[Relay] TVMError: primitive functions not set on Relay function by TECompiler

I’m trying to combine two Relay functions into a single function (not fuse their ops). This is what I did:

import tvm
from tvm import relay

# A function computing x + 1
x = relay.var("x", shape=(1, 8), dtype="float32")
add = relay.Function([x], relay.add(x, relay.const(1.0)))

# A function computing y * 2
y = relay.var("y", shape=(1, 8), dtype="float32")
mul = relay.Function([y], relay.multiply(y, relay.const(2.0)))

# Build an outer function that calls them back to back: mul(add(x))
params = [relay.var(p.name_hint, p.type_annotation) for p in add.params]
combined_func = relay.Function(params, relay.Call(mul, [relay.Call(add, params)]))
print(combined_func)

mod = tvm.IRModule.from_expr(combined_func)
lib = relay.build(mod, target="llvm")

When I print combined_func, it shows:

fn (%x: Tensor[(1, 8), float32]) {
  %0 = fn (%x1: Tensor[(1, 8), float32]) {
    add(%x1, 1f)
  };
  %1 = %0(%x);
  %2 = fn (%y: Tensor[(1, 8), float32]) {
    multiply(%y, 2f)
  };
  %2(%1)
}

This is exactly what I want. However, relay.build raises a TVMError:

TVMError: Check failed: (prim_fns) is false: primitive functions not set on Relay function by TECompiler.

Here is the full stack trace:

tvm._ffi.base.TVMError: Traceback (most recent call last):
  37: TVMFuncCall
  36: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::RelayBuildModule::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  35: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  34: void tvm::relay::backend::ExecutorCodegen::CallFunc<tvm::IRModule, tvm::relay::Function, tvm::runtime::String>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::IRModule, tvm::relay::Function, tvm::runtime::String)
  33: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::GraphExecutorCodegenModule::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  32: tvm::relay::backend::GraphExecutorCodegen::Codegen(tvm::IRModule, tvm::relay::Function, tvm::runtime::String)
  31: tvm::transform::Pass::operator()(tvm::IRModule) const
  30: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  29: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  28: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  27: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  26: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::tec::LowerTE(tvm::runtime::String, tvm::CompilationConfig, std::function<void (tvm::BaseFunc)>)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::tec::LowerTE(tvm::runtime::String, tvm::CompilationConfig, std::function<void (tvm::BaseFunc)>)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  25: tvm::relay::tec::LowerTE(tvm::IRModule const&, tvm::runtime::String const&, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)
  24: tvm::transform::Pass::operator()(tvm::IRModule) const
  23: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  22: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  21: _ZN3tvm7runtime13PackedFuncObj
  20: tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::tec::LowerTensorExpr(tvm::relay::tec::TECompiler, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::tec::LowerTensorExpr(tvm::relay::tec::TECompiler, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  19: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  18: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  17: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
  16: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  15: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode const*)
  14: _ZN3tvm5relay9transform22DeviceAwareExprMutator21DeviceAwareVisit
  13: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  12: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  11: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  10: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
  9: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
  8: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
  7: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  6: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  5: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
  4: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
  3: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
  2: _ZNSt17_Function_handlerIFvN3tvm8BaseFuncEEZNS0_5relay7backend20GraphExecutorCodegen7CodegenENS0_8IRModuleENS3_8FunctionENS0_7runtime6S
  1: tvm::relay::backend::GraphExecutorCodegen::Codegen(tvm::IRModule, tvm::relay::Function, tvm::runtime::String)::{lambda(tvm::BaseFunc)#1}::operator()(tvm::BaseFunc) const
  0: tvm::relay::tec::UpdateFunctionMetadata(tvm::BaseFunc, tvm::runtime::Map<tvm::runtime::String, tvm::relay::backend::FunctionInfo, void, void>&, tvm::Integer)
  File "/home/ywtien/tvm/src/relay/backend/te_compiler.cc", line 1121
TVMError: Check failed: (prim_fns) is false: primitive functions not set on Relay function by TECompiler.

I don’t know what “primitive functions” means in Relay, and I’m not sure what I missed or whether I’m even on the right track. Does anyone have any suggestions?

I don’t remember off the top of my head, but primitive functions have an attribute set (kPrimitive, probably) and they’re produced by operator fusion; they’re the form operators are lowered into during compilation. I would take a look at the fusion pass and its tests.
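
If you want to see one concretely, here’s a minimal sketch on my end (not from the original post) that runs the fusion pass on a plain module and prints the result; the inner function it produces carries the attribute the error message is talking about:

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 8), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, relay.const(1.0))))
mod = relay.transform.InferType()(mod)  # FuseOps expects a type-checked module
mod = relay.transform.FuseOps()(mod)
print(mod)  # the fused inner function is printed with the Primitive=1 attribute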

That said, your example looks correct. I’m surprised that you would encounter an error.


I got the example to work using the VM (which is definitely the preferred way to execute Relay). I’m not sure why build() fails the way it does, but I did reproduce the crash.

import numpy as np

# build up the function and the IRModule exactly as shown in the question
ex = relay.create_executor("vm", device=tvm.cpu(), target="llvm", mod=mod).evaluate()
arr = np.random.rand(1, 8).astype("float32")
res = ex(arr)  # executes and returns the result as a tvm.nd.NDArray
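
To sanity-check the result (my own addition, reusing res and arr from the snippet above), you can compare against NumPy; the combined function should compute (x + 1) * 2:

np.testing.assert_allclose(res.numpy(), (arr + 1.0) * 2.0, rtol=1e-6)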

@slyubomirsky @ywtien I am encountering a similar problem. If you were able to figure out why the TVMError occurred, please share the solution.

Thanks in advance

In my case, I just added with_attr("Primitive", 1) to the Relay functions and it worked!

For example:

# Marking each function as "Primitive" tells the compiler to lower it
# directly, the same way fused functions are handled.
x = relay.var("x", shape=(1, 8), dtype="float32")
add = relay.Function([x], relay.add(x, relay.const(1.0))).with_attr("Primitive", 1)

y = relay.var("y", shape=(1, 8), dtype="float32")
mul = relay.Function([y], relay.multiply(y, relay.const(2.0))).with_attr("Primitive", 1)
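
For completeness: with those attributes set, the rest of the script from the original question (the combining code and relay.build, repeated here unchanged) runs without tripping the TECompiler check:

params = [relay.var(p.name_hint, p.type_annotation) for p in add.params]
combined_func = relay.Function(params, relay.Call(mul, [relay.Call(add, params)]))
mod = tvm.IRModule.from_expr(combined_func)
lib = relay.build(mod, target="llvm")  # builds successfully now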