I ran into an issue: when a Relax module mixes a TIR function (called through `call_tir`) with regular Relax operators, the "zero" pipeline raises an error.
The error says that two different buffers are mapped to the same Relax var.
The error comes from the combination of the `LegalizeOps`, `AnnotateTIROpPattern`, `FuseOps`, and `FuseTIR` passes. It is raised when `FuseTIR` tries to fuse the user-provided TIR function with the TIR function that `LegalizeOps` generates for the Relax operator.
The generated `lv` Relax var (the output of the TIR function invoked through `call_tir`) is mapped both to the output buffer of that TIR function and to the input buffer of the separate TIR function generated from the Relax `add`, which takes the previous output as its input.
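To make this failure mode concrete, here is a toy Python model of the consistency check that fails. It is not TVM's code: `Buffer` and `VarToBufferMap` are hypothetical stand-ins, and the model compares buffers with plain dataclass equality (name, shape, dtype), whereas `fuse_tir.cc` uses `StructuralEqual`. The idea is that each Relax var may be recorded against a TIR buffer more than once, and every later record must match the first one. Recording `lv` first as the PrimFunc's output buffer `res` and then as the legalized `add`'s input buffer `lv` trips the check, in the same order as the message at the end of the traceback.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Buffer:
    # Hypothetical stand-in for tir.Buffer. The name is part of equality here
    # only so the toy reproduces the real failure; TVM uses StructuralEqual.
    name: str
    shape: tuple
    dtype: str


class VarToBufferMap:
    """Toy model: maps Relax var names to TIR buffers, rejecting conflicts."""

    def __init__(self):
        self._map = {}

    def record(self, relax_var: str, buf: Buffer) -> None:
        existing = self._map.get(relax_var)
        if existing is not None and existing != buf:
            # Mirrors the wording of the InternalError raised by FuseTIR
            raise RuntimeError(
                f"Inconsistent buffers {existing.name} and {buf.name} "
                f"mapped to the same relax var: {relax_var}"
            )
        self._map[relax_var] = buf


if __name__ == "__main__":
    m = VarToBufferMap()
    # lv is first seen as the output buffer of the call_tir'd PrimFunc
    m.record("lv", Buffer("res", (3, 3), "float32"))
    try:
        # ...and then as the input buffer of the legalized add
        m.record("lv", Buffer("lv", (3, 3), "float32"))
    except RuntimeError as err:
        print(err)  # prints: Inconsistent buffers res and lv mapped to the same relax var: lv
```

In this model, recording the identical buffer twice is harmless; only a second, different buffer for the same var is rejected, which is the situation the fused function ends up in.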
Code to reproduce:

```python
import tvm
from tvm import te, relax
import tvm.relax.testing.nn as nn
import tvm.relax.transform as transform

# A TIR function built from a TE compute: res = A + 1
Ate = te.placeholder((3, 3), dtype="float32", name="A")
res = te.compute((3, 3), lambda i, j: Ate[i, j] + 1, "res")
f2 = te.create_prim_func([Ate, res])

bb = relax.BlockBuilder()
gv = bb.add_func(f2, "func")
with bb.function("m"):
    with bb.dataflow():
        a = nn.Placeholder([3, 3], "float32")
        # Call the TIR function, then feed its output into a Relax operator
        c = relax.call_tir(gv, a, out_sinfo=relax.TensorStructInfo([3, 3], "float32"))
        z = relax.op.add(c, c)
        o = bb.emit_output(z)
    bb.emit_func_output(o, [a])

mod = relax.get_pipeline("zero")(bb.get())  # raises the error

# The same error can be reproduced by running the passes individually
# (run these instead of the pipeline call above)
mod = transform.LegalizeOps(enable_warning=True)(bb.get())
mod = transform.AnnotateTIROpPattern()(mod)
mod = transform.FuseOps()(mod)
mod = transform.FuseTIR()(mod)  # raises the error
```
```
tvm.error.InternalError: Traceback (most recent call last):
17: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
16: tvm::transform::Pass::operator()(tvm::IRModule) const
15: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
14: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
13: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
12: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
11: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
10: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
9: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::FuseTIR()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::FuseTIR()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
8: tvm::relax::FuseTIR(tvm::IRModule)
7: tvm::relax::TIRFuseMutator::Transform(tvm::IRModule)
6: tvm::relax::FusedTIRConstructor::GetFusedTIR(tvm::IRModule const&, tvm::GlobalVar const&)
5: tvm::relax::FusedTIRConstructor::VisitExpr_(tvm::relax::FunctionNode const*)
4: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::SeqExprNode const*)
3: tvm::relax::ExprVisitor::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
2: _ZZN3tvm5relax11ExprVisitor22InitVisitBindingVTableEvENUlRKNS_7runtime9ObjectRefEPS1_PKNS0_14VarBindingNod
1: tvm::relax::RelaxToTIRVarMapCollector::CollectVarMapping(tvm::relax::CallNode const*, tvm::RelayExpr const&, bool)
0: tvm::relax::RelaxToTIRVarMapCollector::CollectVarMapping(tvm::relax::CallNode const*, tvm::RelayExpr const&, bool)::{lambda(tvm::tir::Buffer, tvm::RelayExpr)#1}::operator()(tvm::tir::Buffer, tvm::RelayExpr) const [clone .isra.0]
File "~/tvm/src/relax/transform/fuse_tir.cc", line 452
InternalError: Check failed: (StructuralEqual()((*it).second, new_buf)) is false: Inconsistent buffers res and lv mapped to the same relax var: lv
```