Severe performance degradation with quantization

Hi there, I’m quantizing a resnet50 model using tvm, running on a aarch64 ubuntu 22.04 virtual machine. It seems like the performance of the quantized model has dropped significantly, with a top1 accuracy of only 0.55 for a batch of 64 samples, compared to 0.75 for the full precision model.

Also, the inference time for TVM compiled model is consistent, about 0.1s per image, while the inference time for TVM quantized model grows linearly and continuously with the iteration, from 0.1s per image to 0.23s per image (the validation script is still running and it is still growing).

The script I’m using is attached below. Is there anything wrong with the script? Thanks so much for the help!

import tvm
from tvm import relay

import numpy as np

from tvm.contrib.download import download_testdata

# PyTorch imports
import torch
import torchvision

import time
model_name = "resnet50"
model = getattr(torchvision.models, model_name)(pretrained=True)


model = model.eval()

# We grab the TorchScripted model via tracing
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()


input_name = "input0"
shape_list = [(input_name, (1, 3, 224, 224))]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)


# quantization
with relay.quantize.qconfig(calibrate_mode="global_scale",
                            global_scale=8.0,
                            nbit_activation=16,
                            dtype_activation="int16",
                            skip_conv_layers=[],
                            skip_dense_layer=False,
                            partition_conversions="enabled",
                            do_simulation=False):
    mod = relay.quantize.quantize(mod, params)



target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)


# from tvm.contrib import graph_executor

# dtype = "float32"
# m = graph_executor.GraphModule(lib["default"](dev))
# before_tvm_inference = time.time()
# # Set inputs
# m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
# # Execute
# m.run()
# # Get outputs
# tvm_output = m.get_output(0)
# after_tvm_inference = time.time()
# print("tvm_inference", after_tvm_inference - before_tvm_inference)


import mxnet as mx


calibration_rec = download_testdata(
    "http://data.mxnet.io.s3-website-us-west-1.amazonaws.com/data/val_256_q90.rec",
    "val_256_q90.rec",
)

batch_size = 64

def get_val_data(num_workers=4):
    mean_rgb = [123.68, 116.779, 103.939]
    std_rgb = [58.393, 57.12, 57.375]
    def batch_fn(batch):
        return batch.data[0].asnumpy(), batch.label[0].asnumpy()
    img_size = 299 if model_name == "inceptionv3" else 224
    val_data = mx.io.ImageRecordIter(
        path_imgrec=calibration_rec,
        preprocess_threads=num_workers,
        shuffle=False,
        batch_size=batch_size,
        resize=256,
        data_shape=(3, img_size, img_size),
        mean_r=mean_rgb[0],
        mean_g=mean_rgb[1],
        mean_b=mean_rgb[2],
        std_r=std_rgb[0],
        std_g=std_rgb[1],
        std_b=std_rgb[2],
    )
    return val_data, batch_fn

validation_samples = 64 * 50
val_data, batch_fn = get_val_data()
val_data.reset()



# for quantization
dtype = "float32"
quant_model = relay.create_executor("vm", mod, dev, target).evaluate()
sum_quant_inference = 0
top1_acc = 0
top5_acc = 0
batch_size = 64
for i, batch in enumerate(val_data):
    data, target = batch_fn(batch)
    for j in range(batch_size):
        before_quant_inference = time.time()
        quant_output = quant_model(data[j].reshape((1, 3, 224, 224)).astype(dtype))
        after_quant_inference = time.time()
        # print(after_quant_inference - before_quant_inference)
        sum_quant_inference += after_quant_inference - before_quant_inference

        quant_output = quant_output.asnumpy()
        top1 = np.argsort(quant_output, axis = 1)[:,-1:]
        top5 = np.argsort(quant_output, axis = 1)[:,-5:]
        if target[j] in top1:
            top1_acc += 1
        if target[j] in top5:
            top5_acc += 1
    print("batch quant_inference", sum_quant_inference/(i+1))
    print("batch quant_accuracy top1_acc", top1_acc/(i+1))
    print("batch quant_accuracy top5_acc", top5_acc/(i+1))

print("quant_inference", sum_quant_inference/validation_samples)
print("quant_accuracy top1_acc", top1_acc/validation_samples)
print("quant_accuracy top5_acc", top5_acc/validation_samples)

sum_torch_inference = 0
top1_acc = 0
top5_acc = 0
for i, batch in enumerate(val_data):
    # Convert input to PyTorch variable and get PyTorch result for comparison
    with torch.no_grad():
        data, target = batch_fn(batch)
        for j in range(batch_size):
            torch_data = torch.from_numpy(data[j])
            before_torch_inference = time.time()
            output = model(torch_data.reshape((1, 3, 224, 224)))
            after_torch_inference = time.time()
            sum_torch_inference += after_torch_inference - before_torch_inference
            
            output = output.numpy()
            top1 = np.argsort(output, axis = 1)[:,-1:]
            top5 = np.argsort(output, axis = 1)[:,-5:]
            if target[j] in top1:
                top1_acc += 1
            if target[j] in top5:
                top5_acc += 1
        print("batch torch_inference", sum_torch_inference/(i+1))
        print("batch torch_accuracy top1_acc", top1_acc/(i+1))
        print("batch torch_accuracy top5_acc", top5_acc/(i+1))
