[Question] Bad codegen result in dense + add fusion when using cublas

If we use CUDA with cublas and just create a dense and an add layer in the model, the dense and the add layer will be fused into a single function because nn.dense's OpPattern is OUT_ELEMWISE_FUSABLE. This happens in the FuseOps phase.
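
For reference, the registered pattern can be checked directly (just a quick sketch, assuming the standard upstream registration where OUT_ELEMWISE_FUSABLE is the integer 4):

from tvm.ir.op import Op
from tvm.relay.op import OpPattern

# nn.dense is registered as OUT_ELEMWISE_FUSABLE, so FuseOps will pull the
# following elementwise add into the same fused group.
print(Op.get("nn.dense").get_attr("TOpPattern"))  # expected: 4
print(OpPattern.OUT_ELEMWISE_FUSABLE)             # 4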

However, because we use cublas, the dense layer is lowered to the cublas implementation while the add layer goes through normal codegen. In this case, the generated add kernel is very naive and inefficient.

We can verify this with the following code:

import tvm
from tvm import relay
from tvm.contrib import graph_executor
import numpy as np

target = tvm.target.cuda(arch="sm_86", options="-libs=cublas")
dev = tvm.device(target.kind.name, 0)


dtype = "float16"

A = relay.var("A", shape=(16384, 256), dtype=dtype)
weight = relay.var("weight", shape=(256, 256), dtype=dtype)
B = relay.var("B", shape=(1, 256), dtype=dtype)

out = relay.nn.dense(A, weight)
out = relay.add(out, B)

free_vars = relay.analysis.free_vars(out)
func = relay.Function(free_vars, out)
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
print("mod:\n")
print(mod)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)
    
dev_module = lib.get_lib().imported_modules[0]
print("cu:\n")
print(dev_module.get_source("cu"))

graph_module = graph_executor.GraphModule(lib["default"](dev))

print("Evaluate inference igie multi-kernel time cost...")
ftimer = graph_module.module.time_evaluator("run", dev, number=1, repeat=10)
prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)))


If we set the dense OpPattern to OPAQUE as below

from tvm.ir.op import Op

dense_op = Op.get("nn.dense")
# Override the default TOpPattern registration (plevel 10) with OPAQUE at a higher plevel (30)
dense_op.set_attr("TOpPattern", tvm.relay.op.OpPattern.OPAQUE, 30)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)
    

and run the test again, the generated add kernel will be different and the total run time will be smaller.

I think there should be some strategy in codegen to decide whether dense and add should be fused when the dense falls back to an external library, so that this inefficient add kernel is avoided.
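
As a local workaround rather than a real fix, the override can at least be scoped so it does not leak into later builds, using Op.reset_attr / Op.set_attr (a sketch; original_pattern and lib_opaque are names I made up here):

from tvm.ir.op import Op
from tvm.relay.op import OpPattern

# Sketch of a scoped override: mark nn.dense as OPAQUE only for this build,
# then restore the original OUT_ELEMWISE_FUSABLE pattern afterwards.
dense_op = Op.get("nn.dense")
original_pattern = dense_op.get_attr("TOpPattern")

dense_op.reset_attr("TOpPattern")
dense_op.set_attr("TOpPattern", OpPattern.OPAQUE)
try:
    with tvm.transform.PassContext(opt_level=3):
        lib_opaque = relay.build(mod, target=target)
finally:
    dense_op.reset_attr("TOpPattern")
    dense_op.set_attr("TOpPattern", original_pattern)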

Codegen and perf result without setting dense OpPattern to OPAQUE:

extern "C" __global__ void __launch_bounds__(1024) tvmgen_default_fused_nn_dense_add_kernel0(half* __restrict__ T_add, half* __restrict__ matmul_cublas, half* __restrict__ p2) {
  T_add[((((int)blockIdx.x) * 1024) + ((int)threadIdx.x))] = (matmul_cublas[((((int)blockIdx.x) * 1024) + ((int)threadIdx.x))] + p2[(((int)threadIdx.x) & 255)]);
}


Evaluate inference igie multi-kernel time cost...
Mean inference time (std dev): 0.10 ms (0.01 ms)


Codegen and perf result with dense OpPattern set to OPAQUE:

extern "C" __global__ void __launch_bounds__(1024) tvmgen_default_fused_add_kernel0(half* __restrict__ T_add, half* __restrict__ p0, half* __restrict__ p1) {
  for (int ax0_ax1_fused_outer_outer = 0; ax0_ax1_fused_outer_outer < 8; ++ax0_ax1_fused_outer_outer) {
    uint1 __1;
      uint1 __2 = *(uint1*)(p0 + (((ax0_ax1_fused_outer_outer * 524288) + (((int)blockIdx.x) * 2048)) + (((int)threadIdx.x) * 2)));
      uint1 __3 = *(uint1*)(p1 + ((((int)threadIdx.x) & 127) * 2));
      ((half2*)(&(__1.x)))->x = (((half2*)(&(__2.x)))->x+((half2*)(&(__3.x)))->x);
      ((half2*)(&(__1.x)))->y = (((half2*)(&(__2.x)))->y+((half2*)(&(__3.x)))->y);
    *(uint1*)(T_add + (((ax0_ax1_fused_outer_outer * 524288) + (((int)blockIdx.x) * 2048)) + (((int)threadIdx.x) * 2))) = __1;
  }
}


Evaluate inference igie multi-kernel time cost...
Mean inference time (std dev): 0.09 ms (0.01 ms)