Fuse bias and ReLU into cuDNN convolution

When I use cuDNN, Relay cannot fuse the convolution with the layers that follow it. The convolution and the following bias and ReLU ops are therefore split into two stages of computation: one for the cuDNN call and one for the remaining operations. However, cuDNN has an interface, cudnnConvolutionBiasActivationForward, that does conv + bias + ReLU in one call. How can I change the Relay fusion logic to fuse bias and ReLU into the cuDNN convolution?
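For reference, the cuDNN documentation describes cudnnConvolutionBiasActivationForward as computing y = act(alpha1 * conv(x, w) + alpha2 * z + bias). The following is a minimal pure-Python sketch of that arithmetic, not TVM or cuDNN code. It assumes a 1-D, single-channel cross-correlation and a scalar bias as a toy stand-in for the real 4-D convolution with per-channel bias, and all function names are hypothetical. It only illustrates that the two-stage form and the single-call form give the same result:

```python
def conv1d(x, w):
    """Valid-mode 1-D cross-correlation, a toy stand-in for conv2d."""
    k = len(w)
    return [sum(x[i + j] * w[j] for j in range(k)) for i in range(len(x) - k + 1)]


def two_stage(x, w, bias):
    """Stage 1: convolution (the library call). Stage 2: bias + ReLU (a separate kernel)."""
    conv_out = conv1d(x, w)
    return [max(v + bias, 0.0) for v in conv_out]


def fused_conv_bias_relu(x, w, bias, alpha1=1.0, alpha2=0.0, z=None):
    """Single pass computing relu(alpha1 * conv(x, w) + alpha2 * z + bias)."""
    k = len(w)
    n = len(x) - k + 1
    z = z if z is not None else [0.0] * n
    out = []
    for i in range(n):
        acc = sum(x[i + j] * w[j] for j in range(k))
        out.append(max(alpha1 * acc + alpha2 * z[i] + bias, 0.0))
    return out


# Hypothetical toy data, chosen so the arithmetic is exact in floating point.
x = [1.0, -2.0, 3.0, -4.0, 5.0]
w = [0.5, -1.0, 0.25]
bias = 0.5
assert two_stage(x, w, bias) == fused_conv_bias_relu(x, w, bias)
```

The fused version never materialises the intermediate convolution output, which is the memory round trip that a single cuDNN call would avoid.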

Hi there lain, welcome to the TVM community.

Could you share a standalone Python script that shows this behaviour? For example, a single-layer DNN with conv2d + bias + ReLU.

That will help folk quickly see if the problem can be reproduced, and make it easier to get to a solution. You can adapt this script I use, if that helps.

I would have thought that TVM supports this, but I have not really explored the cuDNN backend.

Hi Wheest, you can try the following script.

import torch
import torch.nn as nn
import tvm
from tvm import relay

# Single conv2d (with bias) followed by ReLU.
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, padding_mode='zeros')
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.relu(out)
        return out

inputs = torch.ones(1, 32, 128, 128)
model = Model().eval()

# Offload conv2d to cuDNN; use target = "cuda" to compare against the non-cuDNN build.
target = "cuda -libs=cudnn"
dev = tvm.cuda(0)

# Trace the PyTorch model and import it into Relay.
mod, params = relay.frontend.from_pytorch(torch.jit.trace(model, inputs), [('0', inputs.shape)])

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, target_host='c', params=params)

# Print the generated host source to see which functions are called.
print(lib.lib.get_source())

=======================================================

From lib.lib.get_source(), we can see that the code first calls tvm_contrib_cudnn_conv2d_forward_packed to do the convolution through cuDNN, then calls tvmgen_default_fused_nn_conv2d_add_nn_relu_kernel0_packed to do the bias and ReLU.
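One way to check this without reading the whole dump is to search the generated source for the two symbol prefixes. The sketch below assumes the symbol names quoted above. The helper `count_stages` is hypothetical, and the sample string is made up for illustration; it is not real TVM output.

```python
import re


def count_stages(source):
    """Return the distinct cuDNN call names and TVM-generated fused kernel names in `source`."""
    cudnn_calls = sorted(set(re.findall(r"\btvm_contrib_cudnn_\w+", source)))
    fused_kernels = sorted(set(re.findall(r"\btvmgen_default_fused_\w+", source)))
    return cudnn_calls, fused_kernels


# Made-up stand-in for the text returned by lib.lib.get_source(); not real TVM output.
sample = """
tvm_contrib_cudnn_conv2d_forward_packed(args);
tvmgen_default_fused_nn_conv2d_add_nn_relu_kernel0_packed(args);
"""

cudnn_calls, fused_kernels = count_stages(sample)
print(cudnn_calls)
print(fused_kernels)
```

With the real build, pass lib.lib.get_source() in place of sample. If both lists are non-empty, the convolution and the bias + ReLU are being executed as separate stages.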

Of course, if I do not use the cuDNN backend, i.e. if I change the target from "cuda -libs=cudnn" to "cuda", the conv, bias, and ReLU are fused into one kernel.