While switching to TVMC, I noticed a “virtual_device” property on the top-level Relay module function. It was not properly propagated through my Relay passes, which caused an assertion failure during lowering to TE:
Check failed: (!virtual_device->IsFullyUnconstrained()) is false
raised at:
File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/__main__.py", line 24, in <module>
tvmc.main.main()
File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/main.py", line 115, in main
sys.exit(_main(sys.argv[1:]))
File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/main.py", line 103, in _main
return args.func(args)
File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/compiler.py", line 173, in drive_compile
compile_model(
File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/compiler.py", line 337, in compile_model
graph_module = build(
File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/compiler.py", line 410, in build
return relay.build(
File "/home/user1/mlenv/deps/src/tvm/python/tvm/relay/build_module.py", line 431, in build
graph_json, runtime_mod, params = bld_mod.build(
File "/home/user1/mlenv/deps/src/tvm/python/tvm/relay/build_module.py", line 154, in build
self._build(mod, raw_targets, executor, runtime, workspace_memory_pools, mod_name)
File "/home/user1/mlenv/deps/src/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
29: TVMFuncCall
28: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
27: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
26: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::AOTExecutorCodegenModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
25: tvm::relay::backend::AOTExecutorCodegen::Codegen(tvm::IRModule, tvm::relay::Function, tvm::runtime::String)
24: tvm::transform::Pass::operator()(tvm::IRModule) const
23: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
22: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
21: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
20: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
19: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relay3tec7LowerTEENS0_6StringENS_17CompilationConfigESt8functionIFvNS_8BaseFuncEEEEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SL_SP_
18: tvm::relay::tec::LowerTE(tvm::IRModule const&, tvm::runtime::String const&, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)
17: tvm::transform::Pass::operator()(tvm::IRModule) const
16: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
15: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
14: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_5relay8FunctionES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_3tec15LowerTensorExprERKNS0_6StringENSD_10TECompilerESt8functionIFvNS_8BaseFuncEEENS_17CompilationConfigEEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SP_ST_
13: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
12: _ZZN3tvm5relay11ExprFuncto
11: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
10: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode const*)
9: _ZN3tvm5relay9tr
8: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
7: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
6: _ZZN3tvm5relay11ExprFuncto
5: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::LetNode const*)
4: tvm::relay::tec::LowerTensorExprMutator::PreVisitLetBinding_(tvm::relay::Var const&, tvm::RelayExpr const&)
3: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
2: _ZZN3tvm5relay11ExprFuncto
1: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
0: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
File "/home/user1/mlenv/deps/src/tvm/src/relay/backend/te_compiler.cc", line 885
I noticed that this property is sometimes updated manually after new copies of a function are created.
However, this was not done consistently, and I had to patch the following places to make compilation work again:
diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc
index 1a16cc9be..d05a30626 100644
--- a/src/relay/ir/transform.cc
+++ b/src/relay/ir/transform.cc
@@ -131,6 +131,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx)
// only process optimizable Relay Functions
if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
Function updated_func = pass_func(GetRef<Function>(function_node), updated_mod, pass_ctx);
+ updated_func->virtual_device_ = GetRef<Function>(function_node)->virtual_device();
updates.push_back({kv.first, std::move(updated_func)});
}
}
diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py
index b9ca7d0e1..889031ed4 100644
--- a/python/tvm/relay/expr_functor.py
+++ b/python/tvm/relay/expr_functor.py
@@ -204,7 +204,10 @@ class ExprMutator(ExprFunctor):
def visit_function(self, fn):
new_params = [self.visit(x) for x in fn.params]
new_body = self.visit(fn.body)
- return Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs)
+ func = Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs)
+ from tvm.relay.function import FunctionCopyVirtualDevice
+ FunctionCopyVirtualDevice(func, fn)
+ return func
def visit_let(self, let):
new_var = self.visit(let.var)
diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py
index f889f1e59..997fd1776 100644
--- a/python/tvm/relay/function.py
+++ b/python/tvm/relay/function.py
@@ -26,6 +26,10 @@
from . import _ffi_api
+def FunctionCopyVirtualDevice(f1, f2):
+ _ffi_api.FunctionCopyVirtualDevice(f1, f2)
+
+
@tvm._ffi.register_object("relay.Function")
class Function(BaseFunc):
"""A function declaration expression.
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index 63e74144e..bd3906731 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -127,6 +127,10 @@ TVM_REGISTER_GLOBAL("relay.ir.Function")
tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs) {
return Function(params, body, ret_type, ty_params, attrs);
});
+TVM_REGISTER_GLOBAL("relay.ir.FunctionCopyVirtualDevice")
+ .set_body_typed([](Function f1, Function f2) {
+ f1->virtual_device_ = f2->virtual_device_;
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
This does not seem like an elegant solution, and I am wondering why virtual_device is not part of the Function() Python interface. Would adding it there be an appropriate fix?