[Quantization] PyTorch Dynamic Quantization

Hi @masahi, thanks for your PR adding support for PyTorch dynamic quantization in TVM. However, I ran into some strange problems that I don't know how to solve. Could you give me some advice? Thanks!

# Requirements
torch==1.6.0
transformers==4.7.0
tvm==latest

Problem 1: Linear + ReLU with dynamic quantization causes the following error.

I think the problem is that `from_pytorch` still calls `add_input_quant_params_to_op_inputs` even when the model uses dynamic quantization. Therefore, I think the code here should be modified so that `qnn_torch.add_input_quant_params_to_op_inputs` is not called in that case.

Reproduction code (modified from here):

import os
import numpy as np

import torch
torch.manual_seed(0)
from torch import nn

import tvm
import tvm.testing
from tvm import relay
   
def get_tvm_runtime(script_module, input_name, ishape):
    """Compile a TorchScript module with TVM and return a graph-executor runtime.

    Args:
        script_module: Traced/scripted PyTorch module, imported via
            ``relay.frontend.from_pytorch``.
        input_name: Name given to the single graph input.
        ishape: Shape of that input.

    Returns:
        A ``graph_executor.GraphModule`` built for the ``llvm`` target at
        ``opt_level=3`` and bound to ``tvm.cpu(0)``.
    """
    # Import explicitly instead of relying on ``tvm.contrib.graph_executor``
    # having been loaded as a side effect of importing ``tvm.relay``.
    from tvm.contrib import graph_executor

    input_shapes = [(input_name, ishape)]
    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)
    return graph_executor.GraphModule(lib["default"](tvm.cpu(0)))

def test_quantize_dynamic():
    """Check TVM against PyTorch on a dynamically quantized Linear + ReLU model.

    The float model is quantized with per-channel and per-tensor dynamic
    qconfigs. Each quantized model is traced, run in PyTorch and in TVM, and
    the two outputs are compared with rtol=atol=1e-4.
    """

    class LinearWrapper(nn.Module):
        def __init__(self, in_dim, hidden_dim):
            super().__init__()
            self.linear = nn.Linear(in_dim, hidden_dim)
            self.relu = nn.ReLU()

        def forward(self, inp):
            return self.relu(self.linear(inp))

    float_model = LinearWrapper(16, 32).eval()
    qconfigs = (
        torch.quantization.per_channel_dynamic_qconfig,
        torch.quantization.default_dynamic_qconfig,
    )
    input_name = "input"

    for qconfig in qconfigs:
        for ishape in [(16, 16)]:
            # An empty-string key applies the qconfig to the whole module.
            quantized = torch.quantization.quantize_dynamic(
                float_model, qconfig_spec={'': qconfig}, dtype=torch.qint8
            )

            example = torch.randn(size=ishape)
            traced = torch.jit.trace(quantized, example).eval()

            with torch.no_grad():
                expected = traced(example.clone()).numpy()

            runtime = get_tvm_runtime(traced, input_name, example.shape)
            runtime.set_input(input_name, example.numpy().copy())
            runtime.run()
            actual = runtime.get_output(0).asnumpy()
            tvm.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)
            
# Allow running this reproduction directly as a script (outside pytest).
if __name__ == '__main__':
    test_quantize_dynamic()

Error Message

