Issue compiling the PyTorch FasterRCNN model

Hi. I’m trying to compile the PyTorch FasterRCNN model and import the graph into Relay. Here’s the code.

import tvm
from tvm import relay
import torch, torchvision
import numpy as np

def dict_to_tuple(d):
    return d["boxes"], d["scores"], d["labels"]

class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input):
        outputs = self.model(input)
        return dict_to_tuple(outputs[0])

model = torchvision.models.detection.fasterrcnn_resnet50_fpn
model = TraceWrapper(model(pretrained=True))
model.eval()

batch_size = 1
channel = 3
img_width = 1024
img_height = 1024

input = torch.Tensor(np.random.uniform(0.0, 250.0, size=(batch_size, channel, img_height, img_width)))

with torch.no_grad():  
    scripted_model = torch.jit.trace(model, input)
    scripted_model.eval()

shape_list = [("input0", torch.Size([batch_size, channel, img_height, img_width]))]  # NCHW, same order as the traced input
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

Here’s the error message I get at the “torch.jit.trace” call:

Traceback (most recent call last):
  File "/home/plg/yyz_workspace/tvm_DOTA/pass_reset_input.py", line 31, in <module>
    scripted_model = torch.jit.trace(model, input)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/jit/_trace.py", line 794, in trace
    return trace_module(
           ^^^^^^^^^^^^^
  File "/home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/jit/_trace.py", line 1084, in trace_module
    _check_trace(
  File "/home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/jit/_trace.py", line 562, in _check_trace
    raise TracingCheckError(*diag_info)
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
        Graph diff:
                  graph(%self.1 : __torch__.TraceWrapper,
                        %input.1 : Tensor):
                    %model : __torch__.torchvision.models.detection.faster_rcnn.FasterRCNN = prim::GetAttr[name="model"](%self.1)
                    %8 : float = prim::Constant[value=0.03125](), scope: __module.model/__module.model.roi_heads/__module.model.roi_heads.box_roi_pool # /home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/_ops.py:502:0
                    %9 : float = prim::Constant[value=0.0625](), scope: __module.model/__module.model.roi_heads/__module.model.roi_heads.box_roi_pool # /home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/_ops.py:502:0
                    %10 : float = prim::Constant[value=0.125](), scope: __module.model/__module.model.roi_heads/__module.model.roi_heads.box_roi_pool # /home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/_ops.py:502:0
                ......  // Thousands of lines omitted
                +   %2948 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%2947, %2921, %2922)
                ?      ^^                                                      ^^      ^     ^^
                -   %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::TupleUnpack(%2956)
                ?                                                                ^^
                +   %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::TupleUnpack(%2948)
                ?                                                                ^^
                    %7 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%4, %5, %6)
                    return (%7)
        First diverging operator:
        Node diff:
                - %model : __torch__.torchvision.models.detection.faster_rcnn.FasterRCNN = prim::GetAttr[name="model"](%self.1)
                + %model : __torch__.torchvision.models.detection.faster_rcnn.___torch_mangle_355.FasterRCNN = prim::GetAttr[name="model"](%self.1)
                ?                                                             ++++++++++++++++++++

What does this error mean, and how can I fix it? Thanks for your help!


The issue seems to originate from TorchScript rather than TVM itself. Were you able to figure out a solution?
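
One way to read the “First diverging operator” section of the traceback: the two node lines appear to differ only in the `___torch_mangle_355` qualifier that TorchScript added to the type name on the second invocation. Here is a minimal Python sketch that checks this, assuming the two lines are copied verbatim from the traceback (the helper name `strip_mangle` is made up for illustration and is not part of PyTorch):

```python
import re

# Node lines copied from the "First diverging operator" part of the traceback.
first = '%model : __torch__.torchvision.models.detection.faster_rcnn.FasterRCNN = prim::GetAttr[name="model"](%self.1)'
second = '%model : __torch__.torchvision.models.detection.faster_rcnn.___torch_mangle_355.FasterRCNN = prim::GetAttr[name="model"](%self.1)'


def strip_mangle(node: str) -> str:
    """Drop `___torch_mangle_<N>.` qualifiers from a TorchScript node string."""
    return re.sub(r"___torch_mangle_\d+\.", "", node)


# Compare the two nodes once the mangled qualifier is removed.
same_after_stripping = strip_mangle(first) == strip_mangle(second)
print("identical once mangled names are removed:", same_after_stripping)
```

This only covers the first diverging node. The graph diff also shows different value numbers (for example `%2956` versus `%2948`), so the two traces may differ in other ways as well.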

You’re right, it comes from TorchScript. I’ve posted the issue in the PyTorch community, but there’s no helpful solution yet.
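
Since the exception is raised inside `_check_trace`, one thing that may be worth trying is to skip that sanity check by passing `check_trace=False` to `torch.jit.trace`. The sketch below is an untested suggestion for this model, not a confirmed fix: it only disables the comparison between invocations, and the resulting graph may still fail to import into Relay. The helper name `trace_without_check` is hypothetical.

```python
def trace_without_check(model, example_input):
    """Trace `model` once, skipping the check that raised TracingCheckError."""
    import torch  # imported lazily so the helper can be defined without torch installed

    model.eval()
    with torch.no_grad():
        # check_trace=False skips the re-trace and the graph comparison
        # performed by torch.jit's _check_trace.
        return torch.jit.trace(model, example_input, check_trace=False)
```

In the script from the first post, this would replace the failing call, for example `scripted_model = trace_without_check(model, input)`.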

faster_rcnn_from_pytorch.ipynb (Open in Colab)

I was able to hack together a working version of Faster R-CNN based on the Mask R-CNN example in the TVM repo. I’m not sure how our workflows differ or why mine compiles successfully, so please post an update once you find out. Let me know if you run into any issues! (Note: the code is really messy at the moment.)

It seems that we’re using different versions of PyTorch and torchvision. Which versions are you using? I have torch 1.11.10 and torchvision 0.12.0.
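
To make the version comparison concrete, each of us could print the installed package versions with the standard library. A small sketch, assuming the packages are installed under the distribution names `torch` and `torchvision` (the helper name `installed_versions` is made up for illustration):

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("torch", "torchvision")):
    """Map each package name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Print one line per package so the output can be pasted into a reply.
for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```

Reading the metadata avoids importing the packages themselves, so this works even in an environment where the import fails.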