[BERT Quantization] Quantized BERT Model Accuracy Is Not Correct

Device: NVIDIA T4, CUDA version 10.2

import torch
import tvm
from tvm import relay
from pytorch_pretrained_bert import BertForMaskedLM

# Trace the FP32 model on CPU with a dummy batch of token IDs.
bert_model_origin = BertForMaskedLM.from_pretrained('bert-large-uncased').to("cpu").eval()
example_tensor = torch.randint(0, 100, (1, 256))
trace_model_fp32 = torch.jit.trace(bert_model_origin, [example_tensor])

# (input_name, shape) pairs for the importer; the first graph input is the module itself.
shape_list = [(i.debugName().split('.')[0], i.type().sizes()) for i in list(trace_model_fp32.graph.inputs())[1:]]
mod_bert_fp32, params_bert_fp32 = relay.frontend.from_pytorch(trace_model_fp32, shape_list)
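
For completeness, here is a minimal sketch of how the imported FP32 module can be compiled and run to check the baseline before quantizing (assuming a recent TVM with tvm.contrib.graph_executor, and that 'input_ids' is the input name recovered in shape_list):

from tvm.contrib import graph_executor

target = tvm.target.cuda()
dev = tvm.device(target.kind.name, 0)
with tvm.transform.PassContext(opt_level=3):
    lib_fp32 = relay.build(mod_bert_fp32, target=target, params=params_bert_fp32)
module = graph_executor.GraphModule(lib_fp32["default"](dev))
module.set_input("input_ids", tvm.nd.array(example_tensor.numpy(), dev))
module.run()
logits = module.get_output(0).numpy()  # masked-LM logits, shape (1, 256, vocab_size)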

import os

def test_int8_bert_model_tvm(mod_bert, params_bert, dataset, infer_count=1000):
    target = tvm.target.cuda()
    dev = tvm.device(target.kind.name, 0)

    # Yield calibration samples as {input_name: NDArray} dicts for the quantizer.
    def calibrate_dataset(calibrate_num=100, input_name='input_ids'):
        for i in range(calibrate_num):
            record = dataset[i]
            index_tokens = record['index_tokens'].to("cpu")
            index_tokens_tvm = tvm.nd.array(index_tokens.numpy(), dev)
            yield {input_name: index_tokens_tvm}

    # KL-divergence calibration for activations; max-abs scales for weights.
    with relay.quantize.qconfig(calibrate_mode="kl_divergence", weight_scale="max"):
        mod = relay.quantize.quantize(mod_bert, params_bert, dataset=calibrate_dataset())
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target.kind.name, params=params_bert)
        lib.export_library(os.path.realpath("bert_tvm_int8_cuda.tar"))
    if infer_count != 0:
        # tvm_inference is our own evaluation helper (defined elsewhere).
        tvm_inference(lib, dataset, target, infer_count=infer_count)

The final INT8 score is an unbelievable 0.003, while the FP32 model reaches 0.886 accuracy.

@tqchen @Wheest @AndrewZhaoLuo


First of all, quantization of BERT is a hard problem. In particular, a post-training quantization approach like yours probably won’t work.

Second, the quantization module in TVM is very limited and not validated on a wide variety of models. I don’t recommend using it.

Strongly agree with you. TVM indeed does not support quantization very well…

TF, PyTorch, and ONNX each have their own quantization algorithms and various APIs. We prefer TVM, since then we only need to maintain TVM's quantization path rather than all of them.
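
To illustrate how framework-specific those APIs are: PyTorch's post-training dynamic quantization of BERT is only a few lines (a minimal sketch using torch.quantization.quantize_dynamic; this is PyTorch's own API, entirely separate from TVM's):

import torch
from pytorch_pretrained_bert import BertForMaskedLM

model = BertForMaskedLM.from_pretrained('bert-large-uncased').eval()
# Convert the weights of every nn.Linear to int8; activation scales
# are computed dynamically at inference time.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8)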

We have tried post-training quantization approaches in those frameworks and they work. For the reasons above, we also need TVM's quantization to work for BERT.
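
As a next step we plan to experiment with other relay.quantize.qconfig settings. A hedged sketch of two things to try, reusing mod_bert_fp32 / params_bert_fp32 and the calibrate_dataset generator from above (hoisted to module scope); these are guesses for isolating the problem, not a known fix:

# Try 1: fixed global-scale calibration instead of KL divergence.
with relay.quantize.qconfig(calibrate_mode="global_scale", global_scale=8.0,
                            weight_scale="power2"):
    mod_int8_gs = relay.quantize.quantize(mod_bert_fp32, params_bert_fp32)

# Try 2: qconfig skips dense layers by default (skip_dense_layer=True),
# and BERT is almost entirely dense layers; quantize them explicitly.
with relay.quantize.qconfig(calibrate_mode="kl_divergence", weight_scale="max",
                            skip_dense_layer=False):
    mod_int8_dense = relay.quantize.quantize(mod_bert_fp32, params_bert_fp32,
                                             dataset=calibrate_dataset())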