Traceback (most recent call last):
  File "test_tvm_quantize_relu.py", line 56, in <module>
    test_quantize_dynamic()
  File "test_tvm_quantize_relu.py", line 49, in test_quantize_dynamic
    runtime = get_tvm_runtime(script_module, input_name, inp.shape)
  File "test_tvm_quantize_relu.py", line 15, in get_tvm_runtime
    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
  File "/local/home/chengpi/tvm_exp/packages/tvm/python/tvm/relay/frontend/pytorch.py", line 3308, in from_pytorch
    qnn_torch.add_input_quant_params_to_op_inputs(graph)
  File "/local/home/chengpi/tvm_exp/packages/tvm/python/tvm/relay/frontend/qnn_torch.py", line 396, in add_input_quant_params_to_op_inputs
    scale, zp = _get_quant_param_for_input(node.inputsAt(i))
  File "/local/home/chengpi/tvm_exp/packages/tvm/python/tvm/relay/frontend/qnn_torch.py", line 218, in _get_quant_param_for_input
    return dfs(input_value.node())
  File "/local/home/chengpi/tvm_exp/packages/tvm/python/tvm/relay/frontend/qnn_torch.py", line 213, in dfs
    return dfs(arg.node())
  File "/local/home/chengpi/tvm_exp/packages/tvm/python/tvm/relay/frontend/qnn_torch.py", line 213, in dfs
    return dfs(arg.node())
  File "/local/home/chengpi/tvm_exp/packages/tvm/python/tvm/relay/frontend/qnn_torch.py", line 216, in dfs
    assert False, "No producer for %s" % (str(current_node))
AssertionError: No producer for %self.1 : __torch__.LinearWrapper, %x : Float(16:16, 16:1) = prim::Param()

Problem 2: Dynamic quantization in TVM produces mismatched outputs (tried with Linear and BERT).

Linear

I followed your test code using only a single Linear layer, and it works fine (PyTorch output == TVM output). However, when I add more Linear layers, the mismatch becomes more serious, as shown below.

num Layers Input Shape Mismatched elements Max absolute diff Max relative diff
1 1,512,768 no mismatch
2 1,512,768 1440 / 393216 (0.366%) 0.01528 7.04946
3 1,512,768 2807 / 393216 (0.714%) 0.00777 116.80091
4 1,512,768 3636 / 393216 (0.925%) 0.01217 129.867
5 1,512,768 5489 / 393216 (1.4%) 0.00835 9147.358
6 1,512,768 5698 / 393216 (1.45%) 0.01026 245.44225
7 1,512,768 4922 / 393216 (1.25%) 0.00592 52.72281
8 1,512,768 300240 / 393216 (76.4%) 0.00335 1107.7794
16 1,512,768 278363 / 393216 (70.8%) 0.00137 879.3084
32 1,512,768 278059 / 393216 (70.7%) 0.00126 6613.7915
64 1,512,768 287331 / 393216 (73.1%) 0.0014 1599.8228

Reproduction code

import os
import numpy as np

import torch
torch.manual_seed(0)
from torch import nn

import tvm
import tvm.testing
from tvm import relay
    
def get_tvm_runtime(script_module, input_name, ishape):
    """Compile a TorchScript module with TVM and return a graph-executor runtime.

    Args:
        script_module: Traced/scripted PyTorch module, imported via
            ``relay.frontend.from_pytorch``.
        input_name: Name given to the single graph input.
        ishape: Shape of that input.

    Returns:
        A ``graph_executor.GraphModule`` built for the ``llvm`` target at
        ``opt_level=3`` and bound to ``tvm.cpu(0)``.
    """
    # Import explicitly instead of relying on ``tvm.contrib.graph_executor``
    # having been loaded as a side effect of importing ``tvm.relay``.
    from tvm.contrib import graph_executor

    input_shapes = [(input_name, ishape)]
    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)
    return graph_executor.GraphModule(lib["default"](tvm.cpu(0)))

