Hi @tqchen and @maheshambule,
Refer to the TVM tutorial "Bring Your Own Codegen To TVM", which details how to create a custom C source module codegen.
However, ONNX is not a C source module, so we additionally need to define an ONNX module node for the codegen. The following are the steps I took to create the ONNX module codegen.
- Create ONNXModuleCodegen to traverse the Relay IRModule and convert Relay ops to ONNX ops (see codegen.cc).
/*!
 * \brief Converts a Relay IRModule into an ONNX runtime module.
 *
 * The Relay -> ONNX op conversion itself is delegated to the global function
 * "tvm.relay.converter.to_onnx" (registered from Python). This class only
 * wraps the converter's output into a runtime module through the global
 * "runtime.ONNXModuleCreate".
 */
class ONNXModuleCodegen {
 public:
  ONNXModuleCodegen() = default;

  /*!
   * \brief Build an ONNX runtime module from a Relay module.
   * \param ref The object to compile; Downcast requires it to be an IRModule.
   * \return The module returned by "runtime.ONNXModuleCreate" with format "onnx".
   */
  runtime::Module CreateONNXModule(const ObjectRef& ref) const {
    auto mod = Downcast<IRModule>(ref);
    // Registry::Get returns nullptr when the Python converter was never
    // registered; fail with a clear message instead of dereferencing null.
    CHECK(to_onnx_ != nullptr)
        << "Cannot find tvm.relay.converter.to_onnx to convert the relay module to onnx";
    // The converter result is kept as a tvm String rather than std::string:
    // per the original author, some bytes were lost when the payload crossed
    // the PackedFunc boundary as std::string.
    String codes = (*to_onnx_)(mod);
    const auto* pf = runtime::Registry::Get("runtime.ONNXModuleCreate");
    CHECK(pf != nullptr) << "Cannot find onnx module to create the external runtime module";
    return (*pf)(codes, "onnx");
  }

 private:
  /*!
   * \brief Python converter from a Relay module to serialized ONNX text.
   *
   * Looked up when the codegen object is constructed; may be null if
   * "tvm.relay.converter.to_onnx" has not been registered.
   */
  const PackedFunc* to_onnx_ = runtime::Registry::Get("tvm.relay.converter.to_onnx");
};
- Register a global function, “relay.ext.onnx”, whose body is a wrapper function that creates the ONNX module.
/*!
 * \brief Entry point registered as "relay.ext.onnx".
 * \param ref The module to compile, forwarded unchanged to ONNXModuleCodegen.
 * \return The ONNX runtime module produced by the codegen.
 */
runtime::Module ONNXCompiler(const ObjectRef& ref) {
  return ONNXModuleCodegen().CreateONNXModule(ref);
}

TVM_REGISTER_GLOBAL("relay.ext.onnx")
    .set_body_typed(ONNXCompiler);
- Instead of writing the op conversions in C++, use register_func to register a global function, “tvm.relay.converter.to_onnx”, and write the op conversions in Python to convert the Relay module to an ONNX model (see converter/onnx.py).
@tvm.register_func("tvm.relay.converter.to_onnx")
def convert_to_onnx(model):
    """Convert every global function of a Relay IRModule into an ONNX graph.

    Registered as the global function "tvm.relay.converter.to_onnx" so the C++
    ONNX codegen can call it.

    Parameters
    ----------
    model : tvm.IRModule
        The Relay module to convert. Each global function becomes one ONNX
        graph named after its global var.

    Returns
    -------
    str
        Concatenation of ``name<escaped>`` records, one per global function,
        where ``escaped`` is the ``repr`` text of the serialized model without
        the leading ``b'`` and trailing quote (assumes ``get_onnx_bytes``
        returns ``bytes`` -- TODO confirm).
    """
    ...
    opset = onnx.defs.onnx_opset_version()  # get the supported opset version
    records = []
    for global_var in model.get_global_vars():
        func = model[global_var]
        # from_expr is a static constructor; no need to build an empty module first.
        sub_model = tvm.IRModule.from_expr(func.body)
        sub_model = fuse_ops(sub_model)
        func = sub_model["main"]
        graph_name = global_var.name_hint
        # Traverse the Relay function and record the nodes.
        sub_onnx_model = ONNXGenerator({}, opset, graph_name, "").to_onnx(func)
        bytes_data = get_onnx_bytes(sub_onnx_model)
        escaped = str(bytes_data)[2:-1]
        # '<' and '>' delimit records, but repr() leaves these printable bytes
        # unescaped, so a model containing 0x3C/0x3E would be split at the wrong
        # place by the reader. Emit them as \x escapes, the same form repr()
        # already uses for non-printable bytes.
        # NOTE(review): relies on the C++ ConvertEscape decoding \xNN escapes,
        # which it must already do for repr() output -- confirm.
        escaped = escaped.replace("<", "\\x3c").replace(">", "\\x3e")
        records.append(graph_name + "<" + escaped + ">")
    return "".join(records)
