PyTorch slice produces wrong results

Hi everyone, I'm having a problem when compiling a model from PyTorch. Specifically, I'm trying to stack adjacent pixels along the channel dimension. Here is a minimal reproduction:

import numpy as np
import torch
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

class PixelStack(torch.nn.Module):
    def __init__(self, stackit=False):
        super(PixelStack, self).__init__()
        self.stackit = stackit

    def forward(self, x):
        # Expect NCHW input with even spatial dimensions.
        assert len(x.size()) == 4 and x.size(2) % 2 == 0 and x.size(3) % 2 == 0
        # Each slice picks one row/column parity, giving four tensors of
        # shape (N, C, H/2, W/2).
        if self.stackit:
            return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], dim=1)
        else:
            return (x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2])

x = torch.randn(1, 128, 32, 32)
f = PixelStack()
f_ts = torch.jit.trace(f, x)

# Import the traced graph into Relay and build for CPU.
mod, params = relay.frontend.from_pytorch(f_ts, [('input0', (1, 128, 32, 32))])
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target='llvm', target_host='llvm', params=params)
f_tvm = graph_runtime.GraphModule(lib['default'](tvm.cpu(0)))
f_tvm.set_input('input0', tvm.nd.array(x.numpy()))

Expected output, from the traced PyTorch module:

for i in f_ts(x):
    print(i.size())
torch.Size([1, 128, 16, 16])
torch.Size([1, 128, 16, 16])
torch.Size([1, 128, 16, 16])
torch.Size([1, 128, 16, 16])

However, the output from TVM looks like this:

f_tvm.run()
for i in range(4):
    print(f_tvm.get_output(i).shape)
(1, 128, 32, 32)
(1, 128, 31, 32)
(1, 128, 32, 31)
(1, 128, 31, 31)
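
(For completeness, this is how I'd compare the values against PyTorch once the shapes agree; with the shapes above it of course fails immediately. asnumpy() is the NDArray conversion on this TVM version; newer releases also offer .numpy().)

for i, ref in enumerate(f_ts(x)):
    # Compare each TVM output against the traced PyTorch reference.
    np.testing.assert_allclose(f_tvm.get_output(i).asnumpy(), ref.numpy(), rtol=1e-5)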

If I stack them up with stackit=True, I of course get the error

relay.concatenate requires all tensors have the same shape on non-concatenating axes;

TVMError: Traceback (most recent call last):
  [bt] (8) /home/ds/Workspace/tvm/build/libtvm.so(TVMFuncCall+0x65) [0x7fd700d68145]
  [bt] (7) /home/ds/Workspace/tvm/build/libtvm.so(+0x65c554) [0x7fd70039b554]
  [bt] (6) /home/ds/Workspace/tvm/build/libtvm.so(+0x65c0e7) [0x7fd70039b0e7]
  [bt] (5) /home/ds/Workspace/tvm/build/libtvm.so(tvm::IRModuleNode::Add(tvm::GlobalVar const&, tvm::BaseFunc const&, bool)+0x31f) [0x7fd70039999f]
  [bt] (4) /home/ds/Workspace/tvm/build/libtvm.so(tvm::RunTypeCheck(tvm::IRModule const&, tvm::GlobalVar const&, tvm::relay::Function)+0x2e7) [0x7fd700399007]
  [bt] (3) /home/ds/Workspace/tvm/build/libtvm.so(tvm::relay::InferType(tvm::relay::Function const&, tvm::IRModule const&, tvm::GlobalVar const&)+0x1bc) [0x7fd700bb45fc]
  [bt] (2) /home/ds/Workspace/tvm/build/libtvm.so(tvm::relay::TypeInferencer::Infer(tvm::RelayExpr)+0x71) [0x7fd700bb3e91]
  [bt] (1) /home/ds/Workspace/tvm/build/libtvm.so(tvm::ErrorReporter::RenderErrors(tvm::IRModule const&, bool)+0x228b) [0x7fd7003841db]
  [bt] (0) /home/ds/Workspace/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x82) [0x7fd7002af442]
  File "../src/ir/error.cc", line 132
TVMError: 
Error(s) have occurred. The program has been annotated with them:

In `main`: 
#[version = "0.0.5"]
fn (%input0: Tensor[(1, 128, 32, 32), float32]) {
  %0 = strided_slice(%input0, meta[relay.Constant][0], meta[relay.Constant][1], meta[relay.Constant][2], begin=[0, 0, 0, 0], end=[1, 128, -1, 32], strides=[2], slice_mode="size");
  %1 = strided_slice(%0, meta[relay.Constant][3], meta[relay.Constant][4], meta[relay.Constant][5], begin=[0, 0, 0, 0], end=[1, 128, 32, -1], strides=[2], slice_mode="size");
  %2 = strided_slice(%input0, meta[relay.Constant][6], meta[relay.Constant][7], meta[relay.Constant][8], begin=[0, 0, 1, 0], end=[1, 128, -1, 32], strides=[2], slice_mode="size");
  %3 = strided_slice(%2, meta[relay.Constant][9], meta[relay.Constant][10], meta[relay.Constant][11], begin=[0, 0, 0, 0], end=[1, 128, 31, -1], strides=[2], slice_mode="size");
  %4 = strided_slice(%input0, meta[relay.Constant][12], meta[relay.Constant][13], meta[relay.Constant][14], begin=[0, 0, 0, 0], end=[1, 128, -1, 32], strides=[2], slice_mode="size");
  %5 = strided_slice(%4, meta[relay.Constant][15], meta[relay.Constant][16], meta[relay.Constant][17], begin=[0, 0, 0, 1], end=[1, 128, 32, -1], strides=[2], slice_mode="size");
  %6 = strided_slice(%input0, meta[relay.Constant][18], meta[relay.Constant][19], meta[relay.Constant][20], begin=[0, 0, 1, 0], end=[1, 128, -1, 32], strides=[2], slice_mode="size");
  %7 = strided_slice(%6, meta[relay.Constant][21], meta[relay.Constant][22], meta[relay.Constant][23], begin=[0, 0, 0, 1], end=[1, 128, 31, -1], strides=[2], slice_mode="size");
  %8 = (%1, %3, %5, %7);
  concatenate(%8, axis=1) relay.concatenate requires all tensors have the same shape on non-concatenating axes; 
}
/* For debugging purposes the metadata section has been omitted.
 * If you would like to see the full metadata section you can set the 
 * option to `True` when invoking `astext`. 
 */
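
Reading the IR above, the suspicious part is slice_mode="size": as far as I can tell from the Relay docs, in this mode the strides argument is ignored and end gives the size of each slice, with -1 meaning "everything from begin to the end of the axis". That reproduces exactly the shapes I'm seeing: the stride-2 subsampling is dropped, so e.g. x[..., 1::2, ::2] degenerates into a plain crop of size (1, 128, 31, 32). For comparison, a hand-written Relay slice with the default slice_mode="end" does give the expected shape (an illustrative sketch, not how the frontend actually lowers it):

import tvm
from tvm import relay

data = relay.var('data', shape=(1, 128, 32, 32))
# data[..., ::2, ::2]: end is an exclusive index and strides apply per axis.
y = relay.strided_slice(data, begin=[0, 0, 0, 0], end=[1, 128, 32, 32],
                        strides=[1, 1, 2, 2], slice_mode='end')
mod = tvm.IRModule.from_expr(relay.Function([data], y))
# Type inference reports Tensor[(1, 128, 16, 16), float32].
print(relay.transform.InferType()(mod))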

By the way, I'm using the latest master branch of TVM. I've tried tracing down where the code breaks and noticed that the inputs to the slice operator are weird, e.g. [2, 0, 9043433…32354, 2]; the third number is gigantic (possibly the int64 sentinel that TorchScript uses to represent an open-ended slice end), but I'm not familiar with TorchScript either. :sweat_smile:
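
In case it helps, here is a workaround sketch (a hypothetical PixelStackReshape; I haven't verified it against this exact TVM build) that expresses the same stacking with reshape/permute, so tracing never emits strided slices at all:

class PixelStackReshape(torch.nn.Module):
    # Intended to match PixelStack(stackit=True), built from
    # reshape/permute instead of strided slicing.
    def forward(self, x):
        n, c, h, w = x.shape
        # Split H into (H/2, row parity) and W into (W/2, column parity).
        x = x.reshape(n, c, h // 2, 2, w // 2, 2)
        # Reorder to (N, col parity, row parity, C, H/2, W/2) so that
        # flattening matches the channel order of the torch.cat version.
        x = x.permute(0, 5, 3, 1, 2, 4)
        return x.reshape(n, 4 * c, h // 2, w // 2)

# Sanity check: torch.equal(PixelStackReshape()(x), PixelStack(stackit=True)(x))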

I met the same error, "relay.concatenate requires all tensors have the same shape on non-concatenating axes;", when running YOLOv5 (still struggling). I'm on TVM 0.10; after changing the import to "import tvm.contrib.graph_executor as runtime", the error no longer shows up, so maybe it has been fixed.
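
For reference, graph_runtime was renamed to graph_executor in newer TVM releases, so the runner part of the script above becomes:

from tvm.contrib import graph_executor

# Same API as before; only the module name changed.
f_tvm = graph_executor.GraphModule(lib['default'](tvm.cpu(0)))
f_tvm.set_input('input0', tvm.nd.array(x.numpy()))
f_tvm.run()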