[TensorflowLite][Relay] relay.concatenate requires all tensors have the same dtype

I am getting the error "relay.concatenate requires all tensors have the same dtype" when I apply the relay.transform.InferType transformation to a model imported from TensorFlow Lite.

This is the script I use to generate the model:

import tensorflow as tf
import tensorflow_datasets as tfds
import cv2
import numpy as np

input_shape = (224, 224, 3)

model = tf.keras.applications.mobilenet.MobileNet(
    input_shape=input_shape,
    weights='imagenet',
    )

def representative_dataset_gen():
    dataset = tfds.load('imagenet_v2', split='test', shuffle_files=True)
    assert isinstance(dataset, tf.data.Dataset)
    for example in dataset.take(100):
        image = example["image"].numpy()
        scaled = cv2.resize(image, input_shape[:2])
        converted = cv2.cvtColor(scaled, cv2.COLOR_RGB2BGR)
        reshaped = converted.reshape((1,) + input_shape)
        #reshaped = reshaped / 255.
        yield [reshaped.astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_keras_model(model)

converter.optimizations = [tf.lite.Optimize.DEFAULT]
#converter.target_spec.supported_types = [tf.int8]
converter.target_spec.supported_ops = [
    #tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
    tf.lite.OpsSet.TFLITE_BUILTINS_INT8, # enable TensorFlow Lite int8 ops.
    tf.lite.OpsSet.SELECT_TF_OPS, # enable TensorFlow ops.
]
converter.allow_custom_ops = False
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
converter.representative_dataset = representative_dataset_gen

tflite_model = converter.convert()

with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

This is how I import the model into TVM:

import tvm
from tvm import relay

tflite_file = "model.tflite"
tflite_model_buf = open(tflite_file, "rb").read()
try:
    import tflite

    tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
    import tflite.Model

    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)

input_tensor = "layer1_input"
input_shape = (224, 224, 3)
input_dtype = "int8"

mod, params = relay.frontend.from_tflite(
    tflite_model, shape_dict={input_tensor: input_shape}, dtype_dict={input_tensor: input_dtype}
)
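
As a side note, it may be worth double-checking the input tensor name, shape and dtype that the .tflite file actually records before building shape_dict and dtype_dict; for a Keras-converted model the input name is usually something other than "layer1_input", and the recorded shape includes the batch dimension. A quick check with the TFLite interpreter (assuming the model.tflite produced by the script above) could look like this:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model.tflite")
for detail in interpreter.get_input_details():
    # Name, shape (including batch dimension) and numpy dtype of each model input.
    print(detail["name"], detail["shape"], detail["dtype"])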

Finally, I apply a preprocessing function to the imported module. To narrow the problem down for debugging, I reduced it to a simple function that only runs the InferType transformation:

def preprocess_pass(mod):
    mod = relay.transform.InferType()(mod)
    return mod

mod = preprocess_pass(mod)

This is the complete message of the error:

Traceback (most recent call last):
  File "main.py", line 70, in <module>
    mod = preprocess_pass(mod)
  File "main.py", line 66, in preprocess_pass
    mod = relay.transform.InferType()(mod)
  File "/local_disk/local_sw/tvm/python/tvm/ir/transform.py", line 161, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/local_disk/local_sw/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  7: TVMFuncCall
  6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  5: tvm::transform::Pass::operator()(tvm::IRModule) const
  4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relay9transform9InferTypeEvEUlS5_RKS7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SH_SL_
  1: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  0: tvm::relay::TypeSolver::Solve() [clone .cold]
  File "/local_disk/local_sw/tvm/src/relay/analysis/type_solver.cc", line 624
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (false) is false: relay.concatenate requires all tensors have the same dtype

Has anyone encountered a similar problem? I saw a comparable issue when importing from ONNX to Relay, but that seems to have been an ONNX-specific case.

I got the exact same problem. Could you please tell me how you solved it? Thank you!

Have you quantized your model in TFLite? That causes it to contain layers of different data types: the input quantization layer operates on float inputs and produces integer outputs, the subsequent layers operate on integer inputs and outputs, and the final layer converts back to a float representation.
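
If you want to confirm this, one quick way (assuming the model.tflite produced by the script above) is to list the distinct tensor dtypes the converted model contains via the TFLite interpreter:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()

# Collect the set of dtypes used by all tensors in the model.
dtypes = sorted({detail["dtype"].__name__ for detail in interpreter.get_tensor_details()})
print(dtypes)  # typically a mix such as ['float32', 'int32', 'int8'] after post-training quantization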

You could try to fully quantize the model, removing the conversion layers so that it operates directly on integer inputs; see the TensorFlow Lite guide on post-training integer quantization.
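
For reference, a minimal sketch of the converter settings for full integer quantization, reusing the model and representative_dataset_gen from the script above; the key change compared to the original script is dropping SELECT_TF_OPS so that no float fallback ops are emitted:

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
# Restrict the converter to the int8 built-in op set only.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Keep the model's input and output tensors in int8 as well.
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()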

This issue could be related to this one I’m having, with TVM’s native quantization.

In my case, I believe that some inputs to the concat are int8 while others are int32.

Assuming that’s the case, a prudent fix might be to add a conversion step before any quantised concat operation, as in the sketch below. However, it would probably be more performant to do some graph-level analysis, so that the conversions can ideally be done in place where the dependencies allow it.
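
A rough, untested sketch of what such a conversion step could look like on the Relay side (the class name CastConcatenateInputs and the int32 target dtype are my own assumptions, not anything in TVM, and it ignores quantization parameters, so it is only meant to illustrate the idea):

import tvm
from tvm import relay
from tvm.relay.expr_functor import ExprMutator

class CastConcatenateInputs(ExprMutator):
    # Hypothetical pass: cast every input of a concatenate call to one dtype
    # so that type inference no longer fails on mixed int8/int32 inputs.
    def __init__(self, target_dtype="int32"):
        super().__init__()
        self.target_dtype = target_dtype

    def visit_call(self, call):
        new_call = super().visit_call(call)
        if isinstance(new_call.op, tvm.ir.Op) and new_call.op.name == "concatenate":
            tup = new_call.args[0]  # concatenate takes a single Tuple argument
            if isinstance(tup, relay.Tuple):
                casted = [relay.cast(field, self.target_dtype) for field in tup.fields]
                return relay.concatenate(casted, int(new_call.attrs.axis))
        return new_call

# Usage, before running InferType:
# mod["main"] = CastConcatenateInputs().visit(mod["main"])
# mod = relay.transform.InferType()(mod)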