- Create ONNXModuleNode, a subclass of ModuleNode, to implement a dedicated runtime module (see source_module.cc).
class ONNXModuleNode : public runtime::ModuleNode {
public:
/*!
 * \brief Construct a module that holds serialized ONNX payload text.
 * \param code Encoded payload produced by the ONNX converter.
 * \param fmt Format tag of the payload (e.g. "onnx").
 */
ONNXModuleNode(std::string code, std::string fmt) : code_{code}, fmt_{fmt} {}
const char* type_key() const {
return "onnx";
}
/*!
 * \brief Function lookup is not supported; this module only carries ONNX data.
 *
 * Reports a fatal error via LOG(FATAL) for any name. The trailing return only
 * satisfies the signature.
 */
PackedFunc GetFunction(const std::string& name,
                       const ObjectPtr<Object>& sptr_to_self) final {
  LOG(FATAL) << "ONNX Source module cannot execute, to get executable module build TVM with \'"
             << fmt_ << "\' runtime support";
  return PackedFunc();
}
/*!
 * \brief Return the stored payload text; the requested format is ignored.
 */
std::string GetSource(const std::string& format) final { return code_; }
...
void SaveToFile(const std::string& file_name,
const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
std::string folder;
size_t pos = file_name.find_last_of("\\/");
if(pos!=std::string::npos){
folder = file_name.substr(0,pos+1);
}else{
folder = file_name+"/";
}
auto datas = Split(code_,'>');
if (fmt == "onnx") {
CHECK_NE(code_.size(), 0);
std::stringstream ss;
for(auto data : datas){
auto split_data = Split(data,'<');
ss<<folder<<split_data[0].c_str()<<"."<<fmt;
SaveBinaryToFile(ss.str(), ConvertEscape(split_data[1]));
ss.str("");
ss.clear();
}
} else {
CHECK_EQ(fmt, fmt_)
<< "Can only save to format=" << fmt_;
}
}
...
/*!
 * \brief Create an ONNX runtime module from serialized graph text.
 * \param code Payload produced by the ONNX converter. Per the original
 *        author, it is taken as a tvm String instead of std::string because
 *        some bytes were lost when a std::string crossed the PackedFunc boundary.
 * \param fmt Format tag stored on the module (e.g. "onnx").
 * \return A runtime::Module wrapping a newly created ONNXModuleNode.
 */
runtime::Module ONNXModuleCreate(String code, std::string fmt) {
  return runtime::Module(make_object<ONNXModuleNode>(code, fmt));
}

TVM_REGISTER_GLOBAL("runtime.ONNXModuleCreate").set_body_typed(ONNXModuleCreate);
- Create a CMake file for ONNX and include it in CMakeLists.txt. To use the ONNX module codegen, set USE_ONNX_CODEGEN to “ON” and build TVM from source (see ONNX.cmake).
# Compile the ONNX external codegen when USE_ONNX_CODEGEN is enabled.
# Any CMake true constant (ON, TRUE, YES, Y, 1; case-insensitive) enables it.
# A STREQUAL "ON" test would silently skip values such as `on` or `1`.
if(USE_ONNX_CODEGEN)
  file(GLOB ONNX_RELAY_CONTRIB_SRC src/relay/backend/contrib/onnx/codegen.cc)
  list(APPEND COMPILER_SRCS ${ONNX_RELAY_CONTRIB_SRC})
  message(STATUS "Build with ONNX codegen.")
endif()
In addition, I have updated my code to the April 28 version (see the source code and example code).