[Tensorflow2] Conv2D batch=None, strides=2, padding=SAME failed

I found that the following combination of Conv2D parameters fails during Relay vm.run execution.

# Conv2D failed for:
input shape=[None,300,300,3]
strides=[1,2,2,1]
padding="SAME" 

Error: Check failed: ret == 0 (-1 vs. 0) : Assert fail: (150 == int32(arg2.shape[1])), Argument arg2.shape[1] has an unsatisfied constraint: (150 == int32(arg2.shape[1]))

Example:

import tensorflow as tf
import tvm
from tvm import relay
from tvm.relay.frontend import tensorflow2
import numpy as np
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

filters = tf.random.uniform(shape=(3,3,3,32), dtype=tf.float32)

@tf.function(input_signature=(tf.TensorSpec(shape=[None,300,300,3], dtype=tf.float32),))
def fun1(input_tensor):
  return tf.nn.conv2d(input_tensor, filters=filters, strides=2, padding='SAME')

concrete_func = fun1.get_concrete_function()
frozen_func = convert_variables_to_constants_v2(concrete_func, lower_control_flow=False)
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
print(graph_def)

data = tf.ones(shape=(8,300,300,3), dtype=tf.float32)
res = frozen_func(data)
print("TF res shape:", res[0].shape)

print("tensorflow2.from_tensorflow...")
mod, params = tensorflow2.from_tensorflow(graph_def)

print("relay.vm.compile...")
with tvm.transform.PassContext():
    vm_exec = relay.vm.compile(mod, target="llvm", params=params)

print("TVM running...")
vm = tvm.runtime.vm.VirtualMachine(vm_exec, tvm.cpu())
vm_data = np.ones(shape=(8,300,300,3), dtype="float32")
vm_res = vm.run(vm_data)
print("TVM res shape:", vm_res.shape)

Error:

TVM running...
Traceback (most recent call last):
  File "/Users/pivovaa/aaa/tvm/my_conv2d.py", line 38, in <module>
    vm_res = vm.run(vm_data)
  File "/Users/pivovaa/aaa/tvm/python/tvm/runtime/vm.py", line 464, in run
    return self.invoke("main", *args, **kwargs)
  File "/Users/pivovaa/aaa/tvm/python/tvm/runtime/vm.py", line 446, in invoke
    return self._invoke(func_name)
  File "/Users/pivovaa/aaa/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) 9   ???                                 0x00007ffeeaf6a720 0x0 + 140732840453920
  [bt] (7) 8   libffi.8.dylib                      0x0000000105524db2 ffi_call_unix64 + 82
  [bt] (6) 7   libtvm.dylib                        0x000000016f211d76 TVMFuncCall + 70
  [bt] (5) 6   libtvm.dylib                        0x000000016f280f68 std::__1::__function::__func<tvm::runtime::vm::VirtualMachine::GetFunction(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0, std::__1::allocator<tvm::runtime::vm::VirtualMachine::GetFunction(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 1048
  [bt] (4) 5   libtvm.dylib                        0x000000016f27c897 tvm::runtime::vm::VirtualMachine::RunLoop() + 10087
  [bt] (3) 4   libtvm.dylib                        0x000000016f27ef61 tvm::runtime::vm::VirtualMachine::InvokePacked(long long, tvm::runtime::PackedFunc const&, long long, long long, std::__1::vector<tvm::runtime::ObjectRef, std::__1::allocator<tvm::runtime::ObjectRef> > const&) + 1217
  [bt] (2) 3   libtvm.dylib                        0x000000016f225bd3 std::__1::__function::__func<tvm::runtime::WrapPackedFunc(int (*)(TVMValue*, int*, int, TVMValue*, int*, void*), tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0, std::__1::allocator<tvm::runtime::WrapPackedFunc(int (*)(TVMValue*, int*, int, TVMValue*, int*, void*), tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 419
  [bt] (1) 2   libtvm.dylib                        0x000000016ddf9939 tvm::runtime::detail::LogFatal::Entry::Finalize() + 89
  [bt] (0) 1   libtvm.dylib                        0x000000016f2264a8 tvm::runtime::Backtrace() + 24
  File "/Users/pivovaa/bbb/tvm/src/runtime/library_module.cc", line 78
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------

  Check failed: ret == 0 (-1 vs. 0) : Assert fail: (150 == int32(arg2.shape[1])), Argument arg2.shape[1] has an unsatisfied constraint: (150 == int32(arg2.shape[1]))

Process finished with exit code 1

I guess the assertion above checks that the Conv2D output has shape [8,150,150,32], but the actual output has a different shape at runtime.

@yongwww @comaniac

If I change the tf.function TensorSpec shape from [None,300,300,3] to [None,319,319,3], vm.run works fine.

@tf.function(input_signature=(tf.TensorSpec(shape=[None,319,319,3], dtype=tf.float32),))
def fun1(input_tensor):
  return tf.nn.conv2d(input_tensor, filters=filters, strides=2, padding='SAME')
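For reference, here is a minimal sketch of the arithmetic TensorFlow documents for "SAME" padding, applied to the two input sizes above with the 3x3 kernel and stride 2 from the example. The helper name same_padding_1d is made up for illustration and is not part of TensorFlow or TVM. It only shows that 300 needs asymmetric padding while 319 needs symmetric padding; whether that difference is what triggers the failed assertion is an assumption, not a confirmed diagnosis.

```python
import math


def same_padding_1d(in_size, kernel, stride):
    """Return (out_size, pad_before, pad_after) for one spatial dimension
    under TensorFlow's documented 'SAME' padding rule.

    Hypothetical helper for illustration only.
    """
    out_size = math.ceil(in_size / stride)
    # Total padding needed so that the last window fits inside the padded input.
    pad_total = max((out_size - 1) * stride + kernel - in_size, 0)
    # TensorFlow puts the extra element (if any) at the end.
    pad_before = pad_total // 2
    pad_after = pad_total - pad_before
    return out_size, pad_before, pad_after


for size in (300, 319):
    out, before, after = same_padding_1d(size, kernel=3, stride=2)
    print(f"in={size}: out={out}, pad_before={before}, pad_after={after}")
# in=300: out=150, pad_before=0, pad_after=1
# in=319: out=160, pad_before=1, pad_after=1
```

The expected output height of 150 matches the constant in the failed assertion (150 == int32(arg2.shape[1])).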

You could try providing fixed values for the Any entries of end: each value should either match the corresponding dimension of input_tensor or be -1. For example, end = [-1, -1, -1, 3] in your test case.

One workaround is to insert an _op.reshape with a concrete batch size before the Conv2D op. In the diff below, batchSize stands for that batch size and is not defined in the lines shown.

diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py
index 813ee2eab..7f0261a50 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -467,6 +467,10 @@ def _conv(opname):
             else:
                 attr["kernel_layout"] = "HWOI" if attr["data_format"] == "NHWC" else "OIHW"
 
+        if isinstance(input_shape[0], tvm.tir.Any):
+            new_input_shape = (batchSize,) + input_shape[1:]
+            inputs_data = _op.reshape(inputs_data, newshape=new_input_shape)
+
         # Ignore the new attributes from TF2.0, for now.
         out = AttrCvt(
             op_name=_dimension_picker(