print("torch_inference", sum_torch_inference/validation_samples)
print("torch_accuracy top1_acc", top1_acc/validation_samples)
print("torch_accuracy top5_acc", top5_acc/validation_samples)

Also, this is the script that I use to try tvm compiler without quantization



import tvm
from tvm import relay

import numpy as np

from tvm.contrib.download import download_testdata

# PyTorch imports
import torch
import torchvision

import time
model_name = "resnet50"
model = getattr(torchvision.models, model_name)(pretrained=True)


model = model.eval()

# We grab the TorchScripted model via tracing
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()




input_name = "input0"
shape_list = [(input_name, (1, 3, 224, 224))]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)



target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)



import mxnet as mx


calibration_rec = download_testdata(
    "http://data.mxnet.io.s3-website-us-west-1.amazonaws.com/data/val_256_q90.rec",
    "val_256_q90.rec",
)

batch_size = 64

def get_val_data(num_workers=4):
    mean_rgb = [123.68, 116.779, 103.939]
    std_rgb = [58.393, 57.12, 57.375]
    def batch_fn(batch):
        return batch.data[0].asnumpy(), batch.label[0].asnumpy()
    img_size = 299 if model_name == "inceptionv3" else 224
    val_data = mx.io.ImageRecordIter(
        path_imgrec=calibration_rec,
        preprocess_threads=num_workers,
        shuffle=False,
        batch_size=batch_size,
        resize=256,
        data_shape=(3, img_size, img_size),
        mean_r=mean_rgb[0],
        mean_g=mean_rgb[1],
        mean_b=mean_rgb[2],
        std_r=std_rgb[0],
        std_g=std_rgb[1],
        std_b=std_rgb[2],
    )
    return val_data, batch_fn

validation_samples = 64 * 50
val_data, batch_fn = get_val_data()
val_data.reset()
from tvm.contrib import graph_executor

# for quantization
dtype = "float32"
# tvm_model = relay.create_executor("vm", mod, dev, target).evaluate()
tvm_model = graph_executor.GraphModule(lib["default"](dev))
sum_tvm_inference = 0
top1_acc = 0
top5_acc = 0
batch_size = 64
for i, batch in enumerate(val_data):
    data, target = batch_fn(batch)
    for j in range(batch_size):
        before_tvm_inference = time.time()
        tvm_model.set_input(input_name, tvm.nd.array(data[j].reshape((1, 3, 224, 224)).astype(dtype)))
        # Execute
        tvm_model.run()
        # Get outputs
        tvm_output = tvm_model.get_output(0)
        # tvm_output = tvm_model(data[j].reshape((1, 3, 224, 224)).astype(dtype))
        after_tvm_inference = time.time()
        # print(after_tvm_inference - before_tvm_inference)
        sum_tvm_inference += after_tvm_inference - before_tvm_inference

        tvm_output = tvm_output.asnumpy()
        top1 = np.argsort(tvm_output, axis = 1)[:,-1:]
        top5 = np.argsort(tvm_output, axis = 1)[:,-5:]
        if target[j] in top1:
            top1_acc += 1
        if target[j] in top5:
            top5_acc += 1
    print("batch tvm_inference", sum_tvm_inference/(i+1))
    print("batch tvm_accuracy top1_acc", top1_acc/(i+1))
    print("batch tvm_accuracy top5_acc", top5_acc/(i+1))
print("tvm_inference", sum_tvm_inference/validation_samples)
print("tvm_accuracy top1_acc", top1_acc/validation_samples)
print("tvm_accuracy top5_acc", top5_acc/validation_samples)

sum_torch_inference = 0
top1_acc = 0
top5_acc = 0
for i, batch in enumerate(val_data):
    # Convert input to PyTorch variable and get PyTorch result for comparison
    with torch.no_grad():
        data, target = batch_fn(batch)
        for j in range(batch_size):
            torch_data = torch.from_numpy(data[j])
            before_torch_inference = time.time()
            output = model(torch_data.reshape((1, 3, 224, 224)))
            after_torch_inference = time.time()
            # print(after_torch_inference - before_torch_inference)
            sum_torch_inference += after_torch_inference - before_torch_inference
            
            output = output.numpy()
            top1 = np.argsort(output, axis = 1)[:,-1:]
            top5 = np.argsort(output, axis = 1)[:,-5:]
            if target[j] in top1:
                top1_acc += 1
            if target[j] in top5:
                top5_acc += 1
        print("batch torch_inference", sum_torch_inference/(i+1))
        print("batch torch_accuracy top1_acc", top1_acc/(i+1))
        print("batch torch_accuracy top5_acc", top5_acc/(i+1))
print("torch_inference", sum_torch_inference/validation_samples)
print("torch_accuracy top1_acc", top1_acc/validation_samples)
print("torch_accuracy top5_acc", top5_acc/validation_samples)




Also, another interesting result is that the inference time of quant model grows continuously from 6 to 15, then drops continuously from 15 to 12, while the inference time of tvm model grows suddenly from 6 to 8 like a step function.