[PyTorch] Error when converting a simple PyTorch model to Relay

Description

Hi, I ran into some errors when converting a PyTorch model to Relay. After some debugging, I managed to reproduce the bug with the minimal example below:

minimal_example = torch.nn.Sequential(
    torch.nn.ConvTranspose2d(6, 6, kernel_size=(2, 2), groups=2),
    torch.nn.GroupNorm(6, 6)
)

Converting this model to Relay raises the following error:

DiagnosticError: Traceback (most recent call last):
  [bt] (8) 9   libtvm.dylib                        0x000000012db03f57 void std::__1::__invoke_void_return_wrapper<void>::__call<void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_1>(tvm::relay::transform::InferType()::$_1)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*>(tvm::relay::transform::InferType()::$_1&&...) + 71
  [bt] (7) 8   libtvm.dylib                        0x000000012db03fc7 decltype(std::__1::forward<tvm::relay::transform::InferType()::$_1>(fp)(std::__1::forward<void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_1>(tvm::relay::transform::InferType()::$_1)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)&>(fp0)...)) std::__1::__invoke<void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_1>(tvm::relay::transform::InferType()::$_1)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*>(tvm::relay::transform::InferType()::$_1&&, void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_1>(tvm::relay::transform::InferType()::$_1)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)&...) + 71
  [bt] (6) 7   libtvm.dylib                        0x000000012db046b4 void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_1>(tvm::relay::transform::InferType()::$_1)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const + 1748
  [bt] (5) 6   libtvm.dylib                        0x000000012db04bc7 tvm::relay::transform::InferType()::$_1::operator()(tvm::IRModule, tvm::transform::PassContext const&) const + 1063
  [bt] (4) 5   libtvm.dylib                        0x000000012b753643 tvm::DiagnosticContext::Render() + 467
  [bt] (3) 4   libtvm.dylib                        0x000000012b139d75 tvm::runtime::detail::LogFatal::~LogFatal() + 21
  [bt] (2) 3   libtvm.dylib                        0x000000012b13bffd tvm::runtime::detail::LogFatal::~LogFatal() + 29
  [bt] (1) 2   libtvm.dylib                        0x000000012b13c0ac tvm::runtime::detail::LogFatal::Entry::Finalize() + 156
  [bt] (0) 1   libtvm.dylib                        0x000000012e20e365 tvm::runtime::Backtrace() + 37
  File "tvm/src/ir/diagnostic.cc", line 105
DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.

Check the diagnostics:

# tvm.transform.PassContext.current().diag_ctx.diagnostics

The Relay type checker is unable to show the following types match.
In particular dimension 1 conflicts: 1 does not match 3.
---------
The Relay type checker is unable to show the following types match.
In particular `Tensor[(6, 3, 2, 2), float32]` does not match `Tensor[(6, 1, 2, 2), float32]`
---------
The Relay type checker is unable to show the following types match.
In particular dimension 0 conflicts: 3 does not match 6.
---------
The Relay type checker is unable to show the following types match.
In particular `Tensor[(6), float32]` does not match `Tensor[(3), float32]`
---------
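To see where the `(6, 3, 2, 2)` tensor in the second diagnostic comes from, it helps to work out the shapes PyTorch itself uses for this layer. The sketch below only encodes PyTorch's documented `ConvTranspose2d` conventions (weight stored as `(in_channels, out_channels // groups, kH, kW)` and the documented output-size formula); the helper names are hypothetical and nothing here calls TVM.

```python
def conv_transpose2d_weight_shape(in_channels, out_channels, kernel_size, groups=1):
    # PyTorch stores ConvTranspose2d weights as
    # (in_channels, out_channels // groups, kH, kW)
    if in_channels % groups or out_channels % groups:
        raise ValueError("in_channels and out_channels must be divisible by groups")
    kh, kw = kernel_size
    return (in_channels, out_channels // groups, kh, kw)


def conv_transpose2d_out_size(size, kernel, stride=1, padding=0,
                              output_padding=0, dilation=1):
    # H_out = (H_in - 1) * stride - 2 * padding
    #         + dilation * (kernel - 1) + output_padding + 1
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel - 1) + output_padding + 1)


# Layer from the minimal example: ConvTranspose2d(6, 6, kernel_size=(2, 2), groups=2)
print(conv_transpose2d_weight_shape(6, 6, (2, 2), groups=2))  # (6, 3, 2, 2)
print(conv_transpose2d_out_size(112, 2))                       # 113
```

So the weight PyTorch hands to the converter really is `(6, 3, 2, 2)`, and the layer still produces 6 output channels (output `(1, 6, 113, 113)`); the type checker, however, apparently expected a `(6, 1, 2, 2)` weight for this grouped transposed convolution.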

Code to reproduce

import torch
import tvm
from tvm import relay

minimal_example = torch.nn.Sequential(
    torch.nn.ConvTranspose2d(6, 6, kernel_size=(2, 2), groups=2),
    torch.nn.GroupNorm(6, 6)
)

random_input = torch.randn([1, 6, 112, 112])
# The model runs fine in PyTorch
result = minimal_example(random_input)
# Trace the model and convert it to Relay
random_input = [random_input]
trace = torch.jit.trace(minimal_example, random_input)
input_names = ["input{}".format(idx) for idx, inp in enumerate(random_input)]
input_shapes = list(zip(input_names, [inp.shape for inp in random_input]))
mod, params = tvm.relay.frontend.from_pytorch(trace, input_shapes)  # raises DiagnosticError

Environment

TVM: 0.8.dev0 at bf20107ffe6e96e20125a2209500668777095337

Pytorch version: 1.8.1

OS version: macOS 10.15.7


If this is really a bug, should I open an issue in the TVM repo?

Unfortunately, it seems we don't support conv2d_transpose with groups > 1; see tvm/generic.py at commit 4344540ad4206733dd136678180fbf7e3dd616c3 in apache/tvm on GitHub.
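Until grouped transposed convolution is supported, one possible workaround (an assumption, not verified against TVM here) is to rewrite each grouped `ConvTranspose2d` on the PyTorch side as `groups` ungrouped layers, each applied to its own slice of input channels, with the results concatenated along the channel dimension. This is mathematically the same operation, since group `i` only reads input channels `i*in_g .. (i+1)*in_g - 1` and only writes output channels `i*out_g .. (i+1)*out_g - 1`. The class name `SplitConvTranspose2d` is hypothetical, and the sketch assumes the default `padding_mode`.

```python
import torch


class SplitConvTranspose2d(torch.nn.Module):
    """Hypothetical replacement for a grouped ConvTranspose2d built from
    `groups` ungrouped ConvTranspose2d layers (groups=1 each)."""

    def __init__(self, conv: torch.nn.ConvTranspose2d):
        super().__init__()
        g = conv.groups
        self.in_g = conv.in_channels // g   # input channels per group
        out_g = conv.out_channels // g      # output channels per group
        self.convs = torch.nn.ModuleList()
        for i in range(g):
            c = torch.nn.ConvTranspose2d(
                self.in_g, out_g, conv.kernel_size,
                stride=conv.stride, padding=conv.padding,
                output_padding=conv.output_padding, groups=1,
                bias=conv.bias is not None, dilation=conv.dilation,
            )
            with torch.no_grad():
                # Grouped weight is (in_channels, out_channels // groups, kH, kW),
                # so group i owns rows i*in_g .. (i+1)*in_g of the weight tensor.
                c.weight.copy_(conv.weight[i * self.in_g:(i + 1) * self.in_g])
                if conv.bias is not None:
                    c.bias.copy_(conv.bias[i * out_g:(i + 1) * out_g])
            self.convs.append(c)

    def forward(self, x):
        # Split the input channels into per-group chunks, run each chunk through
        # its own ungrouped layer, and concatenate the outputs along channels.
        chunks = torch.split(x, self.in_g, dim=1)
        return torch.cat([c(chunk) for c, chunk in zip(self.convs, chunks)], dim=1)
```

For the minimal example, this would mean running `minimal_example[0] = SplitConvTranspose2d(minimal_example[0])` before tracing. The rewritten model contains only groups=1 transposed convolutions plus `torch.split` and `torch.cat`; whether the Relay frontend then converts the whole model successfully has not been checked here.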

You are welcome to open an issue.

Thanks for the hint! I will open an issue to track this.