code:
from pathlib import Path
import torch
import torch.nn.functional as F
import onnx
from tvm.relax.frontend.onnx import from_onnx
class M(torch.nn.Module):
    """Minimal module that downsamples its input with ``F.interpolate``.

    The module has no parameters; it exists only to wrap a single
    nearest-neighbour resize driven by a fixed ``scale_factor`` so that the
    operation can be exported and fed to a model importer.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` resized by a factor of 0.5 along its two trailing dims.

        Args:
            x: Input tensor. The two-element ``scale_factor`` implies two
                spatial dimensions, e.g. the ``(1, 3, 10, 10)`` tensor used
                by the script in this file.

        Returns:
            The result of ``F.interpolate`` in ``"nearest"`` mode.
        """
        # ``size=None`` is passed explicitly so the output shape is derived
        # from ``scale_factor`` alone rather than from a fixed target size.
        return F.interpolate(x, size=None, scale_factor=(0.5, 0.5), mode="nearest")
# Scratch directory for the exported model; reused if it already exists.
temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)

# Instantiate the module and a random sample input used to trace the export.
torch_model = M()
input_tensor = torch.randn(1, 3, 10, 10)

# Export to ONNX at opset 11, naming the single graph input "x".
onnx_path = temp_dir / "test.onnx"
torch.onnx.export(torch_model, (input_tensor,), onnx_path, input_names=["x"], opset_version=11)

# Reload the exported file and convert it with the TVM Relax ONNX frontend.
# This conversion is the call that raises in the traceback below.
model = onnx.load(onnx_path)
tvm_model = from_onnx(model, keep_params_in_input=True)
bug:
Error converting operator Resize, with inputs: [x, metadata["relax.expr.Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it., metadata["relax.expr.Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it.]
---------------------------------------------------------------------------
TVMError Traceback (most recent call last)
Cell In[13], line 24
16 torch.onnx.export(
17 torch_model,
18 (input_tensor,),
(...)
21 opset_version=11,
22 )
23 model = onnx.load(temp_dir/"test.onnx")
---> 24 tvm_model = from_onnx(model, keep_params_in_input=True)
File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3690, in from_onnx(model, shape_dict, dtype_dict, opset, keep_params_in_input, sanitize_input_names)
3683 warnings.warn(
3684 ""
3685 f"You are overwritting original opset ver = {opset_in_model} by lower ver = {opset}. "
3686 f"That might cause model conversion errors."
3687 )
3689 # Use the graph proto as a scope so that ops can access other nodes if needed.
-> 3690 return g.from_onnx(graph, opset)
File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3321, in ONNXGraphImporter.from_onnx(self, graph, opset)
3319 self._parse_graph_input(graph)
3320 self._check_for_unsupported_ops(graph)
-> 3321 self._construct_nodes(graph)
3323 # now return the outputs
3324 outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3501, in ONNXGraphImporter._construct_nodes(self, graph)
3499 except TVMError as err:
3500 print(f"Error converting operator {op_name}, with inputs: {inputs}")
-> 3501 raise err
3503 if op_name in return_tuple_ops:
3504 outputs_num = 1
File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3496, in ONNXGraphImporter._construct_nodes(self, graph)
3494 raise ValueError(f"Node {node.name} cannot handle ShapeExpr inputs.")
3495 try:
-> 3496 op = self._convert_operator(op_name, inputs, attr, self.opset)
3497 # Create struct information for the new operator.
3498 op = self.bb.normalize(op)
File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3596, in ONNXGraphImporter._convert_operator(self, op_name, inputs, attrs, opset)
3594 convert_class = convert_map[op_name]
3595 op_function = convert_class.get_converter(opset)
-> 3596 sym = op_function(self.bb, inputs, attrs, [self._nodes, self._params])
3597 else:
3598 raise NotImplementedError("Operator {} not implemented.".format(op_name))
File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:2146, in Resize._impl_v18(cls, bb, inputs, attr, params)
2141 assert isinstance(
2142 sizes, relax.Constant
2143 ), "Only constant output size currently supported."
2144 sizes = sizes.data.numpy().astype("int64").tolist()[2:]
-> 2146 return relax.op.image.resize2d(
2147 x,
2148 size=relax.ShapeExpr(sizes),
2149 roi=roi,
2150 layout="NCHW",
2151 method=mode,
2152 coordinate_transformation_mode=coord_mode,
2153 rounding_method=rounding_method,
2154 cubic_alpha=cubic_coeff_a,
2155 cubic_exclude=exclude_outside,
2156 extrapolation_value=extrapolation_value,
2157 )
File /media/pc/data/lxw/ai/tvm/python/tvm/relax/op/image/image.py:116, in resize2d(data, size, roi, layout, method, coordinate_transformation_mode, rounding_method, cubic_alpha, cubic_exclude, extrapolation_value, out_dtype)
113 else:
114 size = ShapeExpr(size)
--> 116 return _ffi_api.resize2d( # type: ignore
117 data,
118 size,
119 roi,
120 layout,
121 method,
122 coordinate_transformation_mode,
123 rounding_method,
124 cubic_alpha,
125 cubic_exclude,
126 extrapolation_value,
127 out_dtype,
128 )
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:339, in tvm._ffi._cy3.core.PackedFuncBase.__call__()
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:284, in tvm._ffi._cy3.core.FuncCall()
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/base.pxi:185, in tvm._ffi._cy3.core.CHECK_CALL()
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/base.py:468, in raise_last_ffi_error()
462 # The exception PyObject may contain a large amount of state,
463 # including all stack frames that may be inspected in a later
464 # PDB post-mortem. Therefore, we must make sure to remove the
465 # underlying PyObject* from the C++ side after we retrieve it.
466 _LIB.TVMDropLastPythonError()
--> 468 raise py_err
TVMError: Traceback (most recent call last):
File "/media/pc/data/lxw/ai/tvm/include/tvm/runtime/packed_func.h", line 924
TVMError: In function relax.op.image.resize2d(0: RelaxExpr, 1: RelaxExpr, 2: Array<FloatImm>, 3: runtime.String, 4: runtime.String, 5: runtime.String, 6: runtime.String, 7: double, 8: int, 9: double, 10: DataType) -> RelaxExpr: error while converting argument 2: [10:17:02] /media/pc/data/lxw/ai/tvm/include/tvm/runtime/packed_func.h:2274: InternalError: Check failed: (!checked_type.defined()) is false: Expected Array[runtime.Object], but got relax.expr.Call