Slower Execution Times After 8-bit Quantization?

Hello,

I am currently attempting to benchmark the execution times of various neural networks in TVM on my computer. However, whenever I quantize a network to 8 bits, it runs significantly slower than when it is executed at full precision.

For example, when I quantize resnet-18 to 8 bits and tune, the mean execution time is ~160 ms, but when I do not quantize, it is ~90 ms. I would expect the 8-bit model to be much faster, but this does not seem to be the case.

Has anyone encountered this issue before? Are 8-bit operations not optimized in TVM? I am new to TVM, so perhaps I am making a simple oversight.

I will provide reproducible code below. If anyone has any insights or can help, I would greatly appreciate it.

import os
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm import autotvm
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
import tvm.contrib.graph_runtime as runtime
from tvm.relay import testing
import tvm.relay.quantize as quantize

# aarch64 CPU target: -mcpu selects the core (Neoverse N1), -mattr enables NEON and the int8 matmul extension
target = tvm.target.Target('llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mcpu=neoverse-n1 -mattr=+neon,+i8mm')

network = "resnet-18"
batch_size = 1
dtype = "float32"
layout = "NCHW"
log_file = "%s.log" % network
graph_opt_sch_file = "%s_graph_opt.log" % network

input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)

mod, params = relay.testing.resnet.get_workload(
    num_layers=18, batch_size=batch_size, layout=layout
)

### The quantization in question
with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
    mod = relay.quantize.quantize(mod, params)

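# Extract tunable conv2d tasks from the quantized module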
tasks = autotvm.task.extract_from_program(
    mod["main"], target=target, params=params, ops=(relay.op.get("nn.conv2d"),)
)

tuning_option = {
    "log_filename": log_file,
    "tuner": "xgb",
    "n_trial": 100,
    "early_stopping": 800,
    "measure_option": autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),
        runner=autotvm.LocalRunner(number=20, repeat=3, timeout=4, min_repeat_ms=150),
    ),
}

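# Tune each extracted task with the XGBoost tuner, logging results to log_file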
for i, task in enumerate(tasks):
    print("Tuning task %d/%d" % (i + 1, len(tasks)))
    tuner = XGBTuner(task)
    tuner.tune(
        n_trial=min(tuning_option["n_trial"], len(task.config_space)),
        early_stopping=tuning_option["early_stopping"],
        measure_option=tuning_option["measure_option"],
        callbacks=[
            autotvm.callback.progress_bar(tuning_option["n_trial"]),
            autotvm.callback.log_to_file(tuning_option["log_filename"]),
        ],
    )
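
# Compile the quantized module with the best schedules found during tuning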
with autotvm.apply_history_best(log_file):
    with tvm.transform.PassContext(opt_level=3):
        graph, lib, params = relay.build_module.build(
            mod, target=target, params=params
        )

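# Create the graph runtime module, set a random float32 input and the parameters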
ctx = tvm.device(str(target), 0)
module = runtime.create(graph, lib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input("data", data_tvm)
module.set_input(**params)

print("Evaluate inference time cost...")
ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=60)
prof_res = np.array(ftimer().results) * 1000  # (convert to milliseconds)
avg_inference_time = np.mean(prof_res)
std_dev = np.std(prof_res)
print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (avg_inference_time, std_dev))

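# pick_best distills the best record per task from log_file into the file named by the second argument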
autotvm.record.pick_best(log_file, 'mean_latency')
with open('avg_inference_time.txt', 'w') as f:
    f.write(f'Average Inference Time: {avg_inference_time} ms with std dev {std_dev} ms\n')

Hi @zpu, some related discussion: "Quantized models are slower than float models on GPUs" (Apache TVM Discuss)

Currently this is a common problem, and the suboptimal result can come from several distinct aspects.

At the computational-graph level, a quantized model may introduce many additional operators, for example cast ops.
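
If you want to see this on your own model, here is a minimal sketch (assuming mod is the quantized module from your snippet) that counts the operators in the Relay main function, so the extra cast ops and related arithmetic become visible:

from collections import Counter
from tvm import relay

def count_relay_ops(mod):
    """Count operator occurrences in the main function of a Relay module."""
    counts = Counter()

    def visit(node):
        # Only count calls to primitive operators (they carry a .name attribute).
        if isinstance(node, relay.Call) and hasattr(node.op, "name"):
            counts[node.op.name] += 1

    relay.analysis.post_order_visit(mod["main"], visit)
    return counts

# Run this on both the float and the quantized module and compare the counts.
for name, n in count_relay_ops(mod).most_common():
    print(name, n)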

There are indeed some optimizations that can be done (but have not been done yet) for quantized models. I am currently working on some computational-graph-level optimizations, which will hopefully be upstreamed to the main branch within a month. For now, if you are interested, you can try running your model with this version and see whether there is a significant improvement.

Another aspect is relay.quantize.quantize itself. The quantization in TVM may not be well optimized currently. You can find some proposals to improve relay.quantize.quantize on the forum; however, not all of those proposals seem to have made it into the main branch. Personally, I would like to improve relay.quantize.quantize systematically in the future. For now, you could try another tool (TFLite, for example) to do the quantization and then import the quantized model into TVM; that may help improve performance.
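
A minimal sketch of that route, assuming you already have a model quantized by the TFLite converter; the file name, input tensor name, shape, and dtype below are placeholders for your own model:

# Import a model that was quantized outside TVM (e.g. by the TFLite converter)
# instead of relying on relay.quantize.
import tflite  # the tflite pip package (flatbuffer schema)
import tvm
from tvm import relay

with open("resnet18_int8.tflite", "rb") as f:  # placeholder file name
    tflite_model = tflite.Model.GetRootAsModel(f.read(), 0)

mod, params = relay.frontend.from_tflite(
    tflite_model,
    shape_dict={"input": (1, 224, 224, 3)},  # TFLite models are usually NHWC
    dtype_dict={"input": "uint8"},           # or "int8", depending on the converter
)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)  # substitute your own target string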

I have encountered the same problem. When running inference on the CPU, using AVX instructions can accelerate the quantized model, and its inference speed then becomes faster than that of the floating-point model. However, combining TVM quantization with AVX instructions can cause other problems.
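
For reference, a minimal sketch of x86 targets whose -mcpu setting lets LLVM emit the wide-vector int8 instructions (the CPU names are only examples; pick the one that matches your machine):

import tvm

# AVX2 (Haswell and newer)
target_avx2 = tvm.target.Target("llvm -mcpu=core-avx2")
# AVX-512
target_avx512 = tvm.target.Target("llvm -mcpu=skylake-avx512")
# AVX-512 plus the VNNI int8 dot-product instructions
target_vnni = tvm.target.Target("llvm -mcpu=cascadelake")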