When a QAT DistilBERT int8 model (TFLite → ONNX → TVM) is imported into TVM, we see that the qnn.batch_matmul operation has been broken down into strided_slice and qnn.dense operations. Converting the pattern below to qnn.batch_matmul would reduce the IR size by about 65%. Could you please comment on whether such a Relay transformation would be useful upstream?
Pattern:
%35 = strided_slice(%34, begin=[0i64, 0i64, 0i64], end=[1i64, 128i64, 64i64], strides=[1i64, 1i64, 1i64], axes=None) /* ty=Tensor[(1, 128, 64), int8] */;
%44 = reshape(%35, newshape=[-1, 64]) /* ty=Tensor[(128, 64), int8] */;
%42 = strided_slice(%41, begin=[0i64, 0i64, 0i64], end=[1i64, 64i64, 128i64], strides=[1i64, 1i64, 1i64], axes=None) /* ty=Tensor[(1, 64, 128), int8] */;
%43 = reshape(%42, newshape=[64, 128]) /* ty=Tensor[(64, 128), int8] */;
%45 = transpose(%43, axes=[1, 0]) /* ty=Tensor[(128, 64), int8] */;
%46 = qnn.dense(%44, %45, 13 /* ty=int32 */, 1 /* ty=int32 */, 0.0541715f /* ty=float32 */, 0.0489368f /* ty=float32 */, units=None, out_dtype="int32") /* ty=Tensor[(128, 128), int32] */;
%47 = qnn.requantize(%46, 0.00265098f /* ty=float32 */, 0 /* ty=int32 */, 0.728874f /* ty=float32 */, -14 /* ty=int32 */, axis=1, out_dtype="int8") /* ty=Tensor[(128, 128), int8] */;
%125 = expand_dims(%47, axis=0) /* ty=Tensor[(1, 128, 128), int8] */;
< The above pattern repeats 12 times, which is the batch size >
%137 = (%125, %126, %127, %128, %129, %130, %131, %132, %133, %134, %135, %136) /* ty=(Tensor[(1, 128, 128), int8], … , Tensor[(1, 128, 128), int8]) */;
%138 = concatenate(%137) /* ty=Tensor[(12, 128, 128), int8] */;
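For context on why this rewrite is safe: concatenating the 12 per-slice dense results along a new leading axis is exactly a batched matmul over that axis. A quick NumPy sanity check (float instead of int8 for simplicity; shapes taken from the IR above):

```python
import numpy as np

# 12 slices: each strided_slice pulls one (128, 64) / (64, 128) operand pair
a = np.random.randn(12, 128, 64).astype("float32")
b = np.random.randn(12, 64, 128).astype("float32")

# Current pattern: 12 separate dense ops, then concatenate on a new batch axis
per_slice = np.stack([a[i] @ b[i] for i in range(12)])  # (12, 128, 128)

# Proposed pattern: a single batched matmul
batched = np.matmul(a, b)                               # (12, 128, 128)

assert np.allclose(per_slice, batched)
```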
Resulting pattern:
%41 = transpose(%39, axes=[0, 2, 1]);
%42 = qnn.batch_matmul(%40, %41, 13 /* ty=int32 */, 1 /* ty=int32 */, 0.0541715f /* ty=float32 */, 0.0489368f /* ty=float32 */, out_dtype="int32", transpose_b=True);
%43 = qnn.requantize(%42, 0.00265098f /* ty=float32 */, 0 /* ty=int32 */, 0.728874f /* ty=float32 */, -14 /* ty=int32 */, out_dtype="int8");
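In case it helps the discussion, here is a rough, untested sketch of how one branch of the repeated pattern could be matched with the Relay dataflow pattern language; an actual pass would additionally need to group the 12 matched branches feeding the concatenate and emit a single qnn.batch_matmul:

```python
from tvm.relay.dataflow_pattern import is_op, wildcard

# One branch of the repeated pattern:
# strided_slice -> reshape -> qnn.dense -> qnn.requantize -> expand_dims
lhs = is_op("reshape")(is_op("strided_slice")(wildcard()))
rhs = is_op("transpose")(is_op("reshape")(is_op("strided_slice")(wildcard())))

dense = is_op("qnn.dense")(
    lhs, rhs,
    wildcard(), wildcard(),  # input / kernel zero points
    wildcard(), wildcard(),  # input / kernel scales
)
requant = is_op("qnn.requantize")(
    dense,
    wildcard(), wildcard(),  # input scale / zero point
    wildcard(), wildcard(),  # output scale / zero point
)
branch = is_op("expand_dims")(requant)
```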