[Bug] relay.transform.FoldScaleAxis() gives an unexpected result when folding the SqueezeExcitation structure of MobileNetV3.

Frontend code (PyTorch model and import into Relay):

import torch
from torch import nn, Tensor
from torch.nn import functional as F

def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

class SqueezeExcitation(nn.Module):
    """Channel gating block: pool, two 1x1 convolutions, hard-sigmoid, rescale.

    The bottleneck width is ``input_channels // squeeze_factor`` passed
    through ``_make_divisible`` with a divisor of 8.
    """

    def __init__(self, input_channels: int, squeeze_factor: int = 4):
        super().__init__()
        bottleneck = _make_divisible(input_channels // squeeze_factor, 8)
        self.fc1 = nn.Conv2d(input_channels, bottleneck, 1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(bottleneck, input_channels, 1)

    def _scale(self, input: Tensor, inplace: bool) -> Tensor:
        """Return the per-channel gate computed from ``input``."""
        # Collapse each channel to a single 1x1 value before the bottleneck.
        pooled = F.adaptive_avg_pool2d(input, 1)
        gate = self.fc2(self.relu(self.fc1(pooled)))
        return F.hardsigmoid(gate, inplace=inplace)

    def forward(self, input: Tensor) -> Tensor:
        """Multiply ``input`` by its gate (the gate is the left operand)."""
        return self._scale(input, True) * input

class M(nn.Module):
    """Minimal model: a SqueezeExcitation block followed by a 1x1 convolution.

    The convolution has no bias and produces 64 output channels.
    """

    def __init__(self, input_channels: int = 16):
        super().__init__()
        # Submodules are created in the same order as before (conv first).
        self.conv = nn.Conv2d(input_channels, 64, 1, bias=False)
        self.se_layer = SqueezeExcitation(input_channels)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the SE block to ``x``, then the 1x1 convolution."""
        return self.conv(self.se_layer(x))

# NumPy and the Relay frontend are used below but were not imported anywhere
# before this point; import them here so the snippet runs on its own.
import numpy as np
from tvm import relay

name = "x"
shape = (1, 16, 64, 48)
# Random integers in 0..255 scaled into [0, 1] and stored as float32.
data_np = (np.random.randint(0, 256, shape) / 255).astype("float32")
data_torch = torch.from_numpy(data_np)

# Trace the model in eval mode and import the traced graph into Relay.
model = M(input_channels=16).eval()
scripted_model = torch.jit.trace(model, data_torch).eval()
shape_list = [(name, shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

First, print the imported Relay function:

# Dump the "main" function of the module returned by the PyTorch frontend.
print(mod["main"])

Result:

fn (%x: Tensor[(1, 16, 64, 48), float32] /* span=aten::adaptive_avg_pool2d_0.x:0:0 */, %aten::_convolution_0.weight: Tensor[(8, 16, 1, 1), float32] /* span=aten::_convolution_0.weight:0:0 */, %aten::_convolution_0.bias: Tensor[(8), float32] /* span=aten::_convolution_0.bias:0:0 */, %aten::_convolution_1.weight: Tensor[(16, 8, 1, 1), float32] /* span=aten::_convolution_1.weight:0:0 */, %aten::_convolution_1.bias: Tensor[(16), float32] /* span=aten::_convolution_1.bias:0:0 */, %aten::_convolution_2.weight: Tensor[(64, 16, 1, 1), float32] /* span=aten::_convolution_2.weight:0:0 */) {
  %0 = nn.adaptive_avg_pool2d(%x, output_size=[1, 1]) /* span=aten::adaptive_avg_pool2d_0:0:0 */;
  %1 = nn.conv2d(%0, %aten::_convolution_0.weight, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1]) /* span=aten::_convolution_0:0:0 */;
  %2 = nn.bias_add(%1, %aten::_convolution_0.bias) /* span=aten::_convolution_0:0:0 */;
  %3 = nn.relu(%2) /* span=aten::relu__0:0:0 */;
  %4 = nn.conv2d(%3, %aten::_convolution_1.weight, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1]) /* span=aten::_convolution_1:0:0 */;
  %5 = nn.bias_add(%4, %aten::_convolution_1.bias) /* span=aten::_convolution_1:0:0 */;
  %6 = add(%5, 3f /* span=aten::hardsigmoid__0:0:0 */) /* span=aten::hardsigmoid__0:0:0 */;
  %7 = clip(%6, a_min=0f, a_max=6f) /* span=aten::hardsigmoid__0:0:0 */;
  %8 = divide(%7, 6f /* span=aten::hardsigmoid__0:0:0 */) /* span=aten::hardsigmoid__0:0:0 */;
  %9 = multiply(%8, %x) /* span=aten::mul_0:0:0 */;
  nn.conv2d(%9, %aten::_convolution_2.weight, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1]) /* span=aten::_convolution_2:0:0 */
}

But after running the following passes:

from copy import deepcopy
import tvm
from tvm import relay
from tvm.relay.quantize.quantize import _bind_params
# Work on a copy of the imported module and bind the weights into "main".
run_mod = deepcopy(mod)
run_mod["main"] = _bind_params(run_mod["main"], params)

# Pass pipeline that reproduces the problem; FoldScaleAxis is the last pass.
# Left disabled, as in the original report: a trailing
# relay.transform.CanonicalizeOps() followed by a second
# relay.transform.FoldConstant().
optimize = tvm.transform.Sequential(
    [
        relay.transform.SimplifyInference(),
        relay.transform.FoldConstant(),
        relay.transform.FoldScaleAxis(),
    ]
)

with tvm.transform.PassContext(opt_level=3):
    # Also left disabled:
    # run_mod2 = relay.quantize.prerequisite_optimize(deepcopy(mod), params)
    run_mod = optimize(run_mod)

The optimized function then shows:

fn (%x: Tensor[(1, 16, 64, 48), float32] /* ty=Tensor[(1, 16, 64, 48), float32] span=aten::adaptive_avg_pool2d_0.x:0:0 */) -> Tensor[(1, 64, 64, 48), float32] {
  %0 = nn.adaptive_avg_pool2d(%x, output_size=[1, 1]) /* ty=Tensor[(1, 16, 1, 1), float32] span=aten::adaptive_avg_pool2d_0:0:0 */;
  %1 = nn.conv2d(%0, meta[relay.Constant][1] /* ty=Tensor[(8, 16, 1, 1), float32] */, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1]) /* ty=Tensor[(1, 8, 1, 1), float32] span=aten::_convolution_0:0:0 */;
  %2 = nn.bias_add(%1, meta[relay.Constant][2] /* ty=Tensor[(8), float32] */) /* ty=Tensor[(1, 8, 1, 1), float32] span=aten::_convolution_0:0:0 */;
  %3 = nn.relu(%2) /* ty=Tensor[(1, 8, 1, 1), float32] span=aten::relu__0:0:0 */;
  %4 = nn.conv2d(%3, meta[relay.Constant][3] /* ty=Tensor[(16, 8, 1, 1), float32] */, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1]) /* ty=Tensor[(1, 16, 1, 1), float32] span=aten::_convolution_1:0:0 */;
  %5 = nn.bias_add(%4, meta[relay.Constant][4] /* ty=Tensor[(16), float32] */) /* ty=Tensor[(1, 16, 1, 1), float32] span=aten::_convolution_1:0:0 */;
  %6 = add(%5, 3f /* ty=float32 span=aten::hardsigmoid__0:0:0 */) /* ty=Tensor[(1, 16, 1, 1), float32] span=aten::hardsigmoid__0:0:0 */;
  %7 = clip(%6, a_min=0f, a_max=6f) /* ty=Tensor[(1, 16, 1, 1), float32] span=aten::hardsigmoid__0:0:0 */;
  %8 = divide(%7, 6f /* ty=float32 span=aten::hardsigmoid__0:0:0 */) /* ty=Tensor[(1, 16, 1, 1), float32] span=aten::hardsigmoid__0:0:0 */;
  %9 = squeeze(%8, axis=[0, 2, 3]) /* ty=Tensor[(16), float32] */;
  %10 = expand_dims(%9, axis=1, num_newaxis=2) /* ty=Tensor[(16, 1, 1), float32] */;
  %11 = multiply(meta[relay.Constant][0] /* ty=Tensor[(64, 16, 1, 1), float32] */, %10) /* ty=Tensor[(64, 16, 1, 1), float32] */;
  nn.conv2d(%x, %11, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1]) /* ty=Tensor[(1, 64, 64, 48), float32] */
} /* ty=fn (Tensor[(1, 16, 64, 48), float32]) -> Tensor[(1, 64, 64, 48), float32] */

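Everything’s a mess: FoldScaleAxis has folded the data-dependent SqueezeExcitation scale (%8) into the weight of the final conv2d (%9–%11), so that weight is no longer a constant and the convolution now reads %x directly.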