[PyTorch] Batch norm result mismatching

Although PyTorch BatchNorm2d can be converted to Relay nn.batch_norm, I found that the results produced by PyTorch BatchNorm2d and the converted Relay batch_norm are different. Here is the testing script:

import numpy as np
import torch

import tvm
from tvm import relay
from tvm.contrib import graph_runtime

data = np.random.uniform(-1, 1, (1, 3, 224, 224)).astype("float32")

# PyTorch model
class TorchBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(x)

t_model = TorchBN()

t_data = torch.Tensor(data)

with torch.no_grad():
    t_out = t_model(t_data)

# Convert to Relay
scripted_model = torch.jit.trace(t_model, t_data).eval()
mod, params = relay.frontend.from_pytorch(scripted_model, [("input0", (1, 3, 224, 224))])

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)

ctx = tvm.cpu(0)
m = graph_runtime.GraphModule(lib["default"](ctx))
m.set_input("input0", tvm.nd.array(data))
m.run()
r_out = m.get_output(0)

# Compare results
t_np = t_out.numpy()
r_np = r_out.asnumpy()
np.testing.assert_allclose(t_np, r_np, rtol=1e-4, atol=1e-4)

And here is the sample output:

Not equal to tolerance rtol=0.0001, atol=0.0001

Mismatched elements: 150503 / 150528 (100%)
Max absolute difference: 0.53077245
Max relative difference: 393.49854
 x: array([[[[ 0.824763,  1.193627, -1.391356, ...,  0.45779 , -0.425513,
          -1.232682],
         [-0.102098, -0.824518,  0.969305, ..., -0.644249, -1.359166,...
 y: array([[[[ 0.574541,  0.831002, -0.966262, ...,  0.319396, -0.294739,
          -0.85594 ],
         [-0.069878, -0.572156,  0.675038, ..., -0.44682 , -0.943881,...

The results look quite different, so I don’t think it’s just floating-point error… I’m not sure whether the issue is in the PyTorch frontend or in the way I convert the model. Does anyone have any idea?

cc @masahi @yzhliu @hzfan

Add one line, t_model.eval(), right after t_model = TorchBN() (see the snippet below). It works.
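
For reference, a minimal sketch of the affected part of the script with that fix applied (everything else stays the same):

t_model = TorchBN()
t_model.eval()  # put BatchNorm2d into inference mode before tracing

t_data = torch.Tensor(data)

with torch.no_grad():
    t_out = t_model(t_data)

scripted_model = torch.jit.trace(t_model, t_data).eval()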

Eval mode graph:

# print(scripted_model.inlined_graph)
graph(%self.1 : __torch__.___torch_mangle_11.TorchBN,
      %input : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu)):
  %2 : __torch__.torch.nn.modules.batchnorm.___torch_mangle_10.BatchNorm2d = prim::GetAttr[name="bn"](%self.1)
  %4 : bool = prim::Constant[value=1](), scope: __module.bn
  %5 : float = prim::Constant[value=1.0000000000000001e-05](), scope: __module.bn
  %6 : float = prim::Constant[value=0.10000000000000001](), scope: __module.bn
  %7 : bool = prim::Constant[value=0](), scope: __module.bn
  %8 : Tensor = prim::GetAttr[name="running_var"](%2)
  %9 : Tensor = prim::GetAttr[name="running_mean"](%2)
  %10 : Tensor = prim::GetAttr[name="bias"](%2)
  %11 : Tensor = prim::GetAttr[name="weight"](%2)
  %12 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=1, device=cpu) = aten::batch_norm(%input, %11, %10, %9, %8, %7, %6, %5, %4), scope: __module.bn
  return (%12)

Train mode graph:

# print(scripted_model.inlined_graph)
graph(%self.1 : __torch__.___torch_mangle_3.TorchBN,
      %input : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu)):
  %2 : __torch__.torch.nn.modules.batchnorm.___torch_mangle_2.BatchNorm2d = prim::GetAttr[name="bn"](%self.1)
  %4 : float = prim::Constant[value=1.0000000000000001e-05](), scope: __module.bn
  %5 : float = prim::Constant[value=0.10000000000000001](), scope: __module.bn
  %6 : bool = prim::Constant[value=1](), scope: __module.bn
  %7 : Tensor = prim::GetAttr[name="running_var"](%2)
  %8 : Tensor = prim::GetAttr[name="running_mean"](%2)
  %9 : Tensor = prim::GetAttr[name="bias"](%2)
  %10 : Tensor = prim::GetAttr[name="weight"](%2)
  %11 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=1, device=cpu) = aten::batch_norm(%input, %10, %9, %8, %7, %6, %5, %4, %6), scope: __module.bn
  return (%11)

The 6th param of aten::batch_norm is the bool training flag. In the eval-mode graph it is the constant 0 (%7), while in the train-mode graph it is the constant 1 (%6), so the traced model only matches Relay’s inference-style batch_norm when the module is in eval mode.
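
Not from the thread, but here is a small standalone check of what that flag changes: in train mode, BatchNorm2d normalizes each channel with the statistics of the current batch, while in eval mode it uses the stored running_mean/running_var, which is what the converted Relay graph relies on.

import torch

x = torch.randn(1, 3, 8, 8)
bn = torch.nn.BatchNorm2d(3)

# Train mode: normalizes with the current batch's mean/var
# (and updates the running statistics as a side effect).
y_train = bn(x)

# Eval mode: normalizes with running_mean / running_var instead.
bn.eval()
with torch.no_grad():
    y_eval = bn(x)

# The eval-mode output can be reproduced by hand from the running stats.
mean = bn.running_mean[None, :, None, None]
var = bn.running_var[None, :, None, None]
manual = (x - mean) / torch.sqrt(var + bn.eps)
manual = manual * bn.weight[None, :, None, None] + bn.bias[None, :, None, None]
print(torch.allclose(y_eval, manual, atol=1e-5))   # True
print(torch.allclose(y_train, y_eval, atol=1e-4))  # False in general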

Thanks @hgt312 for the explanation!