How to force kernels to be fused or combined in TVM?

I have a simple PyTorch model:

import torch

class MyModel(torch.nn.Module):
    def forward(self, x):
        # matmul with its own transpose, elementwise add, then ReLU
        return torch.relu(x @ x.T + x)
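
For context, mod and params in the snippet below come from tracing the model and importing it with relay.frontend.from_pytorch, roughly like this (the input name "x" and the (1, 128) shape are just the ones I use here; the width of 128 matches the generated kernels further down):

import tvm
from tvm import relay

# Trace the model so TVM's PyTorch frontend can consume it
model = MyModel().eval()
example_input = torch.randn(1, 128)  # (1, 128) is my choice of input shape
scripted = torch.jit.trace(model, example_input)

# Import the traced graph into Relay; "x" is the name I give the input
mod, params = relay.frontend.from_pytorch(scripted, [("x", (1, 128))])
target = tvm.target.Target("cuda")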

When I compile it down to CUDA with:

with tvm.transform.PassContext(
    opt_level=3,
    config={
        "relay.FuseOps.max_depth": 30,
        "relay.backend.use_auto_scheduler": False,
    },
):
    lib = relay.build(mod, target, params=params)

rt_mod = lib.get_lib()
cuda_mod = rt_mod.imported_modules[0]  # the imported device (CUDA) module
cuda_source = cuda_mod.get_source()    # generated CUDA C source
print(cuda_source)

I get the following CUDA source, which contains two separate kernels:

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)
#define __shfl_sync(mask, var, lane, width) \
        __shfl((var), (lane), (width))

#define __shfl_down_sync(mask, var, offset, width) \
        __shfl_down((var), (offset), (width))

#define __shfl_up_sync(mask, var, offset, width) \
        __shfl_up((var), (offset), (width))
#endif


#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
     (__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif

#ifdef _WIN32
  using uint = unsigned int;
  using uchar = unsigned char;
  using ushort = unsigned short;
  using int64_t = long long;
  using uint64_t = unsigned long long;
#else
  #define uint unsigned int
  #define uchar unsigned char
  #define ushort unsigned short
  #define int64_t long long
  #define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(128) tvmgen_default_fused_squeeze_add_nn_relu_kernel(float* __restrict__ T_relu, float* __restrict__ p0, float* __restrict__ p1);
extern "C" __global__ void __launch_bounds__(64) tvmgen_default_fused_nn_dense_kernel(float* __restrict__ T_matmul_NT, float* __restrict__ p0);
extern "C" __global__ void __launch_bounds__(128) tvmgen_default_fused_squeeze_add_nn_relu_kernel(float* __restrict__ T_relu, float* __restrict__ p0, float* __restrict__ p1) {
  T_relu[((int)threadIdx.x)] = max((p0[0] + p1[((int)threadIdx.x)]), 0.000000e+00f);
}

extern "C" __global__ void __launch_bounds__(64) tvmgen_default_fused_nn_dense_kernel(float* __restrict__ T_matmul_NT, float* __restrict__ p0) {
  float T_matmul_NT_rf[1];
  __shared__ float red_result[1];
  T_matmul_NT_rf[0] = 0.000000e+00f;
  for (int k_outer = 0; k_outer < 2; ++k_outer) {
    T_matmul_NT_rf[0] = (T_matmul_NT_rf[0] + (p0[((k_outer * 64) + ((int)threadIdx.x))] * p0[((k_outer * 64) + ((int)threadIdx.x))]));
  }
  float red_buf0[1];
  uint mask[1];
  float t0[1];
  float red_buf0_1[1];
  uint mask_1[1];
  float t0_1[1];
  __shared__ float red_buf_staging[2];
  red_buf0_1[0] = T_matmul_NT_rf[0];
  mask_1[0] = __activemask();
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 16, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 8, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 4, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 2, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 1, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  if ((((int)threadIdx.x) % 32) == 0) {
    red_buf_staging[(((int)threadIdx.x) >> 5)] = red_buf0_1[0];
  }
  __syncthreads();
  if (((int)threadIdx.x) < 2) {
    red_buf0[0] = red_buf_staging[((int)threadIdx.x)];
  }
  mask[0] = (__activemask() & (uint)3);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  if (((int)threadIdx.x) == 0) {
    ((volatile float*)red_result)[0] = red_buf0[0];
  }
  __syncthreads();
  if (((int)threadIdx.x) == 0) {
    T_matmul_NT[0] = ((volatile float*)red_result)[0];
  }
}

However, this doesn’t help me much. While I can run the original PyTorch model with a single call, the generated CUDA code has to be invoked kernel by kernel, with me allocating the inputs and intermediate buffers each kernel expects. Is there a way to make TVM emit a single function that I can call for the whole pipeline?
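
To make the contrast concrete: on the Python side I can already run the compiled module in one call through TVM's graph executor (a rough sketch, assuming the Relay input is named "x" as in the import step above):

import numpy as np
from tvm.contrib import graph_executor

dev = tvm.cuda(0)
module = graph_executor.GraphModule(lib["default"](dev))
module.set_input("x", tvm.nd.array(np.random.rand(1, 128).astype("float32"), dev))
module.run()
out = module.get_output(0).numpy()

But this still launches the two kernels above one after the other under the hood. What I am looking for is a single function in the generated CUDA source itself that runs the whole matmul + add + relu pipeline.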