I'm trying to mimic tests/cpp/relay_build_module_test.cc to construct a simple Dense + ReLU + Add function, as below:
auto tensor_type_f32_16_8 = relay::TensorType({16, 8}, DataType::Float(32));
auto tensor_type_f32_8_8 = relay::TensorType({8, 8}, DataType::Float(32));
auto a = relay::Var("a", tensor_type_f32_16_8);
auto b = relay::Var("b", tensor_type_f32_8_8);
auto c = relay::Var("c", tensor_type_f32_16_8);
auto dense_op = relay::Op::Get("nn.dense");
auto x = relay::Call(dense_op, {a, b}, tvm::Attrs(), {});
auto relu_op = relay::Op::Get("nn.relu");
auto y = relay::Call(relu_op, {x}, tvm::Attrs(), {});
auto add_op = relay::Op::Get("add");
auto z = relay::Call(add_op, {y, c}, tvm::Attrs(), {});
auto func = relay::Function(relay::FreeVars(z), z, relay::Type(), {});
// Host (CPU) float32 input buffers matching the function parameters:
// A -> a [16, 8], B -> b [8, 8], C -> c [16, 8].
auto A = tvm::runtime::NDArray::Empty({16, 8}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto B = tvm::runtime::NDArray::Empty({8, 8}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto C = tvm::runtime::NDArray::Empty({16, 8}, {kDLFloat, 32, 1}, {kDLCPU, 0});
// Read the raw buffer pointers straight from the underlying DLTensor.
// ToDLPack() allocates a DLManagedTensor that holds a reference to the array
// and must be freed via its deleter. Calling it only to read ->data leaked
// both the wrapper and the array.
auto pA = static_cast<float*>(A->data);
auto pB = static_cast<float*>(B->data);
auto pC = static_cast<float*>(C->data);
// A[i] = i and C[i] = i + 2 over all 16 * 8 elements.
for (int i = 0; i < 16 * 8; ++i) {
  pA[i] = i;
  pC[i] = i + 2;
}
// B[i] = i + 1 over all 8 * 8 elements.
for (int i = 0; i < 8 * 8; ++i) {
  pB[i] = i + 1;
}
...
// build
// Instantiate Relay's build module and fetch its packed-function entry points.
auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
// Registry::Get returns nullptr for an unregistered name; fail loudly instead
// of dereferencing it.
CHECK(pfb != nullptr) << "relay.build_module._BuildModule is not registered";
tvm::runtime::Module build_mod = (*pfb)();
auto build_f = build_mod.GetFunction("build", false);
auto json_f = build_mod.GetFunction("get_graph_json", false);
auto mod_f = build_mod.GetFunction("get_module", false);
// Single-target build. The key 0 is copied from relay_build_module_test.cc.
// NOTE(review): confirm the key the build module expects for this TVM version.
Map<tvm::Integer, tvm::Target> targets;
Target llvm_tgt = Target::Create("llvm");
targets.Set(0, llvm_tgt);
// Wrapping the function in an IRModule type-checks it; invalid op attrs
// surface here.
auto relay_mod = tvm::IRModule::FromExpr(func);
// Third argument is presumably the host target — confirm against build_module.cc.
build_f(relay_mod, targets, llvm_tgt);
It errors out at the tvm::IRModule::FromExpr(func) step:
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from Relay
[ RUN ] Relay.BuildModule
[09:56:53] /work/git_repo/tvm/src/printer/doc.cc:55: text node: ' an internal invariant was violated while typechecking your program [09:56:53] /work/git_repo/tvm/src/relay/op/nn/nn.h:41: Check failed: param != nullptr:
Is any special setup of params needed in the function-construction code? Thanks.