Incorrect EfficientNet structure after the prerequisite_optimize pass

Hi, I am trying to use prerequisite_optimize to fold batch-norm layers into the preceding convolution layers. However, after this pass I get a network that seems inconsistent with the original one, as shown below.
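For reference, folding a batch norm into the preceding convolution should rewrite the conv weight and bias so the normalization disappears from the graph. Here is a small NumPy sketch of the arithmetic I expect the combined passes to perform (my own illustration, assuming OIHW weight layout, not TVM code):

import numpy as np

def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-3):
    # Per-output-channel scale applied by the batch norm
    scale = gamma / np.sqrt(var + eps)
    # Scale each output channel of the conv weight (layout OIHW)
    w_folded = w * scale.reshape(-1, 1, 1, 1)
    # Fold the mean/shift into the bias; b is zero if the conv had no bias
    b_folded = (b - mean) * scale + beta
    return w_folded, b_folded

If the optimization is correct, each conv2d + batch_norm pair below should become a single conv2d followed by a per-channel add, with rewritten weights.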

The original network:

  %0 = nn.conv2d(%input0, %features.0.0.weight, strides=[2, 2], padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]);
  %1 = nn.batch_norm(%0, %features.0.1.weight, %features.0.1.bias, %features.0.1.running_mean, %features.0.1.running_var, epsilon=0.001f);
  %2 = %1.0;
  %3 = sigmoid(%2);
  %4 = multiply(%2, %3);
  %5 = reshape(%features.1.0.block.0.0.weight, newshape=[64, 1, 3, 3]);
  %6 = nn.conv2d(%4, %5, padding=[1, 1, 1, 1], groups=64, channels=64, kernel_size=[3, 3]);
  %7 = nn.batch_norm(%6, %features.1.0.block.0.1.weight, %features.1.0.block.0.1.bias, %features.1.0.block.0.1.running_mean, %features.1.0.block.0.1.running_var, epsilon=0.001f);
  %8 = %7.0;
  %9 = sigmoid(%8);
  %10 = multiply(%8, %9);
  %11 = nn.adaptive_avg_pool2d(%10, output_size=[1, 1]);
  %12 = nn.conv2d(%11, %features.1.0.block.1.fc1.weight, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1]);
  %13 = nn.bias_add(%12, %features.1.0.block.1.fc1.bias);
  %14 = sigmoid(%13);
  %15 = multiply(%13, %14);
  %16 = nn.conv2d(%15, %features.1.0.block.1.fc2.weight, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1]);
  %17 = nn.bias_add(%16, %features.1.0.block.1.fc2.bias);
  %18 = sigmoid(%17);
  %19 = multiply(%18, %10);
  %20 = nn.conv2d(%19, %features.1.0.block.2.0.weight, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]);
  %21 = nn.batch_norm(%20, %features.1.0.block.2.1.weight, %features.1.0.block.2.1.bias, %features.1.0.block.2.1.running_mean, %features.1.0.block.2.1.running_var, epsilon=0.001f);
  %22 = %21.0;
  %23 = reshape(%features.1.1.block.0.0.weight, newshape=[32, 1, 3, 3]);
  %24 = nn.conv2d(%22, %23, padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3]);
  %25 = nn.batch_norm(%24, %features.1.1.block.0.1.weight, %features.1.1.block.0.1.bias, %features.1.1.block.0.1.running_mean, %features.1.1.block.0.1.running_var, epsilon=0.001f);
  %26 = %25.0;
  %27 = sigmoid(%26);

The optimized network:

The problem can be reproduced with the following code:

import torch
from torchvision import models

import tvm
import tvm.relay as relay
from tvm.relay.quantize import prerequisite_optimize

# Load a pretrained EfficientNet-B7 in eval mode so batch norms use running stats
model = models.efficientnet_b7(pretrained=True)
model = model.eval()

# Trace the model so it can be imported by the Relay PyTorch frontend
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()

input_name = "input0"
shape_list = [(input_name, input_shape)]

mod_original, params = relay.frontend.from_pytorch(scripted_model, shape_list)
print(mod_original)

# Run the pre-quantization optimization pipeline and compare the printed IR
with tvm.transform.PassContext(opt_level=3):
    optimized_mod = prerequisite_optimize(mod_original, params)
print(optimized_mod)
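To check whether the structural difference is an actual numerical bug rather than a differently printed but equivalent graph, both modules can be executed on the same input and their outputs compared. A sketch continuing from the script above (CPU target assumed):

import numpy as np
from tvm.contrib import graph_executor

def run(mod):
    # Build the module for CPU and run it on the traced input
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)
    m = graph_executor.GraphModule(lib["default"](tvm.cpu()))
    m.set_input(input_name, input_data.numpy())
    m.run()
    return m.get_output(0).numpy()

# If prerequisite_optimize only restructured the graph, this should pass
np.testing.assert_allclose(run(mod_original), run(optimized_mod), rtol=1e-4, atol=1e-4)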

After some investigation, I found that the FoldScaleAxis pass inside prerequisite_optimize may be the cause of this problem.
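One way to confirm this is to run the pipeline one pass at a time and print the IR after each step, so the first pass that produces the inconsistent structure can be identified. A sketch, assuming prerequisite_optimize is the usual Sequential of SimplifyInference, FoldConstant, FoldScaleAxis, and CanonicalizeOps (this pass list is my reconstruction; check quantize.py in your TVM version):

from tvm.relay.build_module import bind_params_by_name

# Bind the weights as constants first, since the folding passes need them
mod = tvm.IRModule.from_expr(bind_params_by_name(mod_original["main"], params))

passes = [
    relay.transform.SimplifyInference(),  # decomposes batch_norm into multiply/add
    relay.transform.FoldConstant(),
    relay.transform.FoldScaleAxis(),      # folds per-channel scales into conv weights
    relay.transform.CanonicalizeOps(),
]
for opt in passes:
    with tvm.transform.PassContext(opt_level=3):
        mod = opt(mod)
    print("=== after", opt.info.name, "===")
    print(mod)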