[BYOC][Quantization] Propagate channels-last PyTorch model to TVM without layout_transform ops

Hi, I am trying to import a quantized PyTorch model in NHWC (channels-last) format into TVM. However, the NHWC data layout of the input and weight tensors is not reflected in TVM. Looking more closely at the PyTorch frontend, I see that data_layout is never passed in.

Here is a mock example defined in PyTorch:

import torch
from torch import nn
from torch import quantization

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.qconfig = quantization.default_qconfig
        self.quant = quantization.QuantStub()      # float -> quint8
        self.conv = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=(1, 1))
        self.dequant = quantization.DeQuantStub()  # quint8 -> float

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)  # becomes a quantized conv after convert()
        return self.dequant(x)

model = Net().eval()  # post-training static quantization expects eval mode
model = model.to(memory_format=torch.channels_last)
model = torch.quantization.prepare(model, inplace=False)
# Calibrate the observers with representative channels-last input
model(torch.randn(1, 3, 224, 224).to(memory_format=torch.channels_last))
model = torch.quantization.convert(model, inplace=False)

This model produces channels-last output, and inspecting the convolution's weight tensor confirms that it is channels-last as well. However, the Relay representation of this model does not recognize the layout as channels-last (note data_layout="NCHW" and kernel_layout="OIHW" on the fused conv2d):

def @main(%input: Tensor[(1, 3, 224, 224), float32], param_device_types=[1], result_device_type=1, hash="14499cd1f5811f4e") -> Tensor[(1, 32, 112, 112), float32] {
  %10 = fn (%p01: Tensor[(1, 3, 224, 224), float32], Primitive=1, hash="c7e6893baed0a694") -> Tensor[(1, 3, 224, 224), int16] {
    %6 = round(%p01) /* ty=Tensor[(1, 3, 224, 224), float32] */;
    %7 = cast(%6, dtype="int32") /* ty=Tensor[(1, 3, 224, 224), int32] */;
    %8 = clip(%7, a_min=0f, a_max=255f) /* ty=Tensor[(1, 3, 224, 224), int32] */;
    %9 = cast(%8, dtype="uint8") /* ty=Tensor[(1, 3, 224, 224), uint8] */;
    cast(%9, dtype="int16") /* ty=Tensor[(1, 3, 224, 224), int16] */
  };
  %11 = %10(%input) /* ty=Tensor[(1, 3, 224, 224), int16] */;
  %12 = fn (%p0: Tensor[(1, 3, 224, 224), int16], %p1: Tensor[(32, 3, 3, 3), int16], %p2: Tensor[(32), int32], hash="6a14d785f792066f", data_layout="NCHW", kernel_layout="OIHW", Primitive=1, out_layout="") -> Tensor[(1, 32, 112, 112), float32] {
    %0 = nn.conv2d(%p0, %p1, strides=[2, 2], padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 32, 112, 112), int32] */;
    %1 = nn.bias_add(%0, %p2) /* ty=Tensor[(1, 32, 112, 112), int32] */;
    %2 = fixed_point_multiply(%1, multiplier=1658183424, shift=-9) /* ty=Tensor[(1, 32, 112, 112), int32] */;
    %3 = clip(%2, a_min=0f, a_max=255f) /* ty=Tensor[(1, 32, 112, 112), int32] */;
    %4 = cast(%3, dtype="uint8") /* ty=Tensor[(1, 32, 112, 112), uint8] */;
    %5 = cast(%4, dtype="int32") /* ty=Tensor[(1, 32, 112, 112), int32] */;
    cast(%5, dtype="float32") /* ty=Tensor[(1, 32, 112, 112), float32] */
  };
  %12(%11, meta[relay.Constant][0] /* ty=Tensor[(32, 3, 3, 3), int16] */, meta[relay.Constant][1] /* ty=Tensor[(32), int32] */) /* ty=Tensor[(1, 32, 112, 112), float32] */
}
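My reading of why the frontend only ever sees NCHW (an assumption, not confirmed elsewhere in this thread): torch.channels_last is a memory format. It changes a tensor's strides, while the logical shape stays (N, C, H, W), so the traced graph, and therefore the frontend, still reports shapes such as (1, 3, 224, 224) and (32, 3, 3, 3). The stdlib-only sketch below computes both stride sets; strides_for_order is a hypothetical helper written for illustration, not a PyTorch or TVM API.

```python
from typing import Sequence, Tuple


def strides_for_order(shape: Sequence[int], order: Sequence[int]) -> Tuple[int, ...]:
    """Element strides of a tensor whose axes are laid out in memory in `order`.

    `order` lists logical axes from outermost to innermost in memory.
    (0, 1, 2, 3) is the default NCHW layout; (0, 2, 3, 1) stores NCHW as NHWC,
    which is what channels-last does.
    """
    strides = [0] * len(shape)
    step = 1
    # Walk from the innermost (fastest-moving) axis outward.
    for axis in reversed(order):
        strides[axis] = step
        step *= shape[axis]
    return tuple(strides)


shape = (1, 3, 224, 224)  # same logical shape in both cases
print(strides_for_order(shape, (0, 1, 2, 3)))  # (150528, 50176, 224, 1)
print(strides_for_order(shape, (0, 2, 3, 1)))  # (150528, 1, 672, 3)
print(strides_for_order((32, 3, 3, 3), (0, 2, 3, 1)))  # (27, 1, 9, 3)
```

The channels-last numbers should match what tensor.stride() reports in PyTorch after .to(memory_format=torch.channels_last). Since only the strides differ, nothing in the shapes handed to Relay hints at NHWC.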

What is the path of least resistance to propagate the data layout to Relay? Running the ConvertLayout Relay pass by @anijain2305 to get NHWC would be redundant, since the PyTorch model is already channels-last, and it would insert layout_transform ops.
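For context on why those extra ops matter: for simple layouts, a layout_transform from NCHW to NHWC amounts to an axis permutation (a transpose), so one inserted on an activation materializes a reordered copy of that tensor, while those on constant weights can usually be folded at compile time. A minimal stdlib sketch of the permutation; layout_permutation and transpose_shape are hypothetical helpers, not TVM APIs, and blocked layouts such as NCHW4c are deliberately out of scope.

```python
from typing import Sequence, Tuple


def layout_permutation(src: str, dst: str) -> Tuple[int, ...]:
    """Axis order that turns a tensor in layout `src` into layout `dst`.

    Only handles simple layouts where every axis is one distinct letter,
    e.g. "NCHW" or "OIHW".
    """
    if sorted(src) != sorted(dst) or len(set(src)) != len(src):
        raise ValueError(f"incompatible layouts: {src!r} -> {dst!r}")
    # Output axis i is read from the input axis carrying the same letter.
    return tuple(src.index(axis) for axis in dst)


def transpose_shape(shape: Sequence[int], perm: Sequence[int]) -> Tuple[int, ...]:
    """Shape of the tensor after permuting its axes by `perm`."""
    return tuple(shape[i] for i in perm)


data_perm = layout_permutation("NCHW", "NHWC")
kernel_perm = layout_permutation("OIHW", "HWIO")
print(data_perm, transpose_shape((1, 3, 224, 224), data_perm))    # (0, 2, 3, 1) (1, 224, 224, 3)
print(kernel_perm, transpose_shape((32, 3, 3, 3), kernel_perm))   # (2, 3, 1, 0) (3, 3, 3, 32)
```

The permutation tuples are the same axes argument numpy.transpose would take, which is a convenient way to sanity-check them.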

Right, NHWC input layout is not supported by the PyTorch frontend. It would indeed be a nice feature to have; I actually worked on exactly that before (in a private repo, unfortunately).

This is something I’ve also been looking at.

It seems that PyTorch’s ONNX export doesn’t respect model.to(memory_format=torch.channels_last), so the exported model is NCHW as well.

I was looking at the layout transformation passes on the TVM side. I think the conversion could be done offline if the whole model is converted and the appropriate optimization passes are applied afterwards.

However, it seems you need to be explicit about which ops should have their layout converted. As this post discusses, and as the ConvertLayout docs show, the desired layout has to be listed per op (nn.conv2d, nn.dense, etc.).

As far as I can see there isn’t a single convert_layout(mod, "NHWC") function which does “best effort” conversion of all ops in a model.
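Absent such a function, a thin wrapper can at least build the per-op mapping that ConvertLayout expects, where each entry is {op_name: [data_layout, kernel_layout]} and, as I understand the ConvertLayout docs, "default" lets TVM pick a kernel layout matching the data layout. This is a hypothetical stdlib-only sketch: make_desired_layouts is not a TVM API, the caller still has to choose which ops to list, and the TVM calls appear only in comments because they are not run here.

```python
from typing import Dict, Iterable, List


def make_desired_layouts(ops: Iterable[str], data_layout: str,
                         kernel_layout: str = "default") -> Dict[str, List[str]]:
    """Map every listed op name to the same [data_layout, kernel_layout] pair."""
    op_list = list(ops)
    if not op_list:
        raise ValueError("list at least one op to convert")
    return {op: [data_layout, kernel_layout] for op in op_list}


# Quantized models lower convolutions to qnn.conv2d, so list it alongside nn.conv2d.
desired_layouts = make_desired_layouts(["nn.conv2d", "qnn.conv2d"], "NHWC")
print(desired_layouts)
# {'nn.conv2d': ['NHWC', 'default'], 'qnn.conv2d': ['NHWC', 'default']}

# With TVM installed, the mapping would be passed to the pass (not executed here):
#   seq = tvm.transform.Sequential([
#       relay.transform.RemoveUnusedFunctions(),
#       relay.transform.ConvertLayout(desired_layouts),
#   ])
#   with tvm.transform.PassContext(opt_level=3):
#       mod = seq(mod)
```

This still inserts layout_transform ops at the graph boundaries; it only removes the boilerplate of spelling out the dictionary.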