[QNN] How to add elementwise op fusion in QNN?

Hi,

I am testing conv and dense ops followed by some special activations like leakyrelu, swish, and hswish on the TVM v0.9 master branch. From the function_metadata returned by relay.build, they are not fused together the way they are in fp32. How can I add this kind of fusion?

@masahi, could you give me some suggestions? Thanks a lot!

Can you give a concrete example of this? I expect the fusion you described for QNN to work just as well as it does for fp32.

OK, I will describe it in detail below.

Recently, PR#11729 added quantized::leaky_relu, but it falls back to fp32 when executing.

I designed a quantized torch model in which a conv is followed by leakyrelu. After building with target llvm -mcpu=core-avx2, the concrete parameters of this part of the FunctionInfoNode look like this (a rough sketch of how I build and inspect the model is included right after the dump):

  relay_primfuncs={llvm -keys=cpu -link-params=0 -mcpu=core-avx2: fn (%p0: Tensor[(128, 1, 640, 64, 4), uint8] /* ty=Tensor[(128, 1, 640, 64, 4), uint8] */, %p1: Tensor[(8, 1, 4, 64, 1, 16, 4), int8] /* ty=Tensor[(8, 1, 4, 64, 1, 16, 4), int8] */, %p2: Tensor[(1, 8, 1, 1, 16), int32] /* ty=Tensor[(1, 8, 1, 1, 16), int32] */, %p3: Tensor[(1, 8, 1, 1, 16), int32] /* ty=Tensor[(1, 8, 1, 1, 16), int32] */, %p4: Tensor[(1, 8, 1, 1, 16), float32] /* ty=Tensor[(1, 8, 1, 1, 16), float32] */, %p5: float32 /* ty=float32 */, target=meta[Target][0], prim_funcs={'tvmgen_default_fused_nn_contrib_conv2d_NCHWc_subtract_add_cast_multiply_add_floor_cast_clip_sub_dec13fea0b694606__2'=meta[tir.PrimFunc][0]}, out_layout="NCHW16c", data_layout="NCHW4c", hash="cfb467205b74da7e", kernel_layout="OIHW1i16o4i", prim_fn_var='tvmgen_default_fused_nn_contrib_conv2d_NCHWc_subtract_add_cast_multiply_add_floor_cast_clip_sub_dec13fea0b694606__2', Primitive=1) -> Tensor[(128, 8, 637, 1, 16), uint8] {
  %0 = nn.contrib_conv2d_NCHWc(%p0, %p1, padding=[0, 0, 0, 0], channels=128, kernel_size=[4, 64], data_layout="NCHW4c", kernel_layout="OIHW1i16o4i", out_layout="NCHW16c", out_dtype="int32") /* ty=Tensor[(128, 8, 637, 1, 16), int32] */;
  %1 = subtract(%0, %p2) /* ty=Tensor[(128, 8, 637, 1, 16), int32] */;
  %2 = add(%1, %p3) /* ty=Tensor[(128, 8, 637, 1, 16), int32] */;
  %3 = cast(%2, dtype="float32") /* ty=Tensor[(128, 8, 637, 1, 16), float32] */;
  %4 = multiply(%3, %p4) /* ty=Tensor[(128, 8, 637, 1, 16), float32] */;
  %5 = add(%4, 102.5f /* ty=float32 */) /* ty=Tensor[(128, 8, 637, 1, 16), float32] */;
  %6 = floor(%5) /* ty=Tensor[(128, 8, 637, 1, 16), float32] */;
  %7 = cast(%6, dtype="int32") /* ty=Tensor[(128, 8, 637, 1, 16), int32] */;
  %8 = clip(%7, a_min=0f, a_max=255f) /* ty=Tensor[(128, 8, 637, 1, 16), int32] */;
  %9 = subtract(%8, 102 /* ty=int32 */) /* ty=Tensor[(128, 8, 637, 1, 16), int32] */;
  %10 = cast(%9, dtype="float32") /* ty=Tensor[(128, 8, 637, 1, 16), float32] */;
  %11 = multiply(%10, 0.177173f /* ty=float32 */) /* ty=Tensor[(128, 8, 637, 1, 16), float32] */;
  %12 = nn.leaky_relu(%11, alpha=0.01f) /* ty=Tensor[(128, 8, 637, 1, 16), float32] */;
  %13 = divide(%12, 0.0350804f /* ty=float32 */) /* ty=Tensor[(128, 8, 637, 1, 16), float32] */;
  %14 = round(%13) /* ty=Tensor[(128, 8, 637, 1, 16), float32] */;
  %15 = add(%14, %p5) /* ty=Tensor[(128, 8, 637, 1, 16), float32] */;
  %16 = clip(%15, a_min=0f, a_max=255f) /* ty=Tensor[(128, 8, 637, 1, 16), float32] */;
  cast(%16, dtype="uint8") /* ty=Tensor[(128, 8, 637, 1, 16), uint8] */
} /* ty=fn (Tensor[(128, 1, 640, 64, 4), uint8], Tensor[(8, 1, 4, 64, 1, 16, 4), int8], Tensor[(1, 8, 1, 1, 16), int32], Tensor[(1, 8, 1, 1, 16), int32], Tensor[(1, 8, 1, 1, 16), float32], float32) -> Tensor[(128, 8, 637, 1, 16), uint8] */
}), 
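
For reference, the model and build flow are roughly like this (a minimal sketch; the layer shapes, calibration, and qconfig here are stand-ins, not the exact model):

    import torch
    import tvm
    from tvm import relay

    class ConvLeakyReLU(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = torch.quantization.QuantStub()
            self.conv = torch.nn.Conv2d(4, 128, kernel_size=(4, 64))
            self.act = torch.nn.LeakyReLU(0.01)
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.act(self.conv(x))
            return self.dequant(x)

    model = ConvLeakyReLU().eval()
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    torch.quantization.prepare(model, inplace=True)
    model(torch.randn(128, 4, 640, 64))  # one calibration pass
    torch.quantization.convert(model, inplace=True)

    inp = torch.randn(128, 4, 640, 64)
    traced = torch.jit.trace(model, inp)
    mod, params = relay.frontend.from_pytorch(traced, [("input", (128, 4, 640, 64))])

    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm -mcpu=core-avx2", params=params)

    # function_metadata maps each fused primitive function to its relay_primfuncs
    print(lib.function_metadata)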

Recently, this was updated by PR#11930 and PR#12116, so that leakyrelu executes in int8 with a lookup table. The concrete parameters of this part of the FunctionInfoNode changed to:

  relay_primfuncs={llvm -keys=cpu -link-params=0 -mcpu=core-avx2: fn (%p0: Tensor[(128, 1, 640, 64, 4), uint8] /* ty=Tensor[(128, 1, 640, 64, 4), uint8] */, %p1: Tensor[(8, 1, 4, 64, 1, 16, 4), int8] /* ty=Tensor[(8, 1, 4, 64, 1, 16, 4), int8] */, %p2: Tensor[(1, 8, 1, 1, 16), int32] /* ty=Tensor[(1, 8, 1, 1, 16), int32] */, %p3: Tensor[(1, 8, 1, 1, 16), int32] /* ty=Tensor[(1, 8, 1, 1, 16), int32] */, %p4: Tensor[(1, 8, 1, 1, 16), float32] /* ty=Tensor[(1, 8, 1, 1, 16), float32] */, target=meta[Target][0], prim_funcs={'tvmgen_default_fused_nn_contrib_conv2d_NCHWc_subtract_add_cast_multiply_add_floor_cast_clip_2'=meta[tir.PrimFunc][0]}, out_layout="NCHW16c", data_layout="NCHW4c", hash="cf360873df93e020", kernel_layout="OIHW1i16o4i", prim_fn_var='tvmgen_default_fused_nn_contrib_conv2d_NCHWc_subtract_add_cast_multiply_add_floor_cast_clip_2', Primitive=1) -> Tensor[(128, 8, 637, 1, 16), int32] {
  %0 = nn.contrib_conv2d_NCHWc(%p0, %p1, padding=[0, 0, 0, 0], channels=128, kernel_size=[4, 64], data_layout="NCHW4c", kernel_layout="OIHW1i16o4i", out_layout="NCHW16c", out_dtype="int32") /* ty=Tensor[(128, 8, 637, 1, 16), int32] */;
  %1 = subtract(%0, %p2) /* ty=Tensor[(128, 8, 637, 1, 16), int32] */;
  %2 = add(%1, %p3) /* ty=Tensor[(128, 8, 637, 1, 16), int32] */;
  %3 = cast(%2, dtype="float32") /* ty=Tensor[(128, 8, 637, 1, 16), float32] */;
  %4 = multiply(%3, %p4) /* ty=Tensor[(128, 8, 637, 1, 16), float32] */;
  %5 = add(%4, 102.5f /* ty=float32 */) /* ty=Tensor[(128, 8, 637, 1, 16), float32] */;
  %6 = floor(%5) /* ty=Tensor[(128, 8, 637, 1, 16), float32] */;
  %7 = cast(%6, dtype="int32") /* ty=Tensor[(128, 8, 637, 1, 16), int32] */;
  clip(%7, a_min=0f, a_max=255f) /* ty=Tensor[(128, 8, 637, 1, 16), int32] */
} /* ty=fn (Tensor[(128, 1, 640, 64, 4), uint8], Tensor[(8, 1, 4, 64, 1, 16, 4), int8], Tensor[(1, 8, 1, 1, 16), int32], Tensor[(1, 8, 1, 1, 16), int32], Tensor[(1, 8, 1, 1, 16), float32]) -> Tensor[(128, 8, 637, 1, 16), int32] */
}), 

The results show that qnn leaky_relu can no longer be fused with the conv when it executes in int8 mode with a lookup table. So how should these be handled so that quantized elementwise ops like leakyrelu or hardswish can be fused with complex-out operators like conv or dense? Should I register a new qnn operator with the kElemWise OpPatternKind, or is there another way?
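
For example, is registering the fusion pattern on the Python side, roughly like below, the intended way? (Just a sketch to illustrate what I mean; I am not sure it helps for qnn ops, since they are canonicalized into the exprs above before FuseOps runs.)

    from tvm.relay.op import register_pattern, OpPattern

    # Sketch only: mark the op as elementwise for the fusion pass.
    # qnn.leaky_relu itself is canonicalized into other exprs before FuseOps,
    # so this alone may have no effect.
    register_pattern("qnn.leaky_relu", OpPattern.ELEMWISE)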

Thanks for your kind reply.

I see, so you want to undo the change by https://github.com/apache/tvm/pull/11930? I can imagine that a table look-up based op can be tricky for op fusion.

You can always use Relay pattern matching and rewriting to replace relay.qnn.op.leaky_relu with the fp32 fallback version.
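
Something along these lines (a rough sketch, run on the module before relay.build; the argument order of qnn.leaky_relu and the uint8 output dtype below are assumptions you should check against the op definition):

    from tvm import relay
    from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, rewrite, wildcard

    class QnnLeakyReluToFP32(DFPatternCallback):
        def __init__(self):
            super().__init__()
            self.data = wildcard()
            self.in_scale = wildcard()
            self.in_zp = wildcard()
            self.out_scale = wildcard()
            self.out_zp = wildcard()
            # Assumed call signature: qnn.leaky_relu(data, in_scale, in_zp, out_scale, out_zp)
            self.pattern = is_op("qnn.leaky_relu")(
                self.data, self.in_scale, self.in_zp, self.out_scale, self.out_zp
            )

        def callback(self, pre, post, node_map):
            data = node_map[self.data][0]
            # Fall back to fp32: dequantize -> nn.leaky_relu -> quantize
            deq = relay.qnn.op.dequantize(
                data, node_map[self.in_scale][0], node_map[self.in_zp][0]
            )
            act = relay.nn.leaky_relu(deq, alpha=pre.attrs.alpha)
            return relay.qnn.op.quantize(
                act, node_map[self.out_scale][0], node_map[self.out_zp][0], out_dtype="uint8"
            )

    # mod is the module imported from the quantized model
    mod["main"] = rewrite(QnnLeakyReluToFP32(), mod["main"])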

Because if the activation can execute in int8 instead of fp32, with only a small change in accuracy, it can bring a large inference performance improvement in our specially designed model. PR#11930 and PR#12116 were contributed by members of our group.

A natural next step is that we want to further improve inference performance through quantized operator fusion. Is it possible for a table-lookup-based op to take part in op fusion? What is the difference between such ops and fp32 ops? We want to continue contributing to TVM in this regard. Do you have any suggestions? Thanks.
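
(For reference, one way to see how the ops in those fused functions are classified for fusion is to check their registered TOpPattern, e.g.:)

    from tvm import relay

    # TOpPattern values: 0=kElemWise, 1=kBroadcast, 2=kInjective,
    # 4=kOutEWiseFusable, 8=kOpaque
    for name in ["nn.leaky_relu", "where", "take", "layout_transform"]:
        print(name, relay.op.get(name).get_attr("TOpPattern"))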


How do you know that qnn.leaky_relu is implemented as a table lookup? I don’t see it in https://github.com/apache/tvm/blob/97b3076c3532f73a9d9eeba26a3f329f8e0f803d/python/tvm/relay/qnn/op/legalizations.py#L76-L83

Sorry, I gave an inappropriate example. In the latest code, hardswish and swish (including sigmoid) use a lookup table, while leakyrelu does not.

In the examples below, I noticed that qnn::leakyrelu and qnn::hardswish are replaced with several other exprs. I guess they cannot be executed together with the conv because the TOpPattern of some of these exprs (like layout_transform) no longer belongs to kElemWise. Is that true?

If true, does it mean that it is impossible to do complete int8 activation fusion the way fp32 does?

  • Leakyrelu
# conv
 relay_primfuncs={llvm -keys=cpu -link-params=0 -mcpu=core-avx2: fn (%p0: Tensor[(128, 1, 640, 64, 4), uint8] /* ty=Tensor[(128, 1, 640, 64, 4), uint8] */, %p1: Tensor[(8, 1, 2, 64, 1, 16, 4), int8] /* ty=Tensor[(8, 1, 2, 64, 1, 16, 4), int8] */, %p2: Tensor[(1, 8, 1, 1, 16), int32] /* ty=Tensor[(1, 8, 1, 1, 16), int32] */, %p3: Tensor[(1, 8, 1, 1, 16), int32] /* ty=Tensor[(1, 8, 1, 1, 16), int32] */, %p4: Tensor[(1, 8, 1, 1, 16), float32] /* ty=Tensor[(1, 8, 1, 1, 16), float32] */, target=meta[Target][0], prim_funcs={'tvmgen_default_fused_nn_contrib_conv2d_NCHWc_subtract_add_cast_multiply_add_floor_cast_clip'=meta[tir.PrimFunc][0]}, out_layout="NCHW16c", data_layout="NCHW4c", hash="19115e71f466e294", kernel_layout="OIHW1i16o4i", prim_fn_var='tvmgen_default_fused_nn_contrib_conv2d_NCHWc_subtract_add_cast_multiply_add_floor_cast_clip', Primitive=1) -> Tensor[(128, 8, 639, 1, 16), int32] {
  %0 = nn.contrib_conv2d_NCHWc(%p0, %p1, padding=[0, 0, 0, 0], channels=128, kernel_size=[2, 64], data_layout="NCHW4c", kernel_layout="OIHW1i16o4i", out_layout="NCHW16c", out_dtype="int32") /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  %1 = subtract(%0, %p2) /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  %2 = add(%1, %p3) /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  %3 = cast(%2, dtype="float32") /* ty=Tensor[(128, 8, 639, 1, 16), float32] */;
  %4 = multiply(%3, %p4) /* ty=Tensor[(128, 8, 639, 1, 16), float32] */;
  %5 = add(%4, 103.5f /* ty=float32 */) /* ty=Tensor[(128, 8, 639, 1, 16), float32] */;
  %6 = floor(%5) /* ty=Tensor[(128, 8, 639, 1, 16), float32] */;
  %7 = cast(%6, dtype="int32") /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  clip(%7, a_min=0f, a_max=255f) /* ty=Tensor[(128, 8, 639, 1, 16), int32] */
} /* ty=fn (Tensor[(128, 1, 640, 64, 4), uint8], Tensor[(8, 1, 2, 64, 1, 16, 4), int8], Tensor[(1, 8, 1, 1, 16), int32], Tensor[(1, 8, 1, 1, 16), int32], Tensor[(1, 8, 1, 1, 16), float32]) -> Tensor[(128, 8, 639, 1, 16), int32] */

# leakyrelu
 relay_primfuncs={llvm -keys=cpu -link-params=0 -mcpu=core-avx2: fn (%p0: Tensor[(128, 8, 637, 1, 16), int32] /* ty=Tensor[(128, 8, 637, 1, 16), int32] */, %p1: int32 /* ty=int32 */, %p2: int32 /* ty=int32 */, src_layout="NCHW16c", hash="1e03279bb39a6741", prim_funcs={'tvmgen_default_fused_less_layout_transform_fixed_point_multiply_add_layout_transform_layout_tra_fd75043815579448__2'=meta[tir.PrimFunc][0]}, dst_layout="NCHW", Primitive=1, prim_fn_var='tvmgen_default_fused_less_layout_transform_fixed_point_multiply_add_layout_transform_layout_tra_fd75043815579448__2', target=meta[Target][0]) -> Tensor[(128, 128, 637, 1), uint8] {
  %0 = less(%p0, %p1) /* ty=Tensor[(128, 8, 637, 1, 16), bool] */;
  %1 = fixed_point_multiply(%p0, multiplier=1374389535, shift=-6) /* ty=Tensor[(128, 8, 637, 1, 16), int32] */;
  %2 = add(%1, %p2) /* ty=Tensor[(128, 8, 637, 1, 16), int32] */;
  %3 = layout_transform(%0, src_layout="NCHW16c", dst_layout="NCHW") /* ty=Tensor[(128, 128, 637, 1), bool] */;
  %4 = layout_transform(%2, src_layout="NCHW16c", dst_layout="NCHW") /* ty=Tensor[(128, 128, 637, 1), int32] */;
  %5 = layout_transform(%p0, src_layout="NCHW16c", dst_layout="NCHW") /* ty=Tensor[(128, 128, 637, 1), int32] */;
  %6 = where(%3, %4, %5) /* ty=Tensor[(128, 128, 637, 1), int32] */;
  %7 = clip(%6, a_min=0f, a_max=255f) /* ty=Tensor[(128, 128, 637, 1), int32] */;
  cast(%7, dtype="uint8") /* ty=Tensor[(128, 128, 637, 1), uint8] */
} /* ty=fn (Tensor[(128, 8, 637, 1, 16), int32], int32, int32) -> Tensor[(128, 128, 637, 1), uint8] */
  • Hardswish
# conv
  relay_primfuncs={llvm -keys=cpu -link-params=0 -mcpu=core-avx2: fn (%p0: Tensor[(128, 1, 640, 64, 4), uint8] /* ty=Tensor[(128, 1, 640, 64, 4), uint8] */, %p1: Tensor[(8, 1, 2, 64, 1, 16, 4), int8] /* ty=Tensor[(8, 1, 2, 64, 1, 16, 4), int8] */, %p2: Tensor[(1, 8, 1, 1, 16), int32] /* ty=Tensor[(1, 8, 1, 1, 16), int32] */, %p3: Tensor[(1, 8, 1, 1, 16), int32] /* ty=Tensor[(1, 8, 1, 1, 16), int32] */, %p4: Tensor[(1, 8, 1, 1, 16), float32] /* ty=Tensor[(1, 8, 1, 1, 16), float32] */, target=meta[Target][0], prim_funcs={'tvmgen_default_fused_nn_contrib_conv2d_NCHWc_subtract_add_cast_multiply_add_floor_cast_clip_cas_db782882eb0ba5d2_'=meta[tir.PrimFunc][0]}, out_layout="NCHW16c", data_layout="NCHW4c", hash="036bd901056bd0fa", kernel_layout="OIHW1i16o4i", prim_fn_var='tvmgen_default_fused_nn_contrib_conv2d_NCHWc_subtract_add_cast_multiply_add_floor_cast_clip_cas_db782882eb0ba5d2_', Primitive=1) -> Tensor[(128, 8, 639, 1, 16), uint8] {
  %0 = nn.contrib_conv2d_NCHWc(%p0, %p1, padding=[0, 0, 0, 0], channels=128, kernel_size=[2, 64], data_layout="NCHW4c", kernel_layout="OIHW1i16o4i", out_layout="NCHW16c", out_dtype="int32") /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  %1 = subtract(%0, %p2) /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  %2 = add(%1, %p3) /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  %3 = cast(%2, dtype="float32") /* ty=Tensor[(128, 8, 639, 1, 16), float32] */;
  %4 = multiply(%3, %p4) /* ty=Tensor[(128, 8, 639, 1, 16), float32] */;
  %5 = add(%4, 68.5f /* ty=float32 */) /* ty=Tensor[(128, 8, 639, 1, 16), float32] */;
  %6 = floor(%5) /* ty=Tensor[(128, 8, 639, 1, 16), float32] */;
  %7 = cast(%6, dtype="int32") /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  %8 = clip(%7, a_min=0f, a_max=255f) /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
# hardswish
  %9 = cast(%8, dtype="uint8") /* ty=Tensor[(128, 8, 639, 1, 16), uint8] */;
  reinterpret(%9, dtype="uint8") /* ty=Tensor[(128, 8, 639, 1, 16), uint8] */

  relay_primfuncs={llvm -keys=cpu -link-params=0 -mcpu=core-avx2: fn (%p0: Tensor[(128, 8, 639, 1, 16), uint8] /* ty=Tensor[(128, 8, 639, 1, 16), uint8] */, %p1: Tensor[(256), uint8] /* ty=Tensor[(256), uint8] */, src_layout="NCHW16c", hash="8292eeffb256f4b0", prim_funcs={'tvmgen_default_fused_layout_transform_take'=meta[tir.PrimFunc][0]}, dst_layout="NCHW", Primitive=1, prim_fn_var='tvmgen_default_fused_layout_transform_take', target=meta[Target][0]) -> Tensor[(128, 128, 639, 1), uint8] {
  %0 = layout_transform(%p0, src_layout="NCHW16c", dst_layout="NCHW") /* ty=Tensor[(128, 128, 639, 1), uint8] */;
  take(%p1, %0, axis=0, mode="fast") /* ty=Tensor[(128, 128, 639, 1), uint8] */
} /* ty=fn (Tensor[(128, 8, 639, 1, 16), uint8], Tensor[(256), uint8]) -> Tensor[(128, 128, 639, 1), uint8] */

Yes, that’s the reason fusion breaks at that boundary. But I don’t understand why layout_transform is generated for the leakyrelu case, because the layout shouldn’t matter for the where op. Can you try adding

    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

after https://github.com/apache/tvm/blob/111169c7df2831ab8ee40d5388ebcfcf551fd86f/src/relay/op/tensor/transform.cc#L2275?

For hardswish and other lookup based ops, I don’t think we can fuse them to conv now.
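
Conceptually the lookup-based lowering boils down to something like this (just a sketch; the table values below are placeholders, the real table is precomputed from the activation function and the quantization parameters):

    import numpy as np
    from tvm import relay

    # 256-entry table mapping every possible uint8 input to its quantized output
    lut = relay.const(np.arange(256, dtype="uint8"))

    x = relay.var("x", shape=(128, 128, 639, 1), dtype="uint8")
    # out[...] = lut[x[...]] -- a gather rather than an elementwise computation
    y = relay.take(lut, x, axis=0, mode="fast")

As far as I understand FuseOps, a conv (kOutEWiseFusable) only absorbs following ops whose pattern is at most kBroadcast, and take is registered as kInjective, so the lookup ends up in a separate primitive function.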

This line of code does help with leakyrelu; the same block changes to the following:

  relay_primfuncs={llvm -keys=cpu -link-params=0 -mcpu=core-avx2: fn (%p0: Tensor[(128, 1, 640, 64, 4), uint8] /* ty=Tensor[(128, 1, 640, 64, 4), uint8] */, %p1: Tensor[(8, 1, 2, 64, 1, 16, 4), int8] /* ty=Tensor[(8, 1, 2, 64, 1, 16, 4), int8] */, %p2: Tensor[(1, 8, 1, 1, 16), int32] /* ty=Tensor[(1, 8, 1, 1, 16), int32] */, %p3: Tensor[(1, 8, 1, 1, 16), int32] /* ty=Tensor[(1, 8, 1, 1, 16), int32] */, %p4: Tensor[(1, 8, 1, 1, 16), float32] /* ty=Tensor[(1, 8, 1, 1, 16), float32] */, %p5: int32 /* ty=int32 */, %p6: int32 /* ty=int32 */, target=meta[Target][0], prim_funcs={'tvmgen_default_fused_nn_contrib_conv2d_NCHWc_subtract_add_cast_multiply_add_floor_cast_clip_les_7095284e9dd4134d_'=meta[tir.PrimFunc][0]}, out_layout="NCHW16c", data_layout="NCHW4c", hash="9e0a0dee186a9fde", kernel_layout="OIHW1i16o4i", prim_fn_var='tvmgen_default_fused_nn_contrib_conv2d_NCHWc_subtract_add_cast_multiply_add_floor_cast_clip_les_7095284e9dd4134d_', Primitive=1) -> Tensor[(128, 8, 639, 1, 16), uint8] {
  %0 = nn.contrib_conv2d_NCHWc(%p0, %p1, padding=[0, 0, 0, 0], channels=128, kernel_size=[2, 64], data_layout="NCHW4c", kernel_layout="OIHW1i16o4i", out_layout="NCHW16c", out_dtype="int32") /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  %1 = subtract(%0, %p2) /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  %2 = add(%1, %p3) /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  %3 = cast(%2, dtype="float32") /* ty=Tensor[(128, 8, 639, 1, 16), float32] */;
  %4 = multiply(%3, %p4) /* ty=Tensor[(128, 8, 639, 1, 16), float32] */;
  %5 = add(%4, 103.5f /* ty=float32 */) /* ty=Tensor[(128, 8, 639, 1, 16), float32] */;
  %6 = floor(%5) /* ty=Tensor[(128, 8, 639, 1, 16), float32] */;
  %7 = cast(%6, dtype="int32") /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  %8 = clip(%7, a_min=0f, a_max=255f) /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  %9 = fixed_point_multiply(%8, multiplier=1374389535, shift=-6) /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  %10 = less(%8, %p5) /* ty=Tensor[(128, 8, 639, 1, 16), bool] */;
  %11 = add(%9, %p6) /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  %12 = where(%10, %11, %8) /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  %13 = clip(%12, a_min=0f, a_max=255f) /* ty=Tensor[(128, 8, 639, 1, 16), int32] */;
  cast(%13, dtype="uint8") /* ty=Tensor[(128, 8, 639, 1, 16), uint8] */
} /* ty=fn (Tensor[(128, 1, 640, 64, 4), uint8], Tensor[(8, 1, 2, 64, 1, 16, 4), int8], Tensor[(1, 8, 1, 1, 16), int32], Tensor[(1, 8, 1, 1, 16), int32], Tensor[(1, 8, 1, 1, 16), float32], int32, int32) -> Tensor[(128, 8, 639, 1, 16), uint8] */
}), 

I think I understand the op fusion problem now with your kind help, but I still have another doubt. Why are the exprs in an int8 model executed in fp32 and int32? Dequantization and quantization will definitely affect inference speed. What are the considerations behind this design?

Dequantize / quantize things are typically for the case where there is no good way to execute quantized ops entirely in int8. So they are just a fallback; there is no deep design decision. This is also what PyTorch or ONNXRuntime often do.

That clears up my questions. Thank you for your patience.