Optimizing a QAT DistilBERT int8 model

When a QAT DistilBERT int8 model (TFLite > ONNX > TVM) is imported into TVM, we see that the qnn.batch_matmul operation has been broken down into strided_slice and qnn.dense operations. Converting the pattern below to qnn.batch_matmul would reduce the Relay IR size by 65% (illustrative sketches follow the patterns below). Could you please comment on whether such a Relay transformation would be useful upstream?

Pattern:

%35 = strided_slice(%34, begin=[0i64, 0i64, 0i64], end=[1i64, 128i64, 64i64], strides=[1i64, 1i64, 1i64], axes=None) /* ty=Tensor[(1, 128, 64), int8] */;

%44 = reshape(%35, newshape=[-1, 64]) /* ty=Tensor[(128, 64), int8] */;

%42 = strided_slice(%41, begin=[0i64, 0i64, 0i64], end=[1i64, 64i64, 128i64], strides=[1i64, 1i64, 1i64], axes=None) /* ty=Tensor[(1, 64, 128), int8] */;

%43 = reshape(%42, newshape=[64, 128]) /* ty=Tensor[(64, 128), int8] */;

%45 = transpose(%43, axes=[1, 0]) /* ty=Tensor[(128, 64), int8] */;

%46 = qnn.dense(%44, %45, 13 /* ty=int32 */, 1 /* ty=int32 */, 0.0541715f /* ty=float32 */, 0.0489368f /* ty=float32 */, units=None, out_dtype="int32") /* ty=Tensor[(128, 128), int32] */;

%47 = qnn.requantize(%46, 0.00265098f /* ty=float32 */, 0 /* ty=int32 */, 0.728874f /* ty=float32 */, -14 /* ty=int32 */, axis=1, out_dtype="int8") /* ty=Tensor[(128, 128), int8] */;

%125 = expand_dims(%47, axis=0) /* ty=Tensor[(1, 128, 128), int8] */;

< The above pattern repeats 12 times, which is the batch size >

%137 = (%125, %126, %127, %128, %129, %130, %131, %132, %133, %134, %135, %136) /* ty=(Tensor[(1, 128, 128), int8], … , Tensor[(1, 128, 128), int8]) */;

%138 = concatenate(%137) /* ty=Tensor[(12, 128, 128), int8] */;
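
For illustration, here is a rough sketch of how one unrolled slice of the above pattern could be recognized with the Relay dataflow pattern matcher. This only shows the matching side; an actual pass would also have to match the surrounding tuple/concatenate and recover the original 3-D operands before emitting a single qnn.batch_matmul:

```python
# Sketch only: matches a single unrolled slice of the pattern above.
from tvm.relay.dataflow_pattern import is_op, is_constant, wildcard

lhs_3d = wildcard()  # source of the (1, 128, 64) slice (%34 above)
rhs_3d = wildcard()  # source of the (1, 64, 128) slice (%41 above)

lhs_2d = is_op("reshape")(is_op("strided_slice")(lhs_3d))
rhs_2d = is_op("transpose")(is_op("reshape")(is_op("strided_slice")(rhs_3d)))

dense = is_op("qnn.dense")(
    lhs_2d, rhs_2d,
    is_constant(), is_constant(), is_constant(), is_constant(),
)
requant = is_op("qnn.requantize")(
    dense,
    is_constant(), is_constant(), is_constant(), is_constant(),
)
one_slice = is_op("expand_dims")(requant)

# one_slice.match(expr) is True for each of the 12 expand_dims results
# that feed the final concatenate.
```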

Resulting pattern:

%41 = transpose(%39, axes=[0, 2, 1]);

%42 = qnn.batch_matmul(%40, %41, 13 /* ty=int32 */, 1 /* ty=int32 */, 0.0541715f /* ty=float32 */, 0.0489368f /* ty=float32 */, out_dtype="int32", transpose_b=True);

%43 = qnn.requantize(%42, 0.00265098f /* ty=float32 */, 0 /* ty=int32 */, 0.728874f /* ty=float32 */, -14 /* ty=int32 */, out_dtype="int8");
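
For clarity, the consolidated subgraph can be sketched with the Relay Python API roughly as below. This is only a sketch assuming qnn.batch_matmul keeps the nn.batch_matmul convention of treating the second operand as transposed; the variable names and the batched (12, 128, 64) / (12, 64, 128) input shapes are illustrative stand-ins for %40 and %39 above, and the quantization parameters are copied from the IR:

```python
import tvm
from tvm import relay

lhs = relay.var("lhs", shape=(12, 128, 64), dtype="int8")  # stands in for %40
rhs = relay.var("rhs", shape=(12, 64, 128), dtype="int8")  # stands in for %39

rhs_t = relay.transpose(rhs, axes=[0, 2, 1])  # (12, 128, 64)

bmm = relay.qnn.op.batch_matmul(
    lhs,
    rhs_t,
    relay.const(13, "int32"),            # lhs zero point
    relay.const(1, "int32"),             # rhs zero point
    relay.const(0.0541715, "float32"),   # lhs scale
    relay.const(0.0489368, "float32"),   # rhs scale
    out_dtype="int32",
)

out = relay.qnn.op.requantize(
    bmm,
    relay.const(0.00265098, "float32"),  # input scale
    relay.const(0, "int32"),             # input zero point
    relay.const(0.728874, "float32"),    # output scale
    relay.const(-14, "int32"),           # output zero point
    out_dtype="int8",
)

mod = tvm.IRModule.from_expr(relay.Function([lhs, rhs], out))
print(mod)  # single qnn.batch_matmul + qnn.requantize, output (12, 128, 128) int8
```

Printing the module should show one qnn.batch_matmul followed by one qnn.requantize producing a (12, 128, 128) int8 tensor, in place of the 12 unrolled qnn.dense slices.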

cc: @AndrewZhaoLuo @masahi @kparzysz @jverma

Yes, this looks very useful. I wonder where that “unrolling” of the int8 bmm happens - it looks super inefficient.


Hi Masahi, this is a TFLite QAT model we received. I think the original TensorFlow float model has batch matmul, and this unrolling could be a result of the QAT process applied to the float model. We see batch_matmul in the float model, but not in the QAT model.