Dense layers do not quantize from float32 to int8 in Relay

I’m trying to quantize a simple float32 model that has just a few dense layers. I want all nn.dense ops and their inputs/outputs to be quantized down to int8, but I’m seeing no change from float32 to int8 in the Relay IR.

nn.conv2d layers seem to quantize fine; I’m only having issues with nn.dense ops. I’ve also tried running the deploy_vision_on_vta.py example, and its last dense layer does not quantize either.


@thierry I remember that we disabled dense layer quantization in VTA for some reason; what’s the status now?

@adb do you mind pointing us to a gist of your code?

Also, I now recall that dense quantization was manually turned off here: https://github.com/apache/incubator-tvm/blob/master/python/tvm/relay/quantize/_annotate.py#L179-L180

I can work on a PR that lets us enable/disable dense quantization in the quantization config (and act upon the TODO that was left there). I can add your example to our test cases if you share your code.
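
For illustration, the user-facing side could look roughly like this once the PR lands (a sketch only; skip_dense_layer is a placeholder flag name, not an existing qconfig option, and mod/params are assumed to come from a Relay frontend as usual):

from tvm import relay

# Hypothetical sketch of a qconfig knob for dense quantization. The flag name
# "skip_dense_layer" is illustrative only; the actual option name and default
# would be decided in the PR. `mod` and `params` are assumed to come from a
# Relay frontend such as relay.frontend.from_keras.
with relay.quantize.qconfig(global_scale=8.0, skip_dense_layer=False):
    qmod = relay.quantize.quantize(mod, params=params)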

I’m able to reproduce this with the example below:

# MNIST mlp code from https://keras.io/examples/mnist_mlp/

from __future__ import print_function

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import RMSprop

batch_size = 128
num_classes = 10
epochs = 1

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Dense(512, activation='relu', input_shape=(784,)))
model.add(Dense(512, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))

model.summary()

model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])

history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, y_test))



import tvm
from tvm import relay

def test_quantize_keras(keras_model, shape_dict, target):
    with tvm.target.create(target):
        # Load model
        mod, params = relay.frontend.from_keras(keras_model, shape=shape_dict)
        print("--------- Initial Relay IR ---------")
        print(mod)

        # Quantize
        with relay.build_config(opt_level=3):
            with relay.quantize.qconfig(global_scale=8.0):
                mod = relay.quantize.quantize(mod, params=params)
        print("--------- Quantized Relay IR ---------")
        print(mod)

        # Compile
        print("--------- Begin relay.build ---------")
        with relay.transform.build_config(opt_level=3):
            graph_json, mod, params = relay.build_module.build(mod, target)

shape_dict = {'dense_1_input': (1, 784)}
test_quantize_keras(model, shape_dict, "llvm")

Output:

--------- Initial Relay IR ---------
v0.0.4
def @main(%dense_1_input: Tensor[(1, 784), float32], %v_param_1: Tensor[(512, 784), float32], %v_param_2: Tensor[(512), float32], %v_param_3: Tensor[(512, 512), float32], %v_param_4: Tensor[(512), float32], %v_param_5: Tensor[(10, 512), float32], %v_param_6: Tensor[(10), float32]) -> Tensor[(1, 10), float32] {
  %0 = nn.dense(%dense_1_input, %v_param_1, units=512) /* ty=Tensor[(1, 512), float32] */;
  %1 = nn.bias_add(%0, %v_param_2) /* ty=Tensor[(1, 512), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(1, 512), float32] */;
  %3 = nn.dense(%2, %v_param_3, units=512) /* ty=Tensor[(1, 512), float32] */;
  %4 = nn.bias_add(%3, %v_param_4) /* ty=Tensor[(1, 512), float32] */;
  %5 = nn.relu(%4) /* ty=Tensor[(1, 512), float32] */;
  %6 = nn.dense(%5, %v_param_5, units=10) /* ty=Tensor[(1, 10), float32] */;
  %7 = nn.bias_add(%6, %v_param_6) /* ty=Tensor[(1, 10), float32] */;
  nn.softmax(%7, axis=1) /* ty=Tensor[(1, 10), float32] */
}

--------- Quantized Relay IR ---------
v0.0.4
def @main(%dense_1_input: Tensor[(1, 784), float32]) -> Tensor[(1, 10), float32] {
  %0 = nn.dense(%dense_1_input, meta[relay.Constant][0] /* ty=Tensor[(512, 784), float32] */ /* ty=Tensor[(512, 784), float32] */, units=512) /* ty=Tensor[(1, 512), float32] */;
  %1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(512), float32] */ /* ty=Tensor[(512), float32] */) /* ty=Tensor[(1, 512), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(1, 512), float32] */;
  %3 = nn.dense(%2, meta[relay.Constant][2] /* ty=Tensor[(512, 512), float32] */ /* ty=Tensor[(512, 512), float32] */, units=512) /* ty=Tensor[(1, 512), float32] */;
  %4 = add(%3, meta[relay.Constant][3] /* ty=Tensor[(512), float32] */ /* ty=Tensor[(512), float32] */) /* ty=Tensor[(1, 512), float32] */;
  %5 = nn.relu(%4) /* ty=Tensor[(1, 512), float32] */;
  %6 = nn.dense(%5, meta[relay.Constant][4] /* ty=Tensor[(10, 512), float32] */ /* ty=Tensor[(10, 512), float32] */, units=10) /* ty=Tensor[(1, 10), float32] */;
  %7 = add(%6, meta[relay.Constant][5] /* ty=Tensor[(10), float32] */ /* ty=Tensor[(10), float32] */) /* ty=Tensor[(1, 10), float32] */;
  nn.softmax(%7, axis=1) /* ty=Tensor[(1, 10), float32] */
}

I tried building the same Relay graph directly with the Relay Python API, and that seemed to quantize okay.
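
For reference, here is roughly what I mean by building the graph directly with the Relay API (a reconstruction with random weights, not my exact script):

import numpy as np
import tvm
from tvm import relay

def build_mlp():
    # Same topology as the Keras model: 784 -> 512 -> 512 -> 10.
    data = relay.var("data", shape=(1, 784), dtype="float32")
    w1 = relay.var("w1", shape=(512, 784), dtype="float32")
    b1 = relay.var("b1", shape=(512,), dtype="float32")
    w2 = relay.var("w2", shape=(512, 512), dtype="float32")
    b2 = relay.var("b2", shape=(512,), dtype="float32")
    w3 = relay.var("w3", shape=(10, 512), dtype="float32")
    b3 = relay.var("b3", shape=(10,), dtype="float32")

    out = relay.nn.relu(relay.nn.bias_add(relay.nn.dense(data, w1, units=512), b1))
    out = relay.nn.relu(relay.nn.bias_add(relay.nn.dense(out, w2, units=512), b2))
    out = relay.nn.softmax(relay.nn.bias_add(relay.nn.dense(out, w3, units=10), b3), axis=1)

    func = relay.Function(relay.analysis.free_vars(out), out)
    # Older TVM releases use relay.Module.from_expr instead of tvm.IRModule.from_expr.
    mod = tvm.IRModule.from_expr(func)

    # Random weights are enough to exercise the quantization pass.
    params = {
        "w1": np.random.uniform(-1, 1, (512, 784)).astype("float32"),
        "b1": np.zeros((512,), dtype="float32"),
        "w2": np.random.uniform(-1, 1, (512, 512)).astype("float32"),
        "b2": np.zeros((512,), dtype="float32"),
        "w3": np.random.uniform(-1, 1, (10, 512)).astype("float32"),
        "b3": np.zeros((10,), dtype="float32"),
    }
    return mod, params

mod, params = build_mlp()
with relay.quantize.qconfig(global_scale=8.0):
    qmod = relay.quantize.quantize(mod, params=params)
print(qmod)  # as noted above, this path appeared to quantize the dense ops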

@thierry @vinx13

Update to this: uncommenting the line that @thierry linked does indeed activate dense layer quantization. I now see the nn.dense types change from float32 to int8 in the Relay IR, and after relay.build the weight (W) matrix values have been converted to int8 as well.
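
For anyone following along, a quick way to confirm this is to inspect the parameter dtypes that come back from relay.build (a rough sketch, reusing mod and params from the repro above and assuming the _annotate.py line has been uncommented in your local TVM checkout):

# Quantize, build, and then print the dtypes of the lowered parameters.
# After quantization the weights live in the graph as constants, and
# relay.build returns those constants in the params dict (named p0, p1, ...).
with relay.quantize.qconfig(global_scale=8.0):
    qmod = relay.quantize.quantize(mod, params=params)

with relay.transform.build_config(opt_level=3):
    graph_json, lib, qparams = relay.build_module.build(qmod, "llvm")

for name, arr in qparams.items():
    print(name, arr.dtype, arr.shape)  # the dense weight tensors should now report int8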

Unfortunately, nn.bias_add and the activations still do not quantize; I’ve made another post about those.

Regarding activations, I have tried setting dtype_activation to int8 in the quantization config.
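
Concretely, what I tried looks roughly like this (the same quantize call as in the repro above, just with dtype_activation added):

# Same qconfig as before, but with the activation dtype forced to int8.
# The default activation dtype is int32, which leaves headroom for the
# accumulations, so forcing int8 here may not be sufficient on its own.
with relay.build_config(opt_level=3):
    with relay.quantize.qconfig(global_scale=8.0, dtype_activation="int8"):
        mod = relay.quantize.quantize(mod, params=params)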

Hi all,

Was there any progress on this issue in the meantime, @adb? I’m running into the same limitation, and as of today’s git pull the quantization of dense layers is still commented out here. Why is this so? What are the reasons behind this decision? Could anyone from the TVM developers please comment, maybe @thierry, @vinx13, or @ziheng?

Thank you in advance & Best regards, Robert
