[Codegen][Hybrid Script] Unknown call type for Halide Call generated in hybrid script

There is an issue when handling a buffer with a dynamic (symbolic) shape in hybrid script:

```python
import tvm
from tvm.te.hybrid import script

@script
def hybrid_test(data):
    # Bind the symbolic first dimension to a local variable.
    batch_size = data.shape[0]
    out = output_tensor((batch_size,), data.dtype)
    for i in range(batch_size):
        out[i] = data[i, 0]
    return out

# The first dimension of the input is symbolic ("any_dim").
dshape = (tvm.tir.Var("any_dim", "int32"), 10)
d = tvm.te.placeholder(dshape, name="data", dtype="float32")
o = hybrid_test(d)
s = tvm.te.create_schedule(o.op)
print(tvm.lower(s, [d, o], simple_mode=True))
f = tvm.build(s, [d, o], "llvm")  # fails with "Unknown call type"
```

For this simple example, I got the following error:

```
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) 9   libtvm.dylib                        0x0000000111bb3cee tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 766
  [bt] (7) 8   libtvm.dylib                        0x0000000111bc6970 tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*) + 80
  [bt] (6) 7   libtvm.dylib                        0x000000011124f69c tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const + 284
  [bt] (5) 6   libtvm.dylib                        0x0000000111bb3a26 tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*) + 54
  [bt] (4) 5   libtvm.dylib                        0x0000000111ba3e1c tvm::NodeFunctor<llvm::Value* (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<llvm::Value* (tvm::PrimExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<llvm::Value* (tvm::PrimExpr const&)>*) const + 284
  [bt] (3) 4   libtvm.dylib                        0x0000000111bc2a9d tvm::codegen::CodeGenLLVM::VisitExpr_(tvm::tir::EQNode const*) + 29
  [bt] (2) 3   libtvm.dylib                        0x0000000111ba3e1c tvm::NodeFunctor<llvm::Value* (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<llvm::Value* (tvm::PrimExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<llvm::Value* (tvm::PrimExpr const&)>*) const + 284
  [bt] (1) 2   libtvm.dylib                        0x0000000111bc379e tvm::codegen::CodeGenLLVM::VisitExpr_(tvm::tir::CallNode const*) + 254
  [bt] (0) 1   libtvm.dylib                        0x0000000111209c49 dmlc::LogMessageFatal::~LogMessageFatal() + 57
  File "/Users/wayao/Documents/tvm/src/target/llvm/codegen_llvm.cc", line 1123
TVMError: Unknown call type name= batch_size call_type= 3
```

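The reason is that we pass a symbolic shape into the hybrid function and bind a local variable (batch_size) to it. The hybrid parser turns batch_size into a one-element buffer and each reference to it into a Halide Call, which the LLVM codegen does not know how to handle. Judging from the backtrace (AssertStmt, then EQ, then a Call named batch_size), the failing expression appears to be a shape check generated during the build, presumably comparing the output buffer's shape (batch_size,) against the actual argument; such a check sits outside the allocate of batch_size. This is the generated TVM IR: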

```
// attr [batch_size] storage_scope = "global"
allocate batch_size[int32 * 1]
// attr [0] extern_scope = 0
batch_size[0] = any_dim
for (i, 0, batch_size[0]) {
  hybrid_test[i] = data[(i*stride)]
}
```
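To make the scoping problem concrete, here is a toy model in plain Python (not TVM code; `evaluate`, `check_output_shape`, and the scope dictionaries are hypothetical). It assumes, as the backtrace suggests, that the shape check runs in an outer scope where only any_dim is bound, while batch_size[0] is only readable inside the allocate. Checking against any_dim works, but the load of batch_size[0] has nothing to read from in the outer scope.

```python
# Toy model in plain Python (not TVM): a tiny expression evaluator used to
# illustrate why a shape that refers to a local buffer cannot be checked in a
# scope where that buffer is not allocated.

def evaluate(expr, env):
    """Evaluate a toy expression.

    ("var", name)        -> a scalar variable, e.g. the symbolic dim any_dim
    ("load", buf, index) -> a read from a buffer, e.g. batch_size[0]
    """
    kind = expr[0]
    if kind == "var":
        return env["vars"][expr[1]]
    if kind == "load":
        _, buf, index = expr
        if buf not in env["buffers"]:
            raise NameError(f"buffer {buf!r} is not allocated in this scope")
        return env["buffers"][buf][index]
    raise ValueError(f"unknown expression kind {kind!r}")


def check_output_shape(declared_extent, actual_extent, env):
    """Mimic a shape assertion: actual extent == declared extent."""
    return actual_extent == evaluate(declared_extent, env)


any_dim = ("var", "any_dim")
batch_size_load = ("load", "batch_size", 0)

# Outer scope (assumed to be where argument shape checks run): only any_dim is bound.
outer = {"vars": {"any_dim": 4}, "buffers": {}}

# Inner scope: inside `allocate batch_size[int32 * 1]` after `batch_size[0] = any_dim`.
inner = {"vars": dict(outer["vars"]),
         "buffers": {"batch_size": [evaluate(any_dim, outer)]}}

print(check_output_shape(any_dim, 4, outer))          # True: like the data.shape[0] version
print(check_output_shape(batch_size_load, 4, inner))  # True: fine inside the allocate
try:
    check_output_shape(batch_size_load, 4, outer)     # the situation in the failing build
except NameError as err:
    print(err)  # buffer 'batch_size' is not allocated in this scope
```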

If I don't declare batch_size and use data.shape[0] directly:

```python
import tvm
from tvm.te.hybrid import script

@script
def hybrid_test(data):
    # Use the symbolic dimension directly instead of binding it to a local.
    out = output_tensor((data.shape[0],), data.dtype)
    for i in range(data.shape[0]):
        out[i] = data[i, 0]
    return out

dshape = (tvm.tir.Var("any_dim", "int32"), 10)
d = tvm.te.placeholder(dshape, name="data", dtype="float32")
o = hybrid_test(d)
s = tvm.te.create_schedule(o.op)
print(tvm.lower(s, [d, o], simple_mode=True))
f = tvm.build(s, [d, o], "llvm")  # builds successfully
```

This snippet builds without errors, and the generated IR is:

```
// attr [0] extern_scope = 0
for (i, 0, any_dim) {
  hybrid_test[(i*stride)] = data[(i*stride)]
}
```

Any idea how we can fix this? @tqchen @were

One workaround is to pass batch_size into the hybrid function as an extra argument. However, I still wonder whether there is a better solution.
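One possible direction, sketched below as a toy lowering pass in plain Python (hypothetical; this is not how TVM's hybrid parser is implemented), would be for the parser to treat a local that is assigned exactly once as an alias for its right-hand side and substitute it at every use, instead of allocating a one-element buffer. Under that assumption, both the loop extent and the output shape resolve to any_dim, which matches the IR of the version that uses data.shape[0] directly. The statement encoding and the `lower` function are illustrative only.

```python
# Hypothetical sketch (not TVM's hybrid parser): a toy "lowering" of straight-line
# statements that either keeps each local in a one-element buffer (the behaviour
# shown in the issue) or inlines locals that are assigned exactly once.
from collections import Counter


def lower(stmts, out_extent, inline_single_assignments):
    """stmts: list of ("assign", name, rhs) or ("for", loop_var, extent, body).

    Returns (ir_lines, resolved_output_extent) as strings.
    """
    assign_counts = Counter(s[1] for s in stmts if s[0] == "assign")
    aliases = {}

    def resolve(name):
        if name in aliases:
            return aliases[name]   # substitute the bound expression
        if name in assign_counts:
            return f"{name}[0]"    # read the local's one-element buffer
        return name                # free variable such as any_dim

    ir = []
    for stmt in stmts:
        if stmt[0] == "assign":
            _, name, rhs = stmt
            if inline_single_assignments and assign_counts[name] == 1:
                aliases[name] = resolve(rhs)
            else:
                ir.append(f"allocate {name}[int32 * 1]")
                ir.append(f"{name}[0] = {resolve(rhs)}")
        else:
            _, var, extent, body = stmt
            ir.append(f"for ({var}, 0, {resolve(extent)}) {{ {body} }}")
    return ir, resolve(out_extent)


# The hybrid function from the issue, reduced to its two statements.
program = [
    ("assign", "batch_size", "any_dim"),
    ("for", "i", "batch_size", "hybrid_test[i] = data[i, 0]"),
]

print(lower(program, "batch_size", inline_single_assignments=False))
print(lower(program, "batch_size", inline_single_assignments=True))
```

A real implementation would also have to check that the right-hand side has no side effects and that none of its operands is reassigned before the alias is used; names assigned more than once would still need a buffer.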

@were can you take a look? Thanks!