[TensorFlow2] StridedSlice conversion fails for a graph with input shape [None,None,None,3]

TensorFlow 2 Model Zoo object-detection (OD) models export inference graphs with input shape [None,None,None,3] (see research/object_detection/exporter_lib_v2.py).

TVM's `tensorflow2.from_tensorflow` fails to convert a StridedSlice op applied to a graph node with shape [None,None,None,3].

Code example:

import tensorflow as tf
import tvm
from tvm import relay
from tvm.relay.frontend import tensorflow2
import numpy as np

# Reproducer: a strided slice on a [None, None, None, 3] input converts fine
# in TensorFlow but fails inside TVM's tensorflow2.from_tensorflow.
INPUT_SHAPE = (8, 20, 20, 3)
_SIGNATURE = (tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.float32),)


@tf.function(input_signature=_SIGNATURE)
def fun1(input_tensor):
    """Keep every second element along axes 1 and 2; axes 0 and 3 are untouched."""
    return input_tensor[:, ::2, ::2, :]


concrete = fun1.get_concrete_function()
print("concrete_func signature:")
print(concrete.pretty_printed_signature())
gdef = concrete.graph.as_graph_def(add_shapes=True)
print("First graph_def node:", gdef.node[0].name)

# Reference result computed by TensorFlow itself.
tf_out = concrete(tf.ones(shape=INPUT_SHAPE, dtype=tf.float32))
print("TF res:", tf_out)

# This conversion step is where the StridedSlice error is raised.
print("tensorflow2.from_tensorflow...")
mod, params = tensorflow2.from_tensorflow(gdef)

print("relay.vm.compile...")
with tvm.transform.PassContext():
    executable = relay.vm.compile(mod, target="llvm", params=params)

print("vm.run...")
machine = tvm.runtime.vm.VirtualMachine(executable, tvm.cpu())
tvm_out = machine.run(np.ones(shape=INPUT_SHAPE, dtype="float32"))
print("TVM res:", tvm_out)

Error:

tensorflow2.from_tensorflow...
Traceback (most recent call last):
  File "/Users/pivovaa/aaa/tvm/my_test.py", line 23, in <module>
    mod, params = tensorflow2.from_tensorflow(graph_def)
  File "/Users/pivovaa/aaa/tvm/python/tvm/relay/frontend/tensorflow2.py", line 854, in from_tensorflow
    func, params = g.from_tensorflow(graph_def, layout, shape, outputs, gdef_lib=graph_def_library)
  File "/Users/pivovaa/aaa/tvm/python/tvm/relay/frontend/tensorflow2.py", line 231, in from_tensorflow
    graph, layout=layout, shape=shape, outputs=outputs, input_types=input_types
  File "/Users/pivovaa/aaa/tvm/python/tvm/relay/frontend/tensorflow2.py", line 394, in _get_relay_func
    self._backtrack_construct(graph, node.name)
  File "/Users/pivovaa/aaa/tvm/python/tvm/relay/frontend/tensorflow2.py", line 589, in _backtrack_construct
    op = self._convert_operator(graph, node.op, node.name, inputs, attr)
  File "/Users/pivovaa/aaa/tvm/python/tvm/relay/frontend/tensorflow2.py", line 485, in _convert_operator
    sym = _convert_map_common[op_name](inputs, attrs, self._params, self._module.mod)
  File "/Users/pivovaa/aaa/tvm/python/tvm/relay/frontend/tensorflow_ops.py", line 2313, in _impl
    out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
  File "/Users/pivovaa/aaa/tvm/python/tvm/relay/op/transform.py", line 941, in strided_slice
    return _make.strided_slice(data, begin, end, strides, slice_mode, axes)
  File "/Users/pivovaa/aaa/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (6) 7   ???                                 0x00007ffeed665940 0x0 + 140732881328448
  [bt] (5) 6   libffi.8.dylib                      0x0000000102e25db2 ffi_call_unix64 + 82
  [bt] (4) 5   libtvm.dylib                        0x000000016cb12d76 TVMFuncCall + 70
  [bt] (3) 4   libtvm.dylib                        0x000000016c5c898a void tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::String, tvm::runtime::Optional<tvm::runtime::Array<tvm::Integer, void> >)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::String, tvm::runtime::Optional<tvm::runtime::Array<tvm::Integer, void> >)>(tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::String, tvm::runtime::Optional<tvm::runtime::Array<tvm::Integer, void> >), std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const + 986
  [bt] (2) 3   libtvm.dylib                        0x000000016b8bf05d tvm::runtime::TVMMovableArgValueWithContext_::operator tvm::runtime::Array<tvm::Integer, void><tvm::runtime::Array<tvm::Integer, void> >() const + 333
  [bt] (1) 2   libtvm.dylib                        0x000000016b6fa939 tvm::runtime::detail::LogFatal::Entry::Finalize() + 89
  [bt] (0) 1   libtvm.dylib                        0x000000016cb274a8 tvm::runtime::Backtrace() + 24
  [bt] (8) 9   ???                                 0x00007ffeed665940 0x0 + 140732881328448
  [bt] (7) 8   libffi.8.dylib                      0x0000000102e25db2 ffi_call_unix64 + 82
  [bt] (6) 7   libtvm.dylib                        0x000000016cb12d76 TVMFuncCall + 70
  [bt] (5) 6   libtvm.dylib                        0x000000016c5c898a void tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::String, tvm::runtime::Optional<tvm::runtime::Array<tvm::Integer, void> >)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::String, tvm::runtime::Optional<tvm::runtime::Array<tvm::Integer, void> >)>(tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::String, tvm::runtime::Optional<tvm::runtime::Array<tvm::Integer, void> >), std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const + 986
  [bt] (4) 5   libtvm.dylib                        0x000000016b8bef2c tvm::runtime::TVMMovableArgValueWithContext_::operator tvm::runtime::Array<tvm::Integer, void><tvm::runtime::Array<tvm::Integer, void> >() const + 28
  [bt] (3) 4   libtvm.dylib                        0x000000016b8bf16e tvm::runtime::TVMMovableArgValue_::operator tvm::runtime::Array<tvm::Integer, void><tvm::runtime::Array<tvm::Integer, void>, void>() const + 158
  [bt] (2) 3   libtvm.dylib                        0x000000016b8bf549 tvm::runtime::Array<tvm::Integer, void> tvm::runtime::TVMPODValue_::AsObjectRef<tvm::runtime::Array<tvm::Integer, void> >() const + 841
  [bt] (1) 2   libtvm.dylib                        0x000000016b6fa939 tvm::runtime::detail::LogFatal::Entry::Finalize() + 89
  [bt] (0) 1   libtvm.dylib                        0x000000016cb274a8 tvm::runtime::Backtrace() + 24
  File "/Users/pivovaa/bbb/tvm/include/tvm/runtime/packed_func.h", line 714
TVMError: In function relay.op._make.strided_slice: error while converting argument 2: [23:45:10] /Users/pivovaa/bbb/tvm/include/tvm/runtime/packed_func.h:1591: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (!checked_type.defined()) is false: Expected Array[IntImm], but got Array[index 0: tir.Any]


Process finished with exit code 1

The problem is caused by this code in `_transform_mask`:

    if mask & end_mask:
        m_end[final_index] = (
            -(data_shape[final_index] + 1)
            if stride[index] < 0
            else data_shape[final_index]
        )

The `m_end` list is built from the dimensions in `data_shape` (the shape of `inputs[0]`), which may contain undefined dimensions, e.g. [Any, Any, Any, 3] or [8, Any, Any, 3].

`_transform_mask` therefore copies the `Any` dimensions from `data_shape` into the `m_end` list.

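As a result, the `end` parameter becomes [Any, Any, Any, 3] or [8, Any, Any, 3]. Relay's `_make.strided_slice(data, begin, end, strides, slice_mode, axes)` does not support this, because it expects an `Array[IntImm]` (hence the "Expected Array[IntImm], but got Array[index 0: tir.Any]" error above).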

You could try using either `end=shape_of(input)` or `end=[-1, -1, -1, 3]`; the `end` argument can also be an expression.