## Description
Hi, I encountered some errors when converting a PyTorch model to Relay. After some debugging, I managed to reproduce the bug with the minimal example below:
```python
minimal_example = torch.nn.Sequential(
    torch.nn.ConvTranspose2d(6, 6, kernel_size=(2, 2), groups=2),
    torch.nn.GroupNorm(6, 6),
)
```
Converting this model to Relay raises the following error:
```
DiagnosticError: Traceback (most recent call last):
[bt] (8) 9 libtvm.dylib 0x000000012db03f57 void std::__1::__invoke_void_return_wrapper<void>::__call<void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_1>(tvm::relay::transform::InferType()::$_1)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*>(tvm::relay::transform::InferType()::$_1&&...) + 71
[bt] (7) 8 libtvm.dylib 0x000000012db03fc7 decltype(std::__1::forward<tvm::relay::transform::InferType()::$_1>(fp)(std::__1::forward<void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_1>(tvm::relay::transform::InferType()::$_1)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)&>(fp0)...)) std::__1::__invoke<void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_1>(tvm::relay::transform::InferType()::$_1)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*>(tvm::relay::transform::InferType()::$_1&&, void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_1>(tvm::relay::transform::InferType()::$_1)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)&...) + 71
[bt] (6) 7 libtvm.dylib 0x000000012db046b4 void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_1>(tvm::relay::transform::InferType()::$_1)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const + 1748
[bt] (5) 6 libtvm.dylib 0x000000012db04bc7 tvm::relay::transform::InferType()::$_1::operator()(tvm::IRModule, tvm::transform::PassContext const&) const + 1063
[bt] (4) 5 libtvm.dylib 0x000000012b753643 tvm::DiagnosticContext::Render() + 467
[bt] (3) 4 libtvm.dylib 0x000000012b139d75 tvm::runtime::detail::LogFatal::~LogFatal() + 21
[bt] (2) 3 libtvm.dylib 0x000000012b13bffd tvm::runtime::detail::LogFatal::~LogFatal() + 29
[bt] (1) 2 libtvm.dylib 0x000000012b13c0ac tvm::runtime::detail::LogFatal::Entry::Finalize() + 156
[bt] (0) 1 libtvm.dylib 0x000000012e20e365 tvm::runtime::Backtrace() + 37
File "tvm/src/ir/diagnostic.cc", line 105
DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.
```
Rendering the diagnostics via `tvm.transform.PassContext.current().diag_ctx.diagnostics` shows:

```
The Relay type checker is unable to show the following types match.
In particular dimension 1 conflicts: 1 does not match 3.
---------
The Relay type checker is unable to show the following types match.
In particular `Tensor[(6, 3, 2, 2), float32]` does not match `Tensor[(6, 1, 2, 2), float32]`
---------
The Relay type checker is unable to show the following types match.
In particular dimension 0 conflicts: 3 does not match 6.
---------
The Relay type checker is unable to show the following types match.
In particular `Tensor[(6), float32]` does not match `Tensor[(3), float32]`
---------
```
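The `(6, 3, 2, 2)` shape in the second diagnostic is exactly how PyTorch lays out grouped transposed-convolution weights: `ConvTranspose2d` stores its weight as `(in_channels, out_channels // groups, kH, kW)`, which for this layer is `(6, 3, 2, 2)`. My guess is that the frontend mishandles the `groups` attribute when inferring the expected kernel shape, hence the mismatch against `(6, 1, 2, 2)`. A quick sanity check of the weight-shape arithmetic (the helper name here is just for illustration):

```python
# PyTorch's ConvTranspose2d keeps its weight tensor in layout
# (in_channels, out_channels // groups, kH, kW); this helper just
# spells out that arithmetic so the diagnostic shapes can be checked.
def convtranspose2d_weight_shape(in_channels, out_channels, kernel_size, groups=1):
    kh, kw = kernel_size
    return (in_channels, out_channels // groups, kh, kw)

# The layer from the minimal example above:
print(convtranspose2d_weight_shape(6, 6, (2, 2), groups=2))  # (6, 3, 2, 2)
```

So the actual PyTorch weight shape is the `(6, 3, 2, 2)` side of the mismatch; the `(6, 1, 2, 2)` side is what Relay's type checker inferred for the op.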
## Code to reproduce
```python
import torch
import tvm
from tvm import relay

minimal_example = torch.nn.Sequential(
    torch.nn.ConvTranspose2d(6, 6, kernel_size=(2, 2), groups=2),
    torch.nn.GroupNorm(6, 6),
)
random_input = torch.randn([1, 6, 112, 112])

# Running the model in PyTorch works fine
result = minimal_example(random_input)

# Compiling the model to Relay fails
random_input = [random_input]
trace = torch.jit.trace(minimal_example, random_input)
input_names = ["input{}".format(idx) for idx, inp in enumerate(random_input)]
input_shapes = list(zip(input_names, [inp.shape for inp in random_input]))
mod, params = tvm.relay.frontend.from_pytorch(trace, input_shapes)  # DiagnosticError raised here
```
## Environment
- TVM: 0.8.dev0 at bf20107ffe6e96e20125a2209500668777095337
- PyTorch: 1.8.1
- OS: macOS 10.15.7
If this is indeed a bug, should I open an issue in the TVM repo?