[arm][bitserial][BUG] cryptic error message

Hello, when I try to use the bit-serial convolution operator for Arm CPUs together with another operator (for example, a ReLU afterwards), compilation fails with an error message I cannot understand.

The error disappears if the ReLU after the bit-serial convolution is removed. Casting to float32 before the ReLU doesn’t change this behaviour. The error mentions InferBound and PassDownDomain, and says that an internal invariant was violated. Does anyone have an idea what is going on here? Thanks a lot!

Here is an example code snippet:

# imports
import tvm
from tvm import relay
from tvm.relay.testing.init import create_workload
## define relay network
def create_network():
    """Build the Relay repro network and wrap it as a workload.

    Graph, in order:
      float32 conv2d (NCHW data, OIHW weights) -> ReLU
      -> layout transform NCHW -> NHWC -> cast to int16
      -> 1-bit unipolar bitserial conv2d (NHWC data, HWIO weights) -> ReLU

    The trailing ReLU is the part that makes compilation fail; without it
    the model compiles.

    Returns:
        The ``(mod, params)`` pair produced by ``create_workload``.
    """
    # Float32 conv2d input layer.
    input_var = relay.var('data', shape=(1, 3, 32, 32), dtype='float32')
    conv1_weight = relay.var(
        "conv1_weight", relay.TensorType((32, 3, 3, 3), dtype='float32')  # OIHW
    )
    conv1_out = relay.nn.conv2d(
        input_var,
        conv1_weight,
        strides=(1, 1),
        padding=(1, 1),
        channels=32,
        kernel_size=(3, 3),
        data_layout='NCHW',
        kernel_layout='OIHW',
    )
    conv1_act = relay.nn.relu(conv1_out)

    # Convert layout and dtype for the bit-serial layer.
    nhwc = relay.layout_transform(conv1_act, "NCHW", "NHWC")
    nhwc_int16 = relay.cast(nhwc, 'int16')

    # Bit-serial conv2d layer.
    conv2_weight = relay.var(
        "conv2_weight", relay.TensorType((3, 3, 32, 32), 'uint32')  # HWIO
    )
    bitserial_out = relay.nn.bitserial_conv2d(
        nhwc_int16,
        conv2_weight,
        strides=(1, 1),
        padding=(1, 1),
        channels=32,
        kernel_size=(3, 3),
        activation_bits=1,
        weight_bits=1,
        data_layout="NHWC",
        kernel_layout="HWIO",
        pack_dtype="uint8",
        out_dtype="int16",
        unipolar=True,
    )
    # Removing this ReLU lets the model compile; keeping it reproduces the failure.
    network_out = relay.nn.relu(bitserial_out)

    # Wrap the graph in a function over its free variables and build the workload.
    func = relay.Function(relay.analysis.free_vars(network_out), network_out)
    return create_workload(func)
## compile
def compile_network(mod, params, target=None):
    """Compile a Relay module at opt_level=3 and return the runtime library.

    Args:
        mod: The Relay module to compile (e.g. from ``create_network``).
        params: Parameter dict passed through to the build.
        target: Compilation target. When omitted, ``tvm.target.arm_cpu('rasp4b')``
            is used; it is created per call rather than once when the function
            is defined.

    Returns:
        The compiled library, i.e. the second element of the build result.
    """
    if target is None:
        # Build the default lazily so merely defining/importing this function
        # does not construct a Target object.
        target = tvm.target.arm_cpu('rasp4b')
    # tvm.transform.PassContext is the direct replacement for the deprecated
    # relay.build_config wrapper.
    with tvm.transform.PassContext(opt_level=3):
        _graph, lib, _params = relay.build_module.build(
            mod, target=target, params=params
        )
    return lib
## main
if __name__ == '__main__':
    # Build the repro network, then compile it for the default target.
    network_mod, network_params = create_network()
    compiled_lib = compile_network(network_mod, network_params)
    print('script completely executed!')

And this is the error message:

