TVMError: Check failed: (!usage_map.count(var)) is false

Description

For the Relax IR shown below, running the pass sequence relax.transform.RealizeVDevice() followed by relax.transform.FuseTIR() crashes with "Variable outer_func was used before its definition". However, as we can see, the IR defines the function outer_func before calling it, so it appears to be valid. Why does this crash happen?

Traceback

Traceback (most recent call last):
  File "test.py", line 29, in <module>
    mod = relax.transform.FuseTIR()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  27: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  26: tvm::transform::Pass::operator()(tvm::IRModule) const
  25: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  24: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  23: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  22: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  21: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  20: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  19: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::DeadCodeElimination(tvm::runtime::Array<tvm::runtime::String, void>)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::DeadCodeElimination(tvm::runtime::Array<tvm::runtime::String, void>)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  18: tvm::relax::DeadCodeElimination(tvm::IRModule const&, tvm::runtime::Array<tvm::runtime::String, void>)
  17: tvm::relax::RemoveAllUnused(tvm::RelayExpr)
  16: tvm::relax::CollectVarUsage(tvm::RelayExpr const&)
  15: tvm::relax::UDChain::Collect(tvm::RelayExpr const&)
  14: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  13: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
  12: tvm::relax::UDChain::VisitExpr_(tvm::relax::FunctionNode const*)
  11: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::FunctionNode const*)
  10: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  9: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
  8: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::SeqExprNode const*)
  7: tvm::relax::ExprVisitor::VisitBindingBlock(tvm::relax::BindingBlock const&)
  6: tvm::relax::ExprVisitor::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
  5: tvm::relax::ExprVisitor::VisitBinding(tvm::relax::Binding const&)
  4: tvm::relax::UDChain::VisitBinding_(tvm::relax::VarBindingNode const*)
  3: tvm::relax::ExprVisitor::VisitBinding_(tvm::relax::VarBindingNode const*)
  2: _ZZN3tvm5relax11ExprVisitor22InitVisitBindingVTabl
  1: _ZN3tvm5relax11ExprVisitor13VisitBinding_EPKNS0_14VarBindingNodeEPKNS0_12F
  0: tvm::relax::UDChain::VisitVarDef(tvm::relax::Var const&)
  File "/software/tvm/src/relax/analysis/udchain.cc", line 75
TVMError: Check failed: (!usage_map.count(var)) is false: Variable outer_func was used before its definition

Minimum reproducible script

import tvm
from tvm import relax
import numpy as np
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"attr": 10})
    I.module_global_infos({"vdevice": [I.vdevice({"arch": "sm_50", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global")]})
    @R.function
    def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):

        @R.function
        def outer_func(c1: R.Tensor((2, 3), dtype="float32")) -> R.Callable((R.Tensor((2, 3), dtype="float32"),), R.Tensor((2, 3), dtype="float32"), True):

            @R.function
            def inner_func(x1: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
                s: R.Tensor((2, 3), dtype="float32") = R.add(x1, c1)
                return s
            return inner_func
        in_call: R.Callable((R.Tensor((2, 3), dtype="float32"),), R.Tensor((2, 3), dtype="float32"), True) = outer_func(x)
        res: R.Tensor((2, 3), dtype="float32") = in_call(y)
        return res

mod = Module
mod = tvm.relax.transform.LegalizeOps()(mod)
mod = relax.transform.RealizeVDevice()(mod)
mod = relax.transform.FuseTIR()(mod)  # crash: RealizeVDevice followed by FuseTIR

Your comments and suggestions are highly appreciated!

It looks like this is related to [Bug] [Relax] cannot remove the hint_on_device · Issue #17205 · apache/tvm · GitHub, which results from RealizeVDevice mutating Relax expressions in-place. This isn’t legal for an IRModule transform to do, because the expressions in an IRModule may be shared references that appear in multiple independent IRModules. Such in-place mutation can make some of the Relax analysis tools produce false errors: the mutation may only have been valid in one occurrence of an expression, yet it gets applied across all shared references to that expression.
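The failure mode can be illustrated without any TVM APIs. The plain-Python sketch below is a conceptual analogy, not TVM internals (all names are made up): two "modules" hold the very same expression object, so an in-place update that was only meant for one of them silently leaks into the other — the kind of cross-contamination that later makes an analysis see inconsistent state.

```python
# Conceptual sketch, NOT TVM code: two "modules" share one expression node.
expr = {"op": "add", "vdevice": None}
mod_a = {"main": expr}
mod_b = {"main": expr}  # same object as in mod_a, not a copy

# An in-place "RealizeVDevice"-style update applied while processing mod_a ...
mod_a["main"]["vdevice"] = "cuda:0"

# ... also changes mod_b, which was never supposed to be touched.
print(mod_b["main"]["vdevice"])  # -> cuda:0
```

A non-mutating transform instead constructs new nodes and leaves shared originals intact, which is the approach the refactor in PR 17213 takes for RealizeVDevice.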

I can reproduce your example on TVM main, and it runs successfully when the patch in https://github.com/apache/tvm/pull/17213 is applied. This PR refactors RealizeVDevice to remove the in-place mutation altogether.