Model compilation fails on GPU target

Hello, I am trying to compile a model (available here), but I get the following error when targeting a GPU backend (tested: vulkan, opencl and cuda) with TVM 0.10:

$ tvmc compile --target vulkan model_files/bcnn.onnx 
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
Traceback (most recent call last):
  File "/usr/local/bin/tvmc", line 11, in <module>
    load_entry_point('tvm==0.10.0.dev1+gda7b48f94', 'console_scripts', 'tvmc')()
  File "/usr/local/lib/python3.8/dist-packages/tvm-0.10.0.dev1+gda7b48f94-py3.8-linux-aarch64.egg/tvm/driver/tvmc/main.py", line 115, in main
    sys.exit(_main(sys.argv[1:]))
  File "/usr/local/lib/python3.8/dist-packages/tvm-0.10.0.dev1+gda7b48f94-py3.8-linux-aarch64.egg/tvm/driver/tvmc/main.py", line 103, in _main
    return args.func(args)
  File "/usr/local/lib/python3.8/dist-packages/tvm-0.10.0.dev1+gda7b48f94-py3.8-linux-aarch64.egg/tvm/driver/tvmc/compiler.py", line 180, in drive_compile
    compile_model(
  File "/usr/local/lib/python3.8/dist-packages/tvm-0.10.0.dev1+gda7b48f94-py3.8-linux-aarch64.egg/tvm/driver/tvmc/compiler.py", line 353, in compile_model
    graph_module = build(
  File "/usr/local/lib/python3.8/dist-packages/tvm-0.10.0.dev1+gda7b48f94-py3.8-linux-aarch64.egg/tvm/driver/tvmc/compiler.py", line 428, in build
    return relay.build(
  File "/usr/local/lib/python3.8/dist-packages/tvm-0.10.0.dev1+gda7b48f94-py3.8-linux-aarch64.egg/tvm/relay/build_module.py", line 364, in build
    graph_json, runtime_mod, params = bld_mod.build(
  File "/usr/local/lib/python3.8/dist-packages/tvm-0.10.0.dev1+gda7b48f94-py3.8-linux-aarch64.egg/tvm/relay/build_module.py", line 161, in build
    self._build(
  File "tvm/_ffi/_cython/./packed_func.pxi", line 331, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 276, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./base.pxi", line 181, in tvm._ffi._cy3.core.CHECK_CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  14: TVMFuncCall
  13: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16P
  12: tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  11: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  10: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
  9: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
  8: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
  7: tvm::transform::Pass::operator()(tvm::IRModule) const
  6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_3tir9transform13MakePackedAPIEiEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
  1: tvm::tir::transform::MakePackedAPI(int)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}::operator()(tvm::IRModule, tvm::transform::PassContext) const [clone .isra.0]
  0: tvm::tir::MakePackedAPI(tvm::tir::PrimFunc&&, int)
  File "../src/tir/transforms/make_packed_api.cc", line 329
TVMError: Not all Vars are passed in api_args:  'blockIdx.y'  is not bound to any variables

Compilation works fine with TVM 0.8. Using git bisect, I tracked the regression down to https://github.com/apache/tvm/pull/10423.

More specifically, the change in the transposed convolution computation from dc to c // out_channels * (inp_channels // groups) + dc seems to be the cause of the issue. But my understanding of TVM is not good enough to know what that change implies, or how it relates to blockIdx.y.
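If I read the diff correctly, the indexing before and after the change looks roughly like this in plain Python (simplified for illustration, not the actual TOPI compute definition):

# Simplified illustration of the index change from apache/tvm#10423,
# not the actual TOPI compute definition.

def channel_index_before(dc):
    # before the PR: the channel axis is indexed directly by the
    # reduction variable dc
    return dc

def channel_index_after(c, dc, out_channels, inp_channels, groups):
    # after the PR: an offset derived from the output channel c selects
    # the group that dc indexes into
    return c // out_channels * (inp_channels // groups) + dc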

Should this be raised as a bug in the GitHub repository? Any help is appreciated.

ping @masahi This is a model we would use in Autoware, and you may have a better understanding of the change in question since you reviewed the PR.

OK, I'm on it. I can reproduce the issue, and I also confirmed that reverting that PR makes the error go away.

Dumping the problematic IR showed that blockIdx.y is indeed used in a weird way:

  allocate data_pad[float32 * ((floordiv(((blockIdx.y*2) + 1), 3)*4608) + 576)], storage_scope = shared
  allocate p1.shared[float32 * ((floordiv(((blockIdx.y*2) + 1), 3)*12288) + 1536)], storage_scope = shared

blockIdx.y is a variable, so it can never be used as part of the size of a buffer.
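To make that concrete, here is the data_pad allocation size from the dump evaluated for a few blockIdx.y values (plain Python):

# Evaluate the data_pad allocation size for a few blockIdx.y values:
# the size changes per thread block, but a shared-memory allocation
# needs a size that is known at compile time.
for block_idx_y in range(4):
    size = ((block_idx_y * 2 + 1) // 3) * 4608 + 576
    print(block_idx_y, size)
# prints: 0 576, 1 5184, 2 5184, 3 9792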

Replacing the SplitEntity at https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/conv2d_transpose.py#L164 with SplitEntity([-1, 1, 4, 1]) fixes it.
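As a rough sketch of the workaround (the existing entry's split sizes, apart from the factor of 64 mentioned below, are assumed here for illustration):

from tvm.autotvm.task.space import SplitEntity

# Sketch only: the current fallback splits the channel axis with a factor
# of 64 (other split sizes assumed for illustration) ...
current_fallback = SplitEntity([-1, 1, 64, 1])
# ... and the workaround replaces it with the entry given above.
workaround = SplitEntity([-1, 1, 4, 1])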

The problem is that we are trying to split a small channel axis (12 in your model) with a large factor (64), and this apparently breaks algebraic simplification. Note that your model has some conv2d_transpose ops, but all of them use groups = 1, so the PR should be a no-op if algebraic simplification does its job.
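To illustrate the simplification that should happen, here is a small sketch with the arithmetic analyzer (numbers picked to match the 12-channel case): when the simplifier knows c < out_channels and groups == 1, the new index collapses back to dc.

from tvm import arith, tir

# Sketch: with groups == 1 the new index should simplify back to dc,
# as long as the analyzer knows the output channel c stays below out_channels.
out_channels = 12   # illustration only
inp_channels = 12
groups = 1

c = tir.Var("c", "int32")
dc = tir.Var("dc", "int32")

analyzer = arith.Analyzer()
analyzer.update(c, arith.ConstIntBound(0, out_channels - 1))

index = c // out_channels * (inp_channels // groups) + dc
print(analyzer.simplify(index))  # expected to print: dc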

I’ll send a PR on it tomorrow.


This is fixed by https://github.com/apache/tvm/pull/13341

I can confirm it works, thanks!