[Bug] Relax ONNX frontend fails to convert `Resize` from models exported with `opset_version=11`

Code to reproduce:

from pathlib import Path
import torch
import torch.nn.functional as F
import onnx
from tvm.relax.frontend.onnx import from_onnx

class M(torch.nn.Module):
    """Minimal module used to reproduce the Resize conversion failure.

    Its only operation is a nearest-neighbour interpolation with a scale
    factor of 0.5 on each of the last two dimensions. When exported with
    ``torch.onnx.export``, the graph contains an ONNX ``Resize`` node.
    """

    def forward(self, x):
        # ``size`` is left at its default (None), so the output extent comes
        # from the scale factors alone.
        return F.interpolate(x, scale_factor=(0.5, 0.5), mode="nearest")

# Build the model and an example input first; neither touches the filesystem.
torch_model = M()
input_tensor = torch.randn(1, 3, 10, 10)

# Scratch directory for the exported model (reused if it already exists).
temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)
onnx_path = temp_dir / "test.onnx"

# Export at opset 11, the opset this report is about.
torch.onnx.export(
    torch_model,
    (input_tensor,),
    onnx_path,
    input_names=["x"],
    opset_version=11,
)

# Reload the ONNX graph and hand it to the Relax frontend; this is the call
# that raised the TVMError shown in the traceback.
model = onnx.load(onnx_path)
tvm_model = from_onnx(model, keep_params_in_input=True)

Error:

Error converting operator Resize, with inputs: [x, metadata["relax.expr.Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it., metadata["relax.expr.Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it.]

---------------------------------------------------------------------------
TVMError                                  Traceback (most recent call last)
Cell In[13], line 24
     16 torch.onnx.export(
     17     torch_model, 
     18     (input_tensor,), 
   (...)
     21     opset_version=11,
     22 )
     23 model = onnx.load(temp_dir/"test.onnx")
---> 24 tvm_model = from_onnx(model, keep_params_in_input=True)

File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3690, in from_onnx(model, shape_dict, dtype_dict, opset, keep_params_in_input, sanitize_input_names)
   3683     warnings.warn(
   3684         ""
   3685         f"You are overwritting original opset ver = {opset_in_model} by lower ver = {opset}. "
   3686         f"That might cause model conversion errors."
   3687     )
   3689 # Use the graph proto as a scope so that ops can access other nodes if needed.
-> 3690 return g.from_onnx(graph, opset)

File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3321, in ONNXGraphImporter.from_onnx(self, graph, opset)
   3319 self._parse_graph_input(graph)
   3320 self._check_for_unsupported_ops(graph)
-> 3321 self._construct_nodes(graph)
   3323 # now return the outputs
   3324 outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]

File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3501, in ONNXGraphImporter._construct_nodes(self, graph)
   3499 except TVMError as err:
   3500     print(f"Error converting operator {op_name}, with inputs: {inputs}")
-> 3501     raise err
   3503 if op_name in return_tuple_ops:
   3504     outputs_num = 1

File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3496, in ONNXGraphImporter._construct_nodes(self, graph)
   3494         raise ValueError(f"Node {node.name} cannot handle ShapeExpr inputs.")
   3495 try:
-> 3496     op = self._convert_operator(op_name, inputs, attr, self.opset)
   3497     # Create struct information for the new operator.
   3498     op = self.bb.normalize(op)

File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3596, in ONNXGraphImporter._convert_operator(self, op_name, inputs, attrs, opset)
   3594     convert_class = convert_map[op_name]
   3595     op_function = convert_class.get_converter(opset)
-> 3596     sym = op_function(self.bb, inputs, attrs, [self._nodes, self._params])
   3597 else:
   3598     raise NotImplementedError("Operator {} not implemented.".format(op_name))

File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:2146, in Resize._impl_v18(cls, bb, inputs, attr, params)
   2141     assert isinstance(
   2142         sizes, relax.Constant
   2143     ), "Only constant output size currently supported."
   2144     sizes = sizes.data.numpy().astype("int64").tolist()[2:]
-> 2146 return relax.op.image.resize2d(
   2147     x,
   2148     size=relax.ShapeExpr(sizes),
   2149     roi=roi,
   2150     layout="NCHW",
   2151     method=mode,
   2152     coordinate_transformation_mode=coord_mode,
   2153     rounding_method=rounding_method,
   2154     cubic_alpha=cubic_coeff_a,
   2155     cubic_exclude=exclude_outside,
   2156     extrapolation_value=extrapolation_value,
   2157 )

File /media/pc/data/lxw/ai/tvm/python/tvm/relax/op/image/image.py:116, in resize2d(data, size, roi, layout, method, coordinate_transformation_mode, rounding_method, cubic_alpha, cubic_exclude, extrapolation_value, out_dtype)
    113     else:
    114         size = ShapeExpr(size)
--> 116 return _ffi_api.resize2d(  # type: ignore
    117     data,
    118     size,
    119     roi,
    120     layout,
    121     method,
    122     coordinate_transformation_mode,
    123     rounding_method,
    124     cubic_alpha,
    125     cubic_exclude,
    126     extrapolation_value,
    127     out_dtype,
    128 )

File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:339, in tvm._ffi._cy3.core.PackedFuncBase.__call__()

File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:284, in tvm._ffi._cy3.core.FuncCall()

File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/base.pxi:185, in tvm._ffi._cy3.core.CHECK_CALL()

File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/base.py:468, in raise_last_ffi_error()
    462 # The exception PyObject may contain a large amount of state,
    463 # including all stack frames that may be inspected in a later
    464 # PDB post-mortem.  Therefore, we must make sure to remove the
    465 # underlying PyObject* from the C++ side after we retrieve it.
    466 _LIB.TVMDropLastPythonError()
--> 468 raise py_err

TVMError: Traceback (most recent call last):
  File "/media/pc/data/lxw/ai/tvm/include/tvm/runtime/packed_func.h", line 924
TVMError: In function relax.op.image.resize2d(0: RelaxExpr, 1: RelaxExpr, 2: Array<FloatImm>, 3: runtime.String, 4: runtime.String, 5: runtime.String, 6: runtime.String, 7: double, 8: int, 9: double, 10: DataType) -> RelaxExpr: error while converting argument 2: [10:17:02] /media/pc/data/lxw/ai/tvm/include/tvm/runtime/packed_func.h:2274: InternalError: Check failed: (!checked_type.defined()) is false: Expected Array[runtime.Object], but got relax.expr.Call

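Current implementation: `Resize` defines only `_impl_v18`. As the traceback shows, the opset-11 model is therefore converted with the opset-18 implementation: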

# `tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py`
class Resize(OnnxOpConverter):

    """Converts an onnx Resize node into an equivalent Relax expression.

    Excerpt of the current converter; the method body is elided here. Only
    an opset-18 implementation, ``_impl_v18``, is defined. In this report, a
    model exported with ``opset_version=11`` was still routed to
    ``_impl_v18``. It then failed inside ``relax.op.image.resize2d`` while
    converting argument 2 (``roi``): a Relax ``Call`` was passed where an
    array of floats was expected.

    NOTE(review): the ONNX Resize signature differs between versions 10, 11,
    13 and 18 (per the ONNX operator spec; verify). A single opset-18
    converter therefore presumably cannot serve every older opset.
    """

    @classmethod

    def _impl_v18(cls, bb, inputs, attr, params):
         ...  # body elided in this excerpt; see onnx_frontend.py

Proposed: add converters for the earlier Resize versions as well:

class Resize(OnnxOpConverter):
    """Converts an onnx Resize node into an equivalent Relax expression.

    Proposed layout: one converter per ONNX Resize revision. Models exported
    with an older opset (e.g. ``opset_version=11``) would then be handled by
    a converter matching their inputs and attributes, instead of being
    routed to ``_impl_v18``. Method bodies are elided in this sketch.
    """

    @classmethod
    def _impl_v10(cls, bb, inputs, attr, params):
        # Resize-10: per the ONNX spec, the inputs are (X, scales) only, with
        # no roi or sizes. Verify against the spec before implementing.
        ...

    @classmethod
    def _impl_v11(cls, bb, inputs, attr, params):
        # Resize-11: per the ONNX spec, the inputs are (X, roi, scales[, sizes]).
        # The opset-11 export in this report passed three inputs (x plus two
        # constants, per the error message), presumably X, roi and scales.
        ...

    @classmethod
    def _impl_v13(cls, bb, inputs, attr, params):
        # Resize-13: per the ONNX spec, roi and scales become optional inputs.
        # TODO confirm how empty or omitted inputs reach the converter.
        ...

    @classmethod
    def _impl_v18(cls, bb, inputs, attr, params):
        # Existing opset-18 converter, kept as-is. Resize-18 presumably adds
        # attributes such as axes / antialias / keep_aspect_ratio_policy.
        # Confirm against the spec.
        ...

Thanks for reporting the issue. I don't think it's hard to fix. Happy to review if you could create a PR 🙂

@Hzfengsy The PR is here: "Improve the Relax frontend interface for PyTorch-exported ONNX frontend models" by xinetzone · Pull Request #17795 · apache/tvm