Advice on understanding Operator Fusion

Hello everyone, this might come off as a basic question. I am trying to understand operator fusion better. I have read the paper and it made sense, and I tried a few toy examples where I saw the performance improvements. Now I would like to try it on actual networks. I noticed the benchmarks were for:

conv+bn+relu: 28x28x28, xx28x256
depthwise-conv+bn+relu: 52x4x4, 3x3x52
rnn cell: hidden 28
lstm cell: hidden 28

I am a bit confused about which operations would be fused and which would not in, say, conv+bn+relu. I would be grateful if someone could attach a simple PyTorch snippet. Thank you!

There are two types of fusion happening in the case of conv+bn+relu.

The first is conv+bn. Because we are running inference, batch norm uses fixed running statistics (mean and variance) plus its learned scale and shift, so it is just a per-channel affine transform. That means we can fold the BN parameters into the weights and bias of the conv2d. If you write out the batch norm formula applied to the conv2d output, you can see the simplification yourself. In this way the batch norm operation is removed completely.
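To see the conv+bn simplification concretely, write the conv output as z = W * x + b and inference-mode BN as y = gamma * (z - mean) / sqrt(var + eps) + beta, applied per output channel. With s = gamma / sqrt(var + eps) this becomes y = (s * W) * x + (s * (b - mean) + beta), which is just another convolution with scaled weights and a new bias. The sketch below checks that numerically with NumPy. conv2d, batch_norm_inference and fold_bn_into_conv are hypothetical helpers written for illustration (a single image in (C, H, W) layout, stride 1, no padding), not TVM or PyTorch APIs. For real nn.Module layers, PyTorch ships a helper that does the same folding, torch.nn.utils.fusion.fuse_conv_bn_eval.

```python
import numpy as np


def conv2d(x, w, b):
    """Naive stride-1, unpadded 2D cross-correlation (what nn.Conv2d computes).

    x: (C_in, H, W), w: (C_out, C_in, kh, kw), b: (C_out,)
    """
    c_out, _, kh, kw = w.shape
    ho, wo = x.shape[1] - kh + 1, x.shape[2] - kw + 1
    out = np.zeros((c_out, ho, wo))
    for i in range(kh):
        for j in range(kw):
            patch = x[:, i:i + ho, j:j + wo]  # (C_in, ho, wo) window for tap (i, j)
            out += np.einsum("oc,chw->ohw", w[:, :, i, j], patch)
    return out + b[:, None, None]


def batch_norm_inference(z, gamma, beta, mean, var, eps=1e-5):
    """BN with frozen running statistics: a per-channel affine transform."""
    scale = gamma / np.sqrt(var + eps)
    return (z - mean[:, None, None]) * scale[:, None, None] + beta[:, None, None]


def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
    """Return (w', b') such that conv2d(x, w', b') == BN(conv2d(x, w, b))."""
    scale = gamma / np.sqrt(var + eps)  # one factor per output channel
    w_folded = w * scale[:, None, None, None]
    b_folded = (b - mean) * scale + beta
    return w_folded, b_folded


rng = np.random.default_rng(0)
c_in, c_out = 3, 4
x = rng.standard_normal((c_in, 8, 8))
w = rng.standard_normal((c_out, c_in, 3, 3))
b = rng.standard_normal(c_out)
gamma, beta = rng.standard_normal(c_out), rng.standard_normal(c_out)
mean, var = rng.standard_normal(c_out), rng.uniform(0.5, 2.0, c_out)

# Unfused: conv, then a separate BN pass. Fused: one conv with folded parameters.
unfused = batch_norm_inference(conv2d(x, w, b), gamma, beta, mean, var)
w_f, b_f = fold_bn_into_conv(w, b, gamma, beta, mean, var)
fused = conv2d(x, w_f, b_f)
print(np.allclose(unfused, fused))
```

The folding only needs to happen once, ahead of time, so at run time the BN cost disappears entirely.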

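ReLU cannot be folded away like this: we still need to compute it. But we can apply it to each output element right after it is produced, instead of running an extra set of nested for loops that reads the whole output back from memory and writes it out again.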

Here’s a simple example showing how we can fuse a matmul and a ReLU.

Non-fused version

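import numpy as np
N, M, K = 4, 5, 6
a = np.random.randn(N, K)
b = np.random.randn(K, M)
c = np.zeros((N, M))  # GEMM accumulates into c, so it must start at zero
def relu(x):
  return max(x, 0.0)

for ni in range(N): # GEMM
  for mi in range(M):
    for ki in range(K):
      c[ni, mi] += (
        a[ni, ki] * b[ki, mi]
      )
for ni in range(N): # ReLU: a second full pass over c
  for mi in range(M):
    c[ni, mi] = relu(c[ni, mi])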

And the fused version:

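for ni in range(N): # GEMM+ReLU
  for mi in range(M):
    c[ni, mi] = 0.0  # reset the accumulator (c may still hold an earlier result)
    for ki in range(K):
      c[ni, mi] += (
        a[ni, ki] * b[ki, mi]
      )
    # ReLU runs inside the mi loop, while c[ni, mi] is still in cache or a register
    c[ni, mi] = relu(c[ni, mi])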
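A runnable pure-Python version of the two loop nests makes the difference measurable. It is a sketch with hypothetical function names, and it accumulates into a local acc variable instead of into c directly; the stores counter is a stand-in for memory traffic to the output array. Both variants compute exactly the same values, but the unfused one writes every element of c twice.

```python
import random


def relu(x):
    return max(x, 0.0)


def gemm_relu_unfused(a, b):
    """GEMM into c, then a separate ReLU pass. Returns (c, number of stores to c)."""
    n, k, m = len(a), len(b), len(b[0])
    c = [[0.0] * m for _ in range(n)]
    stores = 0
    for ni in range(n):  # pass 1: GEMM
        for mi in range(m):
            acc = 0.0
            for ki in range(k):
                acc += a[ni][ki] * b[ki][mi]
            c[ni][mi] = acc
            stores += 1
    for ni in range(n):  # pass 2: ReLU re-reads and re-writes all of c
        for mi in range(m):
            c[ni][mi] = relu(c[ni][mi])
            stores += 1
    return c, stores


def gemm_relu_fused(a, b):
    """ReLU applied to each accumulator before it is stored. Returns (c, stores)."""
    n, k, m = len(a), len(b), len(b[0])
    c = [[0.0] * m for _ in range(n)]
    stores = 0
    for ni in range(n):  # single pass: GEMM + ReLU
        for mi in range(m):
            acc = 0.0
            for ki in range(k):
                acc += a[ni][ki] * b[ki][mi]
            c[ni][mi] = relu(acc)
            stores += 1
    return c, stores


random.seed(0)
N, M, K = 3, 4, 5
a = [[random.uniform(-1, 1) for _ in range(K)] for _ in range(N)]
b = [[random.uniform(-1, 1) for _ in range(M)] for _ in range(K)]

c_unfused, stores_unfused = gemm_relu_unfused(a, b)
c_fused, stores_fused = gemm_relu_fused(a, b)
print(c_unfused == c_fused)  # True: same arithmetic in the same order
print(stores_unfused, stores_fused)  # 24 12
```

Timing this in pure Python would mostly measure interpreter overhead; the saving shows up in compiled code, where the second pass means streaming the whole output through memory again.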

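Below is a TVM example with PyTorch that compiles the same model twice, once with the default optimizations and once with fusion disabled (the script passes "OpFusion" and "FoldConstant" to disabled_pass, so the BN parameters are also not pre-folded into the conv weights), and compares the inference times. Since it is such a tiny layer, on a fast machine you might need to increase its size or add more layers to see the difference.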

On my machine I got:

Execution time summary (ms):
           mean      median    max       min       std
Fused      66.5262   65.2542   143.2263  49.2954   10.4740
Unfused    75.3485   69.1626   291.9446  44.1644   29.4445
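The two summaries can be compared directly. This small sketch only divides the reported numbers; nothing is re-measured.

```python
# Timings (ms) copied from the two execution time summaries above.
fused = {"mean": 66.5262, "median": 65.2542, "min": 49.2954, "std": 10.4740}
unfused = {"mean": 75.3485, "median": 69.1626, "min": 44.1644, "std": 29.4445}

# A ratio above 1 means the fused build was faster on that statistic.
ratios = {stat: unfused[stat] / fused[stat] for stat in ("mean", "median", "min")}
for stat, ratio in ratios.items():
    print(f"{stat:>6}: unfused / fused = {ratio:.2f}x")

# Relative spread of each build; a large value means the timings are noisy.
for name, t in (("fused", fused), ("unfused", unfused)):
    print(f"{name}: std / mean = {t['std'] / t['mean']:.2f}")
```

This gives roughly 1.13x on the mean and 1.06x on the median in favor of the fused build, but the unfused minimum is actually lower (about 0.90x) and its standard deviation is almost three times larger, so with number=1 per repeat the measurements are fairly noisy.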

PyTorch WeeNet with fusion (full script):
#!/usr/bin/env python3
import torch
import torch.nn.functional as F
import torch.nn as nn

import tvm
from tvm import relay
from tvm.contrib import graph_executor


class WeeNet(nn.Module):
    def __init__(self, planes=32):
        super().__init__()
        # standard 3x3 convolution (pass groups=planes to make it depthwise)
        self.conv1 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.bn1 = nn.BatchNorm2d(planes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        return out


def get_model():
    planes = 64
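    # eval() before tracing so BatchNorm uses its running statistics,
    # which is what makes folding BN into the conv valid
    model = WeeNet(planes).eval()
    input_shape = (1, planes, 128, 128)
    input_name = "data"
    input_dtype = "float32"
    input_data = torch.rand(input_shape)

    with torch.no_grad():
        scripted_model = torch.jit.trace(model, input_data).eval()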

    relay_mod, params = relay.frontend.from_pytorch(
        scripted_model, [(input_name, input_shape)]
    )

    x = input_data.numpy()
    sample_inputs = [x]
    return relay_mod, params, input_name, input_dtype, input_shape, sample_inputs


def compile_tvm_model(mod, params, disable_passes=None):
    # Build for CPU with LLVM; any pass named in disable_passes is skipped.
    target = tvm.target.Target("llvm", host="llvm")
    dev = tvm.cpu(0)
    with tvm.transform.PassContext(opt_level=3, disabled_pass=disable_passes or []):
        lib = tvm.relay.build(mod, target=target, params=params)
    return graph_executor.GraphModule(lib["default"](dev))


def run_benchmark(label, m, input_name, input_data):
    dev = tvm.cpu(0)
    m.set_input(input_name, input_data)
    m.run()  # warm-up run
    print(label)
    # number=1: each of the 600 repeats times a single run
    print(m.benchmark(dev, number=1, repeat=600))


def main():
    mod, params, input_name, input_dtype, input_shape, sample_inputs = get_model()

    # Default opt_level=3 pipeline, with fusion enabled.
    m = compile_tvm_model(mod, params)
    run_benchmark("Fused", m, input_name, sample_inputs[0])

    # Same model with fusion and constant folding disabled.
    # Recent TVM releases appear to register the Relay fusion pass as "FuseOps";
    # if "OpFusion" has no effect in your version, try that name instead.
    m = compile_tvm_model(mod, params, disable_passes=["OpFusion", "FoldConstant"])
    run_benchmark("\nUnfused", m, input_name, sample_inputs[0])


if __name__ == "__main__":
    main()