[torch] matmul(1D, 2D) causes relay.build to fail

PyTorch's torch.matmul supports the (1D x 2D) case.

If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.

https://pytorch.org/docs/stable/generated/torch.matmul.html

However, compiling such a model in TVM fails:

  • matmul(1D, 2D) - relay.build fails
  • matmul(1D, 1D) - relay.build fails

Example to reproduce the error:

import torch
import tvm
from tvm import relay

class Net(torch.nn.Module):
    """Minimal module whose forward pass is a single torch.matmul call."""

    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        """Return the matrix product of *a* and *b* via torch.matmul."""
        product = torch.matmul(a, b)
        return product

# Inputs chosen to hit the failing case: a 1-D tensor times a 2-D tensor.
a = torch.tensor([4,1])
b = torch.tensor([[1,0],[0,1]])

net = Net()
net(a, b)  # sanity check: plain PyTorch handles matmul(1D, 2D) fine

# Trace the module so it can be imported into Relay.
traced_net = torch.jit.trace(net, (a, b))

ctx = tvm.cpu(0)
target = 'llvm'

# input0 is declared with a 1-D shape [2]; per the traceback below, the
# x86 dense alter-layout pass later unpacks this shape as `M, K` and
# fails because only one dimension is present.
shape_list = [("input0", [2]),("input1", [2,2]),]
mod, params = relay.frontend.from_pytorch(traced_net, shape_list)

# This is the statement that raises (see the traceback below).
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

# Intended usage once the build succeeds (currently never reached
# because relay.build raises above).
func=mod['main']
intrp = relay.create_executor("graph", ctx=ctx, target=target)
ff=intrp.evaluate(func)
ff([4,1], [[1,0],[0,1]])

Error:

>>> with tvm.transform.PassContext(opt_level=3):
...     lib = relay.build(mod, target=target, params=params)
... 
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/build_module.py", line 269, in build
    graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/build_module.py", line 132, in build
    self._build(mod, target, target_host)
  File "/Users/pivovaa/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
ValueError: Traceback (most recent call last):
  [bt] (8) 9   libtvm.dylib                        0x0000000125b8fbc8 tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::InitVTable()::'lambda4'(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)::__invoke(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) + 24
  [bt] (7) 8   libtvm.dylib                        0x0000000125b052fe tvm::relay::MixedModeMutator::VisitExpr_(tvm::relay::CallNode const*) + 14
  [bt] (6) 7   libtvm.dylib                        0x0000000125b07e59 tvm::RelayExpr tvm::relay::MixedModeMutator::Rewrite<tvm::relay::CallNode>(tvm::relay::CallNode const*) + 57
  [bt] (5) 6   libtvm.dylib                        0x0000000125b5cb4b tvm::relay::ForwardRewriter::Rewrite_(tvm::relay::CallNode const*, tvm::RelayExpr const&) + 1707
  [bt] (4) 5   libtvm.dylib                        0x0000000125b5df95 tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::operator()<tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void>&, tvm::runtime::ObjectRef>(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void>&, tvm::runtime::ObjectRef&&) const + 245
  [bt] (3) 4   libtvm.dylib                        0x0000000125ac0211 void tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const + 721
  [bt] (2) 3   libtvm.dylib                        0x0000000125ac4ca7 tvm::RelayExpr tvm::relay::LayoutRewriter<tvm::relay::alter_op_layout::AlterTransformMemorizer>(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&) + 9639
  [bt] (1) 2   libtvm.dylib                        0x0000000125ac68bc tvm::relay::alter_op_layout::AlterTransformMemorizer::CallWithNewLayouts(tvm::relay::Call const&, std::__1::vector<tvm::RelayExpr, std::__1::allocator<tvm::RelayExpr> > const&) + 1084
  [bt] (0) 1   libtvm.dylib                        0x0000000125e02c25 std::__1::__function::__func<TVMFuncCreateFromCFunc::$_2, std::__1::allocator<TVMFuncCreateFromCFunc::$_2>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 213
  File "/Users/pivovaa/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/op/nn/_nn.py", line 84, in alter_op_layout_dense
    return topi.nn.dense_alter_layout(attrs, inputs, tinfos, out_type)
  File "<decorator-gen-69>", line 2, in dense_alter_layout
  File "/Users/pivovaa/workspace/tvm/python/tvm/target/generic_func.py", line 275, in dispatch_func
    return dispatch_dict[k](*args, **kwargs)
  File "/Users/pivovaa/workspace/tvm/python/tvm/topi/x86/dense_alter_op.py", line 35, in _alter_dense_layout
    M, K = get_const_tuple(data_tensor.shape)
ValueError: not enough values to unpack (expected 2, got 1)