Description
For the Relax IR shown below, applying relax.transform.RealizeVDevice() followed by relax.transform.FuseTIR() crashes with "Variable outer_func was used before its definition". However, the IR defines the local function outer_func first and only then calls it, so it appears to be valid. According to the traceback, the failing check is in udchain.cc and is reached from the DeadCodeElimination pass that FuseTIR runs internally. Why does this crash happen?
Traceback
Traceback (most recent call last):
  File "test.py", line 29, in <module>
    mod = relax.transform.FuseTIR()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
27: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
26: tvm::transform::Pass::operator()(tvm::IRModule) const
25: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
24: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
23: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
22: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
21: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
20: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
19: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::DeadCodeElimination(tvm::runtime::Array<tvm::runtime::String, void>)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::DeadCodeElimination(tvm::runtime::Array<tvm::runtime::String, void>)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
18: tvm::relax::DeadCodeElimination(tvm::IRModule const&, tvm::runtime::Array<tvm::runtime::String, void>)
17: tvm::relax::RemoveAllUnused(tvm::RelayExpr)
16: tvm::relax::CollectVarUsage(tvm::RelayExpr const&)
15: tvm::relax::UDChain::Collect(tvm::RelayExpr const&)
14: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
13: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
12: tvm::relax::UDChain::VisitExpr_(tvm::relax::FunctionNode const*)
11: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::FunctionNode const*)
10: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
9: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
8: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::SeqExprNode const*)
7: tvm::relax::ExprVisitor::VisitBindingBlock(tvm::relax::BindingBlock const&)
6: tvm::relax::ExprVisitor::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
5: tvm::relax::ExprVisitor::VisitBinding(tvm::relax::Binding const&)
4: tvm::relax::UDChain::VisitBinding_(tvm::relax::VarBindingNode const*)
3: tvm::relax::ExprVisitor::VisitBinding_(tvm::relax::VarBindingNode const*)
2: _ZZN3tvm5relax11ExprVisitor22InitVisitBindingVTabl
1: _ZN3tvm5relax11ExprVisitor13VisitBinding_EPKNS0_14VarBindingNodeEPKNS0_12F
0: tvm::relax::UDChain::VisitVarDef(tvm::relax::Var const&)
File "/software/tvm/src/relax/analysis/udchain.cc", line 75
TVMError: Check failed: (!usage_map.count(var)) is false: Variable outer_func was used before its definition
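The failing check enforces that a variable's definition is visited before any use of it has been recorded. The sketch below is a toy model of that bookkeeping, not TVM's UDChain: the names `ToyUDChain`, `visit_def`, `visit_use` and `collect` are hypothetical, and the exact traversal order TVM uses is an assumption left out of the model. It only illustrates that the check depends on the order in which the visitor records events, so it can fire even when the script's textual order is def-then-use, if some traversal records a use of outer_func before reaching its binding.

```python
class UsedBeforeDefinition(Exception):
    """Raised when a use of a variable is recorded before its definition."""


class ToyUDChain:
    """Toy use-def collector (hypothetical; not TVM's UDChain)."""

    def __init__(self):
        # Maps a variable name to the set of variables that use it.
        self.usage_map = {}

    def visit_def(self, var):
        # Same idea as CHECK(!usage_map.count(var)): no usage may have
        # been recorded for `var` before its definition is visited.
        if var in self.usage_map:
            raise UsedBeforeDefinition(
                f"Variable {var} was used before its definition")
        self.usage_map[var] = set()

    def visit_use(self, var, user):
        # Record that `user` reads `var`, creating the entry if missing.
        self.usage_map.setdefault(var, set()).add(user)


def collect(events):
    """Replay (kind, var, user) events in traversal order."""
    chain = ToyUDChain()
    for kind, var, user in events:
        if kind == "def":
            chain.visit_def(var)
        else:
            chain.visit_use(var, user)
    return chain.usage_map


# Def-then-use order, as written in the reproducer: no error.
ok = collect([
    ("def", "outer_func", None),
    ("use", "outer_func", "in_call"),
    ("def", "in_call", None),
])

# A traversal that records the use first trips the check.
message = None
try:
    collect([
        ("use", "outer_func", "in_call"),
        ("def", "outer_func", None),
    ])
except UsedBeforeDefinition as err:
    message = str(err)
```

In this model `ok` maps outer_func to {"in_call"}, while the reversed event order raises the same message seen in the traceback.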
Minimal reproducible script
import tvm
from tvm import relax
import numpy as np
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"attr": 10})
    I.module_global_infos({"vdevice": [I.vdevice({"arch": "sm_50", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global")]})

    @R.function
    def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
        # Local function returning a closure that captures c1
        @R.function
        def outer_func(c1: R.Tensor((2, 3), dtype="float32")) -> R.Callable((R.Tensor((2, 3), dtype="float32"),), R.Tensor((2, 3), dtype="float32"), True):
            @R.function
            def inner_func(x1: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
                s: R.Tensor((2, 3), dtype="float32") = R.add(x1, c1)
                return s
            return inner_func

        # outer_func is defined above and only then called here
        in_call: R.Callable((R.Tensor((2, 3), dtype="float32"),), R.Tensor((2, 3), dtype="float32"), True) = outer_func(x)
        res: R.Tensor((2, 3), dtype="float32") = in_call(y)
        return res

mod = Module
mod = tvm.relax.transform.LegalizeOps()(mod)  # lower high-level ops such as R.add to call_tir
mod = relax.transform.RealizeVDevice()(mod)   # propagate the vdevice annotations
mod = relax.transform.FuseTIR()(mod)          # crash: raises the TVMError above after RealizeVDevice
Your comments and suggestions are highly appreciated!