How to get batch norm's running stats such as running_var and running_mean in pytorch and pass them into relay's batch norm api

Hi friends: I have some trouble with finding some bugs in my codes. Since I am trying to match batch norm’s outputs between pytorch and tvm. I am confused that how to get running stats in pytorch and set it into relay batch norm so that they can get the same outputs.

Here is my batch norm function:

def batch_norm_infer(data, gamma=None, beta=None, moving_mean=None, moving_var=None, **kwargs):

    name = kwargs.get("name")
    kwargs.pop("name")
    if not gamma:
        gamma = relay.var(name + "_gamma")
    if not beta:
        beta = relay.var(name + "_beta")
    if not moving_mean:
        moving_mean = relay.mean(data, axis=[0,2,3])
        # moving_mean = relay.var(name + "_moving_mean")
    if not moving_var:
        moving_var = relay.variance(data, axis=[0,2,3])
        # moving_var = relay.var(name + "_moving_var")
    return relay.nn.batch_norm(
        data, gamma=gamma, beta=beta, moving_mean=moving_mean, moving_var=moving_var, **kwargs
    )[0]

And this is my entire script:

import torch, argparse, os, sys
import logging
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
import time
import numpy as np
from tvm.contrib import graph_executor as runtime
from tvm import relay
from tvm.relay import testing
import tvm.testing
import math
import tvm.topi.testing
from tvm import relay, te
from tvm.relay.testing import run_infer_type


import torch
from torch import nn

logging.basicConfig(level=logging.DEBUG)


def compare(v1, v2):
    return np.max(np.abs(v1 - v2))
    

def batch_norm_infer(data, gamma=None, beta=None, moving_mean=None, moving_var=None, **kwargs):

    name = kwargs.get("name")
    kwargs.pop("name")
    if not gamma:
        gamma = relay.var(name + "_gamma")
    if not beta:
        beta = relay.var(name + "_beta")
    if not moving_mean:
        moving_mean = relay.mean(data, axis=[0,2,3])
        # moving_mean = relay.var(name + "_moving_mean")
    if not moving_var:
        moving_var = relay.variance(data, axis=[0,2,3])
        # moving_var = relay.var(name + "_moving_var")
    return relay.nn.batch_norm(
        data, gamma=gamma, beta=beta, moving_mean=moving_mean, moving_var=moving_var, **kwargs
    )[0]


def run_test_conv2d_batch_norm(
    dtype,
    out_dtype,
    scale,
    dshape,
    kshape,
    padding=(1, 1),
    fref=None,
    groups=1,
    dilation=(1, 1),
    except_targets=None,
    **attrs,
):
    x = relay.var("data", relay.TensorType(dshape, "float32"))
    w = relay.var("kernel", relay.TensorType(kshape, "float32"))
    y = relay.nn.conv2d(data = x, weight = w, padding=padding, dilation=dilation, groups=groups, **attrs)
    out = batch_norm_infer(y, name = "my_batch_norm")
    simple_net = relay.Function(relay.analysis.free_vars(out), out)

    params = {}
    mod = tvm.IRModule.from_expr(simple_net)
    mod = relay.transform.InferType()(mod)
    print(mod['main'])

    extract_params = []
    extract_shapes = []
    for v in mod["main"].params:
        str_v = str(v)
        index1 = str_v.find("%")
        index2 = str_v.find(":")
        extract_params.append(str_v[index1+1:index2])

        index1 = str_v.find("[(")
        index1+=2
        index2 = str_v.find("),")
        tensor_shape_str = str_v[index1:index2]
        tensor_shape_str = tensor_shape_str.replace(" ", "")
        array = tensor_shape_str.split(",")
        shape = []
        for ele in array:
            shape.append(int(ele))
        extract_shapes.append(shape)
    
    print("get all of parameters of relay model ... ")
    for i in range(len(extract_shapes)):
        print(extract_params[i])
        print(extract_shapes[i])
    # net, params = testing.create_workload(simple_net)

    kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
    data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
    target = "cuda"
    dev = tvm.device(target, 0)


    in_channels = data.shape[1]
    out_channels = kernel.shape[0]


    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = out_channels,
                        kernel_size = 3, padding=1, bias = False)
            self.bn = nn.BatchNorm2d(out_channels)

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn(x)
            return x
    
    my_net = Net()
    
    with torch.no_grad():
        my_net.conv1.weight = torch.nn.Parameter(torch.from_numpy(kernel))


    print("torch parameters: ")
    torch_params = {}
    for name, param in my_net.named_parameters():
        print(name)
        print(param)
        torch_params[name] = param

    params = {}
    # transfer parameters from torch model to relay model
    for i in range(len(extract_params)):
        if extract_params[i] == 'kernel':
            np_array = torch_params["conv1.weight"].cpu().detach().numpy()
            params['kernel'] = tvm.nd.array(np_array)
        elif extract_params[i] == "my_batch_norm_gamma":
            np_array = torch_params['bn.weight'].cpu().detach().numpy()
            params['my_batch_norm_gamma'] = tvm.nd.array(np_array)
        elif extract_params[i] == "my_batch_norm_beta":
            np_array = torch_params['bn.bias'].cpu().detach().numpy()
            params['my_batch_norm_beta'] = tvm.nd.array(np_array)
        elif extract_params[i] == "my_batch_norm_moving_mean":
            relay_shape = extract_shapes[i]
            np_array = np.zeros(relay_shape, dtype="float32")
            params['my_batch_norm_moving_mean'] = tvm.nd.array(np_array)
        elif extract_params[i] == "my_batch_norm_moving_var":
            relay_shape = extract_shapes[i]
            np_array = np.ones(relay_shape, dtype="float32")
            params['my_batch_norm_moving_var'] = tvm.nd.array(np_array)

    for name in params:
        print(name)
        print(params[name])


    input = torch.from_numpy(data)
    # my_net.eval()
    torch_output = my_net(input)
    output = torch_output.detach().numpy()


    print("torch final output: ")
    print(output)


    lib = relay.build_module.build(mod, target, params=params)
    module = runtime.GraphModule(lib["default"](dev))
    module.set_input("data", data)
    # module.set_input("kernel", kernel)
    module.run()
    # out_shape = (1, 16, 18, 18)
    out = module.get_output(0)
    op_res1 = out.asnumpy()
    print('tvm final batch norm\'s shape : ')
    print(op_res1.shape)

    print("tvm final output: ")
    print(op_res1)


    for i in range(output.shape[0]):
        for j in range(output.shape[1]):
            for k in range(output.shape[2]):
                for q in range(output.shape[3]):
                    if abs(output[i,j,k,q] - op_res1[i,j,k,q]) >= 0.001:
                        print(output[i,j,k,q])
                        print(op_res1[i,j,k,q]) 





if __name__ == '__main__':
    print("checking conv2d op between torch and relay... ")
    dshape = (1, 3, 4, 4)
    kshape = (3, 3, 3, 3)
    
    run_test_conv2d_batch_norm(
        "float32",
        "float32",
        1,
        dshape,
        kshape,
        padding=(1, 1),
        channels=kshape[0],
        groups=1,
        kernel_size=(3, 3)
        )

when I set my_net.eval(), I got different results. I am confused how I should get running_var and running_mean so that they can get the same result.