Relay Function virtual_device property

While switching to TVMC, I noticed a “virtual_device” property on the top-level Relay module function. It was not properly propagated through my Relay passes, which caused an assertion failure while lowering to TE:

Check failed: (!virtual_device->IsFullyUnconstrained()) is false

at:

  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/__main__.py", line 24, in <module>
    tvmc.main.main()
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/main.py", line 115, in main
    sys.exit(_main(sys.argv[1:]))
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/main.py", line 103, in _main
    return args.func(args)
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/compiler.py", line 173, in drive_compile
    compile_model(
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/compiler.py", line 337, in compile_model
    graph_module = build(
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/compiler.py", line 410, in build
    return relay.build(
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/relay/build_module.py", line 431, in build
    graph_json, runtime_mod, params = bld_mod.build(
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/relay/build_module.py", line 154, in build
    self._build(mod, raw_targets, executor, runtime, workspace_memory_pools, mod_name)
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  29: TVMFuncCall
  28: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  27: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  26: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::AOTExecutorCodegenModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  25: tvm::relay::backend::AOTExecutorCodegen::Codegen(tvm::IRModule, tvm::relay::Function, tvm::runtime::String)
  24: tvm::transform::Pass::operator()(tvm::IRModule) const
  23: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  22: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  21: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  20: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  19: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relay3tec7LowerTEENS0_6StringENS_17CompilationConfigESt8functionIFvNS_8BaseFuncEEEEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SL_SP_
  18: tvm::relay::tec::LowerTE(tvm::IRModule const&, tvm::runtime::String const&, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)
  17: tvm::transform::Pass::operator()(tvm::IRModule) const
  16: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  15: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  14: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_5relay8FunctionES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_3tec15LowerTensorExprERKNS0_6StringENSD_10TECompilerESt8functionIFvNS_8BaseFuncEEENS_17CompilationConfigEEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SP_ST_
  13: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  12: _ZZN3tvm5relay11ExprFuncto
  11: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  10: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode const*)
  9: _ZN3tvm5relay9tr
  8: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  7: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  6: _ZZN3tvm5relay11ExprFuncto
  5: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::LetNode const*)
  4: tvm::relay::tec::LowerTensorExprMutator::PreVisitLetBinding_(tvm::relay::Var const&, tvm::RelayExpr const&)
  3: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  2: _ZZN3tvm5relay11ExprFuncto
  1: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
  0: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
  File "/home/user1/mlenv/deps/src/tvm/src/relay/backend/te_compiler.cc", line 885

I noticed that this property is sometimes updated manually after creating new copies of a function:

However, this is not done everywhere, and I had to patch the following cases to get compilation working again:

diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc
index 1a16cc9be..d05a30626 100644
--- a/src/relay/ir/transform.cc
+++ b/src/relay/ir/transform.cc
@@ -131,6 +131,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx)
     // only process optimizable Relay Functions
     if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
       Function updated_func = pass_func(GetRef<Function>(function_node), updated_mod, pass_ctx);
+      updated_func->virtual_device_ = GetRef<Function>(function_node)->virtual_device();
       updates.push_back({kv.first, std::move(updated_func)});
     }
   }
diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py
index b9ca7d0e1..889031ed4 100644
--- a/python/tvm/relay/expr_functor.py
+++ b/python/tvm/relay/expr_functor.py
@@ -204,7 +204,10 @@ class ExprMutator(ExprFunctor):
     def visit_function(self, fn):
         new_params = [self.visit(x) for x in fn.params]
         new_body = self.visit(fn.body)
-        return Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs)
+        func = Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs)
+        from tvm.relay.function import FunctionCopyVirtualDevice
+        FunctionCopyVirtualDevice(func, fn)
+        return func
 
     def visit_let(self, let):
         new_var = self.visit(let.var)
diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py
index f889f1e59..997fd1776 100644
--- a/python/tvm/relay/function.py
+++ b/python/tvm/relay/function.py
@@ -26,6 +26,10 @@
 from . import _ffi_api
 
 
+def FunctionCopyVirtualDevice(f1, f2):
+    _ffi_api.FunctionCopyVirtualDevice(f1, f2)
+
+
 @tvm._ffi.register_object("relay.Function")
 class Function(BaseFunc):
     """A function declaration expression.
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index 63e74144e..bd3906731 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -127,6 +127,10 @@ TVM_REGISTER_GLOBAL("relay.ir.Function")
                        tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs) {
       return Function(params, body, ret_type, ty_params, attrs);
     });
+TVM_REGISTER_GLOBAL("relay.ir.FunctionCopyVirtualDevice")
+    .set_body_typed([](Function f1, Function f2) {
+      f1->virtual_device_ = f2->virtual_device_;
+    });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {

This does not seem like an elegant solution, and I’m wondering why virtual_device is not part of the Python Function() interface. Would adding it there be an appropriate fix?
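A toy model of the failure mode, in plain Python with no TVM dependency, may make the problem concrete. `ToyFunction`, `rebuild_with_ctor`, `copy_virtual_device` and `check_lowerable` are hypothetical stand-ins, not TVM APIs. Like the Python `Function()` signature in the diff, the toy constructor takes no virtual device, so any pass that rebuilds a function through it silently resets the field to its unconstrained default. A simplified check modeled on the one in te_compiler.cc then fails. Copying the field afterwards, which is what the `FunctionCopyVirtualDevice` patch does, restores it.

```python
from typing import Optional, Tuple

# Hypothetical stand-in for a fully unconstrained virtual device
UNCONSTRAINED: Optional[str] = None


class ToyFunction:
    """Toy model of a Relay function whose constructor has no virtual_device argument."""

    def __init__(self, params: Tuple[str, ...], body: str):
        self.params = params
        self.body = body
        self.virtual_device = UNCONSTRAINED  # every new object starts unconstrained


def rebuild_with_ctor(fn: ToyFunction) -> ToyFunction:
    """Mimics a mutator rebuilding via the constructor: only ctor fields survive."""
    return ToyFunction(fn.params, fn.body)


def copy_virtual_device(dst: ToyFunction, src: ToyFunction) -> None:
    """Analogue of the FunctionCopyVirtualDevice helper from the patch."""
    dst.virtual_device = src.virtual_device


def check_lowerable(fn: ToyFunction) -> None:
    """Simplified version of the IsFullyUnconstrained() check hit during lowering."""
    if fn.virtual_device is UNCONSTRAINED:
        raise AssertionError("virtual device is fully unconstrained")


original = ToyFunction(("x",), "add(x, x)")
original.virtual_device = "cpu"  # e.g. assigned earlier by device planning

naive = rebuild_with_ctor(original)
print(naive.virtual_device)  # None: the device was dropped by the rebuild

patched = rebuild_with_ctor(original)
copy_virtual_device(patched, original)
check_lowerable(patched)  # passes
print(patched.virtual_device)  # cpu
```

Adding an optional virtual_device argument to the constructor would also work, but every copy site would still have to remember to pass it along, which is the same class of bug in a different place.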

@mbs-octoml @electriclilies

Hi Rafael, virtual device handling is unfortunately in a halfway-implemented state, and it’s been on my backlog for a while to wrap that up. Sorry about that! I’m hoping I can work on it in a few weeks as a break between other tasks.

There are a few things to be done:

  • Populate the virtual_device field on every expression node, and remove the use of ‘on_device’ annotations to record device information.
  • As you are noticing, propagate virtual_device info through the passes by migrating to WithFields (which should have the happy side effect of also propagating spans and other generic information).
  • Update the existing passes that need to understand virtual devices.
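To illustrate what migrating to WithFields buys, here is a minimal sketch built on Python’s `dataclasses.replace`, which has the same copy-with-overrides semantics: the caller names only the fields it changes, and every other field, including `virtual_device` and `span`, is carried over automatically. `ToyFunction` and `with_fields` are hypothetical and this is not the C++ `WithFields` implementation; returning the original object when nothing changes is a choice made by this sketch so that unchanged nodes stay shared.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToyFunction:
    """Hypothetical immutable stand-in for a Relay function node."""
    params: Tuple[str, ...]
    body: str
    virtual_device: Optional[str] = None  # None models "fully unconstrained"
    span: Optional[str] = None  # source location, another generic field


def with_fields(fn: ToyFunction, **changes) -> ToyFunction:
    """Copy fn, overriding only the named fields; all other fields are preserved.

    If no field actually changes, the original object is returned unchanged.
    """
    if all(getattr(fn, name) == value for name, value in changes.items()):
        return fn
    return replace(fn, **changes)


f = ToyFunction(("x",), "add(x, x)", virtual_device="cpu", span="model.py:12")
g = with_fields(f, body="multiply(x, 2)")
print(g.virtual_device, g.span)  # cpu model.py:12
print(with_fields(f, body="add(x, x)") is f)  # True
```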

However, I think the consensus is that this is better done in Relax, so another possibility is to backtrack.

Hi Mark, thank you for clarifying.

If I’m not interested in using the virtual_device feature, is there a way to disable it? The issue is that without the patch above, it is not possible to use any pass based on the Python ExprMutator, because the TE compiler fails with the assertion above. If it cannot be disabled, I’d be interested in contributing a patch that fixes the issue.

Actually, the first hunk of the patch (the change to src/relay/ir/transform.cc) is not required in my case, so the only missing copy for my use case is in the Python ExprMutator. The WithFields construct does not seem to be available in Python, and copying the property directly with func.virtual_device_ = fn.virtual_device_ (as done here: https://github.com/apache/tvm/blob/553eb1acd0c115adea0c7d04ce36e26332339769/tests/python/relay/test_pass_fold_constant.py#L102) produces the same assertion. Changing the Function() API would probably touch a lot of code, so I’d like to avoid that. If I can find a way around the circular import (the reason for the local import inside visit_function), would the above patch be acceptable?

Yes, I’d very much support a patch to get you going again. I’m confused as to why simply setting virtual_device_ directly in the visitor does not work, so option (a) is that you send me a unit test and I dig into that. Option (b) is your patch; however, since you’ve needed to bounce back to C++ anyway, perhaps just register WithFields as a global function and use it instead of the Function constructor in visit_function (as we do on the C++ side).
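A toy mutator can sketch what option (b) would mean for the Python visit_function: rebuild a function by overriding only the fields that were visited, so everything else travels with the node. `Var`, `Call`, `Function`, `ToyExprMutator` and `AddToSubtract` are hypothetical plain-Python stand-ins, not the tvm.relay classes, and `dataclasses.replace` plays the role of a registered WithFields.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple, Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple["Expr", ...]


@dataclass(frozen=True)
class Function:
    params: Tuple[Var, ...]
    body: "Expr"
    virtual_device: Optional[str] = None  # None models "fully unconstrained"


Expr = Union[Var, Call, Function]


class ToyExprMutator:
    """Hypothetical mini mutator (not tvm.relay.ExprMutator)."""

    def visit(self, expr: Expr) -> Expr:
        if isinstance(expr, Function):
            return self.visit_function(expr)
        if isinstance(expr, Call):
            return self.visit_call(expr)
        return expr  # Var: leaf, returned unchanged

    def visit_call(self, call: Call) -> Expr:
        return replace(call, args=tuple(self.visit(a) for a in call.args))

    def visit_function(self, fn: Function) -> Expr:
        new_params = tuple(self.visit(p) for p in fn.params)
        new_body = self.visit(fn.body)
        # WithFields-style rebuild: only params and body are overridden, so
        # virtual_device is carried over instead of being reset to its default.
        return replace(fn, params=new_params, body=new_body)


class AddToSubtract(ToyExprMutator):
    """Example pass: rewrite every 'add' call into 'subtract'."""

    def visit_call(self, call: Call) -> Expr:
        new_call = super().visit_call(call)
        return replace(new_call, op="subtract") if new_call.op == "add" else new_call


x = Var("x")
fn = Function((x,), Call("add", (x, x)), virtual_device="cpu")
out = AddToSubtract().visit(fn)
print(out.body.op, out.virtual_device)  # subtract cpu
```

Because the rebuild starts from the old node, a pass written against this mutator never needs to know that virtual_device exists, which is the property the constructor-based rebuild lacks.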

I’ll confess it is only now that I realize the Python mutators also copy functions, and they should have been reworked in parallel with the C++ mutators. Oops.