When I compiled a simple network (conv2d + sigmoid) with AlterOpLayout enabled, the layout of the output tensor turned out to be NCHWc instead of NCHW (the layout of the input tensor). Is this the expected behavior? If not, how can it be fixed?
This is my code:
from __future__ import absolute_import as _abs
import numpy as np
from nnvm import symbol as sym
from nnvm.top import registry as reg
from nnvm.testing import utils
import nnvm.compiler
import nnvm.graph as graph
import tvm
from tvm.contrib import graph_runtime

def test_alter_op_layout():
    input_name = "data"
    input_shape = (1, 3, 224, 224)
    data = sym.Variable(input_name, shape=input_shape)
    conv = sym.conv2d(data, name="conv", channels=3,
                      kernel_size=(3, 3), padding=(1, 1),
                      use_bias=True, layout="NCHW")
    sigmoid = sym.sigmoid(conv)
    batch_size = 1
    net, params = utils.create_workload(sigmoid, batch_size, (3, 224, 224))
    opt_level = 3
    target = 'llvm -mcpu=core-avx2'
    with nnvm.compiler.build_config(opt_level=opt_level):
        graph, lib, params = nnvm.compiler.build(
            net, target, shape={input_name: input_shape}, params=params)
    print(graph.symbol().debug_str())
    ctx = tvm.context(target, 0)
    dtype = 'float32'
    m = graph_runtime.create(graph, lib, ctx)
    # set inputs
    data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
    m.set_input(input_name, data_tvm)
    m.set_input(**params)
    # execute
    m.run()
    # get outputs
    # output_shape = input_shape  # fails: the compiled graph's output is 5-D NCHWc
    output_shape = (1, 1, 224, 224, 3)  # NCHW3c layout produced by AlterOpLayout
    tvm_output = m.get_output(0, tvm.nd.empty(output_shape, dtype)).asnumpy()
    print(tvm_output.shape)

if __name__ == "__main__":
    test_alter_op_layout()
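For now I can recover an NCHW array on the host by undoing the channel packing in numpy. A minimal sketch, assuming the 5-D output is laid out as NCHW3c, i.e. (N, C_outer, H, W, c_inner), which matches the (1, 1, 224, 224, 3) shape above (the helper name nchwc_to_nchw is mine):

    import numpy as np

    def nchwc_to_nchw(arr):
        # (N, C_outer, H, W, c_inner) -> (N, C_outer, c_inner, H, W)
        n, c_outer, h, w, c_inner = arr.shape
        arr = arr.transpose(0, 1, 4, 2, 3)
        # collapse the two channel axes back into a single C = C_outer * c_inner
        return arr.reshape(n, c_outer * c_inner, h, w)

    # e.g. nchwc_to_nchw(tvm_output).shape == (1, 3, 224, 224)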
This is the error I see when I set output_shape = input_shape, i.e. hand get_output a 4-D NCHW buffer while the compiled graph produces a 5-D NCHWc tensor:
---------------------------------------------------------------------------
TVMError Traceback (most recent call last)
<ipython-input-8-5a505249d5ef> in <module>()
44
45 if __name__ == "__main__":
---> 46 test_alter_op_layout()
47
<ipython-input-8-5a505249d5ef> in test_alter_op_layout()
40 # output_shape = (1, 1, 224, 224, 3)
41 output_shape = input_shape
---> 42 tvm_output = m.get_output(0, tvm.nd.empty(output_shape, dtype)).asnumpy()
43 print(tvm_output.shape)
44
~/github/tvm/python/tvm/contrib/graph_runtime.py in get_output(self, index, out)
176 """
177 if out:
--> 178 self._get_output(index, out)
179 return out
180
~/github/tvm/python/tvm/_ffi/_cython/function.pxi in tvm._ffi._cy3.core.FunctionBase.__call__()
~/github/tvm/python/tvm/_ffi/_cython/function.pxi in tvm._ffi._cy3.core.FuncCall()
~/github/tvm/python/tvm/_ffi/_cython/function.pxi in tvm._ffi._cy3.core.FuncCall3()
~/github/tvm/python/tvm/_ffi/_cython/base.pxi in tvm._ffi._cy3.core.CALL()
TVMError: [22:50:34] /Users/hlu/github/tvm/src/runtime/graph/graph_runtime.cc:151: Check failed: data->ndim == data_out->ndim (5 vs. 4)
Stack trace returned 10 entries:
[bt] (0) 0 libtvm.dylib 0x000000011a11c070 dmlc::StackTrace() + 288
[bt] (1) 1 libtvm.dylib 0x000000011a11be0f dmlc::LogMessageFatal::~LogMessageFatal() + 47
[bt] (2) 2 libtvm.dylib 0x000000011a718e5f tvm::runtime::GraphRuntime::CopyOutputTo(int, DLTensor*) + 527
[bt] (3) 3 libtvm.dylib 0x000000011a718b81 std::__1::__function::__func<tvm::runtime::GraphRuntime::GetFunction(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, std::__1::shared_ptr<tvm::runtime::ModuleNode> const&)::$_5, std::__1::allocator<tvm::runtime::GraphRuntime::GetFunction(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, std::__1::shared_ptr<tvm::runtime::ModuleNode> const&)::$_5>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 161
[bt] (4) 4 libtvm.dylib 0x000000011a6e46a6 TVMFuncCall + 70
[bt] (5) 5 core.cpython-36m-darwin.so 0x000000011d92298d __pyx_f_3tvm_4_ffi_4_cy3_4core_FuncCall(void*, _object*, TVMValue*, int*) + 477
[bt] (6) 6 core.cpython-36m-darwin.so 0x000000011d928387 __pyx_pw_3tvm_4_ffi_4_cy3_4core_12FunctionBase_5__call__(_object*, _object*, _object*) + 55
[bt] (7) 7 Python 0x000000010cae2e8c _PyObject_FastCallDict + 143
[bt] (8) 8 Python 0x000000010cb7f0fa call_function + 441
[bt] (9) 9 Python 0x000000010cb77ff7 _PyEval_EvalFrameDefault + 4811
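To avoid hardcoding the 5-D shape, I can also look it up in the compiled graph JSON instead of guessing. A minimal sketch, assuming the graph-runtime JSON that nnvm.compiler.build emits stores outputs under "heads", a node-to-entry mapping under "node_row_ptr", and per-entry shapes under attrs["shape"] (the helper name infer_output_shape is mine):

    import json

    def infer_output_shape(compiled_graph, output_index=0):
        # Parse the JSON description the graph runtime consumes.
        g = json.loads(compiled_graph.json())
        # "heads" lists graph outputs as [node_id, output_index, version].
        node_id, entry_index, _ = g["heads"][output_index]
        # "node_row_ptr" maps a node id to its first entry in the flat entry list.
        entry = g["node_row_ptr"][node_id] + entry_index
        # attrs["shape"] is a ["list_shape", [...]] pair of per-entry shapes.
        return tuple(g["attrs"]["shape"][1][entry])

    # e.g. output_shape = infer_output_shape(graph)  # -> (1, 1, 224, 224, 3) here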