[VTA] GraphModule Error

Hello everyone. I want to deploy Unet on a PYNQ-Z2 board. When I use sim mode, I can execute the code without any problems, but if I change the mode to pynq, I get a GraphModule error (Check failed: (data != nullptr) is false: ). My code is listed below; my TVM version is 0.9.

import os
import time

import cv2
import torch
import tvm
import vta
from tvm import autotvm, relay, rpc
from tvm.contrib import graph_executor, utils

env = vta.get_env()
device = "vta"
target = env.target if device == "vta" else env.target_vta_cpu
if env.TARGET not in ["sim", "tsim", "intelfocl"]:

    # Get remote from tracker node if environment variable is set.
    # To set up the tracker, you'll need to follow the "Auto-tuning
    # a convolutional network for VTA" tutorial.
    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
    # Otherwise if you have a device you want to program directly from
    # the host, make sure you've set the variables below to the IP of
    # your board.
    device_host = os.environ.get("VTA_RPC_HOST", "192.168.31.51")
    device_port = os.environ.get("VTA_RPC_PORT", "9091")
    if not tracker_host or not tracker_port:
        remote = rpc.connect(device_host, int(device_port))
    else:
        remote = autotvm.measure.request_remote(
            env.TARGET, tracker_host, int(tracker_port), timeout=10000
        )

    # Reconfigure the JIT runtime and FPGA.
    # You can program the FPGA with your own custom bitstream
    # by passing the path to the bitstream file instead of None.
    reconfig_start = time.time()
    vta.reconfig_runtime(remote)
    vta.program_fpga(remote, bitstream=None)
    reconfig_time = time.time() - reconfig_start
    print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))

else:
    remote = rpc.LocalSession()

ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)

# onnx_model = onnx.load("Unet1.onnx")
model = torch.jit.load("Unet.pt")

test_path = '/home/dengbw/PycharmProjects/Unet/data/test/0.png'

save_res_path = test_path.split('.')[0] + '_VTA_res.png'
img = cv2.imread(test_path)
# Convert to grayscale (cv2.imread loads images in BGR order)
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# Reshape to a batch-1, single-channel, 512x512 array
img = img.reshape(1, 1, img.shape[0], img.shape[1])
# Convert to a torch tensor
img_tensor = torch.from_numpy(img)

img_data = img_tensor

input_name = "input.1"
shape_dict = {input_name: img_data.shape}
shape_list = [(input_name, img_data.shape)]

# mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
mod, params = relay.frontend.from_pytorch(model, shape_list)
with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout", "tir.CommonSubexprElimTIR"}):
    lib = relay.build(
        mod, target=tvm.target.Target(target, host=env.target_host), params=params
    )

# print(lib)

temp = utils.tempdir()
lib.export_library(temp.relpath("graphlib.tar"))
remote.upload(temp.relpath("graphlib.tar"))
lib = remote.load_module("graphlib.tar")
module = graph_executor.GraphModule(lib["default"](ctx))

This is the error I encountered.

Reconfigured FPGA and RPC runtime in 8.74s!
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
Traceback (most recent call last):
  File "/home/dengbw/PycharmProjects/pynq/opt_Unet.py", line 84, in <module>
    module = graph_executor.GraphModule(lib["default"](ctx))
  File "/home/dengbw/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm.error.RPCError: Traceback (most recent call last):
  4: TVMFuncCall
  3: tvm::runtime::RPCWrappedFunc::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  2: tvm::runtime::RPCClientSession::CallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::TVMArgs)> const&)
  1: tvm::runtime::RPCEndpoint::CallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::TVMArgs)>)
  0: tvm::runtime::RPCEndpoint::HandleUntilReturnEvent(bool, std::function<void (tvm::runtime::TVMArgs)>)
  15: TVMFuncCall
  14: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  13: tvm::runtime::RPCServerLoop(int)
  12: tvm::runtime::RPCEndpoint::ServerLoop()
  11: tvm::runtime::RPCEndpoint::HandleUntilReturnEvent(bool, std::function<void (tvm::runtime::TVMArgs)>)
  10: tvm::runtime::RPCEndpoint::EventHandler::HandleNextEvent(bool, bool, std::function<void (tvm::runtime::TVMArgs)>)
  9: tvm::runtime::RPCEndpoint::EventHandler::HandleProcessPacket(std::function<void (tvm::runtime::TVMArgs)>)
  8: tvm::runtime::RPCSession::AsyncCallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::RPCCode, tvm::runtime::TVMArgs)>)
  7: tvm::runtime::LocalSession::CallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::TVMArgs)> const&)
  6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::GraphExecutorFactory::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  5: tvm::runtime::GraphExecutorFactory::ExecutorCreate(std::vector<DLDevice, std::allocator<DLDevice> > const&)
  4: tvm::runtime::GraphExecutor::Init(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::Module, std::vector<DLDevice, std::allocator<DLDevice> > const&, tvm::runtime::PackedFunc)
  3: tvm::runtime::GraphExecutor::SetupStorage()
  2: tvm::runtime::NDArray::Empty(tvm::runtime::ShapeTuple, DLDataType, DLDevice, tvm::runtime::Optional<tvm::runtime::String>)
  1: tvm::runtime::DeviceAPI::AllocDataSpace(DLDevice, int, long long const*, DLDataType, tvm::runtime::Optional<tvm::runtime::String>)
  0: 0xb28a6c9f
  File "/home/dengbw/tvm/src/runtime/rpc/rpc_endpoint.cc", line 376
RPCError: Error caught from RPC call:
[15:57:03] /home/xilinx/tvm/vta/runtime/runtime.cc:179: Check failed: (data != nullptr) is false: 

Can anyone help me?

This is because the memory allocation failed. There are two possibilities: #1, 'start_rpc_server.sh' was not run as root (with sudo); #2, the pynq platform has a bug where buffers allocated before an application crash are not freed, so this can happen after an RPC server crash. In that case a reboot is a workaround.

Thanks for your reply. Regarding reason #1, I did run 'start_rpc_server.sh' as root with sudo, and I can successfully execute two official examples on the pynq board, as shown below.

Here is the runtime screenshot of the pynq board.

After it finishes serving, the host side gets the check-failed error. The image version on pynq is v2.6. Does "reboot" mean restarting the pynq board? I have tried many times and always get the same error (Check failed: (data != nullptr) is false: ).

Another possibility is that your model has an operator that requires very large memory. Please print your model (print(mod["main"])) to check whether it has such a problem. Another debugging method is to add a debug log at vta/runtime/runtime.cc:179 to print the 'size' and check which allocation function failed.
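For example, something along the following lines can flag unusually large intermediate tensors on the host side before building; this is only an illustrative sketch (the helper name and the 8 MB threshold are made up, not part of TVM):

import numpy as np
from tvm import relay

def report_large_tensors(mod, threshold_mb=8):
    # Hypothetical helper: walk the type-annotated Relay program and print
    # every expression whose output tensor exceeds the size threshold.
    mod = relay.transform.InferType()(mod)

    def visit(expr):
        try:
            t = expr.checked_type
        except Exception:
            return  # expression has no checked type
        if isinstance(t, relay.TensorType):
            nbytes = int(np.prod([int(d) for d in t.shape])) * np.dtype(t.dtype).itemsize
            if nbytes > threshold_mb * 1024 * 1024:
                print(type(expr).__name__, t.shape, t.dtype, "%.1f MB" % (nbytes / 2**20))

    relay.analysis.post_order_visit(mod["main"], visit)

# Usage: report_large_tensors(mod) right after relay.frontend.from_pytorch(...)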

Ok, I will check it.

Hi, has your problem been solved? I have the same problem; can you give me some advice? Looking forward to your reply.

I haven't solved it yet, because I'm preparing for my final projects and exams these days. I will try to solve the problem after finals. If you find a solution, I would appreciate your reply as well.

Ok, thank you very much for your reply, and good luck with your exam :grinning_face_with_smiling_eyes:

After printing the size, I found that two modules require too much memory (see the attached screenshots).

So I tried to quantize Unet by adding a quantization block to my original code, as sketched below. However, I got a new error.
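A rough sketch of that block, following the relay.quantize flow from the VTA deploy tutorial (the global_scale and skip_conv_layers values here are illustrative):

# Quantize the Relay module to int8 before building for VTA
with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
        mod = relay.quantize.quantize(mod, params=params)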

 Reconfigured FPGA and RPC runtime in 5.28s!
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
Traceback (most recent call last):
  File "/home/dengbw/PycharmProjects/pynq/untitled.py", line 95, in <module>
    mod, target=tvm.target.Target(target, host=env.target_host), params=params
  File "/home/dengbw/tvm/python/tvm/relay/build_module.py", line 485, in build
    mod_name=mod_name,
  File "/home/dengbw/tvm/python/tvm/relay/build_module.py", line 202, in build
    self._build(mod, target, target_host, executor, runtime, workspace_memory_pools, mod_name)
  File "/home/dengbw/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  32: TVMFuncCall
  31: tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  30: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  29: tvm::relay::backend::RelayBuildModule::OptimizeImpl(tvm::IRModule)
  28: tvm::transform::Pass::operator()(tvm::IRModule) const
  27: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  26: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  25: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  24: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  23: _ZN3tvm7runtime13PackedFuncObj
  22: tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::AlterOpLayout()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::AlterOpLayout()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  21: tvm::relay::alter_op_layout::AlterOpLayout(tvm::RelayExpr const&)
  20: tvm::relay::ForwardRewrite(tvm::RelayExpr const&, tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)> const&, std::function<tvm::runtime::ObjectRef (tvm::relay::Call const&)>, std::function<tvm::RelayExpr (tvm::RelayExpr const&)>)
  19: tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)
  18: void tvm::relay::ExpandDataflow<tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#2}, tvm::relay::ExpandDataflow<{lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1}>(tvm::RelayExpr, {lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1})::{lambda(tvm::RelayExpr const&)#1}>(tvm::RelayExpr, tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#2}, tvm::relay::ExpandDataflow, tvm::relay::ExpandDataflow<{lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1}>(tvm::RelayExpr, {lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1})::{lambda(tvm::RelayExpr const&)#1}) [clone .isra.0]
  17: tvm::relay::MixedModeMutator::VisitLeaf(tvm::RelayExpr const&)
  16: _ZN3tvm5relay16MixedModeMutator17DispatchVisitExprERKNS_9RelayExp
  15: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  14: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  13: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  12: tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)
  11: void tvm::relay::ExpandDataflow<tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#2}, tvm::relay::ExpandDataflow<{lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1}>(tvm::RelayExpr, {lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1})::{lambda(tvm::RelayExpr const&)#1}>(tvm::RelayExpr, tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#2}, tvm::relay::ExpandDataflow, tvm::relay::ExpandDataflow<{lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1}>(tvm::RelayExpr, {lambda(tvm::RelayExpr const&)#1}, tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)::{lambda(tvm::RelayExpr const&)#1})::{lambda(tvm::RelayExpr const&)#1}) [clone .isra.0]
  10: tvm::relay::MixedModeMutator::VisitLeaf(tvm::RelayExpr const&)
  9: _ZN3tvm5relay16MixedModeMutator17DispatchVisitExprERKNS_9RelayExp
  8: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  7: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  6: tvm::relay::MixedModeMutator::VisitExpr_(tvm::relay::CallNode const*)
  5: tvm::relay::ForwardRewriter::Rewrite_(tvm::relay::CallNode const*, tvm::RelayExpr const&)
  4: _ZN3tvm7runtime13PackedFuncObj
  3: tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  2: tvm::RelayExpr tvm::relay::LayoutRewriter<tvm::relay::alter_op_layout::AlterTransformMemorizer>(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)
  1: tvm::relay::alter_op_layout::AlterTransformMemorizerNode::CallWithNewLayouts(tvm::relay::Call const&, tvm::Attrs, std::vector<tvm::RelayExpr, std::allocator<tvm::RelayExpr> > const&)
  0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) [clone .cold]
  File "/home/dengbw/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/dengbw/tvm/python/tvm/relay/op/nn/_nn.py", line 232, in alter_op_layout_conv2d
    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
  File "/home/dengbw/Downloads/Program/envs/Pytorch_cpu/lib/python3.7/site-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/home/dengbw/tvm/python/tvm/target/generic_func.py", line 286, in dispatch_func
    return dispatch_dict[k](*args, **kwargs)
  File "/home/dengbw/tvm/python/tvm/topi/x86/conv2d_alter_op.py", line 61, in _alter_conv2d_layout
    relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
  File "/home/dengbw/tvm/python/tvm/relay/backend/te_compiler.py", line 203, in select_implementation
    outs = impl.compute(attrs, inputs, out_type)
  File "/home/dengbw/tvm/python/tvm/relay/op/op.py", line 126, in compute
    return _OpImplementationCompute(self, attrs, inputs, out_type)
  File "/home/dengbw/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
  3: TVMFuncCall
  2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#4}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  1: tvm::relay::OpImplementation::Compute(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)
  0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) [clone .cold]
  File "/home/dengbw/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/dengbw/tvm/python/tvm/relay/op/strategy/generic.py", line 243, in _compute_conv2d
    return [topi_compute(*args)]
  File "/home/dengbw/tvm/python/tvm/topi/arm_cpu/conv2d_int8.py", line 192, in conv2d_nchw_int8
    data, kernel, strides, padding, dilation, layout, layout, out_dtype
  File "/home/dengbw/tvm/python/tvm/autotvm/task/topi_integration.py", line 165, in wrapper
    node = topi_compute(cfg, *args)
  File "/home/dengbw/tvm/python/tvm/topi/arm_cpu/conv2d_int8.py", line 99, in conv2d_NCHWc_int8
    out_dtype,
  File "/home/dengbw/tvm/python/tvm/topi/arm_cpu/conv2d_int8.py", line 43, in _get_default_config
    conv2d_generic.fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes=4, num_int8_elements=4)
  File "/home/dengbw/tvm/python/tvm/topi/generic/conv2d.py", line 98, in fallback_schedule_cpu_1x1_int8
    int32_lanes,
AssertionError: wkl.out_filter=1, int32_lanes=4

Oh, I just noticed that in your original post you tried to run a non-quantized model on VTA+FPGA instead of the simulator. Please note that VTA+FPGA only supports int8 compute, so quantization is necessary, as the tutorial shows.

After quantization, a graph_pack step is also necessary to convert the NCHW layout into the packed NCHWnc layout required by the FPGA hardware spec, but it seems you did not do that; I think that is your current problem.
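For reference, a minimal sketch of the packing step from the VTA deploy tutorial; start_name and stop_name are model-specific, and the values below are only placeholders, not Unet's actual operator names:

from vta.top import graph_pack

# Pack the quantized Relay function into VTA's NCHWnc layout
relay_prog = graph_pack(
    mod["main"],
    env.BATCH,
    env.BLOCK_OUT,
    env.WGT_WIDTH,
    start_name="nn.max_pool2d",
    stop_name="nn.global_avg_pool2d",
    device_annot=(env.TARGET == "intelfocl"),
)

The packed relay_prog is then what gets passed to relay.build in place of mod.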

Ok, I have added quantization and graph_pack to my code. I want to know whether there are any restrictions on choosing start_name and stop_name.

After adding that, I got this error: Check failed: src_shape.size() == src_axis.size() (6 vs. 4) : Input shape size 6 mismatch with the exepected shape size 4. The input to Unet is a four-dimensional tensor. Does this error mean the input shape changed to six dimensions? Why does the shape change after quantization and graph_pack?

Reconfigured FPGA and RPC runtime in 5.95s!
Traceback (most recent call last):
  File "/home/dengbw/PycharmProjects/pynq/untitled.py", line 91, in <module>
    device_annot=(env.TARGET == "intelfocl"),
  File "/home/dengbw/tvm/vta/python/vta/top/graphpack.py", line 611, in graph_pack
    expr = run_opt_pass(expr, transform.InferType())
  File "/home/dengbw/tvm/vta/python/vta/top/graphpack.py", line 30, in run_opt_pass
    mod = opt_pass(mod)
  File "/home/dengbw/tvm/python/tvm/ir/transform.py", line 161, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/dengbw/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  7: TVMFuncCall
  6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  5: tvm::transform::Pass::operator()(tvm::IRModule) const
  4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relay9transform9InferTypeEvEUlS5_RKS7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SH_SL_
  1: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  0: tvm::relay::TypeSolver::Solve() [clone .cold]
  12: TVMFuncCall
  11: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  10: tvm::transform::Pass::operator()(tvm::IRModule) const
  9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  8: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relay9transform9InferTypeEvEUlS5_RKS7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SH_SL_
  6: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  5: tvm::relay::TypeSolver::Solve()
  4: _ZN3tvm7runtime13PackedFuncObj
  3: tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  2: tvm::relay::Resize2DRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
  1: tvm::tir::BijectiveLayout::ForwardShape(tvm::runtime::Array<tvm::PrimExpr, void> const&) const
  0: tvm::tir::TransformShape(tvm::runtime::Array<tvm::PrimExpr, void> const&, tvm::runtime::Array<tvm::tir::IterVar, void> const&, tvm::runtime::Array<tvm::tir::IterVar, void> const&, tvm::runtime::Array<tvm::PrimExpr, void> const&)
  File "/home/dengbw/tvm/src/relay/analysis/type_solver.cc", line 624
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (false) is false: [15:27:37] /home/dengbw/tvm/src/tir/ir/data_layout.cc:323: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------

  Check failed: src_shape.size() == src_axis.size() (6 vs. 4) : Input shape size 6 mismatch with the exepected shape size 4

This error is caused by Resize2D, which is not yet supported by graph_pack, so the 6-dimensional input shape coming from the upper layer's output does not match what the operator expects. Once I get a chance I will work on a patch, or if you are interested you could help with a fix.