[bug][qnn] Type mismatch in BroadcastRel for 8-bit quantized model

When attempting to run inference with an 8-bit quantized version of DenseNet (the PyTorch implementation), the model crashes. You can reproduce this with this gist.

The crash happens at execution time rather than at quantization time, and the error printed is:

data types int8 and int32 do not match in BroadcastRel

I’m pretty sure that the error is being triggered by this line in relay/op/type_relations.cc.

I’m not seeing a BroadcastRel op in the IR, though I think that it could be the “concatenate” operations that are triggering it:

  %21 = nn.batch_norm(%20, %aten::batch_norm_4.weight, %aten::batch_norm_4.bias, %aten::batch_norm_4.running_mean, %aten::batch_norm_4.running_var) /* span=aten::batch_norm_4:0:0 */;
  %22 = %21.0 /* span=aten::batch_norm_4:0:0 */;
  %23 = nn.relu(%22) /* span=aten::relu__4:0:0 */;
  %24 = nn.conv2d(%23, %aten::_convolution_4.weight, padding=[1, 1, 1, 1], channels=48, kernel_size=[3, 3]) /* span=aten::_convolution_4:0:0 */;
  %25 = (%4, %14, %24) /* span=aten::cat_2:0:0 */;
  %26 = concatenate(%25, axis=1) /* span=aten::cat_2:0:0 */;

My reasoning is that the elements of (%4, %14, %24) might not all have the same type, because of inconsistent re-quantization steps.

Unfortunately I can’t verify this, as I don’t have a way of printing a more verbose IR with the types etc. for quantized models. I can do this fine for float32 models (see my print_tir function).
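For reference, the relevant part of print_tir is roughly the following (a simplified sketch of the idea, not the exact code):

from tvm import relay

# Run type inference so that every expression carries a type annotation,
# then print the annotated module.
def print_typed_ir(mod):
    mod = relay.transform.InferType()(mod)
    print(mod)  # expressions now show /* ty=Tensor[...] */ annotations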

Any tips on going deeper with this?

The link to the gist points to the wrong destination. Could you fix it?

Ah thanks, my mistake. I’ve fixed it, and here is the gist again for good measure.

I also have a second version of the script, in which I’ve isolated a single dense block of the model; it exhibits the same issue. I’m going to try to reduce it further, and to prove with differential testing that the concat is the issue (something along the lines of the sketch below), but it’s getting late in my timezone. The current version of that script is here, and may be useful.
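The kind of minimal repro I have in mind is something like this (entirely hypothetical names and shapes: two conv branches, each presumably getting its own quantization scale, feeding a single concatenate, pushed through the automatic quantization flow):

import numpy as np
import tvm
from tvm import relay

# Two convolution branches feeding one concatenate.
data = relay.var("data", shape=(1, 3, 56, 56), dtype="float32")
w1 = relay.var("w1", shape=(8, 3, 3, 3), dtype="float32")
w2 = relay.var("w2", shape=(8, 3, 3, 3), dtype="float32")
b1 = relay.nn.conv2d(data, w1, padding=(1, 1), channels=8, kernel_size=(3, 3))
b2 = relay.nn.conv2d(data, w2, padding=(1, 1), channels=8, kernel_size=(3, 3))
out = relay.concatenate([b1, b2], axis=1)
mod = tvm.IRModule.from_expr(relay.Function([data, w1, w2], out))

# Random weights are enough for a type-level repro.
params = {
    "w1": tvm.nd.array(np.random.uniform(-1, 1, (8, 3, 3, 3)).astype("float32")),
    "w2": tvm.nd.array(np.random.uniform(-1, 1, (8, 3, 3, 3)).astype("float32")),
}

# Quantize and inspect whether the concatenate inputs end up with matching dtypes.
with relay.quantize.qconfig(calibrate_mode="global_scale", global_scale=8.0):
    qmod = relay.quantize.quantize(mod, params)
print(qmod)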

I have a feeling that the error might be related to batch_norm. In general it is recommended to fold batch norm away before quantizing; see https://github.com/apache/tvm/blob/553eb1acd0c115adea0c7d04ce36e26332339769/tests/python/relay/test_pass_fold_constant.py#L317-L324
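In practice that means running the standard folding passes before quantization. A minimal sketch, assuming mod and params come out of relay.frontend.from_pytorch:

import tvm
from tvm import relay

# Bind the weights into the module so the folding passes can treat them as
# constants, then rewrite batch_norm into multiply/add and fold it away.
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
seq = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.SimplifyInference(),  # batch_norm -> multiply/add
        relay.transform.FoldConstant(),
        relay.transform.FoldScaleAxis(),      # fold the scales into the conv weights
        relay.transform.FoldConstant(),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)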

That’s a good suggestion, and will improve the performance of all of my int8 models.

However, it seems that the issue still occurs when I apply those optimization passes, as can be seen in this updated gist.

EDIT

I have now further simplified the “single block” DenseNet, in this updated gist. By default, the first dense block has 6 “dense layers” (each of which is a collection of conv2d, relu, and concat layers, not a fully-connected layer), but I can reproduce the bug with this value as low as 2.

The bug also happens with the batch norm optimizations applied.

The IR for the optimized (but non-quantized) model is:

def @main(%input0: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.input0:0:0 */) -> Tensor[(1, 32, 56, 56), float32] {
  %0 = nn.conv2d(%input0, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::_convolution_0:0:0 */;
  %1 = multiply(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %2 = add(%1, meta[relay.Constant][2] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %3 = nn.relu(%2) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::relu__0:0:0 */;
  %4 = nn.max_pool2d(%3, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] span=aten::max_pool2d_0:0:0 */;
  %5 = (%4,) /* ty=(Tensor[(1, 64, 56, 56), float32],) span=aten::cat_0:0:0 */;
  %6 = concatenate(%5, axis=1) /* ty=Tensor[(1, 64, 56, 56), float32] span=aten::cat_0:0:0 */;
  %7 = multiply(%6, meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %8 = add(%7, meta[relay.Constant][4] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %9 = nn.relu(%8) /* ty=Tensor[(1, 64, 56, 56), float32] span=aten::relu__1:0:0 */;
  %10 = nn.conv2d(%9, meta[relay.Constant][5] /* ty=Tensor[(128, 64, 1, 1), float32] */, padding=[0, 0, 0, 0], channels=128, kernel_size=[1, 1]) /* ty=Tensor[(1, 128, 56, 56), float32] span=aten::_convolution_1:0:0 */;
  %11 = multiply(%10, meta[relay.Constant][6] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 56, 56), float32] */;
  %12 = add(%11, meta[relay.Constant][7] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 56, 56), float32] */;
  %13 = nn.relu(%12) /* ty=Tensor[(1, 128, 56, 56), float32] span=aten::relu__2:0:0 */;
  %14 = nn.conv2d(%13, meta[relay.Constant][8] /* ty=Tensor[(32, 128, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 56, 56), float32] span=aten::_convolution_2:0:0 */;
  %15 = (%4, %14) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(1, 32, 56, 56), float32]) span=aten::cat_1:0:0 */;
  %16 = concatenate(%15, axis=1) /* ty=Tensor[(1, 96, 56, 56), float32] span=aten::cat_1:0:0 */;
  %17 = multiply(%16, meta[relay.Constant][9] /* ty=Tensor[(96, 1, 1), float32] */) /* ty=Tensor[(1, 96, 56, 56), float32] */;
  %18 = add(%17, meta[relay.Constant][10] /* ty=Tensor[(96, 1, 1), float32] */) /* ty=Tensor[(1, 96, 56, 56), float32] */;
  %19 = nn.relu(%18) /* ty=Tensor[(1, 96, 56, 56), float32] span=aten::relu__3:0:0 */;
  %20 = nn.conv2d(%19, meta[relay.Constant][11] /* ty=Tensor[(128, 96, 1, 1), float32] */, padding=[0, 0, 0, 0], channels=128, kernel_size=[1, 1]) /* ty=Tensor[(1, 128, 56, 56), float32] span=aten::_convolution_3:0:0 */;
  %21 = multiply(%20, meta[relay.Constant][12] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 56, 56), float32] */;
  %22 = add(%21, meta[relay.Constant][13] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 56, 56), float32] */;
  %23 = nn.relu(%22) /* ty=Tensor[(1, 128, 56, 56), float32] span=aten::relu__4:0:0 */;
  nn.conv2d(%23, meta[relay.Constant][14] /* ty=Tensor[(32, 128, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 56, 56), float32] span=aten::_convolution_4:0:0 */
}


Running the quantized version of this model still fails with:

data types int8 and int32 do not match in BroadcastRel
data types int8 and int32 do not match in BroadcastRel

Again, my intuition is telling me that the bug is caused by the concat %16 = concatenate(%15, axis=1) /* ty=Tensor[(1, 96, 56, 56), float32] span=aten::cat_1:0:0 */;, but I haven’t gone deep enough into TVM’s quantization conversion to be sure.

QNN’s BroadcastRel is defined for the “sum”, “mul”, and “subtract” ops. Please check whether all of the inputs to these operations have the same data type. The IR provided doesn’t allow that validation, since it is the non-quantized model.
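If it helps, something along these lines (an untested sketch; adjust the op names to whatever actually appears in your quantized module) would dump the input dtypes of those ops so that a mismatch shows up directly:

import tvm
from tvm import relay

def report_binary_op_dtypes(mod):
    # Print the dtypes of the two data inputs of each op of interest;
    # "?" means the type checker has not annotated that expression yet.
    ops_of_interest = {"qnn.add", "qnn.mul", "qnn.subtract"}

    def dtype_of(expr):
        try:
            return expr.checked_type.dtype
        except (ValueError, AttributeError):
            return "?"

    def visit(expr):
        if isinstance(expr, relay.Call) and isinstance(expr.op, tvm.ir.Op):
            if expr.op.name in ops_of_interest:
                print(expr.op.name, [dtype_of(a) for a in expr.args[:2]])

    relay.analysis.post_order_visit(mod["main"], visit)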

I’m hitting the same error. Can anyone offer some ideas?

I don’t have a fix, but I fixed a different QNN bug a couple of months back; you can see the changes here. That could give you an idea of where to start looking.

Adding a data-type check to one of the operations could help with this.