Error from Compiling and Running RetinaNet from TorchVision

Hi:

We are trying to deploy RetinaNet from TorchVision. We use the code from the deploy_object_detection_pytorch tutorial, changing only

model_func = torchvision.models.detection.maskrcnn_resnet50_fpn

to

model_func = torchvision.models.detection.retinanet_resnet50_fpn

Several errors occur:

  1. During relay.vm.compile, a segmentation fault occurs. This is odd, since mask_rcnn is larger than retina_net and compiling mask_rcnn works fine. If ulimit -s 65535 is executed in the terminal first, relay.vm.compile eventually completes.

  2. Even after relay.vm.compile succeeds with the larger stack size, VirtualMachine.run takes an unusually long time, and its results are quite different from the PyTorch results. For example, VirtualMachine.run produces 117 bboxes while PyTorch produces 472.
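
The stack-size workaround from item 1 can also be tried from inside Python instead of the shell. This is only a sketch, and it assumes the segfault is a stack overflow from deep recursion in native code, which has not been confirmed here. The helper name run_with_big_stack is made up for illustration. The default of 64 MiB roughly matches ulimit -s 65535, which is given in KiB.

```python
import threading


def run_with_big_stack(fn, stack_bytes=64 * 1024 * 1024):
    """Run fn() in a fresh thread whose stack is stack_bytes large."""
    result = {}

    def target():
        try:
            result["value"] = fn()
        except BaseException as exc:  # re-raised in the caller below
            result["error"] = exc

    # The new size applies to threads created after this call.
    previous = threading.stack_size(stack_bytes)
    try:
        worker = threading.Thread(target=target)
        worker.start()
        worker.join()
    finally:
        threading.stack_size(previous)  # restore the earlier setting

    if "error" in result:
        raise result["error"]
    return result["value"]


# Hypothetical usage with the script in this thread:
#   vm_exec = run_with_big_stack(
#       lambda: relay.vm.compile(mod, target=target, params=params))
print(run_with_big_stack(lambda: sum(range(10))))
```

Native code called from the worker thread runs on that thread's stack, so this only changes the limit for the wrapped call and leaves the rest of the process alone.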

Can you share your script?

Sure

This is my test script. By the way, you may need to copy the test image into the same directory as this script.

import tvm
from tvm import relay
from tvm.runtime.vm import VirtualMachine
from tvm.contrib.download import download
import os
import numpy as np
import cv2

# PyTorch imports
import torch
import torchvision

import logging

in_size = 300

def get_test_img():
    img_path = './street_small.jpg'
    img = cv2.imread(img_path).astype("float32")
    img = cv2.resize(img, (in_size, in_size))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = np.transpose(img / 255.0, [2, 0, 1])
    img = np.expand_dims(img, axis=0)
    return img


def dict_to_tuple(out_dict):
    if "masks" in out_dict.keys():
        return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
    return out_dict["boxes"], out_dict["scores"], out_dict["labels"]


def do_trace(model, inp):
    model_trace = torch.jit.trace(model, inp)
    model_trace.eval()
    return model_trace


class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inp):
        out = self.model(inp)
        return dict_to_tuple(out[0])



def get_script_module():
    inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))
    model_func = torchvision.models.detection.retinanet_resnet50_fpn
    model = TraceWrapper(model_func(pretrained=True))
    model.eval()

    with torch.no_grad():
        out = model(inp)
        script_module = do_trace(model, inp)

    return script_module


def retina_net_lab():

    img = get_test_img()
    script_module = get_script_module()

    input_name = "input0"
    input_shape = (1, 3, in_size, in_size)
    shape_list = [(input_name, input_shape)]
    mod, params = relay.frontend.from_pytorch(script_module, shape_list)

     
    target = "llvm"
    with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
        vm_exec = relay.vm.compile(mod, target=target, params=params)


    ctx = tvm.cpu()
    vm = VirtualMachine(vm_exec, ctx)
    vm.set_input("main", **{input_name: img})
    tvm_res = vm.run()

    score_threshold = 0.9
    boxes = tvm_res[0].asnumpy().tolist()
    valid_boxes = []
    for i, score in enumerate(tvm_res[1].asnumpy().tolist()):
        if score > score_threshold:
            valid_boxes.append(boxes[i])
        else:
            # Stop at the first low score; this assumes scores are sorted in descending order.
            break

    print("Get {} valid boxes".format(len(valid_boxes)))


if __name__ == '__main__':
    retina_net_lab()
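
One thing to note about the counting loop in the script: it stops at the first score below the threshold, so it only counts correctly if the scores are sorted in descending order. Whether the TVM output is sorted that way is not established in this thread, so the sketch below filters without relying on order. The function name and the sample values are made up for illustration.

```python
def boxes_above_threshold(boxes, scores, score_threshold=0.9):
    """Keep every box whose score exceeds the threshold, in any order."""
    return [box for box, score in zip(boxes, scores) if score > score_threshold]


# Made-up values for illustration; the scores are deliberately unsorted.
boxes = [[0, 0, 10, 10], [5, 5, 20, 20], [1, 1, 4, 4]]
scores = [0.95, 0.40, 0.92]
valid_boxes = boxes_above_threshold(boxes, scores)
print("Get {} valid boxes".format(len(valid_boxes)))
```

With these values the break-based loop would stop at the second entry and miss the third box, while the order-independent filter keeps both high-scoring boxes.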

I got this error when I ran your script. Have you seen this?

  File "retinanet_test.py", line 97, in <module>
    retina_net_lab()
  File "retinanet_test.py", line 71, in retina_net_lab
    mod, params = relay.frontend.from_pytorch(script_module, shape_list)
  File "/home/masa/projects/dev/tvm/python/tvm/relay/frontend/pytorch.py", line 3041, in from_pytorch
    ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)[0]
  File "/home/masa/projects/dev/tvm/python/tvm/relay/frontend/pytorch.py", line 2463, in convert_operators
    inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype)
  File "/home/masa/projects/dev/tvm/python/tvm/relay/frontend/pytorch.py", line 1868, in nms
    data, score_threshold=-1.0, id_index=-1, score_index=0
  File "/home/masa/projects/dev/tvm/python/tvm/relay/op/vision/nms.py", line 54, in get_valid_counts
    _make.get_valid_counts(data, score_threshold, id_index, score_index), 3
  File "/home/masa/projects/dev/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (3) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(TVMFuncCall+0x69) [0x7f9b4e76afb9]
  [bt] (2) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::RelayExpr, double, int, int)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::RelayExpr, double, int, int)>(tvm::RelayExpr (*)(tvm::RelayExpr, double, int, int))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)+0x1fb) [0x7f9b4e3b7b8b]
  [bt] (1) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(tvm::runtime::TVMPODValue_::operator double() const+0x179) [0x7f9b4dc27dc9]
  [bt] (0) /mnt/2e797a66-fd2b-44fc-a3ba-24d7d65f2780/projects/dev/tvm/build_llvm8/libtvm.so(+0x78185b) [0x7f9b4dc1e85b]
  File "/home/masa/projects/dev/tvm/include/tvm/runtime/packed_func.h", line 372
TVMError: 
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------

  Check failed: type_code_ == kDLFloat (8 vs. 2) :  expected float but get Object

I have not seen this before. The PyTorch version is 1.7.0, the TorchVision version is 0.8.1, and the TVM repo I am using was synced with upstream yesterday.
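
When comparing environments like this, a small standard-library helper can print the installed versions in one go. This is a minimal sketch: the distribution name "tvm" is an assumption, and a source build may be registered under a different name or not at all.

```python
from importlib import metadata


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# "tvm" is an assumed distribution name; a source build may register a different one.
for name, version in installed_versions(["torch", "torchvision", "tvm"]).items():
    print(name, version or "not installed")
```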

Also, would you mind changing

model_func = torchvision.models.detection.retinanet_resnet50_fpn

in the get_script_module() function to

model_func = torchvision.models.detection.maskrcnn_resnet50_fpn

Running the script then checks whether the environment is correct, since it will be the same as the mask_rcnn tutorial.

Sorry, I was on a different branch. I can reproduce the segfault. It seems the segfault is happening in the Relay DeDup pass. Something is definitely wrong; I’ll take a look.

cc @kevinthesun This is a new detection model from torchvision 0.8 (PyTorch 1.7).

The segfault can be avoided by running ulimit -s 65535 in the terminal first, which is quite strange since mask_rcnn is a larger model and works fine.

And the error seems to be caused by

detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes)

in RetinaNet.

Yes, I tried ulimit, and in my case I got zero valid boxes for some reason. Compile time is also extremely long, so I’ll look at that first.

What do you mean by this? Which error?

I mean that in the forward() method of the RetinaNet class, if we return before

detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes),

TVM works fine and produces the same results as PyTorch.

Therefore, I believe something in postprocess_detections causes the segfault, the extremely long compile time, and the zero valid boxes in the output.

Yeah, that function has a large loop that gets unrolled 91 times during tracing, so the Relay model becomes very large.
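
To illustrate why tracing unrolls the loop: a tracer records the operations that actually run, so a Python loop shows up as one copy of its body per iteration. The sketch below is a toy stand-in, not torch.jit, and the op names are placeholders rather than what torchvision really emits.

```python
class ToyTracer:
    """Toy stand-in for a tracer: it records each operation that actually runs."""

    def __init__(self):
        self.ops = []

    def op(self, name):
        self.ops.append(name)


def postprocess(tracer, num_classes):
    # Hypothetical per-class body; the op names are placeholders.
    for class_index in range(num_classes):
        tracer.op("select_scores[{}]".format(class_index))
        tracer.op("threshold[{}]".format(class_index))
        tracer.op("nms[{}]".format(class_index))
    tracer.op("concat")


tracer = ToyTracer()
postprocess(tracer, num_classes=91)
print(len(tracer.ops))
```

The recorded list has three entries per class plus the final concat, 274 in total, and the loop itself is gone from the record. A real traced graph grows in the same way, only with many more ops per iteration.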

It seems PR https://github.com/pytorch/vision/pull/2828 improved that function and also reduced the number of loop iterations. Unfortunately, that commit is not part of the v0.8 release.

It seems that they changed the per-class nms into a single batched_nms, which is also used in the postprocessing of mask_rcnn.

Even so, if we ignore the very long compile and execution time, shouldn’t TVM give the same results whether batched_nms is used or not?

Yes, the result should be (almost) identical. Something is definitely wrong.
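
For reference, here is a minimal pure-Python sketch of why the two formulations should agree. It compares a per-class NMS loop against a single batched call that shifts each class's boxes by a class-dependent offset so boxes of different classes can never overlap. As far as I know this is the same idea torchvision's batched_nms uses, but the code below is an illustration with made-up data, not the library's implementation.

```python
def iou(a, b):
    """Intersection over union of two [x1, y1, x2, y2] boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms(boxes, scores, iou_threshold):
    """Greedy NMS; returns kept indices in descending score order."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep


def per_class_nms(boxes, scores, labels, iou_threshold):
    """One NMS call per class, as in a loop over classes."""
    keep = []
    for label in sorted(set(labels)):
        idx = [i for i, l in enumerate(labels) if l == label]
        kept = nms([boxes[i] for i in idx], [scores[i] for i in idx], iou_threshold)
        keep.extend(idx[k] for k in kept)
    return keep


def batched_nms(boxes, scores, labels, iou_threshold):
    """One NMS call for all classes; assumes non-negative coordinates and labels."""
    max_coord = max(max(b) for b in boxes)
    shifted = [[c + label * (max_coord + 1) for c in b]
               for b, label in zip(boxes, labels)]
    return nms(shifted, scores, iou_threshold)


# Made-up data: boxes 0 and 1 overlap within class 1; box 2 matches box 0 but is class 2.
boxes = [[0, 0, 10, 10], [1, 1, 11, 11], [0, 0, 10, 10], [50, 50, 60, 60]]
scores = [0.9, 0.8, 0.7, 0.6]
labels = [1, 1, 2, 2]
looped = sorted(per_class_nms(boxes, scores, labels, 0.5))
batched = sorted(batched_nms(boxes, scores, labels, 0.5))
print(looped, batched)
```

Both versions keep the same indices here. With floating-point boxes the shift can introduce tiny rounding differences in the IoU, which is why the results are expected to be almost, rather than exactly, identical.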