Hi,
I created a pytorch quantization model. After compiling with tvm, I did inference. The result was inconsistent with pytorch. The strange thing is that this phenomenon occurs sometimes.
my code:
import torch
from torch import nn
from torch.quantization import QuantStub, DeQuantStub, get_default_qat_qconfig, convert, prepare_qat
from tvm import relay
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor
class AdaptiveAvgPool2d(nn.Module):
def __init__(self):
super().__init__()
self.quant = QuantStub()
self.dequant = DeQuantStub()
self.pool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.quant(x)
y = self.pool(x)
y = self.dequant(y)
return y
def fuse_model(self):
pass
fp32_input = torch.randn(1, 3, 128, 128)
model = AdaptiveAvgPool2d()
BACKEND = "qnnpack"
model.qconfig = get_default_qat_qconfig(BACKEND)
prepare_qat(model, inplace=True)
model.eval()
y = model(fp32_input)
model_int8 = convert(model, inplace=True)
script_module = torch.jit.trace(model, fp32_input).eval()
input_name = "input"
input_infos = [(input_name, ((1, 3, 128, 128), 'float32'))]
img_input = np.random.rand(1, 3, 128, 128).astype(np.float32)
pt_input = torch.from_numpy(img_input)
torch.backends.quantized.engine = 'qnnpack'
with torch.no_grad():
pt_result = script_module(pt_input)
mod, params = relay.frontend.from_pytorch(script_module, input_infos)
target = "llvm"
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params=params)
module = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
module.set_input(input_name, img_input)
module.run()
print(pt_result[0].numpy().flatten())
print(module.get_output(0).asnumpy().flatten())
print("compare result: ", np.allclose(pt_result[0].numpy().flatten(), module.get_output(0).asnumpy().flatten(), atol=1e-05))
If you run the above code repeatedly, you will find that the comparison result is sometimes true and sometimes false. Why is this?
compare result: True
[0.48794654 0.48794654 0.48794654]
[0.48794654 0.48794654 0.48794654]
compare result: True
compare result: False
[0.5066974 0.5066974 0.5066974]
[0.47291756 0.47291756 0.47291756]
compare result: False