Hi @masahi, thanks for your PR adding support for PyTorch dynamic quantization in TVM. However, I ran into some strange problems that I don't know how to solve. Maybe you have some advice. Thanks!
# Requirements

- torch==1.6.0
- transformers==4.7.0
- tvm==latest
# Problem 1: linear + ReLU with dynamic quantization raises an error
I think the problem is that we still try to call `add_input_quant_params_to_op_inputs` when using dynamic quantization. Therefore, I think the code here should be modified so that `qnn_torch.add_input_quant_params_to_op_inputs` is not called in this case.
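A minimal sketch of the kind of guard I have in mind (hypothetical, not the actual TVM frontend code; `_is_statically_quantized` is a made-up helper): only call `qnn_torch.add_input_quant_params_to_op_inputs` when the graph is statically quantized, i.e. it contains `aten::quantize_per_tensor` nodes, since dynamically quantized graphs (e.g. `quantized::linear_dynamic`) compute their input scale/zero point at runtime.

```python
# Hypothetical sketch, not the actual TVM frontend code.
def _is_statically_quantized(graph):
    # aten::quantize_per_tensor only appears in statically quantized graphs;
    # dynamic quantization lowers to ops such as quantized::linear_dynamic,
    # whose input qparams are computed at runtime.
    op_kinds = set(node.kind() for node in graph.nodes())
    return "aten::quantize_per_tensor" in op_kinds


# In from_pytorch(), roughly where quantized models are handled:
# if _is_statically_quantized(graph):
#     qnn_torch.add_input_quant_params_to_op_inputs(graph)
```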
## Reproduce code (modified from here)
```python
import os
import numpy as np
import torch
torch.manual_seed(0)
from torch import nn

import tvm
import tvm.testing
from tvm import relay
from tvm.contrib import graph_executor


def get_tvm_runtime(script_module, input_name, ishape):
    # Convert the TorchScript module with the PyTorch frontend and build for CPU.
    input_shapes = [(input_name, ishape)]
    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)
    runtime = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
    return runtime


def test_quantize_dynamic():
    class LinearWrapper(nn.Module):
        def __init__(self, in_dim, hidden_dim):
            super().__init__()
            self.linear = nn.Linear(in_dim, hidden_dim)
            self.relu = nn.ReLU()

        def forward(self, inp):
            inp = self.linear(inp)
            inp = self.relu(inp)
            return inp

    mod = LinearWrapper(16, 32).eval()

    for qconfig in [
        torch.quantization.per_channel_dynamic_qconfig,
        torch.quantization.default_dynamic_qconfig,
    ]:
        for ishape in [(16, 16)]:
            # Apply dynamic quantization to the whole module.
            qspec = {"": qconfig}
            qmod = torch.quantization.quantize_dynamic(mod, qconfig_spec=qspec, dtype=torch.qint8)

            inp = torch.randn(size=ishape)
            script_module = torch.jit.trace(qmod, inp).eval()

            with torch.no_grad():
                pt_result = script_module(inp.clone()).numpy()

            input_name = "input"
            runtime = get_tvm_runtime(script_module, input_name, inp.shape)
            runtime.set_input(input_name, inp.numpy().copy())
            runtime.run()
            tvm_result = runtime.get_output(0).asnumpy()
            tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
    test_quantize_dynamic()
```
## Error message
```
Traceback (most recent call last):
  File "test_tvm_quantize_relu.py", line 56, in <module>
    test_quantize_dynamic()
  File "test_tvm_quantize_relu.py", line 49, in test_quantize_dynamic
    runtime = get_tvm_runtime(script_module, input_name, inp.shape)
  File "test_tvm_quantize_relu.py", line 15, in get_tvm_runtime
    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
  File "/local/home/chengpi/tvm_exp/packages/tvm/python/tvm/relay/frontend/pytorch.py", line 3308, in from_pytorch
    qnn_torch.add_input_quant_params_to_op_inputs(graph)
  File "/local/home/chengpi/tvm_exp/packages/tvm/python/tvm/relay/frontend/qnn_torch.py", line 396, in add_input_quant_params_to_op_inputs
    scale, zp = _get_quant_param_for_input(node.inputsAt(i))
  File "/local/home/chengpi/tvm_exp/packages/tvm/python/tvm/relay/frontend/qnn_torch.py", line 218, in _get_quant_param_for_input
    return dfs(input_value.node())
  File "/local/home/chengpi/tvm_exp/packages/tvm/python/tvm/relay/frontend/qnn_torch.py", line 213, in dfs
    return dfs(arg.node())
  File "/local/home/chengpi/tvm_exp/packages/tvm/python/tvm/relay/frontend/qnn_torch.py", line 213, in dfs
    return dfs(arg.node())
  File "/local/home/chengpi/tvm_exp/packages/tvm/python/tvm/relay/frontend/qnn_torch.py", line 216, in dfs
    assert False, "No producer for %s" % (str(current_node))
AssertionError: No producer for %self.1 : __torch__.LinearWrapper, %x : Float(16:16, 16:1) = prim::Param()
```
# Problem 2: dynamic quantization in TVM gives mismatched outputs (Linear and BERT)
## Linear
I followed your test code with a single linear layer and it works fine (PyTorch output == TVM output). However, when I add more linear layers the mismatch becomes serious, as shown below.
| Num layers | Input shape | Mismatched elements | Max absolute diff | Max relative diff |
|---|---|---|---|---|
| 1 | (1, 512, 768) | no mismatch | | |
| 2 | (1, 512, 768) | 1440 / 393216 (0.366%) | 0.01528 | 7.04946 |
| 3 | (1, 512, 768) | 2807 / 393216 (0.714%) | 0.00777 | 116.80091 |
| 4 | (1, 512, 768) | 3636 / 393216 (0.925%) | 0.01217 | 129.867 |
| 5 | (1, 512, 768) | 5489 / 393216 (1.4%) | 0.00835 | 9147.358 |
| 6 | (1, 512, 768) | 5698 / 393216 (1.45%) | 0.01026 | 245.44225 |
| 7 | (1, 512, 768) | 4922 / 393216 (1.25%) | 0.00592 | 52.72281 |
| 8 | (1, 512, 768) | 300240 / 393216 (76.4%) | 0.00335 | 1107.7794 |
| 16 | (1, 512, 768) | 278363 / 393216 (70.8%) | 0.00137 | 879.3084 |
| 32 | (1, 512, 768) | 278059 / 393216 (70.7%) | 0.00126 | 6613.7915 |
| 64 | (1, 512, 768) | 287331 / 393216 (73.1%) | 0.0014 | 1599.8228 |
### Reproduce code
```python
import os
import numpy as np
import torch
torch.manual_seed(0)
from torch import nn

import tvm
import tvm.testing
from tvm import relay
from tvm.contrib import graph_executor


def get_tvm_runtime(script_module, input_name, ishape):
    input_shapes = [(input_name, ishape)]
    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)
    runtime = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
    return runtime


def test_quantize_dynamic():
    class LinearWrapper(nn.Module):
        def __init__(self, in_dim, hidden_dim, num_layers):
            super().__init__()
            self.linears = nn.ModuleList(
                [nn.Linear(in_dim, hidden_dim) for _ in range(num_layers)]
            )

        def forward(self, inp):
            for linear in self.linears:
                inp = linear(inp)
            return inp

    # 8 stacked linear layers; change the last argument to sweep the depth.
    mod = LinearWrapper(768, 768, 8).eval()

    for qconfig in [
        torch.quantization.per_channel_dynamic_qconfig,
        torch.quantization.default_dynamic_qconfig,
    ]:
        for ishape in [(1, 512, 768)]:
            qspec = {"": qconfig}
            qmod = torch.quantization.quantize_dynamic(mod, qconfig_spec=qspec, dtype=torch.qint8)

            inp = torch.randn(size=ishape)
            script_module = torch.jit.trace(qmod, inp).eval()

            with torch.no_grad():
                pt_result = script_module(inp.clone()).numpy()

            input_name = "input"
            runtime = get_tvm_runtime(script_module, input_name, inp.shape)
            runtime.set_input(input_name, inp.numpy().copy())
            runtime.run()
            tvm_result = runtime.get_output(0).asnumpy()
            tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
    test_quantize_dynamic()
```
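For reference, the mismatch statistics in the table above follow the report format of `numpy.testing.assert_allclose` (which `tvm.testing.assert_allclose` wraps). A minimal sketch, assuming that is how the numbers were collected, that prints the report instead of failing so one sweep over layer counts can run to completion:

```python
import numpy as np


def report_mismatch(tvm_result, pt_result, rtol=1e-4, atol=1e-4):
    # assert_allclose's error message contains the "Mismatched elements",
    # "Max absolute difference" and "Max relative difference" fields
    # shown in the table above.
    try:
        np.testing.assert_allclose(tvm_result, pt_result, rtol=rtol, atol=atol)
        print("no mismatch")
    except AssertionError as err:
        print(err)


# Call report_mismatch(tvm_result, pt_result) inside the loop above,
# once per num_layers value passed to LinearWrapper.
```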
## BERT
When I run a pretrained BERT with dynamic quantization in TVM, the output mismatch is severe (mismatched elements = 99.8%). To find which layer causes the problem, I dumped the output of each `BertLayer` (structure shown below). The output of FC1 in the first `BertLayer` is already seriously mismatched. To dig deeper, I also initialized BERT with random weights instead of pretrained weights: with pretrained weights, TVM matches the PyTorch quantized model much less closely than with randomly initialized weights, and the more `BertLayer`s there are, the worse the problem gets.
I know you have tested the performance of BERT on this repo, but I still wonder whether this mismatch problem may affect the performance?
*(Figure: structure of a BertLayer, labeling the MHA/FC/LN/GELU stages referenced in the tables below.)*
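The per-layer reference outputs compared in the tables below can be captured on the PyTorch side with forward hooks; here is a minimal sketch (my assumption, not necessarily the code used to produce the tables). Comparing against TVM additionally requires a Relay module that returns the same intermediate tensors, which is omitted here.

```python
# Hypothetical sketch: capture the hidden states after each BertLayer of the
# quantized PyTorch model (qmod and inp as in the reproduce code below).
captured = {}


def make_hook(name):
    def hook(module, inputs, output):
        # BertLayer.forward returns a tuple; the hidden states are element 0.
        captured[name] = output[0].detach().numpy()
    return hook


for i, layer in enumerate(qmod.encoder.layer):
    layer.register_forward_hook(make_hook("BertLayer-%d" % (i + 1)))

with torch.no_grad():
    qmod(inp)
```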
### BERT with pretrained weights
| Layer (pretrained weights) | Shape | Mismatched elements | Max absolute diff | Max relative diff |
|---|---|---|---|---|
| BertLayer-12 | (1, 512, 768) | 392624 / 393216 (99.8%) | 0.76333 | 45480.062 |
| BertLayer-11 | (1, 512, 768) | 392602 / 393216 (99.8%) | 2.02033 | 18859.594 |
| BertLayer-10 | (1, 512, 768) | 392641 / 393216 (99.9%) | 1.53778 | 2202526.2 |
| BertLayer-8 | (1, 512, 768) | 392634 / 393216 (99.9%) | 1.45215 | 126715.2 |
| BertLayer-4 | (1, 512, 768) | 392486 / 393216 (99.8%) | 0.66306 | 26701.11 |
| BertLayer-2 | (1, 512, 768) | 391995 / 393216 (99.7%) | 0.38736 | 269670.72 |
| BertLayer-1-LN2 | (1, 512, 768) | 389262 / 393216 (99%) | 0.09919 | 12717.782 |
| BertLayer-1-FC3+Skip2 | (1, 512, 768) | 390410 / 393216 (99.3%) | 0.84573 | 13479.788 |
| BertLayer-1-FC3 | (1, 512, 768) | 391103 / 393216 (99.5%) | 0.84648 | 50794.223 |
| BertLayer-1-GELU | (1, 512, 3072) | 929712 / 1572864 (59.1%) | 0.18429 | 888.477 |
| BertLayer-1-FC2 | (1, 512, 3072) | 1538435 / 1572864 (97.8%) | 0.25227 | 861.802 |
| BertLayer-1-LN1 | (1, 512, 768) | 339587 / 393216 (86.4%) | 0.09439 | 1578.3201 |
| BertLayer-1-FC1+Skip1 | (1, 512, 768) | 328933 / 393216 (83.7%) | 0.02212 | 659.6583 |
| BertLayer-1-FC1 | (1, 512, 768) | 335263 / 393216 (85.3%) | 0.02212 | 8279.921 |
| BertLayer-1-MHA | (1, 512, 768) | 12301 / 393216 (3.13%) | 0.0063 | 75.73743 |
| BertEmbeddings | (1, 512, 768) | no mismatch | | |
### BERT with randomly initialized weights
| Layer (random weights) | Shape | Mismatched elements | Max absolute diff | Max relative diff |
|---|---|---|---|---|
| BertLayer-12 | (1, 512, 768) | 391783 / 393216 (99.6%) | 0.18707 | 16690.617 |
| BertLayer-11 | (1, 512, 768) | 391756 / 393216 (99.6%) | 0.16773 | 42095.33 |
| BertLayer-10 | (1, 512, 768) | 391633 / 393216 (99.6%) | 0.1696 | 9282.958 |
| BertLayer-8 | (1, 512, 768) | 391395 / 393216 (99.5%) | 0.14026 | 33522.902 |
| BertLayer-4 | (1, 512, 768) | 389969 / 393216 (99.2%) | 0.08623 | 26893.422 |
| BertLayer-2 | (1, 512, 768) | 384717 / 393216 (97.8%) | 0.04765 | 628.9598 |
| BertLayer-1-LN2 | (1, 512, 768) | 61468 / 393216 (15.6%) | 0.02406 | 91.54878 |
| BertLayer-1-FC3+Skip2 | (1, 512, 768) | 183873 / 1572864 (11.7%) | 0.01628 | 151.31506 |
| BertLayer-1-FC3 | (1, 512, 768) | 60318 / 393216 (15.3%) | 0.0254 | 618 |
| BertLayer-1-GELU | (1, 512, 3072) | 183873 / 1572864 (11.7%) | 0.01628 | 151.31506 |
| BertLayer-1-FC2 | (1, 512, 3072) | 213255 / 1572864 (13.6%) | 0.01601 | 152.00002 |
| BertLayer-1-LN1 | (1, 512, 768) | 35804 / 393216 (9.11%) | 0.00191 | 20.36604 |
| BertLayer-1-FC1+Skip1 | (1, 512, 768) | 35911 / 393216 (9.13%) | 0.00194 | 38.69728 |
| BertLayer-1-FC1 | (1, 512, 768) | 43518 / 393216 (11.1%) | 0.00194 | 58.00007 |
| BertLayer-1-MHA | (1, 512, 768) | no mismatch | | |
| BertEmbeddings | (1, 512, 768) | no mismatch | | |
### Reproduce code
```python
import os
import numpy as np
import torch
torch.manual_seed(0)
from torch import nn

import tvm
import tvm.testing
from tvm import relay
from tvm.contrib import graph_executor

from transformers import BertModel, BertConfig


def get_tvm_runtime(script_module, input_name, input_shape):
    input_shapes = [(input_name, input_shape)]
    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)
    runtime = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
    return runtime


def test_quantize_dynamic():
    # torchscript=True so the pretrained model can be traced.
    mod = BertModel.from_pretrained("bert-base-uncased", torchscript=True).eval()

    for qconfig in [
        torch.quantization.per_channel_dynamic_qconfig,
        torch.quantization.default_dynamic_qconfig,
    ]:
        for ishape in [(1, 512)]:
            qspec = {"": qconfig}
            qmod = torch.quantization.quantize_dynamic(mod, qconfig_spec=qspec, dtype=torch.qint8)

            inp = torch.randint(high=2000, size=ishape)
            script_module = torch.jit.trace(qmod, inp).eval()

            with torch.no_grad():
                pt_result = script_module(inp)[0].numpy()

            input_name = "input"
            runtime = get_tvm_runtime(script_module, input_name, inp.shape)
            runtime.set_input(input_name, inp.numpy().copy())
            runtime.run()
            tvm_result = runtime.get_output(0).asnumpy()
            tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
    test_quantize_dynamic()
```