PyTorch model inference time way faster than autotuned TVM

Hi,

I have created simple code to auto-tune a quantized PyTorch model and run it on the Raspberry Pi with ACL support, but the TVM model is much slower than PyTorch: PyTorch inference time is 25 ms while TVM is 38 ms. Does anyone know whether this is the expected result, or whether there is a way to get a better inference time? I thought TVM could improve the inference time on a specific target, but so far that has not been the case.

I remember that PyTorch runs quantized models on ARM very fast thanks to QNNPACK. But when I tested it before, PyTorch didn’t support using multiple threads on ARM, so in a multicore scenario TVM was faster.

What do you mean by “using ACL support”? If you use ACL for TVM, there is no auto-tuning.

Hi @i02fesea,

In addition to @masahi’s questions, would it be possible to share the Relay IR that is passed to relay.build?

Hi Masahi,

Thanks for your answer. Do you have any reference to support this: “I remember PT didn’t support using multiple threads on ARM, so in a multicore scenario TVM was faster”? If I am not mistaken, PyTorch supports multi-core and multi-threading: CPU threading and TorchScript inference — PyTorch 1.11.0 documentation
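For reference, this is a quick way to check and set the intra-op thread count in PyTorch (standard PyTorch API):

import torch

# Intra-op parallelism: threads used inside ops such as conv2d.
print(torch.get_num_threads())

# Request 4 intra-op threads (only effective if the backend supports it).
torch.set_num_threads(4)
print(torch.__config__.parallel_info())  # shows the threading backend in use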

“If you use ACL for TVM, there is no auto tuning” — so if I understand what you are saying, if I compile an ACL runtime module, I don’t need to auto-tune the module/model, right?

Thanks for the help

Hi lhutton,

Thanks for your help. This is part of my code:

import numpy as np
import tvm
from tvm import autotvm, relay
from tvm.contrib import graph_executor as runtime

# NOTE: target, log_file, pathSaveLib, device_key, hostIP, dtype and
# tune_tasks are defined elsewhere in the script.

def tune_and_evaluate(tuning_opt, mod, params, input_shape, input_name):
    # extract workloads from the relay program
    tasks = autotvm.task.extract_from_program(
        mod["main"], target=target, params=params, ops=(relay.op.get("nn.conv2d"),))

    tune_tasks(tasks, **tuning_opt)

    # compile kernels with the best records from the tuning history
    with autotvm.apply_history_best(log_file):
        print("Compile...")
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build_module.build(mod, target=target, params=params)

        filename = "net_dep_and_opt.tar"
        # lib.export_library(tmp.relpath(filename))
        lib.export_library(pathSaveLib + filename)

        # upload the compiled module to the device
        print("Upload...")
        remote = autotvm.measure.request_remote(device_key, hostIP, 9190, timeout=10000)
        remote.upload(pathSaveLib + filename)

        rlib = remote.load_module(filename)

        # create a graph executor module on the remote device and set the input
        dev = remote.device(str(target), 0)
        module = runtime.GraphModule(rlib["default"](dev))
        data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
        module.set_input(input_name, data_tvm)

        print("Evaluate inference time cost...")
        print(module.benchmark(dev, number=1, repeat=10))

Thanks @i02fesea, I’m actually more interested in the IR you see when you add print(mod) before this line:
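i.e. something like this (assuming the line in question is the relay.build call from your snippet):

print(mod)  # dump the Relay IR to see how the graph was partitioned
lib = relay.build_module.build(mod, target=target, params=params)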

If you’re able to share that output it would help give us an idea about how the graph is being partitioned.

To answer your question here: if the whole of your graph is offloaded to ACL, no tuning is necessary, since ACL kernels cannot be tuned. However, if only part of the graph is offloaded to ACL, the rest of the graph, which falls back to TVM, can still be tuned.
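For reference, a minimal sketch of offloading to ACL, assuming the standard arm_compute_lib integration (the partitioning happens before relay.build):

from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib

# Annotate supported subgraphs for ACL and partition them out; everything
# left in the main function goes through the normal TVM compilation flow
# and can still be auto-tuned.
mod = partition_for_arm_compute_lib(mod)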

Hi lhutton,

This is the output. Sorry, I had to cut it, since it is very long:

    def @main(%input: Tensor[(1, 3, 224, 224), float32], %features.0.0_weight: Tensor[(32, 3, 3, 3), float32], %features.0.0_bias: Tensor[(32), float32], %features.1.conv.0.0_weight: Tensor[(32, 1, 3, 3), float32], %features.1.conv.0.0_bias: Tensor[(32), float32], %features.1.conv.1_weight: Tensor[(16, 32, 1, 1), float32], %features.1.conv.1_bias: Tensor[(16), float32], %features.2.conv.0.0_weight: Tensor[(96, 16, 1, 1), float32], %features.2.conv.0.0_bias: Tensor[(96), float32], %features.2.conv.1.0_weight: Tensor[(96, 1, 3, 3), float32], %features.2.conv.1.0_bias: Tensor[(96), float32], %features.2.conv.2_weight: Tensor[(24, 96, 1, 1), float32], %features.2.conv.2_bias: Tensor[(24), float32], %features.3.conv.0.0_weight: Tensor[(144, 24, 1, 1), float32], %features.3.conv.0.0_bias: Tensor[(144), float32], %features.3.conv.1.0_weight: Tensor[(144, 1, 3, 3), float32], %features.3.conv.1.0_bias: Tensor[(144), float32], %features.3.conv.2_weight: Tensor[(24, 144, 1, 1), float32], %features.3.conv.2_bias: Tensor[(24), float32], %features.4.conv.0.0_weight: Tensor[(144, 24, 1, 1), float32], %features.4.conv.0.0_bias: Tensor[(144), float32], %features.4.conv.1.0_weight: Tensor[(144, 1, 3, 3), float32], %features.4.conv.1.0_bias: Tensor[(144), float32], %features.4.conv.2_weight: Tensor[(32, 144, 1, 1), float32], %features.4.conv.2_bias: Tensor[(32), float32], %features.5.conv.0.0_weight: Tensor[(192, 32, 1, 1), float32], %features.5.conv.0.0_bias: Tensor[(192), float32], %features.5.conv.1.0_weight: Tensor[(192, 1, 3, 3), float32], %features.5.conv.1.0_bias: Tensor[(192), float32], %features.5.conv.2_weight: Tensor[(32, 192, 1, 1), float32], %features.5.conv.2_bias: Tensor[(32), float32], %features.6.conv.0.0_weight: Tensor[(192, 32, 1, 1), float32], %features.6.conv.0.0_bias: Tensor[(192), float32], %features.6.conv.1.0_weight: Tensor[(192, 1, 3, 3), float32], %features.6.conv.1.0_bias: Tensor[(192), float32], %features.6.conv.2_weight: Tensor[(32, 192, 1, 1), float32], %features.6.conv.2_bias: Tensor[(32), float32], %features.7.conv.0.0_weight: Tensor[(192, 32, 1, 1), float32], %features.7.conv.0.0_bias: Tensor[(192), float32], %features.7.conv.1.0_weight: Tensor[(192, 1, 3, 3), float32], %features.7.conv.1.0_bias: Tensor[(192), float32], %features.7.conv.2_weight: Tensor[(64, 192, 1, 1), float32], %features.7.conv.2_bias: Tensor[(64), float32], %features.8.conv.0.0_weight: Tensor[(384, 64, 1, 1), float32], %features.8.conv.0.0_bias: Tensor[(384), float32], %features.8.conv.1.0_weight: Tensor[(384, 1, 3, 3), float32], %features.8.conv.1.0_bias: Tensor[(384), float32], %features.8.conv.2_weight: Tensor[(64, 384, 1, 1), float32], %features.8.conv.2_bias: Tensor[(64), float32], %features.9.conv.0.0_weight: Tensor[(384, 64, 1, 1), float32], %features.9.conv.0.0_bias: Tensor[(384), float32], %features.9.conv.1.0_weight: Tensor[(384, 1, 3, 3), float32], %features.9.conv.1.0_bias: Tensor[(384), float32], %features.9.conv.2_weight: Tensor[(64, 384, 1, 1), float32], %features.9.conv.2_bias: Tensor[(64), float32], %features.10.conv.0.0_weight: Tensor[(384, 64, 1, 1), float32], %features.10.conv.0.0_bias: Tensor[(384), float32], %features.10.conv.1.0_weight: Tensor[(384, 1, 3, 3), float32], %features.10.conv.1.0_bias: Tensor[(384), float32], %features.10.conv.2_weight: Tensor[(64, 384, 1, 1), float32], %features.10.conv.2_bias: Tensor[(64), float32], %features.11.conv.0.0_weight: Tensor[(384, 64, 1, 1), float32], %features.11.conv.0.0_bias: Tensor[(384), float32], %features.11.conv.1.0_weight: 
Tensor[(384, 1, 3, 3), float32], %features.11.conv.1.0_bias: Tensor[(384), float32], %features.11.conv.2_weight: Tensor[(96, 384, 1, 1), float32], %features.11.conv.2_bias: Tensor[(96), float32], %features.12.conv.0.0_weight: Tensor[(576, 96, 1, 1), float32], %features.12.conv.0.0_bias: Tensor[(576), float32], %features.12.conv.1.0_weight: Tensor[(576, 1, 3, 3), float32], %features.12.conv.1.0_bias: Tensor[(576), float32], %features.12.conv.2_weight: Tensor[(96, 576, 1, 1), float32], %features.12.conv.2_bias: Tensor[(96), float32], %features.13.conv.0.0_weight: Tensor[(576, 96, 1, 1), float32], %features.13.conv.0.0_bias: Tensor[(576), float32], %features.13.conv.1.0_weight: Tensor[(576, 1, 3, 3), float32], %features.13.conv.1.0_bias: Tensor[(576), float32], %features.13.conv.2_weight: Tensor[(96, 576, 1, 1), float32], %features.13.conv.2_bias: Tensor[(96), float32], %features.14.conv.0.0_weight: Tensor[(576, 96, 1, 1), float32], %features.14.conv.0.0_bias: Tensor[(576), float32], %features.14.conv.1.0_weight: Tensor[(576, 1, 3, 3), float32], %features.14.conv.1.0_bias: Tensor[(576), float32], %features.14.conv.2_weight: Tensor[(160, 576, 1, 1), float32], %features.14.conv.2_bias: Tensor[(160), float32], %features.15.conv.0.0_weight: Tensor[(960, 160, 1, 1), float32], %features.15.conv.0.0_bias: Tensor[(960), float32], %features.15.conv.1.0_weight: Tensor[(960, 1, 3, 3), float32], %features.15.conv.1.0_bias: Tensor[(960), float32], %features.15.conv.2_weight: Tensor[(160, 960, 1, 1), float32], %features.15.conv.2_bias: Tensor[(160), float32], %features.16.conv.0.0_weight: Tensor[(960, 160, 1, 1), float32], %features.16.conv.0.0_bias: Tensor[(960), float32], %features.16.conv.1.0_weight: Tensor[(960, 1, 3, 3), float32], %features.16.conv.1.0_bias: Tensor[(960), float32], %features.16.conv.2_weight: Tensor[(160, 960, 1, 1), float32], %features.16.conv.2_bias: Tensor[(160), float32], %features.17.conv.0.0_weight: Tensor[(960, 160, 1, 1), float32], %features.17.conv.0.0_bias: Tensor[(960), float32], %features.17.conv.1.0_weight: Tensor[(960, 1, 3, 3), float32], %features.17.conv.1.0_bias: Tensor[(960), float32], %features.17.conv.2_weight: Tensor[(320, 960, 1, 1), float32], %features.17.conv.2_bias: Tensor[(320), float32], %features.18.0_weight: Tensor[(1280, 320, 1, 1), float32], %features.18.0_bias: Tensor[(1280), float32], %classifier.1._packed_params_weight: Tensor[(1000, 1280), float32], %classifier.1._packed_params_bias: Tensor[(1000), float32]) -> Tensor[(1, 1000), float32] {
              %0 = qnn.quantize(%input, 0.0359743f /* ty=float32 */, 54 /* ty=int32 */, out_dtype="uint8", axis=1) /* ty=Tensor[(1, 3, 224, 224), uint8] */;
              %1 = nn.pad(%0, 54f /* ty=float32 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* ty=Tensor[(1, 3, 226, 226), uint8] */;
              %2 = qnn.quantize(%features.0.0_weight, meta[relay.Constant][0] /* ty=Tensor[(32), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(32, 3, 3, 3), int8] */;
              %3 = qnn.conv2d(%1, %2, 54 /* ty=int32 */, 0 /* ty=int32 */, 0.0359743f /* ty=float32 */, meta[relay.Constant][0] /* ty=Tensor[(32), float32] */, strides=[2, 2], padding=[0, 0, 0, 0], channels=32, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 32, 112, 112), int32] */;
              %4 = qnn.quantize(%features.0.0_bias, meta[relay.Constant][1] /* ty=Tensor[(32), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(32), int32] */;
              %5 = nn.bias_add(%3, %4) /* ty=Tensor[(1, 32, 112, 112), int32] */;
              %6 = qnn.requantize(%5, meta[relay.Constant][2] /* ty=Tensor[(32), float32] */, 0 /* ty=int32 */, 0.0132992f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 32, 112, 112), int32] */;
              %7 = clip(%6, a_min=0f, a_max=255f) /* ty=Tensor[(1, 32, 112, 112), int32] */;
              %8 = cast(%7, dtype="uint8") /* ty=Tensor[(1, 32, 112, 112), uint8] */;
              %9 = nn.pad(%8, 0f /* ty=float32 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* ty=Tensor[(1, 32, 114, 114), uint8] */;
              %10 = qnn.quantize(%features.1.conv.0.0_weight, meta[relay.Constant][3] /* ty=Tensor[(32), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(32, 1, 3, 3), int8] */;
              %11 = qnn.conv2d(%9, %10, 0 /* ty=int32 */, 0 /* ty=int32 */, 0.0132992f /* ty=float32 */, meta[relay.Constant][3] /* ty=Tensor[(32), float32] */, padding=[0, 0, 0, 0], groups=32, channels=32, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 32, 112, 112), int32] */;
              %12 = qnn.quantize(%features.1.conv.0.0_bias, meta[relay.Constant][4] /* ty=Tensor[(32), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(32), int32] */;
              %13 = nn.bias_add(%11, %12) /* ty=Tensor[(1, 32, 112, 112), int32] */;
              %14 = qnn.requantize(%13, meta[relay.Constant][5] /* ty=Tensor[(32), float32] */, 0 /* ty=int32 */, 0.0710117f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 32, 112, 112), int32] */;
              %15 = clip(%14, a_min=0f, a_max=255f) /* ty=Tensor[(1, 32, 112, 112), int32] */;
              %16 = cast(%15, dtype="uint8") /* ty=Tensor[(1, 32, 112, 112), uint8] */;
              %17 = qnn.quantize(%features.1.conv.1_weight, meta[relay.Constant][6] /* ty=Tensor[(16), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(16, 32, 1, 1), int8] */;
              %18 = qnn.conv2d(%16, %17, 0 /* ty=int32 */, 0 /* ty=int32 */, 0.0710117f /* ty=float32 */, meta[relay.Constant][6] /* ty=Tensor[(16), float32] */, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 16, 112, 112), int32] */;
              %19 = qnn.quantize(%features.1.conv.1_bias, meta[relay.Constant][7] /* ty=Tensor[(16), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(16), int32] */;
              %20 = nn.bias_add(%18, %19) /* ty=Tensor[(1, 16, 112, 112), int32] */;
              %21 = qnn.requantize(%20, meta[relay.Constant][8] /* ty=Tensor[(16), float32] */, 0 /* ty=int32 */, 0.0674597f /* ty=float32 */, 58 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 16, 112, 112), int32] */;
              %22 = clip(%21, a_min=0f, a_max=255f) /* ty=Tensor[(1, 16, 112, 112), int32] */;
              %23 = cast(%22, dtype="uint8") /* ty=Tensor[(1, 16, 112, 112), uint8] */;
              %24 = qnn.quantize(%features.2.conv.0.0_weight, meta[relay.Constant][9] /* ty=Tensor[(96), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(96, 16, 1, 1), int8] */;
              %25 = qnn.conv2d(%23, %24, 58 /* ty=int32 */, 0 /* ty=int32 */, 0.0674597f /* ty=float32 */, meta[relay.Constant][9] /* ty=Tensor[(96), float32] */, padding=[0, 0, 0, 0], channels=96, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 96, 112, 112), int32] */;
              %26 = qnn.quantize(%features.2.conv.0.0_bias, meta[relay.Constant][10] /* ty=Tensor[(96), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(96), int32] */;
              %27 = nn.bias_add(%25, %26) /* ty=Tensor[(1, 96, 112, 112), int32] */;
              %28 = qnn.requantize(%27, meta[relay.Constant][11] /* ty=Tensor[(96), float32] */, 0 /* ty=int32 */, 0.0274019f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 96, 112, 112), int32] */;
              %29 = clip(%28, a_min=0f, a_max=255f) /* ty=Tensor[(1, 96, 112, 112), int32] */;
              %30 = cast(%29, dtype="uint8") /* ty=Tensor[(1, 96, 112, 112), uint8] */;
              %31 = nn.pad(%30, 0f /* ty=float32 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* ty=Tensor[(1, 96, 114, 114), uint8] */;
              %32 = qnn.quantize(%features.2.conv.1.0_weight, meta[relay.Constant][12] /* ty=Tensor[(96), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(96, 1, 3, 3), int8] */;
              %33 = qnn.conv2d(%31, %32, 0 /* ty=int32 */, 0 /* ty=int32 */, 0.0274019f /* ty=float32 */, meta[relay.Constant][12] /* ty=Tensor[(96), float32] */, strides=[2, 2], padding=[0, 0, 0, 0], groups=96, channels=96, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 96, 56, 56), int32] */;
              %34 = qnn.quantize(%features.2.conv.1.0_bias, meta[relay.Constant][13] /* ty=Tensor[(96), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(96), int32] */;
              %35 = nn.bias_add(%33, %34) /* ty=Tensor[(1, 96, 56, 56), int32] */;
              %36 = qnn.requantize(%35, meta[relay.Constant][14] /* ty=Tensor[(96), float32] */, 0 /* ty=int32 */, 0.018431f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 96, 56, 56), int32] */;
              %37 = clip(%36, a_min=0f, a_max=255f) /* ty=Tensor[(1, 96, 56, 56), int32] */;
              %38 = cast(%37, dtype="uint8") /* ty=Tensor[(1, 96, 56, 56), uint8] */;
              %39 = qnn.quantize(%features.2.conv.2_weight, meta[relay.Constant][15] /* ty=Tensor[(24), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(24, 96, 1, 1), int8] */;
              %40 = qnn.conv2d(%38, %39, 0 /* ty=int32 */, 0 /* ty=int32 */, 0.018431f /* ty=float32 */, meta[relay.Constant][15] /* ty=Tensor[(24), float32] */, padding=[0, 0, 0, 0], channels=24, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 24, 56, 56), int32] */;
              %41 = qnn.quantize(%features.2.conv.2_bias, meta[relay.Constant][16] /* ty=Tensor[(24), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(24), int32] */;
              %42 = nn.bias_add(%40, %41) /* ty=Tensor[(1, 24, 56, 56), int32] */;
              %43 = qnn.requantize(%42, meta[relay.Constant][17] /* ty=Tensor[(24), float32] */, 0 /* ty=int32 */, 0.044329f /* ty=float32 */, 59 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 24, 56, 56), int32] */;
              %44 = clip(%43, a_min=0f, a_max=255f) /* ty=Tensor[(1, 24, 56, 56), int32] */;
              %45 = cast(%44, dtype="uint8") /* ty=Tensor[(1, 24, 56, 56), uint8] */;
              %46 = qnn.quantize(%features.3.conv.0.0_weight, meta[relay.Constant][18] /* ty=Tensor[(144), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(144, 24, 1, 1), int8] */;
              %47 = qnn.conv2d(%45, %46, 59 /* ty=int32 */, 0 /* ty=int32 */, 0.044329f /* ty=float32 */, meta[relay.Constant][18] /* ty=Tensor[(144), float32] */, padding=[0, 0, 0, 0], channels=144, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 144, 56, 56), int32] */;
              %48 = qnn.quantize(%features.3.conv.0.0_bias, meta[relay.Constant][19] /* ty=Tensor[(144), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(144), int32] */;
              %49 = nn.bias_add(%47, %48) /* ty=Tensor[(1, 144, 56, 56), int32] */;
              %50 = qnn.requantize(%49, meta[relay.Constant][20] /* ty=Tensor[(144), float32] */, 0 /* ty=int32 */, 0.00914225f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 144, 56, 56), int32] */;
              %51 = clip(%50, a_min=0f, a_max=255f) /* ty=Tensor[(1, 144, 56, 56), int32] */;
              %52 = cast(%51, dtype="uint8") /* ty=Tensor[(1, 144, 56, 56), uint8] */;
              %53 = nn.pad(%52, 0f /* ty=float32 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* ty=Tensor[(1, 144, 58, 58), uint8] */;
              %54 = qnn.quantize(%features.3.conv.1.0_weight, meta[relay.Constant][21] /* ty=Tensor[(144), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(144, 1, 3, 3), int8] */;
              %55 = qnn.conv2d(%53, %54, 0 /* ty=int32 */, 0 /* ty=int32 */, 0.00914225f /* ty=float32 */, meta[relay.Constant][21] /* ty=Tensor[(144), float32] */, padding=[0, 0, 0, 0], groups=144, channels=144, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 144, 56, 56), int32] */;
              %56 = qnn.quantize(%features.3.conv.1.0_bias, meta[relay.Constant][22] /* ty=Tensor[(144), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(144), int32] */;
              %57 = nn.bias_add(%55, %56) /* ty=Tensor[(1, 144, 56, 56), int32] */;
              %58 = qnn.requantize(%57, meta[relay.Constant][23] /* ty=Tensor[(144), float32] */, 0 /* ty=int32 */, 0.0202167f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 144, 56, 56), int32] */;
              %59 = clip(%58, a_min=0f, a_max=255f) /* ty=Tensor[(1, 144, 56, 56), int32] */;
              %60 = cast(%59, dtype="uint8") /* ty=Tensor[(1, 144, 56, 56), uint8] */;
              %61 = qnn.quantize(%features.3.conv.2_weight, meta[relay.Constant][24] /* ty=Tensor[(24), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(24, 144, 1, 1), int8] */;
              %62 = qnn.conv2d(%60, %61, 0 /* ty=int32 */, 0 /* ty=int32 */, 0.0202167f /* ty=float32 */, meta[relay.Constant][24] /* ty=Tensor[(24), float32] */, padding=[0, 0, 0, 0], channels=24, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 24, 56, 56), int32] */;
              %63 = qnn.quantize(%features.3.conv.2_bias, meta[relay.Constant][25] /* ty=Tensor[(24), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(24), int32] */;
              %64 = nn.bias_add(%62, %63) /* ty=Tensor[(1, 24, 56, 56), int32] */;
              %65 = qnn.requantize(%64, meta[relay.Constant][26] /* ty=Tensor[(24), float32] */, 0 /* ty=int32 */, 0.0576775f /* ty=float32 */, 58 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 24, 56, 56), int32] */;
              %66 = clip(%65, a_min=0f, a_max=255f) /* ty=Tensor[(1, 24, 56, 56), int32] */;
              %67 = cast(%66, dtype="uint8") /* ty=Tensor[(1, 24, 56, 56), uint8] */;
              %68 = @tvmgen_default_arm_compute_lib_main_0(%45, %67) /* ty=Tensor[(1, 24, 56, 56), uint8] */;
              %69 = qnn.quantize(%features.4.conv.0.0_weight, meta[relay.Constant][27] /* ty=Tensor[(144), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(144, 24, 1, 1), int8] */;
              %70 = qnn.conv2d(%68, %69, 61 /* ty=int32 */, 0 /* ty=int32 */, 0.0764764f /* ty=float32 */, meta[relay.Constant][27] /* ty=Tensor[(144), float32] */, padding=[0, 0, 0, 0], channels=144, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 144, 56, 56), int32] */;
              %71 = qnn.quantize(%features.4.conv.0.0_bias, meta[relay.Constant][28] /* ty=Tensor[(144), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(144), int32] */;
              %72 = nn.bias_add(%70, %71) /* ty=Tensor[(1, 144, 56, 56), int32] */;
              %73 = qnn.requantize(%72, meta[relay.Constant][29] /* ty=Tensor[(144), float32] */, 0 /* ty=int32 */, 0.0126277f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 144, 56, 56), int32] */;
              %74 = clip(%73, a_min=0f, a_max=255f) /* ty=Tensor[(1, 144, 56, 56), int32] */;
              %75 = cast(%74, dtype="uint8") /* ty=Tensor[(1, 144, 56, 56), uint8] */;
              %76 = nn.pad(%75, 0f /* ty=float32 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* ty=Tensor[(1, 144, 58, 58), uint8] */;
              %77 = qnn.quantize(%features.4.conv.1.0_weight, meta[relay.Constant][30] /* ty=Tensor[(144), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(144, 1, 3, 3), int8] */;
              %78 = qnn.conv2d(%76, %77, 0 /* ty=int32 */, 0 /* ty=int32 */, 0.0126277f /* ty=float32 */, meta[relay.Constant][30] /* ty=Tensor[(144), float32] */, strides=[2, 2], padding=[0, 0, 0, 0], groups=144, channels=144, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 144, 28, 28), int32] */;
              %79 = qnn.quantize(%features.4.conv.1.0_bias, meta[relay.Constant][31] /* ty=Tensor[(144), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(144), int32] */;
              %80 = nn.bias_add(%78, %79) /* ty=Tensor[(1, 144, 28, 28), int32] */;
              %81 = qnn.requantize(%80, meta[relay.Constant][32] /* ty=Tensor[(144), float32] */, 0 /* ty=int32 */, 0.0223427f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 144, 28, 28), int32] */;
              %82 = clip(%81, a_min=0f, a_max=255f) /* ty=Tensor[(1, 144, 28, 28), int32] */;
              %83 = cast(%82, dtype="uint8") /* ty=Tensor[(1, 144, 28, 28), uint8] */;
              %84 = qnn.quantize(%features.4.conv.2_weight, meta[relay.Constant][33] /* ty=Tensor[(32), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(32, 144, 1, 1), int8] */;
              %85 = qnn.conv2d(%83, %84, 0 /* ty=int32 */, 0 /* ty=int32 */, 0.0223427f /* ty=float32 */, meta[relay.Constant][33] /* ty=Tensor[(32), float32] */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 32, 28, 28), int32] */;
              %86 = qnn.quantize(%features.4.conv.2_bias, meta[relay.Constant][34] /* ty=Tensor[(32), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(32), int32] */;
              %87 = nn.bias_add(%85, %86) /* ty=Tensor[(1, 32, 28, 28), int32] */;
              %88 = qnn.requantize(%87, meta[relay.Constant][35] /* ty=Tensor[(32), float32] */, 0 /* ty=int32 */, 0.0386601f /* ty=float32 */, 72 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 32, 28, 28), int32] */;
              %89 = clip(%88, a_min=0f, a_max=255f) /* ty=Tensor[(1, 32, 28, 28), int32] */;
              %90 = cast(%89, dtype="uint8") /* ty=Tensor[(1, 32, 28, 28), uint8] */;
              %91 = qnn.quantize(%features.5.conv.0.0_weight, meta[relay.Constant][36] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(192, 32, 1, 1), int8] */;
              %92 = qnn.conv2d(%90, %91, 72 /* ty=int32 */, 0 /* ty=int32 */, 0.0386601f /* ty=float32 */, meta[relay.Constant][36] /* ty=Tensor[(192), float32] */, padding=[0, 0, 0, 0], channels=192, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %93 = qnn.quantize(%features.5.conv.0.0_bias, meta[relay.Constant][37] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(192), int32] */;
              %94 = nn.bias_add(%92, %93) /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %95 = qnn.requantize(%94, meta[relay.Constant][38] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, 0.00836057f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %96 = clip(%95, a_min=0f, a_max=255f) /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %97 = cast(%96, dtype="uint8") /* ty=Tensor[(1, 192, 28, 28), uint8] */;
              %98 = nn.pad(%97, 0f /* ty=float32 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* ty=Tensor[(1, 192, 30, 30), uint8] */;
              %99 = qnn.quantize(%features.5.conv.1.0_weight, meta[relay.Constant][39] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(192, 1, 3, 3), int8] */;
              %100 = qnn.conv2d(%98, %99, 0 /* ty=int32 */, 0 /* ty=int32 */, 0.00836057f /* ty=float32 */, meta[relay.Constant][39] /* ty=Tensor[(192), float32] */, padding=[0, 0, 0, 0], groups=192, channels=192, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %101 = qnn.quantize(%features.5.conv.1.0_bias, meta[relay.Constant][40] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(192), int32] */;
              %102 = nn.bias_add(%100, %101) /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %103 = qnn.requantize(%102, meta[relay.Constant][41] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, 0.0112039f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %104 = clip(%103, a_min=0f, a_max=255f) /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %105 = cast(%104, dtype="uint8") /* ty=Tensor[(1, 192, 28, 28), uint8] */;
              %106 = qnn.quantize(%features.5.conv.2_weight, meta[relay.Constant][42] /* ty=Tensor[(32), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(32, 192, 1, 1), int8] */;
              %107 = qnn.conv2d(%105, %106, 0 /* ty=int32 */, 0 /* ty=int32 */, 0.0112039f /* ty=float32 */, meta[relay.Constant][42] /* ty=Tensor[(32), float32] */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 32, 28, 28), int32] */;
              %108 = qnn.quantize(%features.5.conv.2_bias, meta[relay.Constant][43] /* ty=Tensor[(32), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(32), int32] */;
              %109 = nn.bias_add(%107, %108) /* ty=Tensor[(1, 32, 28, 28), int32] */;
              %110 = qnn.requantize(%109, meta[relay.Constant][44] /* ty=Tensor[(32), float32] */, 0 /* ty=int32 */, 0.0291111f /* ty=float32 */, 69 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 32, 28, 28), int32] */;
              %111 = clip(%110, a_min=0f, a_max=255f) /* ty=Tensor[(1, 32, 28, 28), int32] */;
              %112 = cast(%111, dtype="uint8") /* ty=Tensor[(1, 32, 28, 28), uint8] */;
              %113 = @tvmgen_default_arm_compute_lib_main_8(%90, %112) /* ty=Tensor[(1, 32, 28, 28), uint8] */;
              %114 = qnn.quantize(%features.6.conv.0.0_weight, meta[relay.Constant][45] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(192, 32, 1, 1), int8] */;
              %115 = qnn.conv2d(%113, %114, 74 /* ty=int32 */, 0 /* ty=int32 */, 0.049797f /* ty=float32 */, meta[relay.Constant][45] /* ty=Tensor[(192), float32] */, padding=[0, 0, 0, 0], channels=192, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %116 = qnn.quantize(%features.6.conv.0.0_bias, meta[relay.Constant][46] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(192), int32] */;
              %117 = nn.bias_add(%115, %116) /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %118 = qnn.requantize(%117, meta[relay.Constant][47] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, 0.0105545f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %119 = clip(%118, a_min=0f, a_max=255f) /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %120 = cast(%119, dtype="uint8") /* ty=Tensor[(1, 192, 28, 28), uint8] */;
              %121 = nn.pad(%120, 0f /* ty=float32 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* ty=Tensor[(1, 192, 30, 30), uint8] */;
              %122 = qnn.quantize(%features.6.conv.1.0_weight, meta[relay.Constant][48] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(192, 1, 3, 3), int8] */;
              %123 = qnn.conv2d(%121, %122, 0 /* ty=int32 */, 0 /* ty=int32 */, 0.0105545f /* ty=float32 */, meta[relay.Constant][48] /* ty=Tensor[(192), float32] */, padding=[0, 0, 0, 0], groups=192, channels=192, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %124 = qnn.quantize(%features.6.conv.1.0_bias, meta[relay.Constant][49] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(192), int32] */;
              %125 = nn.bias_add(%123, %124) /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %126 = qnn.requantize(%125, meta[relay.Constant][50] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, 0.0122615f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %127 = clip(%126, a_min=0f, a_max=255f) /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %128 = cast(%127, dtype="uint8") /* ty=Tensor[(1, 192, 28, 28), uint8] */;
              %129 = qnn.quantize(%features.6.conv.2_weight, meta[relay.Constant][51] /* ty=Tensor[(32), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(32, 192, 1, 1), int8] */;
              %130 = qnn.conv2d(%128, %129, 0 /* ty=int32 */, 0 /* ty=int32 */, 0.0122615f /* ty=float32 */, meta[relay.Constant][51] /* ty=Tensor[(32), float32] */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 32, 28, 28), int32] */;
              %131 = qnn.quantize(%features.6.conv.2_bias, meta[relay.Constant][52] /* ty=Tensor[(32), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(32), int32] */;
              %132 = nn.bias_add(%130, %131) /* ty=Tensor[(1, 32, 28, 28), int32] */;
              %133 = qnn.requantize(%132, meta[relay.Constant][53] /* ty=Tensor[(32), float32] */, 0 /* ty=int32 */, 0.0446687f /* ty=float32 */, 62 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 32, 28, 28), int32] */;
              %134 = clip(%133, a_min=0f, a_max=255f) /* ty=Tensor[(1, 32, 28, 28), int32] */;
              %135 = cast(%134, dtype="uint8") /* ty=Tensor[(1, 32, 28, 28), uint8] */;
              %136 = @tvmgen_default_arm_compute_lib_main_16(%113, %135) /* ty=Tensor[(1, 32, 28, 28), uint8] */;
              %137 = qnn.quantize(%features.7.conv.0.0_weight, meta[relay.Constant][54] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(192, 32, 1, 1), int8] */;
              %138 = qnn.conv2d(%136, %137, 66 /* ty=int32 */, 0 /* ty=int32 */, 0.0765793f /* ty=float32 */, meta[relay.Constant][54] /* ty=Tensor[(192), float32] */, padding=[0, 0, 0, 0], channels=192, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %139 = qnn.quantize(%features.7.conv.0.0_bias, meta[relay.Constant][55] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(192), int32] */;
              %140 = nn.bias_add(%138, %139) /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %141 = qnn.requantize(%140, meta[relay.Constant][56] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, 0.0156126f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %142 = clip(%141, a_min=0f, a_max=255f) /* ty=Tensor[(1, 192, 28, 28), int32] */;
              %143 = cast(%142, dtype="uint8") /* ty=Tensor[(1, 192, 28, 28), uint8] */;
              %144 = nn.pad(%143, 0f /* ty=float32 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* ty=Tensor[(1, 192, 30, 30), uint8] */;
              %145 = qnn.quantize(%features.7.conv.1.0_weight, meta[relay.Constant][57] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, out_dtype="int8", axis=0) /* ty=Tensor[(192, 1, 3, 3), int8] */;
              %146 = qnn.conv2d(%144, %145, 0 /* ty=int32 */, 0 /* ty=int32 */, 0.0156126f /* ty=float32 */, meta[relay.Constant][57] /* ty=Tensor[(192), float32] */, strides=[2, 2], padding=[0, 0, 0, 0], groups=192, channels=192, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 192, 14, 14), int32] */;
              %147 = qnn.quantize(%features.7.conv.1.0_bias, meta[relay.Constant][58] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, out_dtype="int32", axis=0) /* ty=Tensor[(192), int32] */;
              %148 = nn.bias_add(%146, %147) /* ty=Tensor[(1, 192, 14, 14), int32] */;
              %149 = qnn.requantize(%148, meta[relay.Constant][59] /* ty=Tensor[(192), float32] */, 0 /* ty=int32 */, 0.0258851f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 192, 14, 14), int32] */;}
            def @tvmgen_default_arm_compute_lib_main_0(%arm_compute_lib_0_i0: Tensor[(1, 24, 56, 56), uint8], %arm_compute_lib_0_i1: Tensor[(1, 24, 56, 56), uint8], Inline=1, Compiler="arm_compute_lib", global_symbol="tvmgen_default_arm_compute_lib_main_0", Primitive=1) -> Tensor[(1, 24, 56, 56), uint8] {
              qnn.add(%arm_compute_lib_0_i0, %arm_compute_lib_0_i1, 0.044329f /* ty=float32 */, 59 /* ty=int32 */, 0.0576775f /* ty=float32 */, 58 /* ty=int32 */, 0.0764764f /* ty=float32 */, 61 /* ty=int32 */) /* ty=Tensor[(1, 24, 56, 56), uint8] */
            }
            def @tvmgen_default_arm_compute_lib_main_16(%arm_compute_lib_16_i0: Tensor[(1, 32, 28, 28), uint8], %arm_compute_lib_16_i1: Tensor[(1, 32, 28, 28), uint8], Inline=1, Compiler="arm_compute_lib", global_symbol="tvmgen_default_arm_compute_lib_main_16", Primitive=1) -> Tensor[(1, 32, 28, 28), uint8] {
              qnn.add(%arm_compute_lib_16_i0, %arm_compute_lib_16_i1, 0.049797f /* ty=float32 */, 74 /* ty=int32 */, 0.0446687f /* ty=float32 */, 62 /* ty=int32 */, 0.0765793f /* ty=float32 */, 66 /* ty=int32 */) /* ty=Tensor[(1, 32, 28, 28), uint8] */
            }
            def @tvmgen_default_arm_compute_lib_main_24(%arm_compute_lib_24_i0: Tensor[(1, 64, 14, 14), uint8], %arm_compute_lib_24_i1: Tensor[(1, 64, 14, 14), uint8], Inline=1, Compiler="arm_compute_lib", global_symbol="tvmgen_default_arm_compute_lib_main_24", Primitive=1) -> Tensor[(1, 64, 14, 14), uint8] {
              qnn.add(%arm_compute_lib_24_i0, %arm_compute_lib_24_i1, 0.0366653f /* ty=float32 */, 68 /* ty=int32 */, 0.025833f /* ty=float32 */, 69 /* ty=int32 */, 0.0461001f /* ty=float32 */, 70 /* ty=int32 */) /* ty=Tensor[(1, 64, 14, 14), uint8] */
            }
            def @tvmgen_default_arm_compute_lib_main_32(%arm_compute_lib_32_i0: Tensor[(1, 64, 14, 14), uint8], %arm_compute_lib_32_i1: Tensor[(1, 64, 14, 14), uint8], Inline=1, Compiler="arm_compute_lib", global_symbol="tvmgen_default_arm_compute_lib_main_32", Primitive=1) -> Tensor[(1, 64, 14, 14), uint8] {
              qnn.add(%arm_compute_lib_32_i0, %arm_compute_lib_32_i1, 0.0461001f /* ty=float32 */, 70 /* ty=int32 */, 0.0231473f /* ty=float32 */, 68 /* ty=int32 */, 0.0460796f /* ty=float32 */, 67 /* ty=int32 */) /* ty=Tensor[(1, 64, 14, 14), uint8] */
            }

PyTorch supports multithreading for sure, but if I remember correctly, the QNNPACK library, which PyTorch uses for quantized models, was not compiled with multithreading support for ARM64 Linux at that time. That may have changed now, if you are seeing better PyTorch performance than TVM.
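As a side note, you can check and select the quantized backend from Python (standard PyTorch API):

import torch

# Quantized kernels are dispatched through a backend engine;
# on ARM this is typically QNNPACK.
print(torch.backends.quantized.supported_engines)
torch.backends.quantized.engine = "qnnpack"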

Actually, I documented performance benchmark results for popular ImageNet models running on a Raspberry Pi 4 here: https://github.com/Edgecortix-Inc/pytorch_quantization/blob/rasp4/tvm_qnn_evaluation/perf_bench_rasp4.py#L146-L169. As you can see, the PyTorch results suggest it is running on a single thread, and I confirmed that by looking at their source code (though I forget which file).

By the way, does your model happen to be MobileNet v2? The numbers you mentioned roughly match the results I got (36 ms for TVM, and 88 ms for PyTorch single-threaded).

Thanks @i02fesea for sending the Relay over. It seems as though the graph is in NCHW format, which the ACL integration does not support, hence only a few qnn.add operations being offloaded. If you would like to try running the graph with ACL offloading, I would suggest looking at the ConvertLayout pass to convert the graph to NHWC format.
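A rough sketch of how that could look (the exact op list may need adjusting for your graph):

import tvm
from tvm import relay

# Convert the data layout to NHWC so the ACL integration can match the
# convolutions; "default" lets TVM pick the corresponding kernel layout.
desired_layouts = {
    "nn.conv2d": ["NHWC", "default"],
    "qnn.conv2d": ["NHWC", "default"],
}
seq = tvm.transform.Sequential([
    relay.transform.RemoveUnusedFunctions(),
    relay.transform.ConvertLayout(desired_layouts),
])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)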

Thanks Masahi,

Yes, I am running the quantized MobileNet v2 model. It is interesting that, comparing against your results, PyTorch seems to have a better execution time than TVM: I am getting 25 ms with 4 cores and 84 ms with 1 core. Is it possible to get better results using TVM and ACL?

These are your results: 4T: 36.034155 ms, 1T: 101.058245 ms, Torch: 88.816326 ms.

Thanks lhutton1. Yes, I will modify the input layout and get back to you with the results.