Error when converting TFLite Conv2D kernel layout to OIHW4o

My custom accelerator expects convolution kernels in the OIHW4o layout. My input is a TFLite file, and its Conv2D kernels use the HWIO layout once imported into Relay.
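For reference, the target layout can be reproduced on a plain NumPy array. In TVM's layout notation, `OIHW4o` splits the output-channel axis `O` into an outer axis of size O/4 and an inner axis `4o` of size 4, placed last. The sketch below assumes the 3x3 kernel with 1 input and 8 output channels used in the scripts. It uses a hypothetical helper, `hwio_to_oihw4o`, which is not a TVM API, to convert an HWIO kernel to OIHW4o:

```python
import numpy as np


def hwio_to_oihw4o(kernel_hwio: np.ndarray, block: int = 4) -> np.ndarray:
    """Hypothetical helper: repack an HWIO kernel into OIHW{block}o with NumPy."""
    h, w, i, o = kernel_hwio.shape
    if o % block != 0:
        raise ValueError(f"output channels ({o}) must be divisible by {block}")
    # HWIO -> OIHW: move the output-channel axis to the front.
    kernel_oihw = kernel_hwio.transpose(3, 2, 0, 1)
    # Split O into (O // block, block): shape (O // block, block, I, H, W).
    split = kernel_oihw.reshape(o // block, block, i, h, w)
    # Move the inner block axis to the end: (O // block, I, H, W, block).
    return split.transpose(0, 2, 3, 4, 1)


kernel = np.arange(3 * 3 * 1 * 8).reshape(3, 3, 1, 8)  # HWIO, as in the scripts
packed = hwio_to_oihw4o(kernel)
print(packed.shape)  # (2, 1, 3, 3, 4)
```

Element `packed[o_outer, i, h, w, o_inner]` holds `kernel[h, w, i, 4 * o_outer + o_inner]`, so the primal `O` axis of the packed kernel has size 2 rather than 8.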

To convert the kernel layout, I am using the ConvertLayout pass. However, the pass fails with a type-checker error stating that two dimensions do not match.

To reproduce both the error and the expected output, I give two scripts below. The first script converts the layout of a network loaded from a TFLite file and fails during the ConvertLayout pass. The second does the same with a network built directly in Relay, and there the pass works as expected. Both scripts define roughly the same network.

The versions of dependencies are as follows:

  • TVM: v0.16.dev0
  • TensorFlow: 2.12.1
  • NumPy: 1.24.3

The following script, which uses a TFLite file as input, fails during layout conversion:

import tensorflow as tf
from tvm.testing.aot import create_relay_module_and_inputs_from_tflite_file
import numpy as np
import tvm
from tvm import relay


def generate_keras_model():
    tf_model = tf.keras.models.Sequential(
        [
            tf.keras.Input(shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(8, (3, 3), padding="valid", activation=None),
        ]
    )
    return tf_model


def convert_to_tflite(keras_model, filename):
    def representative_dataset():
        for _ in range(100):
            data = np.random.rand(1, 28, 28, 1)
            yield [data.astype(np.float32)]

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8  # or tf.uint8
    converter.inference_output_type = tf.int8  # or tf.uint8
    tflite_quant_model = converter.convert()
    with open(filename, "wb") as f:
        f.write(tflite_quant_model)



if __name__ == "__main__":
    keras_model = generate_keras_model()
    tflite_model_filename = "model.tflite"
    convert_to_tflite(keras_model, tflite_model_filename)

    mod, _, params = create_relay_module_and_inputs_from_tflite_file(
        tflite_model_filename, bind_params_by_name=False
    )

    desired_layouts = {'qnn.conv2d': ['NCHW', 'OIHW4o']}

    seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)

    print(mod)


    # We get the following error:
    # Error: The Relay type checker is unable to show the following types match:
    #   Tensor[(2), float32]
    #   Tensor[(8), float32]
    # In particular:
    #   dimension 0 conflicts: 2 does not match 8.
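The numbers in this message line up with the kernel shape after conversion. A plausible reading, stated as an assumption: the TFLite-imported `qnn.conv2d` carries a `kernel_scale` vector with one entry per output channel (8 entries), while the type checker expects as many scales as the size of the kernel's primal `O` axis, which is 8 // 4 = 2 under `OIHW4o`. The hypothetical helpers `layout_axes` and `expected_scale_len` below only mimic that comparison in plain Python; they are not TVM APIs:

```python
import re


def layout_axes(layout: str) -> list:
    """Split a TVM-style layout string into axes, e.g. "OIHW4o" -> ["O", "I", "H", "W", "4o"]."""
    return re.findall(r"\d*[A-Za-z]", layout)


def expected_scale_len(kernel_shape: tuple, kernel_layout: str) -> int:
    """Hypothetical mimic of the check: size of the primal (uppercase) 'O' axis."""
    axes = layout_axes(kernel_layout)
    if len(axes) != len(kernel_shape):
        raise ValueError("layout and shape rank differ")
    # The lowercase "4o" sub-axis is not the primal axis, so it is not matched here.
    return kernel_shape[axes.index("O")]


# Assumption: the TFLite model provides one weight scale per output channel (8 channels).
num_kernel_scales = 8

for shape, layout in [((3, 3, 1, 8), "HWIO"), ((2, 1, 3, 3, 4), "OIHW4o")]:
    expected = expected_scale_len(shape, layout)
    status = "ok" if expected == num_kernel_scales else "mismatch"
    print(f"{layout}: expects {expected} scales, has {num_kernel_scales} -> {status}")
```

Under these assumptions, HWIO expects 8 scales and matches, while OIHW4o expects 2 and conflicts with the 8-element vector. This is the same "2 does not match 8" pair reported above.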

In contrast, the following script, which builds the network directly in Relay, works as expected:

import tvm
from tvm import relay


def generate_relay_model():
    data_shape = [1, 28, 28, 1]  # NHWC
    weight_shape = [3, 3, 1, 8]  # HWIO
    data = relay.var("data", shape=data_shape, dtype="int8")
    weight = relay.var("weight", shape=weight_shape, dtype="int8")
    op2 = relay.qnn.conv2d(
        data,
        weight,
        input_zero_point=relay.const(0),
        kernel_zero_point=relay.const(0),
        input_scale=relay.const(0.078),
        kernel_scale=relay.const(0.07),
        padding=[0, 0, 0, 0],
        channels=8,
        kernel_size=[3, 3],
        kernel_layout="HWIO",
        data_layout="NHWC"
    )
    op3 = relay.nn.bias_add(
        op2,
        relay.const([0]*8),
        axis=3
    )

    op5 = relay.qnn.requantize(
        op3,
        input_scale=relay.const(0.05),
        input_zero_point=relay.const(0),
        output_scale=relay.const(0.21),
        output_zero_point=relay.const(61),
        out_dtype="int8",
    )
    relay_mod = tvm.IRModule.from_expr(op5)
    return relay_mod

if __name__ == "__main__":
    relay_mod = generate_relay_model()
    relay_mod = relay.transform.InferType()(relay_mod)
    relay_mod = relay.transform.ConvertLayout({"qnn.conv2d": ["NCHW", "OIHW4o"]})(relay_mod)

    print(relay_mod)  # Works as expected.

Both scripts define roughly the same graph, but the conversion fails only when the source is a TFLite file. Because of this, it looks like a bug.
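The two graphs may differ in one respect. This is an assumption rather than a confirmed diagnosis. With `tf.lite.Optimize.DEFAULT` and int8 builtins, TFLite quantizes Conv2D weights symmetrically per output channel, so the imported model would carry 8 kernel scales. The Relay script instead passes a single scalar, `kernel_scale=relay.const(0.07)`. The NumPy sketch below computes symmetric per-channel scales (max |w| / 127 for each output channel) for a random HWIO kernel. It illustrates the shape of the scales the TFLite path produces, next to a per-tensor scale:

```python
import numpy as np

rng = np.random.default_rng(0)
kernel_hwio = rng.standard_normal((3, 3, 1, 8)).astype(np.float32)  # random HWIO weights

# Symmetric per-channel quantization: one scale per output channel (last axis in HWIO),
# chosen so that the largest |weight| in each channel maps to 127.
per_channel_scale = np.abs(kernel_hwio).max(axis=(0, 1, 2)) / 127.0
quantized = np.clip(np.round(kernel_hwio / per_channel_scale), -127, 127).astype(np.int8)

# A per-tensor scheme, like the scalar kernel_scale in the Relay script, uses one scale.
per_tensor_scale = np.abs(kernel_hwio).max() / 127.0

print(per_channel_scale.shape)  # (8,)
print(np.ndim(per_tensor_scale))  # 0
```

If this difference is what matters, then giving the second script an 8-element `kernel_scale` instead of a scalar would be expected to reproduce the same error.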

I would like my custom accelerator to support TFLite models, so any help would be appreciated. Thanks in advance!