Hi. I’m trying to compile a PyTorch Faster R-CNN model and import the traced graph into Relay. Here’s the code:
import tvm
from tvm import relay
import torch, torchvision
import numpy as np


def dict_to_tuple(d):
    return d["boxes"], d["scores"], d["labels"]


class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input):
        outputs = self.model(input)
        return dict_to_tuple(outputs[0])


model = torchvision.models.detection.fasterrcnn_resnet50_fpn
model = TraceWrapper(model(pretrained=True))
model.eval()

batch_size = 1
channel = 3
img_width = 1024
img_height = 1024

# Random NCHW dummy input with pixel-scale values
input = torch.Tensor(np.random.uniform(0.0, 250.0, size=(batch_size, channel, img_height, img_width)))

with torch.no_grad():
    scripted_model = torch.jit.trace(model, input)
    scripted_model.eval()

# Shape order matches the dummy input: (N, C, H, W)
shape_list = [("input0", torch.Size([batch_size, channel, img_height, img_width]))]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
Here’s the error message I get at the torch.jit.trace line:
Traceback (most recent call last):
  File "/home/plg/yyz_workspace/tvm_DOTA/pass_reset_input.py", line 31, in <module>
    scripted_model = torch.jit.trace(model, input)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/jit/_trace.py", line 794, in trace
    return trace_module(
           ^^^^^^^^^^^^^
  File "/home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/jit/_trace.py", line 1084, in trace_module
    _check_trace(
  File "/home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/jit/_trace.py", line 562, in _check_trace
    raise TracingCheckError(*diag_info)
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
Graph diff:
graph(%self.1 : __torch__.TraceWrapper,
%input.1 : Tensor):
%model : __torch__.torchvision.models.detection.faster_rcnn.FasterRCNN = prim::GetAttr[name="model"](%self.1)
%8 : float = prim::Constant[value=0.03125](), scope: __module.model/__module.model.roi_heads/__module.model.roi_heads.box_roi_pool # /home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/_ops.py:502:0
%9 : float = prim::Constant[value=0.0625](), scope: __module.model/__module.model.roi_heads/__module.model.roi_heads.box_roi_pool # /home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/_ops.py:502:0
%10 : float = prim::Constant[value=0.125](), scope: __module.model/__module.model.roi_heads/__module.model.roi_heads.box_roi_pool # /home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/_ops.py:502:0
...... // Thousands of lines omitted
+ %2948 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%2947, %2921, %2922)
? ^^ ^^ ^ ^^
- %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::TupleUnpack(%2956)
? ^^
+ %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::TupleUnpack(%2948)
? ^^
%7 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%4, %5, %6)
return (%7)
First diverging operator:
Node diff:
- %model : __torch__.torchvision.models.detection.faster_rcnn.FasterRCNN = prim::GetAttr[name="model"](%self.1)
+ %model : __torch__.torchvision.models.detection.faster_rcnn.___torch_mangle_355.FasterRCNN = prim::GetAttr[name="model"](%self.1)
? ++++++++++++++++++++
What’s the error about and what should I do to solve the problem? Thanks for your help!
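One workaround I’m wondering about (I’m not sure it is correct): disabling the tracer’s consistency check, since the diff seems to involve only the mangled submodule type and the data-dependent detection outputs. A minimal sketch, assuming the rest of the script above stays the same:

with torch.no_grad():
    # check_trace=False skips torch.jit.trace's re-run consistency check
    # (check_trace is a documented torch.jit.trace argument); whether skipping
    # it is actually safe for a detection model like this is what I'm unsure about.
    scripted_model = torch.jit.trace(model, input, check_trace=False)
    scripted_model.eval()

Would that just hide the real problem once the traced module is converted with relay.frontend.from_pytorch?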