There are two types of fusion happening in the case of conv+bn+relu.
The first is conv+bn. Because we're running inference, we can fold the batch norm parameters directly into the weights of the conv2d. If you write out the formula for batch norm applied to a conv2d output, you can see the simplification yourself. In this way the batch norm operation is completely removed.
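To make the folding concrete, here is a minimal sketch in PyTorch. The helper name fold_bn_into_conv is mine, for illustration; it assumes inference mode, where batch norm uses its running statistics. (PyTorch ships essentially the same transform as torch.nn.utils.fusion.fuse_conv_bn_eval.)

import torch
import torch.nn as nn

def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # BN(conv(x)) = scale * (W*x + b - mean) + beta, where
    # scale = gamma / sqrt(running_var + eps), so scale can be
    # baked into the conv weights and bias ahead of time.
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=True,
    )
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    conv_bias = conv.bias.data if conv.bias is not None else 0.0
    fused.bias.data = (conv_bias - bn.running_mean) * scale + bn.bias.data
    return fused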
For the ReLU fusion, we still need to run the ReLU, but we can fold it into the same loop nest and avoid an extra pass over the output.
Here’s a simple example showing how we can fuse a matmul and a ReLU.
The non-fused version:

for ni in range(N):  # GEMM
    for mi in range(M):
        for ki in range(K):
            c[ni, mi] += a[ni, ki] * b[ki, mi]

for ni in range(N):  # ReLU
    for mi in range(M):
        c[ni, mi] = relu(c[ni, mi])
And the fused version:
for ni in range(N):  # GEMM+ReLU
    for mi in range(M):
        for ki in range(K):
            c[ni, mi] += a[ni, ki] * b[ki, mi]
        c[ni, mi] = relu(c[ni, mi])
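As a quick sanity check (my own addition, not part of the TVM example below), the fused loop matches a NumPy reference as long as c starts out zeroed:

import numpy as np

N, M, K = 4, 5, 6
a = np.random.rand(N, K)
b = np.random.rand(K, M)
c = np.zeros((N, M))
for ni in range(N):  # GEMM+ReLU in a single pass over c
    for mi in range(M):
        for ki in range(K):
            c[ni, mi] += a[ni, ki] * b[ki, mi]
        c[ni, mi] = max(c[ni, mi], 0.0)  # relu
np.testing.assert_allclose(c, np.maximum(a @ b, 0.0))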
You can see a TVM example with PyTorch below, where we disable the OpFusion pass and compare the inference times. Since it's such a tiny model, if you have a fast machine you might need to increase the input size or add more layers to see the difference.
On my machine I got:
Fused
Execution time summary:
  mean (ms)   median (ms)    max (ms)    min (ms)    std (ms)
    66.5262       65.2542    143.2263     49.2954     10.4740
Unfused
Execution time summary:
  mean (ms)   median (ms)    max (ms)    min (ms)    std (ms)
    75.3485       69.1626    291.9446     44.1644     29.4445
PyTorch WeeNet with Fusion
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F

import tvm
from tvm import relay
from tvm.contrib import graph_executor


class WeeNet(nn.Module):
    """Tiny conv -> batch norm -> ReLU model."""

    def __init__(self, planes=32):
        super(WeeNet, self).__init__()
        # 3x3 convolution that preserves the channel count
        self.conv1 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.bn1 = nn.BatchNorm2d(planes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        return out


def get_model():
    planes = 64
    model = WeeNet(planes)
    input_shape = (1, planes, 128, 128)
    input_name = "data"
    input_dtype = "float32"
    input_data = torch.rand(input_shape)
    # Trace the model and sanity-check the traced module.
    scripted_model = torch.jit.trace(model, input_data).eval()
    scripted_model(input_data)
    relay_mod, params = relay.frontend.from_pytorch(
        scripted_model, [(input_name, input_shape)]
    )
    sample_inputs = [input_data.numpy()]
    return relay_mod, params, input_name, input_dtype, input_shape, sample_inputs


def compile_tvm_model(mod, params, disable_passes=()):
    target = tvm.target.Target("llvm", host="llvm")
    dev = tvm.cpu(0)
    with tvm.transform.PassContext(opt_level=3, disabled_pass=list(disable_passes)):
        lib = tvm.relay.build(mod, target=target, params=params)
    m = graph_executor.GraphModule(lib["default"](dev))
    return m


def main():
    mod, params, input_name, input_dtype, input_shape, sample_inputs = get_model()
    target = tvm.target.Target("llvm")
    dev = tvm.device(str(target), 0)

    # Fused build (default pass pipeline).
    m = compile_tvm_model(mod, params)
    m.set_input(input_name, sample_inputs[0])
    m.run()
    m.get_output(0)
    print("Fused")
    print(m.benchmark(dev, number=1, repeat=600))

    # Unfused build: disable operator fusion (and constant folding,
    # which would otherwise pre-fold the batch norm parameters).
    m = compile_tvm_model(mod, params, disable_passes=["OpFusion", "FoldConstant"])
    m.set_input(input_name, sample_inputs[0])
    m.run()
    m.get_output(0)
    print("\nUnfused")
    print(m.benchmark(dev, number=1, repeat=600))


if __name__ == "__main__":
    main()
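Fusion should change only the timing, not the results. As a quick check (a sketch of mine reusing the helpers above), you can compare the outputs of the two builds directly; note that older TVM versions spell .numpy() as .asnumpy():

import numpy as np

def check_outputs_match():
    mod, params, input_name, _, _, sample_inputs = get_model()
    fused = compile_tvm_model(mod, params)
    unfused = compile_tvm_model(mod, params, disable_passes=["OpFusion", "FoldConstant"])
    for m in (fused, unfused):
        m.set_input(input_name, sample_inputs[0])
        m.run()
    # The fused and unfused graphs compute the same function, so the
    # outputs should agree to within floating-point tolerance.
    np.testing.assert_allclose(
        fused.get_output(0).numpy(),
        unfused.get_output(0).numpy(),
        rtol=1e-4,
        atol=1e-4,
    )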