That’s a good suggestion, and it should improve the performance of all of my int8 models.
However, the issue still occurs when I apply the optimization passes, as can be seen in this updated gist.
EDIT
I have now further simplified the “single block” DenseNet in this updated gist. By default, the first denseblock has 6 “dense layers” (each a collection of conv2d, relu, and concat layers, not a fully-connected layer), but I can reproduce the bug with this value as low as 2.
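For anyone who doesn’t want to open the gist, the repro model is roughly the following sketch (from memory, so the gist is authoritative; the gist also truncates the graph inside the dense block, which is why the IR dump below ends at the last conv2d rather than a final concat):

```python
import torch
import torchvision

# A DenseNet with a single dense block. block_config=(2,) gives the two
# "dense layers" that are enough to reproduce the bug; growth_rate=32 and
# num_init_features=64 match the channel counts visible in the IR dump below.
model = torchvision.models.DenseNet(
    growth_rate=32, block_config=(2,), num_init_features=64
)

# Keep only the stem and the dense block, dropping the final norm and
# classifier, so the traced graph stays small.
features = torch.nn.Sequential(
    model.features.conv0,
    model.features.norm0,
    model.features.relu0,
    model.features.pool0,
    model.features.denseblock1,
).eval()

input_shape = (1, 3, 224, 224)
scripted = torch.jit.trace(features, torch.randn(input_shape))
```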
The bug also happens with the batch norm optimizations applied.
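Concretely, the pipeline I mean (including the batch norm folding) is along these lines; the exact pass list and qconfig values here are a sketch, and the gist has the real code:

```python
import tvm
from tvm import relay

# `scripted` and `input_shape` come from the tracing sketch above.
mod, params = relay.frontend.from_pytorch(scripted, [("input0", input_shape)])

# SimplifyInference turns batch norm into the multiply/add pairs seen in
# the IR dump below; the fold passes then push what they can into the
# conv weights.
seq = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.SimplifyInference(),
        relay.transform.FoldConstant(),
        relay.transform.FoldScaleAxis(),
        relay.transform.FoldConstant(),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)

# Quantizing the optimized module is what triggers the errors below.
with relay.quantize.qconfig(calibrate_mode="global_scale", global_scale=8.0):
    qmod = relay.quantize.quantize(mod, params)
```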
The IR for the optimized (but non-quantized) model is:
```
def @main(%input0: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.input0:0:0 */) -> Tensor[(1, 32, 56, 56), float32] {
%0 = nn.conv2d(%input0, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::_convolution_0:0:0 */;
%1 = multiply(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%2 = add(%1, meta[relay.Constant][2] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
%3 = nn.relu(%2) /* ty=Tensor[(1, 64, 112, 112), float32] span=aten::relu__0:0:0 */;
%4 = nn.max_pool2d(%3, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] span=aten::max_pool2d_0:0:0 */;
%5 = (%4,) /* ty=(Tensor[(1, 64, 56, 56), float32],) span=aten::cat_0:0:0 */;
%6 = concatenate(%5, axis=1) /* ty=Tensor[(1, 64, 56, 56), float32] span=aten::cat_0:0:0 */;
%7 = multiply(%6, meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%8 = add(%7, meta[relay.Constant][4] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 56, 56), float32] */;
%9 = nn.relu(%8) /* ty=Tensor[(1, 64, 56, 56), float32] span=aten::relu__1:0:0 */;
%10 = nn.conv2d(%9, meta[relay.Constant][5] /* ty=Tensor[(128, 64, 1, 1), float32] */, padding=[0, 0, 0, 0], channels=128, kernel_size=[1, 1]) /* ty=Tensor[(1, 128, 56, 56), float32] span=aten::_convolution_1:0:0 */;
%11 = multiply(%10, meta[relay.Constant][6] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 56, 56), float32] */;
%12 = add(%11, meta[relay.Constant][7] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 56, 56), float32] */;
%13 = nn.relu(%12) /* ty=Tensor[(1, 128, 56, 56), float32] span=aten::relu__2:0:0 */;
%14 = nn.conv2d(%13, meta[relay.Constant][8] /* ty=Tensor[(32, 128, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 56, 56), float32] span=aten::_convolution_2:0:0 */;
%15 = (%4, %14) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(1, 32, 56, 56), float32]) span=aten::cat_1:0:0 */;
%16 = concatenate(%15, axis=1) /* ty=Tensor[(1, 96, 56, 56), float32] span=aten::cat_1:0:0 */;
%17 = multiply(%16, meta[relay.Constant][9] /* ty=Tensor[(96, 1, 1), float32] */) /* ty=Tensor[(1, 96, 56, 56), float32] */;
%18 = add(%17, meta[relay.Constant][10] /* ty=Tensor[(96, 1, 1), float32] */) /* ty=Tensor[(1, 96, 56, 56), float32] */;
%19 = nn.relu(%18) /* ty=Tensor[(1, 96, 56, 56), float32] span=aten::relu__3:0:0 */;
%20 = nn.conv2d(%19, meta[relay.Constant][11] /* ty=Tensor[(128, 96, 1, 1), float32] */, padding=[0, 0, 0, 0], channels=128, kernel_size=[1, 1]) /* ty=Tensor[(1, 128, 56, 56), float32] span=aten::_convolution_3:0:0 */;
%21 = multiply(%20, meta[relay.Constant][12] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 56, 56), float32] */;
%22 = add(%21, meta[relay.Constant][13] /* ty=Tensor[(128, 1, 1), float32] */) /* ty=Tensor[(1, 128, 56, 56), float32] */;
%23 = nn.relu(%22) /* ty=Tensor[(1, 128, 56, 56), float32] span=aten::relu__4:0:0 */;
nn.conv2d(%23, meta[relay.Constant][14] /* ty=Tensor[(32, 128, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 56, 56), float32] span=aten::_convolution_4:0:0 */
}
```
The quantization step then fails with:

```
data types int8 and int32 do not match in BroadcastRel
data types int8 and int32 do not match in BroadcastRel
```
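For what it’s worth, the message itself is just Relay’s type checker rejecting a broadcasting op whose operands have different dtypes. A minimal standalone snippet (unrelated to DenseNet) that produces the same diagnostic:

```python
import tvm
from tvm import relay

# add/multiply use the BroadcastRel type relation, which requires both
# operands to have the same dtype; mixing int8 and int32 reproduces the
# exact diagnostic above.
x = relay.var("x", shape=(1, 4), dtype="int8")
y = relay.var("y", shape=(1, 4), dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.add(x, y)))
# Raises: data types int8 and int32 do not match in BroadcastRel
mod = relay.transform.InferType()(mod)
```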
Again, my intuition tells me that the bug is caused by the concat `%16 = concatenate(%15, axis=1)`: its inputs %4 and %14 reach it along different paths (one straight from the max_pool2d, one through the conv layers), so they may end up with different quantized dtypes. But I haven’t gone deep enough into TVM’s quantization passes to be sure.
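To check this, my plan is to look at the graph just before realization using the `do_simulation` option of `qconfig` (again a sketch, with the same caveats as above):

```python
import tvm
from tvm import relay

# do_simulation=True stops the quantizer after calibration: the
# simulated_quantize ops stay in the graph instead of being realized to
# int8/int32, so the IR around the concatenate can be inspected just
# before the step that actually rewrites dtypes.
with relay.quantize.qconfig(
    calibrate_mode="global_scale", global_scale=8.0, do_simulation=True
):
    sim_mod = relay.quantize.quantize(mod, params)
print(sim_mod["main"])
```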