I used `relay.build` to generate a module and then extracted the C code using `get_source`. However, the generated C code does not contain a definition of the overall model. Instead, it only includes individual functions corresponding to the model's layers.
It seems that I need to call these functions in a specific order (presumably the one TVM defines internally) to reproduce the model's behavior correctly.
Is there a way to obtain generated source code that explicitly includes this execution order when using `get_source`?
If not, how can I determine the correct order in which to call these functions when writing a custom `main` function?
Below is the example script I used to test this behavior, included for reference, followed by the (truncated) C source it produced.
from torch import nn


class SimpleModel(nn.Module):
    """Minimal test network: ``nn.Linear(3, 1)`` followed by an in-place ReLU.

    The linear layer acts on the last dimension, so for an input whose last
    dimension is 3 (e.g. shape ``(1, 1, 3, 3)``), the output keeps the leading
    dimensions and has a last dimension of 1 (e.g. ``(1, 1, 3, 1)``).
    """

    def __init__(self):
        super().__init__()
        # NOTE: despite its name this is a fully-connected layer, not a
        # normalization layer. The attribute name is kept so the state_dict
        # keys ("norm.weight", "norm.bias") stay unchanged.
        self.norm = nn.Linear(3, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(norm(x))``."""
        x = self.norm(x)
        x = self.relu(x)
        return x
import torch
import tvm
from tvm import relay

if __name__ == '__main__':
    # Trace SimpleModel with TorchScript, compile it with TVM for the C
    # target, and dump both the generated kernels and the executor graph.
    input_shape = (1, 1, 3, 3)
    sample_input = torch.randn(input_shape)
    # eval() before tracing so any train/eval-dependent layers are traced in
    # inference mode (no-op for this model, but the correct habit).
    model = SimpleModel().eval()
    # Eager run as a sanity check that the model accepts input_shape.
    output = model(sample_input)
    traced_model = torch.jit.trace(model, sample_input)
    input_name = "input0"
    shape_list = [(input_name, input_shape)]
    mod, params = relay.frontend.from_pytorch(traced_model, shape_list)
    with tvm.transform.PassContext(opt_level=0):
        lib = relay.build(mod, target='c', params=params, mod_name="")

    # get_source('c') returns only the per-operator kernels (C, not LLVM IR).
    c_source = lib.get_lib().get_source('c')
    with open('simple_cnn.c', 'w', encoding='utf-8') as f:
        f.write(c_source)

    # With the default graph executor, the order in which the kernels above
    # must be called is not in the C file; it lives in the executor graph.
    # In this JSON, "nodes" appears in execution order, nodes with
    # "op": "tvm_op" name their kernel in attrs["func_name"], and "inputs"
    # refer to earlier node entries. Params are supplied separately
    # (see lib.get_params()).
    # NOTE(review): alternatively, building with the AOT executor
    # (tvm.relay.backend.Executor("aot")) is documented to emit an
    # entry function that calls these kernels in order; confirm the
    # options for your TVM version.
    with open('simple_cnn_graph.json', 'w', encoding='utf-8') as f:
        f.write(lib.get_graph_json())
// tvm target: c -keys=cpu
#define TVM_EXPORTS
#include "tvm/runtime/c_runtime_api.h"
#include "tvm/runtime/c_backend_api.h"
#include <math.h>
#include <stdbool.h>
/*
 * Prototypes for every fused kernel emitted in this translation unit.
 *
 * All kernels share one packed-call signature:
 *   args            - argument array (the definitions cast it to TVMValue*)
 *   arg_type_ids    - per-argument codes (presumably TVM type codes)
 *   num_args        - number of entries in args
 *   out_ret_value   - return-value slot
 *   out_ret_tcode   - return-type-code slot
 *   resource_handle - opaque handle
 *
 * The kernels are listed sorted by name. This listing does not by itself
 * tell you the order in which they must be invoked.
 */
#ifdef __cplusplus
extern "C" {
#endif

TVM_DLL int32_t tvmgen_fused_broadcast_to(
    void* args, int32_t* arg_type_ids, int32_t num_args,
    void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);

TVM_DLL int32_t tvmgen_fused_broadcast_to_1(
    void* args, int32_t* arg_type_ids, int32_t num_args,
    void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);

TVM_DLL int32_t tvmgen_fused_nn_batch_matmul(
    void* args, int32_t* arg_type_ids, int32_t num_args,
    void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);

TVM_DLL int32_t tvmgen_fused_nn_bias_add(
    void* args, int32_t* arg_type_ids, int32_t num_args,
    void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);

TVM_DLL int32_t tvmgen_fused_nn_relu(
    void* args, int32_t* arg_type_ids, int32_t num_args,
    void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);

TVM_DLL int32_t tvmgen_fused_reshape(
    void* args, int32_t* arg_type_ids, int32_t num_args,
    void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);

TVM_DLL int32_t tvmgen_fused_reshape_1(
    void* args, int32_t* arg_type_ids, int32_t num_args,
    void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);

TVM_DLL int32_t tvmgen_fused_reshape_2(
    void* args, int32_t* arg_type_ids, int32_t num_args,
    void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);

TVM_DLL int32_t tvmgen_fused_squeeze(
    void* args, int32_t* arg_type_ids, int32_t num_args,
    void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);

TVM_DLL int32_t tvmgen_fused_transpose(
    void* args, int32_t* arg_type_ids, int32_t num_args,
    void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);

#ifdef __cplusplus
}  /* extern "C" */
#endif
/*
 * Generated kernel `broadcast_to`: copies 3 x 3 = 9 float32 values from the
 * data buffer of the first tensor argument (p0) into the data buffer of the
 * second (T_broadcast_to), element for element.
 *
 * args[0] and args[1] are TVMValue entries whose v_handle points to a DLTensor.
 * The other packed-call parameters are part of the fixed signature and are
 * unused. Always returns 0.
 *
 * The original emission used OpenCL-style vector types (int32_t3, float3,
 * .s0/.s1/.s2 swizzles), which are not valid C. This version writes the same
 * copy as plain scalar loops.
 */
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t tvmgen_fused_broadcast_to(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle) {
  (void)arg_type_ids;
  (void)num_args;
  (void)out_ret_value;
  (void)out_ret_tcode;
  (void)resource_handle;

  void* p0 = (((TVMValue*)args)[0].v_handle);
  void* T_broadcast_to = (((TVMValue*)args)[1].v_handle);
  const float* src = (const float*)(((DLTensor*)p0)[0].data);
  float* dst = (float*)(((DLTensor*)T_broadcast_to)[0].data);

  /* NOTE(review): strides and byte_offset are not inspected; both buffers
     are assumed to be compact and contiguous. */
  for (int32_t ax2 = 0; ax2 < 3; ++ax2) {
    int32_t row = ax2 * 3;
    for (int32_t ax3 = 0; ax3 < 3; ++ax3) {
      dst[row + ax3] = src[row + ax3];
    }
  }
  return 0;
}
/*
 * Generated kernel `broadcast_to_1`: copies 3 float32 values from the data
 * buffer of the first tensor argument (p0) into the data buffer of the
 * second (T_broadcast_to).
 *
 * args[0] and args[1] are TVMValue entries whose v_handle points to a DLTensor.
 * The other packed-call parameters are part of the fixed signature and are
 * unused. Always returns 0.
 */
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t tvmgen_fused_broadcast_to_1(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle) {
  (void)arg_type_ids;
  (void)num_args;
  (void)out_ret_value;
  (void)out_ret_tcode;
  (void)resource_handle;

  void* p0 = (((TVMValue*)args)[0].v_handle);
  void* T_broadcast_to = (((TVMValue*)args)[1].v_handle);
  const float* src = (const float*)(((DLTensor*)p0)[0].data);
  float* dst = (float*)(((DLTensor*)T_broadcast_to)[0].data);

  /* NOTE(review): strides and byte_offset are not inspected; both buffers
     are assumed to be compact and contiguous. */
  for (int32_t ax2 = 0; ax2 < 3; ++ax2) {
    dst[ax2] = src[ax2];
  }
  return 0;
}
...