I’m trying to combine two Relay functions into one (not combine ops). This is what I did:
x = relay.var("x", shape=(1, 8), dtype="float32")
add = relay.Function([x], relay.add(x, relay.const(1.0)))
y = relay.var("y", shape=(1, 8), dtype="float32")
mul = relay.Function([y], relay.multiply(y, relay.const(2.0)))
params = [relay.var(p.name_hint, p.type_annotation) for p in add.params]
combined_func = relay.Function(params, relay.Call(mul, [relay.Call(add, params)]))
print(combined_func)
mod = tvm.IRModule.from_expr(combined_func)
lib = relay.build(mod, target="llvm")
When I print out `combined_func`, it shows:
fn (%x: Tensor[(1, 8), float32]) {
%0 = fn (%x1: Tensor[(1, 8), float32]) {
add(%x1, 1f)
};
%1 = %0(%x);
%2 = fn (%y: Tensor[(1, 8), float32]) {
multiply(%y, 2f)
};
%2(%1)
}
This is what I want.
However, I encountered a TVMError in `relay.build`:
TVMError: Check failed: (prim_fns) is false: primitive functions not set on Relay function by TECompiler.
Here is the full stack trace:
tvm._ffi.base.TVMError: Traceback (most recent call last):
37: TVMFuncCall
36: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::RelayBuildModule::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
35: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
34: void tvm::relay::backend::ExecutorCodegen::CallFunc<tvm::IRModule, tvm::relay::Function, tvm::runtime::String>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::IRModule, tvm::relay::Function, tvm::runtime::String)
33: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::GraphExecutorCodegenModule::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
32: tvm::relay::backend::GraphExecutorCodegen::Codegen(tvm::IRModule, tvm::relay::Function, tvm::runtime::String)
31: tvm::transform::Pass::operator()(tvm::IRModule) const
30: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
29: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
28: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
27: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
26: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::tec::LowerTE(tvm::runtime::String, tvm::CompilationConfig, std::function<void (tvm::BaseFunc)>)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::tec::LowerTE(tvm::runtime::String, tvm::CompilationConfig, std::function<void (tvm::BaseFunc)>)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
25: tvm::relay::tec::LowerTE(tvm::IRModule const&, tvm::runtime::String const&, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)
24: tvm::transform::Pass::operator()(tvm::IRModule) const
23: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
22: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
21: _ZN3tvm7runtime13PackedFuncObj
20: tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::tec::LowerTensorExpr(tvm::relay::tec::TECompiler, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::tec::LowerTensorExpr(tvm::relay::tec::TECompiler, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
19: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
18: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
17: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
16: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
15: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode const*)
14: _ZN3tvm5relay9transform22DeviceAwareExprMutator21DeviceAwareVisit
13: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
12: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
11: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
10: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
9: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
8: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
7: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
6: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
5: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
4: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
3: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
2: _ZNSt17_Function_handlerIFvN3tvm8BaseFuncEEZNS0_5relay7backend20GraphExecutorCodegen7CodegenENS0_8IRModuleENS3_8FunctionENS0_7runtime6S
1: tvm::relay::backend::GraphExecutorCodegen::Codegen(tvm::IRModule, tvm::relay::Function, tvm::runtime::String)::{lambda(tvm::BaseFunc)#1}::operator()(tvm::BaseFunc) const
0: tvm::relay::tec::UpdateFunctionMetadata(tvm::BaseFunc, tvm::runtime::Map<tvm::runtime::String, tvm::relay::backend::FunctionInfo, void, void>&, tvm::Integer)
File "/home/ywtien/tvm/src/relay/backend/te_compiler.cc", line 1121
TVMError: Check failed: (prim_fns) is false: primitive functions not set on Relay function by TECompiler.
I don’t know what “primitive functions” means in Relay. I’m not sure what I missed or whether I’m on the right track. Does anyone have any suggestions?