Run pytorch QAT quantized model on TVM

I have played around with PyTorch QAT, and the quantized model works correctly in PyTorch. Is there a sample that can guide me through running inference of this quantized model on TVM? Thanks.

Yes, we have a tutorial on importing a quantized PyTorch model from torchvision: https://tvm.apache.org/docs/tutorials/frontend/deploy_prequantized.html#sphx-glr-tutorials-frontend-deploy-prequantized-py

The tutorial uses the torchvision quantized MobileNet v2, which was QAT-ed. The other quantized models from torchvision were produced with post-training quantization rather than QAT. I tested many of them, and they should import to TVM without issues. Let me know if your QAT model can be imported to TVM.

@masahi Thanks a lot for the pointer. I could run the MobileNet v2 example model on TVM correctly, but when I try my own model following the tutorial, the results are not correct. Is there a way to debug this issue? Here is a snippet of my code:

import torch
from tvm import relay
# run_tvm_model() is the helper defined in the deploy_prequantized tutorial

loaded = torch.jit.load('traced_qat_model.pt')
out_data = loaded(input_tensor)
print("from_trace:", out_data)  # this output is correct

input_name = "input"
input_shapes = [(input_name, (1, 3, 384, 384))]
mod, params = relay.frontend.from_pytorch(loaded, input_shapes)
tvm_result, rt_mod = run_tvm_model(mod, params, input_name, input_tensor, target="llvm")
print("from_tvm:", tvm_result)  # this output is incorrect

Also, I notice that when I execute relay.frontend.from_pytorch(), I get a lot of these warnings: WARNING:root:Untyped Tensor found, assume it is float32

By “incorrect”, do you mean the result is a garbage?

The warning is probably due to the fact that you serialized the model after QAT. PyTorch jit erases all dtype information when it serializes the jitted model. So after loading it back, there is no dtype information anymore and everything defaults to float32. See this issue for more details. https://github.com/pytorch/pytorch/issues/39690

One thing you can try: serialize your original float32 model that has gone through QAT, before running the PyTorch quantization flow. After you torch.jit.load(...) your float32 model, run the quantization pass and import the result to TVM. That way, TVM can see the int8 dtype. I believe this is how the QAT-ed, quantized MobileNet v2 model from torchvision is able to work with TVM.

By “incorrect”, I mean the output values are totally different from the FP32 output. For this segmentation model, there is no mask in the output.

Here is what I changed accordingly. Please let me know if I did it correctly.

import torch
from tvm import relay

loaded = torch.jit.load('traced_QAT_fp32.pt')
loaded.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
loaded = torch.quantization.prepare_qat(loaded)
loaded = torch.quantization.convert(loaded)
out_data = loaded(input_tensor)
print("from_trace:", out_data)  # output is correct

input_name = "input"  # the input name can be arbitrary for the PyTorch frontend
input_shapes = [(input_name, (1, 3, 384, 384))]
mod, params = relay.frontend.from_pytorch(loaded, input_shapes)
tvm_result, rt_mod = run_tvm_model(mod, params, input_name, input_tensor, target="llvm")
print("from_tvm:", tvm_result)  # output from TVM is not correct

When I run relay.frontend.from_pytorch(), it gives the following error:

  File "run_depth_pytorch_int8_aqt_model_simple.py", line 275, in main
    mod, params = relay.frontend.from_pytorch(loaded, input_shapes)
  File "C:\Users\bigtree\Documents\installation\TVM\tvm\python\tvm\relay\frontend\pytorch.py", line 3561, in from_pytorch
    converter.report_missing_conversion(op_names)
  File "C:\Users\bigtree\Documents\installation\TVM\tvm\python\tvm\relay\frontend\pytorch.py", line 2821, in report_missing_conversion
    raise NotImplementedError(msg)
NotImplementedError: The following operators are not implemented: ['aten::fake_quantize_per_channel_affine', 'aten::fake_quantize_per_tensor_affine']

Interesting that you now see a NotImplementedError. I believe a correctly quantized model shouldn’t have fake_quantize ops (they are used during QAT, but should be converted to real quantized ops). Can you dig deeper into the conversion process? Unfortunately, I’m not familiar with the PT QAT workflow.

For example, this is how torchvision quantizes QAT-ed mobilenet v3.

TVM should be able to load this model (see https://github.com/Edgecortix-Inc/pytorch_quantization/pull/16). You can compare it with your model / flow and see what is missing.
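To make the fake-quantize versus real-quantize distinction concrete, here is a minimal pure-Python sketch of what a per-tensor fake-quantize op computes, next to real quantization. The formula follows my reading of PyTorch's documented semantics for fake_quantize_per_tensor_affine; the function names, scale, and zero point are hypothetical. The fake version stays in float (it only snaps values onto the quantization grid), which is why it belongs to training and should be replaced by real integer ops after convert.

```python
def fake_quantize_per_tensor(x, scale, zero_point, quant_min=0, quant_max=255):
    """Quantize, then immediately dequantize: the result is still float."""
    out = []
    for v in x:
        q = round(v / scale) + zero_point          # nearest quantization level
        q = max(quant_min, min(quant_max, q))      # saturate to the uint8 range
        out.append((q - zero_point) * scale)       # map back to float
    return out


def quantize_per_tensor(x, scale, zero_point, quant_min=0, quant_max=255):
    """Real quantization: the result is the integer levels themselves."""
    return [max(quant_min, min(quant_max, round(v / scale) + zero_point)) for v in x]


x = [-1.0, 0.0, 0.3, 100.0]
print(fake_quantize_per_tensor(x, scale=0.25, zero_point=4))  # [-1.0, 0.0, 0.25, 62.75]
print(quantize_per_tensor(x, scale=0.25, zero_point=4))       # [0, 4, 5, 255]
```

A converted model would carry the integer levels (second line) plus the scale and zero point, not the float values of the first line.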

I think I followed these steps. The earlier fake_quantize ops error happens if I load a pre-saved trace/jit file. After following the steps above, the model runs on TVM but still gives the wrong result.

My model is EfficientNet-based. Do you know if anyone has successfully run a QAT model with this backbone on TVM?

After printing mod and params from mod, params = relay.frontend.from_pytorch(loaded, input_shapes)

I found that the params’ dtype is float32. Shouldn’t this be uint8?

   [[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
     [ 0.00000000e+00, -4.88758087e-06,  0.00000000e+00],
     [ 0.00000000e+00,  9.77516174e-06,  0.00000000e+00]]],


   [[[-2.19370108e-02, -9.50603783e-02,  1.90120757e-01],
     [ 1.31622061e-01, -9.35979128e-01, -1.46246739e-02],
     [ 7.31233656e-02, -5.84986955e-02,  3.65616828e-02]]]],
  dtype=float32), 'pretrained.layer1.3.0.conv_dw_bias': <tvm.nd.NDArray shape=(32,), cpu(0)>
array([ 4.1551757e+00,  7.0967369e+00, -1.3823882e+00,  6.2057533e+00,
        7.3633879e-02, -7.4240267e-03,  2.6812696e-01,  7.7180414e+00,
        7.4075651e-01,  6.4278054e+00,  2.8354998e+00,  4.5981092e+00,
       -1.1487323e+00, -1.6544257e-01, -8.3797729e-01,  4.2483969e+00,
        1.1879738e+01,  4.4961596e-01,  6.6468501e+00,  5.3413010e+00,
        5.6053448e-01, -6.0282320e-01,  4.9038339e+00,  3.5531816e+00,
       -8.5106981e-01,  4.2608976e-02,  4.0155058e+00, -5.6184396e-02,
        8.7596375e-01,  2.8182871e+00, -2.1643894e+00,  4.8622866e+00],
      dtype=float32), 'pretrained.layer1.3.0.conv_pw_weight': <tvm.nd.NDArray shape=(16, 32, 1, 1), cpu(0)>
array([[[[-0.8916866 ]],

No, having float32 parameters after import is expected. Since PyTorch stores quantized tensors in a custom format that only PyTorch understands, to extract 8-bit weights we first have to “unpack” the custom quantized tensor into float32, convert it to numpy, and then convert it back to int8 using a Relay op.

The conversion of weights back to int8 happens during relay.build(...). To see this, you can replace

    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)

with

    with tvm.transform.PassContext(opt_level=3):
        opt_mod, opt_params = relay.optimize(mod, target=target, params=params)

and opt_params should be int8.
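A minimal pure-Python sketch of why the float32 params are harmless: if the float values come from dequantizing int8 weights with a per-channel scale and zero point, quantizing them again with the same scale and zero point recovers the original integers exactly. The weights, scales, and helper names below are hypothetical, and real frameworks work on tensors rather than lists. As a side note, small entries in the params dump above such as -4.88758087e-06 and 9.77516174e-06 differ by an exact factor of 2, which is consistent with them being dequantized integers, though that is not verified here.

```python
def dequantize_per_channel(q, scales, zero_points):
    """int8 levels -> float, with one (scale, zero_point) pair per output channel."""
    return [[(v - zp) * s for v in row] for row, s, zp in zip(q, scales, zero_points)]


def quantize_per_channel(w, scales, zero_points, qmin=-128, qmax=127):
    """float -> int8 levels, with one (scale, zero_point) pair per output channel."""
    return [[max(qmin, min(qmax, round(v / s) + zp)) for v in row]
            for row, s, zp in zip(w, scales, zero_points)]


# Hypothetical int8 weights for two output channels (symmetric scheme, so zero_point = 0).
q_weight = [[0, -1, 2], [-27, 127, -128]]
scales = [4.88758087e-06, 0.0146]
zero_points = [0, 0]

w_float = dequantize_per_channel(q_weight, scales, zero_points)  # like the float32 params after import
w_int8 = quantize_per_channel(w_float, scales, zero_points)      # roughly what constant folding recovers
print(w_int8 == q_weight)  # True
```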

I’m not aware of a PyTorch-quantized EfficientNet, but if you have such a model, it should be possible to import it to TVM; otherwise, I’d consider it a bug. I definitely want to see quantized EfficientNet running on TVM!

    ## generate executable with trace
    script_module = torch.jit.trace(model_int8, input_tensor).eval()
    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
    with tvm.transform.PassContext(opt_level=3):
        opt_mod, opt_params = relay.optimize(mod, target=target, params=params)
    print("opt_params:", opt_params)

The output is empty: opt_params: {}. Please let me know if I missed something.

Ah, right. This means all weights are embedded in opt_mod as constants. You can print opt_mod to see that the weights are all int8, but you don’t need to worry about it.

Here is the output of opt_mod. The weight types seem to be mixed precision…

type tensor_int8_t {
  tensor_nil_int8,
  tensor0_int8(int8),
  tensor1_int8(Tensor[(?), int8]),
  tensor2_int8(Tensor[(?, ?), int8]),
  tensor3_int8(Tensor[(?, ?, ?), int8]),
  tensor4_int8(Tensor[(?, ?, ?, ?), int8]),
  tensor5_int8(Tensor[(?, ?, ?, ?, ?), int8]),
  tensor6_int8(Tensor[(?, ?, ?, ?, ?, ?), int8]),
}

type tensor_uint8_t {
  tensor_nil_uint8,
  tensor0_uint8(uint8),
  tensor1_uint8(Tensor[(?), uint8]),
  tensor2_uint8(Tensor[(?, ?), uint8]),
  tensor3_uint8(Tensor[(?, ?, ?), uint8]),
  tensor4_uint8(Tensor[(?, ?, ?, ?), uint8]),
  tensor5_uint8(Tensor[(?, ?, ?, ?, ?), uint8]),
  tensor6_uint8(Tensor[(?, ?, ?, ?, ?, ?), uint8]),
}

type tensor_float16_t {
  tensor_nil_float16,
  tensor0_float16(float16),
  tensor1_float16(Tensor[(?), float16]),
  tensor2_float16(Tensor[(?, ?), float16]),
  tensor3_float16(Tensor[(?, ?, ?), float16]),
  tensor4_float16(Tensor[(?, ?, ?, ?), float16]),
  tensor5_float16(Tensor[(?, ?, ?, ?, ?), float16]),
  tensor6_float16(Tensor[(?, ?, ?, ?, ?, ?), float16]),
}

type Tree[A] {
  Rose(A, List[Tree[A]]),
}

type tensor_int16_t {
  tensor_nil_int16,
  tensor0_int16(int16),
  tensor1_int16(Tensor[(?), int16]),
  tensor2_int16(Tensor[(?, ?), int16]),
  tensor3_int16(Tensor[(?, ?, ?), int16]),
  tensor4_int16(Tensor[(?, ?, ?, ?), int16]),
  tensor5_int16(Tensor[(?, ?, ?, ?, ?), int16]),
  tensor6_int16(Tensor[(?, ?, ?, ?, ?, ?), int16]),
}

type tensor_uint16_t {
  tensor_nil_uint16,
  tensor0_uint16(uint16),
  tensor1_uint16(Tensor[(?), uint16]),
  tensor2_uint16(Tensor[(?, ?), uint16]),
  tensor3_uint16(Tensor[(?, ?, ?), uint16]),
  tensor4_uint16(Tensor[(?, ?, ?, ?), uint16]),
  tensor5_uint16(Tensor[(?, ?, ?, ?, ?), uint16]),
  tensor6_uint16(Tensor[(?, ?, ?, ?, ?, ?), uint16]),
}

type tensor_int32_t {
  tensor_nil_int32,
  tensor0_int32(int32),
  tensor1_int32(Tensor[(?), int32]),
  tensor2_int32(Tensor[(?, ?), int32]),
  tensor3_int32(Tensor[(?, ?, ?), int32]),
  tensor4_int32(Tensor[(?, ?, ?, ?), int32]),
  tensor5_int32(Tensor[(?, ?, ?, ?, ?), int32]),
  tensor6_int32(Tensor[(?, ?, ?, ?, ?, ?), int32]),
}

type List[A] {
  Cons(A, List[A]),
  Nil,
}

type Option[A] {
  Some(A),
  None,
}

type tensor_float32_t {
  tensor_nil_float32,
  tensor0_float32(float32),
  tensor1_float32(Tensor[(?), float32]),
  tensor2_float32(Tensor[(?, ?), float32]),
  tensor3_float32(Tensor[(?, ?, ?), float32]),
  tensor4_float32(Tensor[(?, ?, ?, ?), float32]),
  tensor5_float32(Tensor[(?, ?, ?, ?, ?), float32]),
  tensor6_float32(Tensor[(?, ?, ?, ?, ?, ?), float32]),
}

type tensor_int64_t {
  tensor_nil_int64,
  tensor0_int64(int64),
  tensor1_int64(Tensor[(?), int64]),
  tensor2_int64(Tensor[(?, ?), int64]),
  tensor3_int64(Tensor[(?, ?, ?), int64]),
  tensor4_int64(Tensor[(?, ?, ?, ?), int64]),
  tensor5_int64(Tensor[(?, ?, ?, ?, ?), int64]),
  tensor6_int64(Tensor[(?, ?, ?, ?, ?, ?), int64]),
}

type tensor_float64_t {
  tensor_nil_float64,
  tensor0_float64(float64),
  tensor1_float64(Tensor[(?), float64]),
  tensor2_float64(Tensor[(?, ?), float64]),
  tensor3_float64(Tensor[(?, ?, ?), float64]),
  tensor4_float64(Tensor[(?, ?, ?, ?), float64]),
  tensor5_float64(Tensor[(?, ?, ?, ?, ?), float64]),
  tensor6_float64(Tensor[(?, ?, ?, ?, ?, ?), float64]),
}

def @main(%input: Tensor[(1, 3, 256, 256), float32], hash="95d81b5910b70cc4") -> Tensor[(1, 1, 256, 256), float32] {
  %1074 = fn (%p0134: Tensor[(1, 3, 256, 256), float32], %p1105: int16, Primitive=1, hash="34f16c31ee5dbe5e", src_layout="NCHW", dst_layout="NCHW3c") -> Tensor[(1, 1, 258, 258, 3), int16] {
    %1065 = divide(%p0134, 0.0374636f /* ty=float32 */) /* ty=Tensor[(1, 3, 256, 256), float32] */;
    %1066 = add(%1065, 57f /* ty=float32 */) /* ty=Tensor[(1, 3, 256, 256), float32] */;
    %1067 = round(%1066) /* ty=Tensor[(1, 3, 256, 256), float32] */;
    %1068 = cast(%1067, dtype="int32") /* ty=Tensor[(1, 3, 256, 256), int32] */;
    %1069 = clip(%1068, a_min=0f, a_max=255f) /* ty=Tensor[(1, 3, 256, 256), int32] */;
    %1070 = cast(%1069, dtype="uint8") /* ty=Tensor[(1, 3, 256, 256), uint8] */;
    %1071 = nn.pad(%1070, 57f /* ty=float32 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* ty=Tensor[(1, 3, 258, 258), uint8] */;
    %1072 = cast(%1071, dtype="int16") /* ty=Tensor[(1, 3, 258, 258), int16] */;
    %1073 = subtract(%1072, %p1105) /* ty=Tensor[(1, 3, 258, 258), int16] */;
    layout_transform(%1073, src_layout="NCHW", dst_layout="NCHW3c") /* ty=Tensor[(1, 1, 258, 258, 3), int16] */
  };
  %1075 = %1074(%input, meta[relay.Constant][0] /* ty=int16 */) /* ty=Tensor[(1, 1, 258, 258, 3), int16] */;
  %1076 = fn (%p0133: Tensor[(1, 1, 258, 258, 3), int16], %p1104: Tensor[(4, 1, 3, 3, 3, 8), int16], %p280: Tensor[(1, 4, 1, 1, 8), int32], %p378: Tensor[(1, 4, 1, 1, 8), int64], %p478: Tensor[(1, 4, 1, 1, 8), int64], %p577: Tensor[(1, 4, 1, 1, 8), int64], kernel_layout="OIHW3i8o", Primitive=1, out_layout="NCHW8c", hash="90c700876890682c", data_layout="NCHW3c") -> Tensor[(1, 4, 128, 128, 8), uint8] {
    %1057 = nn.contrib_conv2d_NCHWc(%p0133, %p1104, strides=[2, 2], padding=[0, 0, 0, 0], channels=32, kernel_size=[3, 3], data_layout="NCHW3c", kernel_layout="OIHW3i8o", out_layout="NCHW8c", out_dtype="int32") /* ty=Tensor[(1, 4, 128, 128, 8), int32] */;
    %1058 = add(%1057, %p280) /* ty=Tensor[(1, 4, 128, 128, 8), int32] */;
    %1059 = cast(%1058, dtype="int64") /* ty=Tensor[(1, 4, 128, 128, 8), int64] */;
    %1060 = multiply(%1059, %p378) /* ty=Tensor[(1, 4, 128, 128, 8), int64] */;
    %1061 = add(%1060, %p478) /* ty=Tensor[(1, 4, 128, 128, 8), int64] */;
    %1062 = right_shift(%1061, %p577) /* ty=Tensor[(1, 4, 128, 128, 8), int64] */;
    %1063 = cast(%1062, dtype="int32") /* ty=Tensor[(1, 4, 128, 128, 8), int32] */;
    %1064 = clip(%1063, a_min=0f, a_max=255f) /* ty=Tensor[(1, 4, 128, 128, 8), int32] */;
    cast(%1064, dtype="uint8") /* ty=Tensor[(1, 4, 128, 128, 8), uint8] */
  };
  %1077 = %1076(%1075, meta[relay.Constant][1] /* ty=Tensor[(4, 1, 3, 3, 3, 8), int16] */, meta[relay.Constant][2] /* ty=Tensor[(1, 4, 1, 1, 8), int32] */, meta[relay.Constant][3] /* ty=Tensor[(1, 4, 1, 1, 8), int64] */, meta[relay.Constant][4] /* ty=Tensor[(1, 4, 1, 1, 8), int64] */, meta[relay.Constant][5] /* ty=Tensor[(1, 4, 1, 1, 8), int64] */) /* ty=Tensor[(1, 4, 128, 128, 8), uint8] */;
  %1078 = fn (%p0132: Tensor[(1, 4, 128, 128, 8), uint8], Primitive=1, hash="58cb412d65bf60ed") -> Tensor[(1, 4, 130, 130, 8), int16] {
    %1056 = nn.pad(%p0132, 0f /* ty=float32 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1], [0, 0]]) /* ty=Tensor[(1, 4, 130, 130, 8), uint8] */;
    cast(%1056, dtype="int16") /* ty=Tensor[(1, 4, 130, 130, 8), int16] */
  };
  %1079 = %1078(%1077) /* ty=Tensor[(1, 4, 130, 130, 8), int16] */;
  %1080 = fn (%p0131: Tensor[(1, 4, 130, 130, 8), int16], %p1103: Tensor[(4, 1, 3, 3, 1, 8), int16], %p279: Tensor[(1, 4, 1, 1, 8), int32], %p377: Tensor[(1, 4, 1, 1, 8), int64], %p477: Tensor[(1, 4, 1, 1, 8), int64], %p576: Tensor[(1, 4, 1, 1, 8), int64], kernel_layout="OIHW1i8o", Primitive=1, out_layout="NCHW8c", hash="49e7c1b46488c986", data_layout="NCHW8c") -> Tensor[(1, 4, 128, 128, 8), int16] {
    %1046 = nn.contrib_depthwise_conv2d_NCHWc(%p0131, %p1103, padding=[0, 0, 0, 0], groups=32, channels=32, kernel_size=[3, 3], data_layout="NCHW8c", kernel_layout="OIHW1i8o", out_layout="NCHW8c", out_dtype="int32") /* ty=Tensor[(1, 4, 128, 128, 8), int32] */;
    %1047 = add(%1046, %p279) /* ty=Tensor[(1, 4, 128, 128, 8), int32] */;
    %1048 = cast(%1047, dtype="int64") /* ty=Tensor[(1, 4, 128, 128, 8), int64] */;
    %1049 = multiply(%1048, %p377) /* ty=Tensor[(1, 4, 128, 128, 8), int64] */;
    %1050 = add(%1049, %p477) /* ty=Tensor[(1, 4, 128, 128, 8), int64] */;
    %1051 = right_shift(%1050, %p576) /* ty=Tensor[(1, 4, 128, 128, 8), int64] */;
    %1052 = cast(%1051, dtype="int32") /* ty=Tensor[(1, 4, 128, 128, 8), int32] */;
    %1053 = clip(%1052, a_min=0f, a_max=255f) /* ty=Tensor[(1, 4, 128, 128, 8), int32] */;
    %1054 = cast(%1053, dtype="uint8") /* ty=Tensor[(1, 4, 128, 128, 8), uint8] */;
    %1055 = clip(%1054, a_min=0f, a_max=6f) /* ty=Tensor[(1, 4, 128, 128, 8), uint8] */;
    cast(%1055, dtype="int16") /* ty=Tensor[(1, 4, 128, 128, 8), int16] */
  };
  %1081 = %1080(%1079, meta[relay.Constant][6] /* ty=Tensor[(4, 1, 3, 3, 1, 8), int16] */, meta[relay.Constant][7] /* ty=Tensor[(1, 4, 1, 1, 8), int32] */, meta[relay.Constant][8] /* ty=Tensor[(1, 4, 1, 1, 8), int64] */, meta[relay.Constant][9] /* ty=Tensor[(1, 4, 1, 1, 8), int64] */, meta[relay.Constant][10] /* ty=Tensor[(1, 4, 1, 1, 8), int64] */) /* ty=Tensor[(1, 4, 128, 128, 8), int16] */;
  %1082 = fn (%p0130: Tensor[(1, 4, 128, 128, 8), int16], %p1102: Tensor[(2, 4, 1, 1, 8, 8), int16], %p278: Tensor[(1, 2, 1, 1, 8), int32], %p376: Tensor[(1, 2, 1, 1, 8), int64], %p476: Tensor[(1, 2, 1, 1, 8), int64], %p575: Tensor[(1, 2, 1, 1, 8), int64], kernel_layout="OIHW8i8o", Primitive=1, out_layout="NCHW8c", hash="18a04f1b56a3790c", data_layout="NCHW8c") -> Tensor[(1, 2, 128, 128, 8), int16] {
    %1036 = nn.contrib_conv2d_NCHWc(%p0130, %p1102, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1], data_layout="NCHW8c", kernel_layout="OIHW8i8o", out_layout="NCHW8c", out_dtype="int32") /* ty=Tensor[(1, 2, 128, 128, 8), int32] */;
    %1037 = add(%1036, %p278) /* ty=Tensor[(1, 2, 128, 128, 8), int32] */;
    %1038 = cast(%1037, dtype="int64") /* ty=Tensor[(1, 2, 128, 128, 8), int64] */;
    %1039 = multiply(%1038, %p376) /* ty=Tensor[(1, 2, 128, 128, 8), int64] */;
    %1040 = add(%1039, %p476) /* ty=Tensor[(1, 2, 128, 128, 8), int64] */;
    %1041 = right_shift(%1040, %p575) /* ty=Tensor[(1, 2, 128, 128, 8), int64] */;
    %1042 = cast(%1041, dtype="int32") /* ty=Tensor[(1, 2, 128, 128, 8), int32] */;
    %1043 = clip(%1042, a_min=0f, a_max=255f) /* ty=Tensor[(1, 2, 128, 128, 8), int32] */;
    %1044 = cast(%1043, dtype="uint8") /* ty=Tensor[(1, 2, 128, 128, 8), uint8] */;
    %1045 = clip(%1044, a_min=0f, a_max=6f) /* ty=Tensor[(1, 2, 128, 128, 8), uint8] */;
    cast(%1045, dtype="int16") /* ty=Tensor[(1, 2, 128, 128, 8), int16] */
  };
  %1083 = %1082(%1081, meta[relay.Constant][11] /* ty=Tensor[(2, 4, 1, 1, 8, 8), int16] */, meta[relay.Constant][12] /* ty=Tensor[(1, 2, 1, 1, 8), int32] */, meta[relay.Constant][13] /* ty=Tensor[(1, 2, 1, 1, 8), int64] */, meta[relay.Constant][14] /* ty=Tensor[(1, 2, 1, 1, 8), int64] */, meta[relay.Constant][15] /* ty=Tensor[(1, 2, 1, 1, 8), int64] */) /* ty=Tensor[(1, 2, 128, 128, 8), int16] */;
  %1084 = fn (%p0129: Tensor[(1, 2, 128, 128, 8), int16], %p1101: Tensor[(12, 2, 1, 1, 8, 8), int16], %p277: Tensor[(1, 12, 1, 1, 8), int32], %p375: Tensor[(1, 12, 1, 1, 8), int64], %p475: Tensor[(1, 12, 1, 1, 8), int64], %p574: Tensor[(1, 12, 1, 1, 8), int64], kernel_layout="OIHW8i8o", Primitive=1, out_layout="NCHW8c", hash="21bb327e8e3334e4", data_layout="NCHW8c") -> Tensor[(1, 12, 128, 128, 8), uint8] {
    %1027 = nn.contrib_conv2d_NCHWc(%p0129, %p1101, padding=[0, 0, 0, 0], channels=96, kernel_size=[1, 1], data_layout="NCHW8c", kernel_layout="OIHW8i8o", out_layout="NCHW8c", out_dtype="int32") /* ty=Tensor[(1, 12, 128, 128, 8), int32] */;
    %1028 = add(%1027, %p277) /* ty=Tensor[(1, 12, 128, 128, 8), int32] */;
    %1029 = cast(%1028, dtype="int64") /* ty=Tensor[(1, 12, 128, 128, 8), int64] */;
    %1030 = multiply(%1029, %p375) /* ty=Tensor[(1, 12, 128, 128, 8), int64] */;
    %1031 = add(%1030, %p475) /* ty=Tensor[(1, 12, 128, 128, 8), int64] */;
    %1032 = right_shift(%1031, %p574) /* ty=Tensor[(1, 12, 128, 128, 8), int64] */;
    %1033 = cast(%1032, dtype="int32") /* ty=Tensor[(1, 12, 128, 128, 8), int32] */;
    %1034 = clip(%1033, a_min=0f, a_max=255f) /* ty=Tensor[(1, 12, 128, 128, 8), int32] */;
    %1035 = cast(%1034, dtype="uint8") /* ty=Tensor[(1, 12, 128, 128, 8), uint8] */;
    clip(%1035, a_min=0f, a_max=6f) /* ty=Tensor[(1, 12, 128, 128, 8), uint8] */
  };
  %1085 = %1084(%1083, meta[relay.Constant][16] /* ty=Tensor[(12, 2, 1, 1, 8, 8), int16] */, meta[relay.Constant][17] /* ty=Tensor[(1, 12, 1, 1, 8), int32] */, meta[relay.Constant][18] /* ty=Tensor[(1, 12, 1, 1, 8), int64] */, meta[relay.Constant][19] /* ty=Tensor[(1, 12, 1, 1, 8), int64] */, meta[relay.Constant][20] /* ty=Tensor[(1, 12, 1, 1, 8), int64] */) /* ty=Tensor[(1, 12, 128, 128, 8), uint8] */;
  %1086 = fn (%p0128: Tensor[(1, 12, 128, 128, 8), uint8], Primitive=1, hash="9b9b85ef238a2748") -> Tensor[(1, 12, 130, 130, 8), int16] {
    %1026 = nn.pad(%p0128, 0f /* ty=float32 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1], [0, 0]]) /* ty=Tensor[(1, 12, 130, 130, 8), uint8] */;
    cast(%1026, dtype="int16") /* ty=Tensor[(1, 12, 130, 130, 8), int16] */
  };
  %1087 = %1086(%1085) /* ty=Tensor[(1, 12, 130, 130, 8), int16] */;
  %1088 = fn (%p0127: Tensor[(1, 12, 130, 130, 8), int16], %p1100: Tensor[(12, 1, 3, 3, 1, 8), int16], %p276: Tensor[(1, 12, 1, 1, 8), int32], %p374: Tensor[(1, 12, 1, 1, 8), int64], %p474: Tensor[(1, 12, 1, 1, 8), int64], %p573: Tensor[(1, 12, 1, 1, 8), int64], kernel_layout="OIHW1i8o", Primitive=1, out_layout="NCHW8c", hash="23b7c01e92ee2d1b", data_layout="NCHW8c") -> Tensor[(1, 12, 64, 64, 8), int16] {
    %1016 = nn.contrib_depthwise_conv2d_NCHWc(%p0127, %p1100, strides=[2, 2], padding=[0, 0, 0, 0], groups=96, channels=96, kernel_size=[3, 3], data_layout="NCHW8c", kernel_layout="OIHW1i8o", out_layout="NCHW8c", out_dtype="int32") /* ty=Tensor[(1, 12, 64, 64, 8), int32] */;
    %1017 = add(%1016, %p276) /* ty=Tensor[(1, 12, 64, 64, 8), int32] */;
    %1018 = cast(%1017, dtype="int64") /* ty=Tensor[(1, 12, 64, 64, 8), int64] */;
    %1019 = multiply(%1018, %p374) /* ty=Tensor[(1, 12, 64, 64, 8), int64] */;
    %1020 = add(%1019, %p474) /* ty=Tensor[(1, 12, 64, 64, 8), int64] */;
    %1021 = right_shift(%1020, %p573) /* ty=Tensor[(1, 12, 64, 64, 8), int64] */;
    %1022 = cast(%1021, dtype="int32") /* ty=Tensor[(1, 12, 64, 64, 8), int32] */;
    %1023 = clip(%1022, a_min=0f, a_max=255f) /* ty=Tensor[(1, 12, 64, 64, 8), int32] */;
    %1024 = cast(%1023, dtype="uint8") /* ty=Tensor[(1, 12, 64, 64, 8), uint8] */;
    %1025 = clip(%1024, a_min=0f, a_max=6f) /* ty=Tensor[(1, 12, 64, 64, 8), uint8] */;
    cast(%1025, dtype="int16") /* ty=Tensor[(1, 12, 64, 64, 8), int16] */
  };
  %1089 = %1088(%1087, meta[relay.Constant][21] /* ty=Tensor[(12, 1, 3, 3, 1, 8), int16] */, meta[relay.Constant][22] /* ty=Tensor[(1, 12, 1, 1, 8), int32] */, meta[relay.Constant][23] /* ty=Tensor[(1, 12, 1, 1, 8), int64] */, meta[relay.Constant][24] /* ty=Tensor[(1, 12, 1, 1, 8), int64] */, meta[relay.Constant][25] /* ty=Tensor[(1, 12, 1, 1, 8), int64] */) /* ty=Tensor[(1, 12, 64, 64, 8), int16] */;
  %1090 = fn (%p0126: Tensor[(1, 12, 64, 64, 8), int16], %p199: Tensor[(3, 12, 1, 1, 8, 8), int16], %p275: Tensor[(1, 3, 1, 1, 8), int32], %p373: Tensor[(1, 3, 1, 1, 8), int64], %p473: Tensor[(1, 3, 1, 1, 8), int64], %p572: Tensor[(1, 3, 1, 1, 8), int64], kernel_layout="OIHW8i8o", Primitive=1, out_layout="NCHW8c", hash="ab08e1ad75eb7d95", data_layout="NCHW8c") -> Tensor[(1, 3, 64, 64, 8), uint8] {
    %1007 = nn.contrib_conv2d_NCHWc(%p0126, %p199, padding=[0, 0, 0, 0], channels=24, kernel_size=[1, 1], data_layout="NCHW8c", kernel_layout="OIHW8i8o", out_layout="NCHW8c", out_dtype="int32") /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %1008 = add(%1007, %p275) /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %1009 = cast(%1008, dtype="int64") /* ty=Tensor[(1, 3, 64, 64, 8), int64] */;
    %1010 = multiply(%1009, %p373) /* ty=Tensor[(1, 3, 64, 64, 8), int64] */;
    %1011 = add(%1010, %p473) /* ty=Tensor[(1, 3, 64, 64, 8), int64] */;
    %1012 = right_shift(%1011, %p572) /* ty=Tensor[(1, 3, 64, 64, 8), int64] */;
    %1013 = cast(%1012, dtype="int32") /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %1014 = add(64 /* ty=int32 */, %1013) /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %1015 = clip(%1014, a_min=0f, a_max=255f) /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    cast(%1015, dtype="uint8") /* ty=Tensor[(1, 3, 64, 64, 8), uint8] */
  };
  %1091 = %1090(%1089, meta[relay.Constant][26] /* ty=Tensor[(3, 12, 1, 1, 8, 8), int16] */, meta[relay.Constant][27] /* ty=Tensor[(1, 3, 1, 1, 8), int32] */, meta[relay.Constant][28] /* ty=Tensor[(1, 3, 1, 1, 8), int64] */, meta[relay.Constant][29] /* ty=Tensor[(1, 3, 1, 1, 8), int64] */, meta[relay.Constant][30] /* ty=Tensor[(1, 3, 1, 1, 8), int64] */) /* ty=Tensor[(1, 3, 64, 64, 8), uint8] */;
  %1092 = fn (%p0125: Tensor[(1, 3, 64, 64, 8), uint8], %p198: int16, Primitive=1, hash="e1510e19549852ba") -> Tensor[(1, 3, 64, 64, 8), int16] {
    %1006 = cast(%p0125, dtype="int16") /* ty=Tensor[(1, 3, 64, 64, 8), int16] */;
    subtract(%1006, %p198) /* ty=Tensor[(1, 3, 64, 64, 8), int16] */
  };
  %1093 = %1092(%1091, meta[relay.Constant][31] /* ty=int16 */) /* ty=Tensor[(1, 3, 64, 64, 8), int16] */;
  %1094 = fn (%p0124: Tensor[(1, 3, 64, 64, 8), int16], %p197: Tensor[(18, 3, 1, 1, 8, 8), int16], %p274: Tensor[(1, 18, 1, 1, 8), int32], %p372: Tensor[(1, 18, 1, 1, 8), int64], %p472: Tensor[(1, 18, 1, 1, 8), int64], %p571: Tensor[(1, 18, 1, 1, 8), int64], kernel_layout="OIHW8i8o", Primitive=1, out_layout="NCHW8c", hash="7bf1388e46e6a4eb", data_layout="NCHW8c") -> Tensor[(1, 18, 64, 64, 8), uint8] {
    %997 = nn.contrib_conv2d_NCHWc(%p0124, %p197, padding=[0, 0, 0, 0], channels=144, kernel_size=[1, 1], data_layout="NCHW8c", kernel_layout="OIHW8i8o", out_layout="NCHW8c", out_dtype="int32") /* ty=Tensor[(1, 18, 64, 64, 8), int32] */;
    %998 = add(%997, %p274) /* ty=Tensor[(1, 18, 64, 64, 8), int32] */;
    %999 = cast(%998, dtype="int64") /* ty=Tensor[(1, 18, 64, 64, 8), int64] */;
    %1000 = multiply(%999, %p372) /* ty=Tensor[(1, 18, 64, 64, 8), int64] */;
    %1001 = add(%1000, %p472) /* ty=Tensor[(1, 18, 64, 64, 8), int64] */;
    %1002 = right_shift(%1001, %p571) /* ty=Tensor[(1, 18, 64, 64, 8), int64] */;
    %1003 = cast(%1002, dtype="int32") /* ty=Tensor[(1, 18, 64, 64, 8), int32] */;
    %1004 = clip(%1003, a_min=0f, a_max=255f) /* ty=Tensor[(1, 18, 64, 64, 8), int32] */;
    %1005 = cast(%1004, dtype="uint8") /* ty=Tensor[(1, 18, 64, 64, 8), uint8] */;
    clip(%1005, a_min=0f, a_max=6f) /* ty=Tensor[(1, 18, 64, 64, 8), uint8] */
  };
  %1095 = %1094(%1093, meta[relay.Constant][32] /* ty=Tensor[(18, 3, 1, 1, 8, 8), int16] */, meta[relay.Constant][33] /* ty=Tensor[(1, 18, 1, 1, 8), int32] */, meta[relay.Constant][34] /* ty=Tensor[(1, 18, 1, 1, 8), int64] */, meta[relay.Constant][35] /* ty=Tensor[(1, 18, 1, 1, 8), int64] */, meta[relay.Constant][36] /* ty=Tensor[(1, 18, 1, 1, 8), int64] */) /* ty=Tensor[(1, 18, 64, 64, 8), uint8] */;
  %1096 = fn (%p0123: Tensor[(1, 18, 64, 64, 8), uint8], Primitive=1, hash="8667a3517900c8d6") -> Tensor[(1, 18, 66, 66, 8), int16] {
    %996 = nn.pad(%p0123, 0f /* ty=float32 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1], [0, 0]]) /* ty=Tensor[(1, 18, 66, 66, 8), uint8] */;
    cast(%996, dtype="int16") /* ty=Tensor[(1, 18, 66, 66, 8), int16] */
  };
  %1097 = %1096(%1095) /* ty=Tensor[(1, 18, 66, 66, 8), int16] */;
  %1098 = fn (%p0122: Tensor[(1, 18, 66, 66, 8), int16], %p196: Tensor[(18, 1, 3, 3, 1, 8), int16], %p273: Tensor[(1, 18, 1, 1, 8), int32], %p371: Tensor[(1, 18, 1, 1, 8), int64], %p471: Tensor[(1, 18, 1, 1, 8), int64], %p570: Tensor[(1, 18, 1, 1, 8), int64], kernel_layout="OIHW1i8o", Primitive=1, out_layout="NCHW8c", hash="b730ef33ea4b22ab", data_layout="NCHW8c") -> Tensor[(1, 18, 64, 64, 8), int16] {
    %986 = nn.contrib_depthwise_conv2d_NCHWc(%p0122, %p196, padding=[0, 0, 0, 0], groups=144, channels=144, kernel_size=[3, 3], data_layout="NCHW8c", kernel_layout="OIHW1i8o", out_layout="NCHW8c", out_dtype="int32") /* ty=Tensor[(1, 18, 64, 64, 8), int32] */;
    %987 = add(%986, %p273) /* ty=Tensor[(1, 18, 64, 64, 8), int32] */;
    %988 = cast(%987, dtype="int64") /* ty=Tensor[(1, 18, 64, 64, 8), int64] */;
    %989 = multiply(%988, %p371) /* ty=Tensor[(1, 18, 64, 64, 8), int64] */;
    %990 = add(%989, %p471) /* ty=Tensor[(1, 18, 64, 64, 8), int64] */;
    %991 = right_shift(%990, %p570) /* ty=Tensor[(1, 18, 64, 64, 8), int64] */;
    %992 = cast(%991, dtype="int32") /* ty=Tensor[(1, 18, 64, 64, 8), int32] */;
    %993 = clip(%992, a_min=0f, a_max=255f) /* ty=Tensor[(1, 18, 64, 64, 8), int32] */;
    %994 = cast(%993, dtype="uint8") /* ty=Tensor[(1, 18, 64, 64, 8), uint8] */;
    %995 = clip(%994, a_min=0f, a_max=6f) /* ty=Tensor[(1, 18, 64, 64, 8), uint8] */;
    cast(%995, dtype="int16") /* ty=Tensor[(1, 18, 64, 64, 8), int16] */
  };
  %1099 = %1098(%1097, meta[relay.Constant][37] /* ty=Tensor[(18, 1, 3, 3, 1, 8), int16] */, meta[relay.Constant][38] /* ty=Tensor[(1, 18, 1, 1, 8), int32] */, meta[relay.Constant][39] /* ty=Tensor[(1, 18, 1, 1, 8), int64] */, meta[relay.Constant][40] /* ty=Tensor[(1, 18, 1, 1, 8), int64] */, meta[relay.Constant][41] /* ty=Tensor[(1, 18, 1, 1, 8), int64] */) /* ty=Tensor[(1, 18, 64, 64, 8), int16] */;
  %1100 = fn (%p0121: Tensor[(1, 18, 64, 64, 8), int16], %p195: Tensor[(3, 18, 1, 1, 8, 8), int16], %p272: Tensor[(1, 3, 1, 1, 8), int32], %p370: Tensor[(1, 3, 1, 1, 8), int64], %p470: Tensor[(1, 3, 1, 1, 8), int64], %p569: Tensor[(1, 3, 1, 1, 8), int64], %p621: Tensor[(1, 3, 64, 64, 8), uint8], kernel_layout="OIHW8i8o", Primitive=1, out_layout="NCHW8c", hash="5fb2140ee53d7af7", data_layout="NCHW8c") -> Tensor[(1, 3, 64, 64, 8), uint8] {
    %965 = nn.contrib_conv2d_NCHWc(%p0121, %p195, padding=[0, 0, 0, 0], channels=24, kernel_size=[1, 1], data_layout="NCHW8c", kernel_layout="OIHW8i8o", out_layout="NCHW8c", out_dtype="int32") /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %966 = add(%965, %p272) /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %967 = cast(%966, dtype="int64") /* ty=Tensor[(1, 3, 64, 64, 8), int64] */;
    %968 = multiply(%967, %p370) /* ty=Tensor[(1, 3, 64, 64, 8), int64] */;
    %969 = add(%968, %p470) /* ty=Tensor[(1, 3, 64, 64, 8), int64] */;
    %970 = right_shift(%969, %p569) /* ty=Tensor[(1, 3, 64, 64, 8), int64] */;
    %971 = cast(%970, dtype="int32") /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %972 = add(63 /* ty=int32 */, %971) /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %973 = clip(%972, a_min=0f, a_max=255f) /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %974 = cast(%973, dtype="uint8") /* ty=Tensor[(1, 3, 64, 64, 8), uint8] */;
    %975 = cast(%974, dtype="int32") /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %976 = subtract(%975, 63 /* ty=int32 */) /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %977 = fixed_point_multiply(%976, multiplier=1466124156, shift=0) /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %978 = cast(%p621, dtype="int32") /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %979 = subtract(%978, 64 /* ty=int32 */) /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %980 = fixed_point_multiply(%979, multiplier=1817801142, shift=0) /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %981 = add(64 /* ty=int32 */, %977) /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %982 = add(64 /* ty=int32 */, %980) /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %983 = add(%981, %982) /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %984 = subtract(%983, 64 /* ty=int32 */) /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    %985 = clip(%984, a_min=0f, a_max=255f) /* ty=Tensor[(1, 3, 64, 64, 8), int32] */;
    cast(%985, dtype="uint8") /* ty=Tensor[(1, 3, 64, 64, 8), uint8] */
  };

Ah yes. This is another gotcha.

For platforms that don’t have fast int8 instructions, we fall back to doing convolution in 16-bit. Can you try target = "llvm -mcpu=core-avx2" (if your TVM is newer than https://github.com/apache/tvm/pull/8897) or target = "llvm -mcpu=skylake-avx512"?
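Besides the int16 casts, the int64 cast, multiply, add, and right_shift ops inside each fused function of the printed opt_mod look like fixed-point requantization: the int32 convolution accumulator is rescaled to the output scale with an integer multiplier and a shift instead of a float multiply. The pure-Python sketch below mirrors that op sequence. The multiplier/shift derivation and the "add half, then shift" rounding are assumptions chosen for illustration; TVM's exact rounding mode and constants may differ.

```python
import math


def quantize_multiplier(real_scale):
    """Approximate 0 < real_scale < 1 as multiplier / 2**shift with a 31-bit multiplier."""
    assert 0.0 < real_scale < 1.0
    mantissa, exponent = math.frexp(real_scale)  # real_scale = mantissa * 2**exponent, 0.5 <= mantissa < 1
    multiplier = round(mantissa * (1 << 31))
    shift = 31 - exponent
    if multiplier == 1 << 31:  # mantissa rounded up to 1.0; renormalize
        multiplier //= 2
        shift -= 1
    return multiplier, shift


def requantize(acc, bias, multiplier, shift, out_zero_point=0, qmin=0, qmax=255):
    """Same op order as the printed IR: add bias, multiply, add rounding term, right_shift, clip."""
    val = acc + bias                                        # int32 accumulator + int32 bias
    val = (val * multiplier + (1 << (shift - 1))) >> shift  # done in int64 in the IR
    val += out_zero_point                                   # like the add(64, ...) in one layer
    return max(qmin, min(qmax, val))                        # clip to the uint8 range


m, s = quantize_multiplier(0.25)
print(m, s)                                         # 1073741824 32
print(requantize(90, 10, m, s))                     # 25  (100 * 0.25)
print(requantize(90, 10, m, s, out_zero_point=64))  # 89
print(requantize(-40, 0, m, s))                     # 0   (negative values clipped)
```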

Also note that the output shape of your network, after conversion to TVM, is (1, 3, 64, 64, 8). This is because we use the NCHWc layout for x86 targets, so you need to convert the output back to NCHW.
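If you pull a tensor in NCHWc layout out of the module, converting it back is just an index remapping: the NCHW channel index is block * c + inner. Below is a minimal pure-Python sketch on nested lists (the function name is hypothetical). With numpy, the same remapping would be arr.transpose(0, 1, 4, 2, 3).reshape(n, cb * c, h, w).

```python
def nchwc_to_nchw(x):
    """Convert a nested-list tensor from (N, C//c, H, W, c) to (N, C, H, W).

    NCHW channel index = block * c + inner, so each block expands into c channels.
    """
    out = []
    for batch in x:                   # batch: [C//c][H][W][c]
        channels = []
        for block in batch:           # block: [H][W][c]
            inner = len(block[0][0])  # the "c" in NCHWc, e.g. 8 for NCHW8c
            for ci in range(inner):
                channels.append([[pixel[ci] for pixel in row] for row in block])
        out.append(channels)
    return out


x = [[  # N=1, C//c=2, H=1, W=2, c=2, so C=4
    [[[0, 10], [1, 11]]],
    [[[20, 30], [21, 31]]],
]]
print(nchwc_to_nchw(x))  # [[[[0, 1]], [[10, 11]], [[20, 21]], [[30, 31]]]]
```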

Hi masahi, do you know which pass converts the float32 weights back to int8? Thanks.

FoldConstant will do the constant folding to quantize weights at compile time. You can also use the new option added in https://github.com/apache/tvm/pull/9135 to get int8 weights directly from the frontend.

@masahi When you ran inference of a PyTorch QAT model on both TVM and PyTorch, did you notice any differences between them?

Yes, there will be differences, due to slight differences in how quantization numerics are implemented in TVM and PyTorch. See for example https://github.com/apache/tvm/blob/main/tests/python/frontend/pytorch/qnn_test.py#L449-L505
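One generic way such small numeric differences can arise is that two implementations resolve rounding ties differently, so a value exactly halfway between two quantization levels lands one level apart. The pure-Python sketch below only illustrates that mechanism with hypothetical values; it is not a claim about which rounding convention TVM or PyTorch actually uses.

```python
import math


def round_half_to_even(x):
    return round(x)  # Python's built-in round() sends ties to the nearest even integer


def round_half_away_from_zero(x):
    return int(math.copysign(math.floor(abs(x) + 0.5), x))


def quantize(x, scale, zero_point, rounder, qmin=0, qmax=255):
    return max(qmin, min(qmax, rounder(x / scale) + zero_point))


values = [0.125, 0.375, 0.625, -0.125]  # hypothetical inputs; each is an exact tie at scale 0.25
a = [quantize(v, 0.25, 128, round_half_to_even) for v in values]
b = [quantize(v, 0.25, 128, round_half_away_from_zero) for v in values]
print(a)  # [128, 130, 130, 128]
print(b)  # [129, 130, 131, 127]
```

Each disagreement is a single quantization step, but such steps can compound across many quantized layers.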

Yes, that’s what I am looking for. I am working on a pixel-wise model. Is there any data for this kind of model?

Also, what’s the main difference between fbgemm and qnnpack, and which one is recommended for TVM? Thanks

If your model output is a discrete label rather than raw floating point values, I expect no difference between TVM and PyTorch. In other words, argmin / argmax should be stable regardless of the raw floating point output.
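For a pixel-wise model, one simple check is to report both the largest raw difference and the fraction of pixels whose argmax label agrees between the PyTorch and TVM outputs. Below is a minimal pure-Python sketch (helper names are hypothetical; inputs are [C][H][W] nested lists for a single image). For a single-channel output such as the (1, 1, 256, 256) tensor in the printed IR, there is no argmax to take, so only the raw-difference metric is meaningful.

```python
def argmax_labels(logits):
    """[C][H][W] scores -> [H][W] label map, taking the argmax over channels."""
    num_classes = len(logits)
    height, width = len(logits[0]), len(logits[0][0])
    return [[max(range(num_classes), key=lambda c: logits[c][i][j]) for j in range(width)]
            for i in range(height)]


def compare_outputs(ref, test):
    """Return (max absolute difference, fraction of pixels with the same argmax label)."""
    max_abs = max(abs(r - t)
                  for ref_c, test_c in zip(ref, test)
                  for ref_row, test_row in zip(ref_c, test_c)
                  for r, t in zip(ref_row, test_row))
    ref_labels, test_labels = argmax_labels(ref), argmax_labels(test)
    pairs = [(a, b) for ra, rb in zip(ref_labels, test_labels) for a, b in zip(ra, rb)]
    agreement = sum(a == b for a, b in pairs) / len(pairs)
    return max_abs, agreement


pytorch_out = [[[0.9, 0.1]], [[0.2, 0.8]]]  # hypothetical 2-class, 1x2 output
tvm_out = [[[0.85, 0.15]], [[0.25, 0.75]]]  # slightly different raw values
max_abs, agreement = compare_outputs(pytorch_out, tvm_out)
print(round(max_abs, 6), agreement)         # 0.05 1.0
```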

I always used fbgemm, I don’t know how qnnpack works.