Cannot find config for target=llvm -keys=arm_cpu,cpu -device=arm_cpu -link-params=0 -mattr=+neon -mcpu=cortex-a72 -model=bcm2711 -mtriple=armv8l-linux-gnueabihf, workload=('conv2d_nchw_spatial_pack.arm_cpu', ('TENSOR', (1, 3, 32, 32), 'float32'), ('TENSOR', (32, 3, 3, 3), 'float32'), (1, 1), (1, 1, 1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
Cannot find config for target=llvm -keys=arm_cpu,cpu -device=arm_cpu -link-params=0 -mattr=+neon -mcpu=cortex-a72 -model=bcm2711 -mtriple=armv8l-linux-gnueabihf, workload=('bitserial_conv2d_nhwc.arm_cpu', ('TENSOR', (1, 32, 32, 32), 'int16'), ('TENSOR', (3, 3, 32, 32), 'uint32'), (1, 1), (1, 1, 1, 1), 1, 1, 'uint8', 'int16', 1). A fallback configuration is used, which may bring great performance regression.
Traceback (most recent call last):
  File "tvm-bitserial-operator-issue.py", line 62, in <module>
    lib = compile_network( mod, params )
  File "tvm-bitserial-operator-issue.py", line 56, in compile_network
    graph, lib, params = relay.build_module.build( mod, target=target, params=params)
  File "tvm/python/tvm/relay/build_module.py", line 290, in build
    graph_json, runtime_mod, params = bld_mod.build(mod=ir_mod, target=target, params=params)
  File "tvm/python/tvm/relay/build_module.py", line 136, in build
    self._build(mod, target, target_host)
  File "tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  11: TVMFuncCall
  10: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  9: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::NDArray, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, tvm::runtime::NDArray> > > const&)
  8: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::relay::backend::GraphExecutorCodegenModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  7: tvm::relay::backend::GraphExecutorCodegen::Codegen(tvm::relay::Function)
  6: tvm::relay::backend::MemoizedExprTranslator<std::vector<tvm::relay::backend::GraphNodeRef, std::allocator<tvm::relay::backend::GraphNodeRef> > >::VisitExpr(tvm::RelayExpr const&)
  5: _ZZN3tvm5relay11ExprFunctorIFSt6vectorINS0_7backend12GraphNodeRefESaIS4_EER
  4: tvm::relay::backend::GraphExecutorCodegen::VisitExpr_(tvm::relay::CallNode const*)
  3: tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::operator()<tvm::relay::CompileEngine&, tvm::relay::CCacheKey&>(tvm::relay::CompileEngine&, tvm::relay::CCacheKey&) const
  2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::relay::CachedFunc (tvm::relay::CompileEngine, tvm::relay::CCacheKey)>::AssignTypedLambda<tvm::relay::{lambda(tvm::relay::CompileEngine, tvm::relay::CCacheKey)#9}>(tvm::relay::{lambda(tvm::relay::CompileEngine, tvm::relay::CCacheKey)#9}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  1: tvm::relay::CompileEngineImpl::LowerInternal(tvm::relay::CCacheKey const&)
  0: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) [clone .cold]
  File "tvm/python/tvm/relay/backend/_backend.py", line 50, in lower
    f = tvm.driver.lower(sch, inputs, name=func_name)
  File "tvm/python/tvm/driver/build_module.py", line 164, in lower
    mod = form_irmodule(sch, args, name, binds)
  File "tvm/python/tvm/driver/build_module.py", line 106, in form_irmodule
    bounds = schedule.InferBound(sch)
  File "tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
  3: TVMFuncCall
  2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::runtime::Map<tvm::tir::IterVar, tvm::Range, void, void> (tvm::te::Schedule const&)>::AssignTypedLambda<tvm::runtime::Map<tvm::tir::IterVar, tvm::Range, void, void> (*)(tvm::te::Schedule const&)>(tvm::runtime::Map<tvm::tir::IterVar, tvm::Range, void, void> (*)(tvm::te::Schedule const&), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  1: tvm::te::InferBound(tvm::te::Schedule const&)
  0: tvm::te::PassDownDomain(tvm::te::Stage const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > >*, tvm::arith::Analyzer*, bool)
  File "tvm/src/te/schedule/message_passing.cc", line 112
  File "tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "tvm/python/tvm/relay/backend/_backend.py", line 58, in lower
    raise RuntimeError(msg)
  File "tvm/python/tvm/relay/backend/_backend.py", line 50, in lower
    f = tvm.driver.lower(sch, inputs, name=func_name)
  File "tvm/python/tvm/driver/build_module.py", line 164, in lower
    mod = form_irmodule(sch, args, name, binds)
  File "tvm/python/tvm/driver/build_module.py", line 106, in form_irmodule
    bounds = schedule.InferBound(sch)
  File "tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
  3: TVMFuncCall
  2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::runtime::Map<tvm::tir::IterVar, tvm::Range, void, void> (tvm::te::Schedule const&)>::AssignTypedLambda<tvm::runtime::Map<tvm::tir::IterVar, tvm::Range, void, void> (*)(tvm::te::Schedule const&)>(tvm::runtime::Map<tvm::tir::IterVar, tvm::Range, void, void> (*)(tvm::te::Schedule const&), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  1: tvm::te::InferBound(tvm::te::Schedule const&)
  0: tvm::te::PassDownDomain(tvm::te::Stage const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > >*, tvm::arith::Analyzer*, bool)
  File "tvm/src/te/schedule/message_passing.cc", line 112
TVMError: 
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------
  Check failed: (allow_missing) is false: 

During handling of the above exception, another exception occurred:

TVMError: 
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------
  Check failed: (allow_missing) is false: 
Error during compile function
-----------------------------
#[version = "0.0.5"]
fn (%p0: Tensor[(1, 32, 32, 32), int16], %p1: Tensor[(3, 3, 32, 32), uint32], Primitive=1) -> Tensor[(1, 32, 32, 32), int16] {
  %0 = nn.bitserial_conv2d(%p0, %p1, padding=[1, 1, 1, 1], channels=32, data_layout="NHWC", kernel_layout="HWIO", pack_dtype="uint8", out_dtype="int16") /* ty=Tensor[(1, 32, 32, 32), int16] */;
  nn.relu(%0) /* ty=Tensor[(1, 32, 32, 32), int16] */
}

I used TVM commit 390b4d1ce32533810fc8d65dcf211f6d147764aa.

Hi @comaniac, do you have an idea what's going on, or whom I could ask? It doesn't seem normal that a bit-serial conv2d layer on its own compiles fine, but adding a ReLU afterwards produces an error like this:

Thank you!

It seems the TOPI schedule doesn’t support elementwise ops that follow the convolution (such as the ReLU here):

It should have logic like this example:

1 Like

Thank you very much for pointing me in the right direction. :grinning: I will try to solve it, but unfortunately I am not familiar with writing TOPI schedules.

In any case, I opened a related GitHub issue. Thank you!

Solved with #7929. Thanks to @comaniac and @jwfromm.