[PyTorch] Unable to import a simple torch.nn.Linear to Relay

Hi, I tried to do the following to import a simple torch.nn.Linear to Relay:

import tvm
from tvm import relay

import torch

# Create PyTorch eager model
in_features = 300
out_features = 100
m = torch.nn.Linear(in_features, out_features)

# Create PyTorch JIT-traced model
batch_size = 10
shape_list = [("input0", (batch_size, in_features))]
input = torch.randn(shape_list[0][1])
sm = torch.jit.trace(m, input)

# Set up TVM config
target = tvm.target.Target("cuda")
dtype = "float32"

# Import the PyTorch graph to Relay
mod, params = relay.frontend.from_pytorch(sm, shape_list)

This gives:

Traceback (most recent call last):
  File "test_tvm_auto_scheduler.py", line 22, in <module>
    mod, params = relay.frontend.from_pytorch(sm, shape_list)
  File "/home/willfeng/tvm/python/tvm/relay/frontend/pytorch.py", line 3329, in from_pytorch
    ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)[0]
  File "/home/willfeng/tvm/python/tvm/relay/frontend/pytorch.py", line 2753, in convert_operators
    self.record_output_type(relay_out)
  File "/home/willfeng/tvm/python/tvm/relay/frontend/pytorch.py", line 219, in record_output_type
    self.infer_type_with_prelude(output)
  File "/home/willfeng/tvm/python/tvm/relay/frontend/pytorch.py", line 167, in infer_type_with_prelude
    body = self.infer_type(val, self.prelude.mod)
  File "/home/willfeng/tvm/python/tvm/relay/frontend/pytorch.py", line 160, in infer_type
    new_mod = transform.InferType()(new_mod)
  File "/home/willfeng/tvm/python/tvm/ir/transform.py", line 161, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/willfeng/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  7: TVMFuncCall
  6: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::string)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  5: tvm::transform::Pass::operator()(tvm::IRModule) const
  4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  1: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  0: tvm::relay::TypeSolver::Solve()
  9: TVMFuncCall
  8: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::string)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  7: tvm::transform::Pass::operator()(tvm::IRModule) const
  6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  2: tvm::relay::TypeSolver::Solve()
  1: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), void tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  0: bool tvm::relay::MatmulRel<tvm::relay::DenseAttrs>(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
  File "/home/willfeng/tvm/src/relay/analysis/type_solver.cc", line 624
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (false) is false: [23:55:13] /home/willfeng/tvm/src/relay/op/nn/nn.h:100: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: ((transpose_b && reporter->AssertEQ(reduce, tensor_b->shape[1])) || (!transpose_b && reporter->AssertEQ(reduce, tensor_b->shape[0]))) is false: MatmulRel: input dimension doesn't match, tensor_a shape=[10, 300], tensor_b shape=[300, 100]

with the error

Check failed: ((transpose_b && reporter->AssertEQ(reduce, tensor_b->shape[1])) || (!transpose_b && reporter->AssertEQ(reduce, tensor_b->shape[0]))) is false: MatmulRel: input dimension doesn't match, tensor_a shape=[10, 300], tensor_b shape=[300, 100]
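
For readers skimming the traceback, the condition in that message can be sketched in plain Python. This is only an illustration of the quoted check, not TVM's code: the helper name `matmul_rel_check` is hypothetical, and it assumes that `reduce` is the last dimension of tensor_a and that dense treats its weight as (out_features, in_features), i.e. transpose_b is true.

```python
def matmul_rel_check(a_shape, b_shape, transpose_b):
    """Mirror the shape condition quoted in the error message.

    Assumption: `reduce` is the last dimension of tensor_a.
    """
    reduce = a_shape[-1]
    if transpose_b:
        # Weight laid out as (out_features, in_features).
        return reduce == b_shape[1]
    # Weight laid out as (in_features, out_features).
    return reduce == b_shape[0]


a_shape = (10, 300)

# Shapes from the error message, assuming transpose_b=True for dense.
print(matmul_rel_check(a_shape, (300, 100), transpose_b=True))   # False
# The untransposed nn.Linear weight shape (out_features, in_features).
print(matmul_rel_check(a_shape, (100, 300), transpose_b=True))   # True
# The same transposed weight passes if transpose_b were False.
print(matmul_rel_check(a_shape, (300, 100), transpose_b=False))  # True
```

Under those assumptions, the failure pattern suggests the weight reaches the dense type relation already transposed to [300, 100].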

Is there something obvious that I should fix? Thanks!

Thanks for reporting the error; this could be related to a recent bug. Issue #8392 ([BUG][FRONTEND] Pytorch Importer mm, apache/tvm) has been opened to track it. cc @jcf94 @comaniac

Hmm, I’m on my dev branch, but the script works in my environment. Maybe it’s a PyTorch version difference? I’m testing on PT 1.7.

Issue #8392 ([BUG][FRONTEND] Pytorch Importer mm, apache/tvm) is not related; see PR #8397 ([Torch] Remove unused conversion, by masahi, apache/tvm).

With PT 1.7, the traced module that the frontend receives has the following TorchScript graph:

graph(%self : __torch__.torch.nn.modules.linear.Linear,
      %input : Float(10:300, 300:1, requires_grad=0, device=cpu)):
  %2 : Tensor = prim::GetAttr[name="bias"](%self)
  %3 : Tensor = prim::GetAttr[name="weight"](%self)
  %4 : Float(300:1, 100:300, requires_grad=1, device=cpu) = aten::t(%3) # /home/masa/anaconda3/envs/torch-1.7/lib/python3.8/site-packages/torch/nn/functional.py:1690:0
  %5 : int = prim::Constant[value=1]() # /home/masa/anaconda3/envs/torch-1.7/lib/python3.8/site-packages/torch/nn/functional.py:1690:0
  %6 : int = prim::Constant[value=1]() # /home/masa/anaconda3/envs/torch-1.7/lib/python3.8/site-packages/torch/nn/functional.py:1690:0
  %7 : Float(10:100, 100:1, requires_grad=1, device=cpu) = aten::addmm(%2, %input, %4, %5, %6) # /home/masa/anaconda3/envs/torch-1.7/lib/python3.8/site-packages/torch/nn/functional.py:1690:0
  return (%7)
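
In the PT 1.7 graph above, nn.Linear is traced as aten::t on the weight followed by aten::addmm. A minimal pure-Python sketch of why that pair computes the same thing as a linear layer (y = x W^T + b) is below. The helper functions are hypothetical stand-ins written for illustration, not PyTorch or TVM APIs, and the tiny matrices are made-up values.

```python
def transpose(m):
    """Stand-in for aten::t on a 2-D matrix stored as a list of rows."""
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Plain 2-D matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def addmm(bias, inp, mat2, beta=1, alpha=1):
    """Stand-in for aten::addmm: beta * bias + alpha * (inp @ mat2)."""
    prod = matmul(inp, mat2)
    return [[beta * b + alpha * p for b, p in zip(bias, row)] for row in prod]


def linear(inp, weight, bias):
    """Reference linear layer with weight shaped (out_features, in_features)."""
    return [
        [sum(x * w for x, w in zip(row, w_row)) + b for w_row, b in zip(weight, bias)]
        for row in inp
    ]


inp = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]      # shape (batch=2, in=3)
weight = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]]  # shape (out=2, in=3)
bias = [0.1, -0.2]

via_addmm = addmm(bias, inp, transpose(weight))  # the t + addmm form
via_linear = linear(inp, weight, bias)           # the single linear form

# Both forms should agree elementwise, up to floating-point tolerance.
agree = all(
    abs(p - q) < 1e-9
    for row_p, row_q in zip(via_addmm, via_linear)
    for p, q in zip(row_p, row_q)
)
print(agree)
```

The point of the sketch is only that the weight is transposed before the matrix product in this form, which matters for whichever layout the importer expects.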

Thanks @tqchen and @masahi. I tried the script on the following PyTorch versions:

torch 1.10.0.dev20210629+cu111 → throws error (original post)

torch 1.9.0 → throws same error as above

torch 1.8.1 → works

torch 1.7.0 → works (@masahi’s experiment)
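
The results above could be captured in a small lookup for anyone scripting a check around this repro. This is a sketch based only on the four data points reported in this thread at the time of posting; the table and function names are hypothetical, and versions not listed are deliberately reported as unknown rather than guessed.

```python
import re

# (major, minor) -> whether the nn.Linear import was reported to work in this thread.
# Only the versions actually tried above are listed.
REPORTED_STATUS = {
    (1, 7): True,    # torch 1.7.0
    (1, 8): True,    # torch 1.8.1
    (1, 9): False,   # torch 1.9.0
    (1, 10): False,  # torch 1.10.0.dev20210629+cu111
}


def parse_major_minor(version):
    """Extract (major, minor) from strings such as '1.10.0.dev20210629+cu111'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def reported_status(version):
    """Return True/False as reported in this thread, or None if untested."""
    return REPORTED_STATUS.get(parse_major_minor(version))


print(reported_status("1.10.0.dev20210629+cu111"))  # False
print(reported_status("1.8.1"))                     # True
print(reported_status("2.0.0"))                     # None
```

In a real script the version string would typically come from `torch.__version__`.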

I have a few questions:

  1. Is PyTorch 1.8.1 the version currently supported by TVM? Is there a CI job that tests the integration that I could look at to learn more?
  2. I would like to help add PyTorch 1.9.0 conversion support, at least for torch.nn.Linear. Where should I start? (I’ve been doing PyTorch internal development for several years, so a few quick code pointers will be enough for me to get started.)

Thanks!

CI is running 1.7, so that’s the version we support. I’d expect things to work with 1.8 just as well, but I haven’t tested 1.9 yet.

There are two issues with nn.Linear as of v1.9:

@yf225, can you try the above fix and validate the result? If it works, can you send a PR with the fix and your test case? You can follow PR #7569 ([torch] Add linear operator support, by apivovarov, apache/tvm) as an example.

Just wanted to add that I’m also encountering this issue with Torch 1.9.0+cu102 and MobileNetV2 (without quantization).

The code:

import torch
from torchvision import models
from tvm import relay

# get_imagenet_input() comes from https://tvm.apache.org/docs/tutorials/frontend/deploy_prequantized.html
inp = get_imagenet_input()
model = models.mobilenet_v2(pretrained=True, progress=True)

with torch.no_grad():
    script_module = torch.jit.trace(model, inp).eval()

input_name = "input"
input_shapes = [(input_name, (1, 3, 224, 224))]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)

Produces the error:

Check failed: ((transpose_b && reporter->AssertEQ(reduce, tensor_b->shape[1])) || (!transpose_b && reporter->AssertEQ(reduce, tensor_b->shape[0]))) is false: MatmulRel: input dimension doesn't match, tensor_a shape=[1, 1280], tensor_b shape=[1280, 1000]

EDIT#1: PR#7569 is in my git log, so this may be a different thing that coincidentally produces the same class of error.

EDIT#2: Without testing PR#8622, this works with PyTorch v1.7.0.

Could you use the latest upstream TVM with this merged PR to see if the problem persists?

Yeah, PR#8622 seems to resolve the issue! Thanks for the tip.