If we use CUDA with cuBLAS and build a model containing just a dense layer followed by an add layer, the two will be fused into a single function, because nn.dense's OpPattern is OUT_ELEMWISE_FUSABLE. This happens in the FuseOps pass.
However, since we use cuBLAS, the dense layer is lowered to the cuBLAS implementation while the add layer goes through the normal codegen. In this case, the generated add kernel is very naive and inefficient.
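To see why the add ends up in the dense op's group, here is a minimal sketch of the grouping rule; this is illustrative pseudologic under my own simplifications, not TVM's actual FuseOps implementation (the pattern values do mirror tvm.relay.op.OpPattern):

```python
# Illustrative sketch only -- mirrors the spirit of FuseOps, not its code.
# Pattern values follow tvm.relay.op.OpPattern.
ELEMWISE, BROADCAST, OUT_ELEMWISE_FUSABLE, OPAQUE = 0, 1, 4, 8

def fuse_groups(ops):
    """Greedily group a linear chain of (name, pattern) ops.

    A group whose anchor is fusable (pattern <= OUT_ELEMWISE_FUSABLE)
    absorbs trailing elementwise/broadcast ops; an OPAQUE anchor never
    absorbs anything, so its successors start a new group.
    """
    groups = []
    for name, pattern in ops:
        if groups and pattern <= BROADCAST and groups[-1]["anchor"] != OPAQUE:
            groups[-1]["ops"].append(name)  # fuse into the previous group
        else:
            groups.append({"ops": [name], "anchor": pattern})
    return [g["ops"] for g in groups]

# Default: dense (OUT_ELEMWISE_FUSABLE) absorbs the following add.
print(fuse_groups([("dense", OUT_ELEMWISE_FUSABLE), ("add", BROADCAST)]))
# -> [['dense', 'add']]
# With dense forced to OPAQUE, add gets its own kernel.
print(fuse_groups([("dense", OPAQUE), ("add", BROADCAST)]))
# -> [['dense'], ['add']]
```

This is exactly the situation in the report: the fused group hands dense to cuBLAS, and the leftover add epilogue falls back to naive codegen.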
We can verify this with the following code:
import tvm
from tvm import relay
from tvm.contrib import graph_executor
import numpy as np
target = tvm.target.cuda(arch="sm_86", options="-libs=cublas")
dev = tvm.device(target.kind.name, 0)
dtype = "float16"
A = relay.var("A", shape=(16384, 256), dtype=dtype)
weight = relay.var("weight", shape=(256, 256), dtype=dtype)
B = relay.var("B", shape=(1, 256), dtype=dtype)
out = relay.nn.dense(A, weight)
out = relay.add(out, B)
free_vars = relay.analysis.free_vars(out)
func = relay.Function(free_vars, out)
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
print("mod:\n")
print(mod)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)
dev_module = lib.get_lib().imported_modules[0]
print("cu:\n")
print(dev_module.get_source("cu"))
graph_module = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
print("Evaluate inference time cost...")
ftimer = graph_module.module.time_evaluator("run", dev, number=1, repeat=10)
prof_res = np.array(ftimer().results) * 1000 # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)))
If we set the dense OpPattern to OPAQUE as below
from tvm.ir.op import Op
dense_op = Op.get("nn.dense")
dense_op.set_attr("TOpPattern", tvm.relay.op.OpPattern.OPAQUE, 30)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)
and run the test again, the generated add kernel is different and the total run time is noticeably smaller.
I think there should be some strategy at this point in codegen to decide whether to fuse, so that the inefficient kernel is avoided when an external library implements the anchor op.
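One possible direction, sketched under assumptions (this is not an existing TVM API, and the lib-to-op mapping below is purely illustrative): before building, look at the target's -libs and mark the ops the external library will take over as OPAQUE, so FuseOps leaves the elementwise epilogue as its own kernel:

```python
# Hypothetical workaround sketch. EXTERN_LIB_OPS is an illustrative guess at
# which ops each external library implements; adjust it for your setup.
EXTERN_LIB_OPS = {
    "cublas": ["nn.dense", "nn.batch_matmul"],
    "cudnn": ["nn.conv2d"],
}

def ops_to_mark_opaque(libs):
    """Return the sorted op names that an external library will lower."""
    names = set()
    for lib in libs:
        names.update(EXTERN_LIB_OPS.get(lib, []))
    return sorted(names)

def mark_extern_ops_opaque(target):
    """Set TOpPattern to OPAQUE for ops handled by the target's -libs.

    Assumes target.libs exists on tvm.target.Target; deferred import keeps
    ops_to_mark_opaque usable without TVM installed.
    """
    import tvm
    for name in ops_to_mark_opaque(list(target.libs)):
        op = tvm.ir.Op.get(name)
        op.set_attr("TOpPattern", tvm.relay.op.OpPattern.OPAQUE, 30)
```

With something like `mark_extern_ops_opaque(target)` called before `relay.build`, the manual set_attr hack above would not be needed per-model; a proper fix would presumably live inside FuseOps itself, taking the selected implementation into account.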