def test_quantize_dynamic():
    """Check TVM against PyTorch on a stack of dynamically quantized Linear layers.

    An 8-layer 768->768 Linear stack is quantized with per-channel and
    per-tensor dynamic qconfigs. Each quantized model is traced, run in PyTorch
    and in TVM, and the outputs are compared with rtol=atol=1e-4.
    """

    class LinearWrapper(nn.Module):
        def __init__(self, in_dim, hidden_dim, num_layers):
            super().__init__()
            layers = [nn.Linear(in_dim, hidden_dim) for _ in range(num_layers)]
            self.linears = nn.ModuleList(layers)

        def forward(self, inp):
            for layer in self.linears:
                inp = layer(inp)
            return inp

    float_model = LinearWrapper(768, 768, 8).eval()
    qconfigs = (
        torch.quantization.per_channel_dynamic_qconfig,
        torch.quantization.default_dynamic_qconfig,
    )
    input_name = "input"

    for qconfig in qconfigs:
        for ishape in [(1, 512, 768)]:
            # An empty-string key applies the qconfig to the whole module.
            quantized = torch.quantization.quantize_dynamic(
                float_model, qconfig_spec={'': qconfig}, dtype=torch.qint8
            )

            example = torch.randn(size=ishape)
            traced = torch.jit.trace(quantized, example).eval()

            with torch.no_grad():
                expected = traced(example.clone()).numpy()

            runtime = get_tvm_runtime(traced, input_name, example.shape)
            runtime.set_input(input_name, example.numpy().copy())
            runtime.run()
            actual = runtime.get_output(0).asnumpy()
            tvm.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)
            
# Allow running this reproduction directly as a script (outside pytest).
if __name__ == '__main__':
    test_quantize_dynamic()

BERT

When I tried pretrained BERT with dynamic quantization in TVM, I found that the output mismatches badly (99.8% of elements mismatched). To find which layer causes the problem, I dumped the output of each stage of a BERTLayer, as shown below. The output of FC1 in the first BERTLayer is already seriously mismatched. To dig deeper, I also tried initializing BERT with random weights instead of pretrained weights. With pretrained weights, it is much harder for TVM to match PyTorch's quantized output than with randomly initialized weights. Also, the more BERTLayers there are, the more serious the problem becomes.

I know you have tested the performance of BERT in this repo, but I still wonder whether this mismatch could affect the performance.

BERTLayer

BERT with pretrained weights

Pretrained Weights Shape Mismatched elements Max abs diff Max relative diff
BertLayer-12 1,512,768 392624 / 393216 (99.8%) 0.76333 45480.062
BertLayer-11 1,512,768 392602 / 393216 (99.8%) 2.02033 18859.594
BertLayer-10 1,512,768 392641 / 393216 (99.9%) 1.53778 2202526.2
BertLayer-8 1,512,768 392634 / 393216 (99.9%) 1.45215 126715.2
BertLayer-4 1,512,768 392486 / 393216 (99.8%) 0.66306 26701.11
BertLayer-2 1,512,768 391995 / 393216 (99.7%) 0.38736 269670.72
BertLayer-1-LN2 1,512,768 389262 / 393216 (99%) 0.09919 12717.782
BertLayer-1-FC3+Skip2 1,512,768 390410 / 393216 (99.3%) 0.84573 13479.788
BertLayer-1-FC3 1,512,768 391103 / 393216 (99.5%) 0.84648 50794.223
BertLayer-1-GELU 1,512,3072 929712 / 1572864 (59.1%) 0.18429 888.477
BertLayer-1-FC2 1,512,3072 1538435 / 1572864 (97.8%) 0.25227 861.802
BertLayer-1-LN1 1,512,768 339587 / 393216 (86.4%) 0.09439 1578.3201
BertLayer-1-FC1+Skip1 1,512,768 328933 / 393216 (83.7%) 0.02212 659.6583
BertLayer-1-FC1 1,512,768 335263 / 393216 (85.3%) 0.02212 8279.921
BertLayer-1-MHA 1,512,768 12301 / 393216 (3.13%) 0.0063 75.73743
BertEmbeddings 1,512,768 no mismatch

BERT with random initialized weights

