[PyTorch] Failed to convert Mask R-CNN with batch size > 1

Recently I have been working on some models from PyTorch implementations, so I use the Relay PyTorch frontend (from_pytorch) to convert them. In general everything works well, but a problem arises when I convert Mask R-CNN with a batch size larger than 1.

Specifically, I use the script from the tutorial (Compile PyTorch Object Detection Models — tvm 0.8.dev0 documentation) and only change the batch size. Here is the script to reproduce the error:

```python
import numpy as np

import torch
import torchvision

import tvm
from tvm import relay

batch_size = 2  # Error when > 1
image_size = [300, 300]
input_shape = (batch_size, 3, *image_size)

def do_trace(model, inp):
    # Trace the wrapped model into TorchScript so from_pytorch can consume it.
    model_trace = torch.jit.trace(model, inp)
    model_trace.eval()
    return model_trace

def dict_to_tuple(out_dict):
    # Flatten the detection output dict into a tuple for tracing.
    if "masks" in out_dict.keys():
        return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
    return out_dict["boxes"], out_dict["scores"], out_dict["labels"]

class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inp):
        out = self.model(inp)
        # Return only the detections of the first image (out[0]).
        return dict_to_tuple(out[0])

model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
model = TraceWrapper(model_func(pretrained=True))

model.eval()
inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=input_shape))

with torch.no_grad():
    out = model(inp)
    script_module = do_trace(model, inp)

input_name = "input0"
shape_list = [(input_name, input_shape)]
relay_model, params = relay.frontend.from_pytorch(script_module, shape_list)
```

And here is the error message I got when the batch size was set to 2:

```
Traceback (most recent call last):
  File "model_od.py", line 50, in <module>
    relay_model, params = relay.frontend.from_pytorch(script_module, shape_list)
  File "/home/ubuntu/cody-tvm/python/tvm/relay/frontend/pytorch.py", line 3132, in from_pytorch
    ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)[0]
  File "/home/ubuntu/cody-tvm/python/tvm/relay/frontend/pytorch.py", line 2554, in convert_operators
    inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype)
  File "/home/ubuntu/cody-tvm/python/tvm/relay/frontend/pytorch.py", line 509, in split_with_sizes
    split_index += index
  File "/home/ubuntu/cody-tvm/python/tvm/relay/expr.py", line 105, in __radd__
    return self.__add__(other)
  File "/home/ubuntu/cody-tvm/python/tvm/relay/expr.py", line 100, in __add__
    raise TypeError('convert "%s" with `const` first' % str(other))
TypeError: convert "0" with `const` first
```
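The last frames of this traceback follow Python's reflected-operator protocol: `split_index` starts as the integer 0, so `split_index += index` with an expression object on the right falls back to `index.__radd__(0)`, which (per the frames above) forwards to `__add__` and rejects the plain number. The hypothetical `FakeExpr` class and `accumulate` helper below only imitate that behaviour in pure Python as a sketch; they are not Relay's actual `Expr` implementation or the frontend's code.

```python
class FakeExpr:
    """Hypothetical imitation of an expression type that only adds to its own kind."""

    def __add__(self, other):
        if isinstance(other, FakeExpr):
            return FakeExpr()
        # Mirrors the message seen in the traceback for plain Python numbers.
        raise TypeError('convert "%s" with `const` first' % str(other))

    def __radd__(self, other):
        # Called for `number + expr` after int.__add__ returns NotImplemented.
        return self.__add__(other)


def accumulate(sections):
    """Hypothetical running sum, shaped like the loop in the traceback."""
    split_index = 0
    for index in sections:
        split_index += index  # 0 + FakeExpr() ends up in FakeExpr.__radd__(0)
    return split_index


if __name__ == "__main__":
    print(accumulate([3, 4]))  # 7
    try:
        accumulate([FakeExpr()])
    except TypeError as err:
        print(err)  # convert "0" with `const` first
```

With plain integers the running sum works; as soon as one entry is an unevaluated expression, the very first addition to the integer 0 fails with the same message as above.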

Please note that I could successfully convert image classification models with batch size > 1 using this tutorial (Compile PyTorch Models — tvm 0.8.dev0 documentation), so I suspect the failure comes from some processing specific to object detection models.

cc @masahi @alexwong


I investigated this issue. Indeed, when the batch size is not 1, this code path is not taken,

so we need to infer the value here: tvm/pytorch.py at 50e013dd3a5e23450ff4ae98324be07aa6160a6d · apache/tvm · GitHub

It seems sections[i] is a complicated expression that does not constant-evaluate to an integer, and our relay.split op does not support non-constant split sections.
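To make the constraint concrete: the traceback shows `split_with_sizes` accumulating `split_index += index`, i.e. turning PyTorch's per-chunk sizes into cumulative split points, which is the numpy-style form `relay.split` takes for `indices_or_sections`. The plain-Python sketch below imitates that conversion, assuming a size is usable only when it is already a Python int. `SymbolicSize`, `sizes_to_split_indices`, and `split_at` are hypothetical names for illustration, not TVM APIs.

```python
from typing import List, Sequence, Union


class SymbolicSize:
    """Hypothetical stand-in for a size expression that cannot be
    constant-evaluated at conversion time (e.g. a per-image box count)."""

    def __init__(self, desc: str) -> None:
        self.desc = desc


def sizes_to_split_indices(sizes: Sequence[Union[int, SymbolicSize]]) -> List[int]:
    """Turn split_with_sizes-style chunk sizes into numpy-style split points.

    Only the first len(sizes) - 1 sizes matter: the last chunk simply runs
    to the end, so a single section produces no split points at all.
    """
    indices: List[int] = []
    split_index = 0
    for size in sizes[:-1]:
        if not isinstance(size, int):
            # This is where a converter needs a compile-time constant.
            raise TypeError(f"split size {size.desc!r} is not a constant")
        split_index += size
        indices.append(split_index)
    return indices


def split_at(seq: Sequence[int], indices: Sequence[int]) -> List[List[int]]:
    """Cut seq at the given split points, like numpy.split with indices."""
    bounds = [0, *indices, len(seq)]
    return [list(seq[lo:hi]) for lo, hi in zip(bounds, bounds[1:])]


if __name__ == "__main__":
    indices = sizes_to_split_indices([3, 4, 5])
    print(indices)  # [3, 7]
    print([len(chunk) for chunk in split_at(range(12), indices)])  # [3, 4, 5]
    print(sizes_to_split_indices([12]))  # []
    try:
        sizes_to_split_indices([SymbolicSize("boxes in image 0"), 5])
    except TypeError as err:
        print(err)
```

Cutting at the cumulative split points reproduces the original chunk sizes, but only if every size except the last is known when the graph is built; a size that depends on runtime data (such as how many boxes survive for each image) leaves nothing constant to accumulate.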

So unfortunately I don’t see a way to work around this issue.


Thanks for the investigation; I understand the problem now. I’ll try to find a workaround in the model itself, although that doesn’t sound promising…