Why does quantizing a PyTorch model give low accuracy?

I quantized a PyTorch ResNet-50 with TVM's relay.quantize.quantize (code shown below), and the correct class does not even appear in the top-5 predictions. However, I do get the correct top-1 result when predicting the same picture with a model that was converted from PyTorch to ONNX via torch.onnx.export() and then quantized by TVM.
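
For comparison, the ONNX route that gives the correct top-1 looks roughly like this (a minimal sketch, not my exact script: the .onnx path is the one written by export_onnx below, and the calibration generator would have to yield batches keyed by the ONNX input name 'images' instead of 'data'):

import onnx
import tvm
from tvm import relay

onnx_model = onnx.load("/workspace/tvm_demo/quantized_demo/resnet50.onnx")
shape_dict = {"images": (1, 3, 224, 224)}  # 'images' is the input name used in export_onnx
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

# same data-aware qconfig as in the script below; the calibration batches
# must be keyed by the input name, i.e. yield {"images": data}
with relay.quantize.qconfig(calibrate_mode="percentile", weight_scale="max",
                            skip_conv_layers=[0], skip_dense_layer=True,
                            round_for_shift=True):
    mod = relay.quantize.quantize(mod, params, dataset=calibrate_dataset())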

import os

import numpy as np
from PIL import Image

import mxnet as mx
import torch
import torchvision.datasets as dset  # used by the commented-out COCO calibrator below
import torchvision.models as models
import torchvision.transforms as T

import tvm
from tvm import relay, autotvm


target = "cuda"
dev = tvm.device(target)
batch_size = 1
calibration_samples = 10
model_name = "resnet50"


def export_onnx(model, im, file, opset, train, dynamic, simplify):
    try:
        import onnx
        print(f'\nstarting export with onnx {onnx.__version__}...')
        # f = file.with_suffix('onnx')
        torch.onnx.export(model, im, file, verbose=False, opset_version=opset,
                          training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                          do_constant_folding=not train,
                          input_names=['images'],
                          output_names=['output'],
                          dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,224,224)
                                        'output': {0: 'batch', 1: 'num_classes'}  # shape(1,1000)
                                        } if dynamic else None)
        
        # checks
        model_onnx = onnx.load(file)
        onnx.checker.check_model(model_onnx)
        
        if simplify:
            try:
                import onnxsim
                print(f'simplifying with onnx-simplifier {onnxsim.__version__}...')
                model_onnx, check = onnxsim.simplify(
                    model_onnx,
                    dynamic_input_shape=dynamic,
                    input_shapes={'images': list(im.shape)} if dynamic else None)
                assert check, 'onnx-simplifier check failed'
                onnx.save(model_onnx, file)
            except Exception as e:
                print(f'simplifier failure: {e}')
        print(f'export success, saved as {file}')
        # print(f"run --dynamic ONNX model inference with: 'python detect.py --weights {file}'")
    except Exception as e:
        print(f'export failure: {e}')


calibration_rec = '/workspace/dataset/val_256_q90.rec'
# calibration_rec = '/workspace/tvm_demo/quantized_demo/quantized_datasets'
def get_val_data(num_workers=4):
    mean_rgb = [123.68, 116.779, 103.939]
    std_rgb = [58.393, 57.12, 57.375]
    # mean_rgb = [0, 0, 0]
    # std_rgb = [1, 1, 1]
    def batch_fn(batch):
        return batch.data[0].asnumpy(), batch.label[0].asnumpy()

    img_size = 299 if model_name == "inceptionv3" else 224
    val_data = mx.io.ImageRecordIter(
        path_imgrec=calibration_rec,
        preprocess_threads=num_workers,
        shuffle=False,
        batch_size=batch_size,
        resize=256,
        data_shape=(3, img_size, img_size),
        mean_r=mean_rgb[0],
        mean_g=mean_rgb[1],
        mean_b=mean_rgb[2],
        std_r=std_rgb[0],
        std_g=std_rgb[1],
        std_b=std_rgb[2],
    )
    return val_data, batch_fn
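
# Sanity check (my own addition, plain numpy): the 0-255 scale mean/std used
# above are, up to rounding, the same normalization that transform_img applies
# later on the 0-1 scale, so calibration and inference preprocessing agree.
assert np.allclose(np.array([123.68, 116.779, 103.939]) / 255.0, [0.485, 0.456, 0.406], atol=0.01)
assert np.allclose(np.array([58.393, 57.12, 57.375]) / 255.0, [0.229, 0.224, 0.225], atol=0.01)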

def calibrate_dataset():
    # Yield calibration batches; relay.quantize feeds each dict to the graph
    # inputs, so the key should match the model's input name.
    val_data, batch_fn = get_val_data()
    val_data.reset()
    for i, batch in enumerate(val_data):
        if i * batch_size >= calibration_samples:
            break
        data, _ = batch_fn(batch)
        yield {"data": data}

def calibrate_one(img_root):
    # calibrate on a single image taken from img_root
    for i, img_file in enumerate(os.listdir(img_root)):
        if i >= 1:
            break
        img_path = os.path.join(img_root, img_file)
        img = Image.open(img_path).convert("RGB")  # drop any alpha channel
        img = transform_img(img)  # transform_img already resizes to 224x224
        img = np.array(img, dtype="float32")
        yield {'data': img}


# def calibrate_dataset():
#     # mean = 
#     # std = []
#     coco_datasets = dset.CocoDetection(root = "/workspace/dataset/coco/images/val2017",
#         annFile = '/workspace/dataset/coco/annotations/instances_val2017.json')
#     for i, (data, target) in enumerate(coco_datasets):
#         custom_transforms = T.Compose([
#             T.Resize((224, 224)),
#             T.ToTensor(),
#             T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#         ])
#         data = custom_transforms(data).numpy()

#         # data /= 255.0
#         data = np.expand_dims(data, axis=0)
#         data = np.array(data, dtype="float32")
#         # print(data.shape)
#         if i  >= calibration_samples:
#             break
#         yield {'data': data}

def get_model(model_name, get_onnx=False):
    '''
    Description: convert a pretrained torchvision model to a tvm.IRModule

    Parameters
    -------
    model_name : str
        Name of a pretrained model in torchvision.models.
    get_onnx : bool
        If True, also export the model to ONNX via export_onnx().

    Returns
    -------
    mod : tvm.IRModule
        The module that optimizations will be performed on.
    params : dict of str to tvm.runtime.NDArray
        Dict of converted parameters stored in tvm.runtime.ndarray format.
    model : torch.nn.Module
        The original PyTorch model in eval mode.
    '''
    model = getattr(models, model_name)(pretrained=True)
    model = model.eval()
    img = torch.zeros((1, 3, 224, 224))
    if get_onnx:
        onnx_model_path = f'/workspace/tvm_demo/quantized_demo/{model_name}.onnx'
        export_onnx(model, img, onnx_model_path, opset=10, dynamic=True, train=False, simplify=False)
    
    input_shape = [1, 3, 224, 224]
    input_data = torch.randn(input_shape)
    scripted_model = torch.jit.trace(model, input_data).eval()
    input_name = 'input:0'  # Relay input name; runtime inputs must be set under this name
    shape_list = [(input_name, input_shape)]
    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
    return mod, params, model


def quantize_model(mod, params, data_aware):
    '''
    Description: quantize a Relay module, either data-aware (activation scales
    calibrated on calibrate_dataset()) or with a single global scale.

    Parameters
    -------
    mod : tvm.IRModule
    params : dict of str to tvm.runtime.NDArray
    data_aware : bool
        If True, calibrate on a small dataset; otherwise use a global scale.

    Returns
    -------
    mod : tvm.IRModule
        The quantized module.
    '''
    if data_aware:
        with relay.quantize.qconfig(calibrate_mode='percentile', weight_scale='max',
                                    skip_conv_layers=[0], skip_dense_layer=True,
                                    round_for_shift=True):
            mod = relay.quantize.quantize(mod, params, dataset=calibrate_dataset())
    else:
        with relay.quantize.qconfig(calibrate_mode='global_scale', global_scale=0.8):
            mod = relay.quantize.quantize(mod, params)
    return mod
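
# For reference, TVM's quantization tutorial calibrates with KL divergence
# rather than percentile; a minimal sketch of that variant (my own addition,
# reusing calibrate_dataset() above):
def quantize_model_kl(mod, params):
    with relay.quantize.qconfig(calibrate_mode='kl_divergence', weight_scale='max'):
        return relay.quantize.quantize(mod, params, dataset=calibrate_dataset())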



def transform_img(image):
    # PIL image -> normalized float32 NCHW array of shape (1, 3, 224, 224)
    custom_transforms = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),  # scales pixel values to [0, 1]
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    data = custom_transforms(image).numpy()
    data = np.expand_dims(data, axis=0)
    return data

def inference_one(mod, torch_mod, img_path):
    img = Image.open(img_path).convert('RGB')
    img = transform_img(img)
    img = np.array(img, dtype="float32")

    print('run inference on the quantized mod')
    executor = relay.create_executor('vm', mod, dev, target)
    prediction = executor.evaluate()(img)
    pred_sort = np.argsort(prediction.asnumpy()[0])[::-1]
    print(f"top-5 classes: {pred_sort[:5]}, top-1 is {pred_sort[0]}")

    print('run inference on the pytorch mod')
    img_tensor = torch.tensor(img)
    with torch.no_grad():
        out = torch_mod(img_tensor)
    print(torch.topk(out, 5))
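
# Alternative inference path (my own sketch, not in the original script):
# compile with relay.build and run through the graph executor, which makes
# the input-name binding explicit.
def inference_one_graph(mod, img):
    from tvm.contrib import graph_executor
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target)  # params are already bound after quantization
    m = graph_executor.GraphModule(lib["default"](dev))
    m.set_input("input:0", tvm.nd.array(img))  # must match the name passed to from_pytorch
    m.run()
    return m.get_output(0).asnumpy()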


if __name__ == "__main__":
    with autotvm.tophub.context(target):
        # convert the pretrained torchvision model to a Relay module
        mod, params, torch_model = get_model(model_name, get_onnx=False)
        # print('before quantize:', mod)
        mod = quantize_model(mod, params, data_aware=True)
        inference_one(mod, torch_model, '/workspace/tvm_demo/quantized_demo/cat.png')