[Quantization] Status of int8 input -> int8 output operators (with 32 bit accumulator)

I’m running some quantization tests with ResNet18 on CPU, ingesting a float32 model and then quantizing it. With int8 inputs and int32 operator outputs, the network produces roughly the same output as the float32 model. However, when I configure the operators to produce int8 outputs, the network outputs 0.001 for every element of the final output vector. Has int8 output from operators been tested yet?

I was also testing int8 output for dense layers on a commit from around mid-January and got the same kind of junk output. In both cases (ResNet on the latest master branch, and the dense layer on the earlier commit) every element of the output is (1 / num_output_elements), which is interesting.

Output

CPU (float32 out): 135 0.0014134465
CPU (int32 out): 135 0.0011491627
CPU (int8 out): 0 0.001
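
For what it’s worth, 0.001 is exactly what a softmax over identical logits produces for a 1000-class head (1 / 1000), so the int8 path appears to collapse to a constant pre-softmax tensor. A quick numpy sanity check (the 1000-class head is assumed from the standard relay.testing ResNet18 workload):

import numpy as np

# Softmax over a constant logit vector is uniform: every element is 1 / N.
logits = np.zeros(1000, dtype=np.float32)   # any constant value gives the same result
probs = np.exp(logits) / np.exp(logits).sum()
print(probs.max())                          # 0.001, matching the int8 output above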

Test script

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay import testing

from io import BytesIO
import requests
from PIL import Image

def process_image(image):
    # Standard ImageNet preprocessing: subtract the per-channel mean, divide by
    # the per-channel std, transpose HWC -> NCHW, and add a batch dimension.
    image = np.array(image) - np.array([123., 117., 104.])
    image /= np.array([58.395, 57.12, 57.375])
    image = image.transpose((2, 0, 1))
    image = image[np.newaxis, :]
    return tvm.nd.array(image.astype("float32"))


image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
response = requests.get(image_url)
image = Image.open(BytesIO(response.content)).resize((224, 224))
image = process_image(image)

def get_calibration_dataset(input_name):
    # A single batch of random data, used only to drive the calibration pass.
    dataset = []
    for i in range(1):
        data = np.random.uniform(size=(1, 3, 224, 224))
        dataset.append({input_name: data})
    return dataset


def get_cpu_runtime(mod, params):
    # Build for the LLVM CPU target and return a graph runtime with params set.
    func = mod['main']

    target = 'llvm'
    ctx = tvm.context(target)

    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build_module.build(func, target=target, params=params)
    runtime = graph_runtime.create(graph, lib, ctx)
    runtime.set_input(**params)

    return runtime


def quantize_calibrate(relay_ir, params, dataset, out_bits=32):
    # nbit_activation / dtype_activation control the operator output width:
    # int32 by default, int8 when out_bits=8.
    quantization_config = {
        "nbit_input": 8,
        "nbit_weight": 8,
        "nbit_activation": out_bits,
        "dtype_input": "int8",
        "dtype_weight": "int8",
        "dtype_activation": "int"+str(out_bits),
        "calibrate_mode": "kl_divergence",
        "global_scale": 8.0,
        "weight_scale": "power2",
        "skip_conv_layers": [],
        "do_simulation": False,
        "round_for_shift": True,
        "debug_enabled_ops": None,
        "rounding": "UPWARD"
    }

    print("Quantize calibration....")
    with relay.quantize.qconfig(**quantization_config):
        relay_ir = relay.quantize.quantize(relay_ir, params, dataset=dataset)

    return relay_ir


def run_inference(runtime, batch):
    runtime.set_input(**batch)
    runtime.run()
    return runtime.get_output(0).asnumpy()


def test_cpu_quantization():
    # Get cpu runtime
    mod, params = testing.resnet.get_workload(num_layers=18)
    batch = get_calibration_dataset("data")
    cpu_runtime = get_cpu_runtime(mod, params)
    
    # Get quantized cpu runtime (int32 output dtype)
    mod, params = testing.resnet.get_workload(num_layers=18)
    batch = get_calibration_dataset("data")
    mod = quantize_calibrate(mod, params, batch)
    cpu_runtime_q32 = get_cpu_runtime(mod, params)

    # Get quantized cpu runtime (int8 output dtype)
    mod, params = testing.resnet.get_workload(num_layers=18)
    batch = get_calibration_dataset("data")
    mod = quantize_calibrate(mod, params, batch, out_bits=8)
    cpu_runtime_q8 = get_cpu_runtime(mod, params)

    # Run inferences
    cpu_outputs = run_inference(cpu_runtime, {"data": image})
    cpu_outputs_q32 = run_inference(cpu_runtime_q32, {"data": image})
    cpu_outputs_q8 = run_inference(cpu_runtime_q8, {"data": image})

    print("CPU (float32 out):", np.argmax(cpu_outputs), cpu_outputs.max())
    print("CPU (int32 out):", np.argmax(cpu_outputs_q32), cpu_outputs_q32.max())
    print("CPU (int8 out):", np.argmax(cpu_outputs_q8), cpu_outputs_q8.max())


if __name__ == "__main__":
    test_cpu_quantization()

Trying to understand the question: if we have an int8 -> int8 operator (like dense/conv), it will lead to bad accuracy unless we perform int32 accumulation. So, are you expecting the accumulation to be in int32?

Sorry, yes, this expects a 32-bit accumulator.
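
To spell out why the accumulator width matters, here is a minimal numpy sketch (not TVM code; the 512-wide reduction and value ranges are arbitrary) of keeping the reduction of a single dense-layer output element in 8 bits instead of 32:

import numpy as np

# One output element of an int8 dense layer: a reduction over 512 products.
a = np.random.randint(-64, 64, size=512).astype(np.int32)
b = np.random.randint(-64, 64, size=512).astype(np.int32)
products = a * b

acc32 = int(products.sum())            # 32-bit accumulation: exact, typically tens of thousands
acc8 = 0
for p in products:                     # crude model of an 8-bit wrapping accumulator
    acc8 = (acc8 + int(p) + 128) % 256 - 128
print(acc32, acc8)                     # acc8 has wrapped repeatedly and bears no relation to acc32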

@vinx13 Does AutoQ assume int32 accumulation for dense and conv, even when we specify the outputs to be int8?

No. If you specify an int8 output type, conv and dense will output int8 (the accumulation is done in int32 and then converted to int8).
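
Roughly, the int8-out pattern looks like the hand-written Relay sketch below: accumulate in int32 inside the operator, then requantize (shift / clip / cast) down to int8. This is for illustration only (arbitrary shapes and shift amount), not the exact graph the realize pass emits:

from tvm import relay

# Hand-written illustration of "accumulate in int32, then convert to int8".
data = relay.var("data", shape=(1, 16, 8, 8), dtype="int8")
weight = relay.var("weight", shape=(32, 16, 3, 3), dtype="int8")

acc = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=32,
                      out_dtype="int32")                    # int32 accumulator
scaled = relay.right_shift(acc, relay.const(8, "int32"))    # example rescale by 2^-8
clipped = relay.clip(scaled, a_min=-127, a_max=127)
out = relay.cast(clipped, "int8")                           # int8 leaves the operator

print(relay.Function([data, weight], out))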

If the accumulation is happening in int32, should we expect accuracy as bad as @adb is seeing? Probably not.

@janimesh int32 accumulation only happens in the intermediate computation of conv or dense; their outputs are still int8. How the intermediate accumulation is done is actually implementation-dependent.

I think if you specify the out_dtype of relay.nn.conv2d as ‘int8’, the accumulation will not happen in int32. So, if AutoQ sets out_dtype to int8 in this case, I would expect the results that @adb is seeing.
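
One way to check what AutoQ actually emitted is to dump the quantized module from the script above and look at the out_dtype attribute on the conv2d/dense calls (this reuses the quantize_calibrate helper defined earlier):

# Using the helpers from the test script above: quantize with int8 outputs,
# then print the Relay graph and inspect out_dtype on conv2d / dense.
mod, params = testing.resnet.get_workload(num_layers=18)
mod = quantize_calibrate(mod, params, get_calibration_dataset("data"), out_bits=8)
print(mod['main'])   # look for out_dtype="int8" vs out_dtype="int32"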


@janimesh @vinx13, thank you for the discussion. Would you be able to link me to where this specific accumulation is occurring?