Random Weights Shape Mismatched elements Max abs diff Max relative diff
BertLayer-12 1,512,768 391783 / 393216 (99.6%) 0.18707 16690.617
BertLayer-11 1,512,768 391756 / 393216 (99.6%) 0.16773 42095.33
BertLayer-10 1,512,768 391633 / 393216 (99.6%) 0.1696 9282.958
BertLayer-8 1,512,768 391395 / 393216 (99.5%) 0.14026 33522.902
BertLayer-4 1,512,768 389969 / 393216 (99.2%) 0.08623 26893.422
BertLayer-2 1,512,768 384717 / 393216 (97.8%) 0.04765 628.9598
BertLayer-1-LN2 1,512,768 61468 / 393216 (15.6%) 0.02406 91.54878
BertLayer-1-FC3+Skip2 1,512,768 183873 / 1572864 (11.7%) 0.01628 151.31506
BertLayer-1-FC3 1,512,768 60318 / 393216 (15.3%) 0.0254 618
BertLayer-1-GELU 1,512,3072 183873 / 1572864 (11.7%) 0.01628 151.31506
BertLayer-1-FC2 1,512,3072 213255 / 1572864 (13.6%) 0.01601 152.00002
BertLayer-1-LN1 1,512,768 35804 / 393216 (9.11%) 0.00191 20.36604
BertLayer-1-FC1+Skip1 1,512,768 35911 / 393216 (9.13%) 0.00194 38.69728
BertLayer-1-FC1 1,512,768 43518 / 393216 (11.1%) 0.00194 58.00007
BertLayer-1-MHA 1,512,768 no mismatch
BertEmbeddings 1,512,768 no mismatch

Reproduction code

import os
import numpy as np

import torch
torch.manual_seed(0)
from torch import nn

import tvm
import tvm.testing
from tvm import relay

from transformers import BertModel, BertConfig

def get_tvm_runtime(script_module, input_name, input_shape):
    """Compile a TorchScript module with TVM and return a graph-executor runtime.

    Args:
        script_module: Traced/scripted PyTorch module, imported via
            ``relay.frontend.from_pytorch``.
        input_name: Name given to the single graph input.
        input_shape: Shape of that input.

    Returns:
        A ``graph_executor.GraphModule`` built for the ``llvm`` target at
        ``opt_level=3`` and bound to ``tvm.cpu(0)``.
    """
    # Import explicitly instead of relying on ``tvm.contrib.graph_executor``
    # having been loaded as a side effect of importing ``tvm.relay``.
    from tvm.contrib import graph_executor

    input_shapes = [(input_name, input_shape)]
    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)
    return graph_executor.GraphModule(lib["default"](tvm.cpu(0)))

def test_quantize_dynamic():
    """Check TVM against PyTorch on dynamically quantized pretrained BERT.

    ``bert-base-uncased`` is quantized with per-channel and per-tensor dynamic
    qconfigs, traced on random token ids, and the first model output from
    PyTorch is compared with TVM's output 0 (rtol=atol=1e-4).
    """
    float_model = BertModel.from_pretrained('bert-base-uncased', torchscript=True).eval()
    qconfigs = (
        torch.quantization.per_channel_dynamic_qconfig,
        torch.quantization.default_dynamic_qconfig,
    )
    input_name = "input"

    for qconfig in qconfigs:
        for ishape in [(1, 512)]:
            # An empty-string key applies the qconfig to the whole module.
            quantized = torch.quantization.quantize_dynamic(
                float_model, qconfig_spec={'': qconfig}, dtype=torch.qint8
            )

            token_ids = torch.randint(high=2000, size=ishape)
            traced = torch.jit.trace(quantized, token_ids).eval()

            with torch.no_grad():
                # torchscript=True makes BertModel return a tuple; take item 0.
                expected = traced(token_ids)[0].numpy()

            runtime = get_tvm_runtime(traced, input_name, token_ids.shape)
            runtime.set_input(input_name, token_ids.numpy().copy())
            runtime.run()
            actual = runtime.get_output(0).asnumpy()
            tvm.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)

# Allow running this reproduction directly as a script (outside pytest).
if __name__ == '__main__':
    test_quantize_dynamic()

Yes, I also found that outputs from dynamically quantized Linear are off for larger inputs. I have no idea what the issue could be; it would be great if you could dig a bit more.

Okay! I'll dig further to find the cause. As for Problem 1, should I modify the code here? I think that when using dynamic quantization, we don't need to add the input scale and zero_point to the ops.

That makes sense. We need to decide whether to use `add_input_quant_params_to_op_inputs` based on the ops in the graph.