Does Relay support GPUs as targets to compile PyTorch Models?

Hi,

I am exploring the capabilities of tvm for accelerating PyTorch models.

In this tutorial from tvm’s official page, I read the following:

Compile with Relay VM

Note: Currently only CPU target is supported. For x86 target, it is highly recommended to build TVM with Intel MKL and Intel OpenMP to get best performance, due to the existence of large dense operator in torchvision rcnn models.

Does that mean that it is not yet possible to achieve PyTorch model acceleration for GPUs via tvm?

@masahi has worked on that and obtained some results.


Sure, of course we support PyTorch models on GPU. But for the particular case of object detection models, due to some limitations in Relay and other challenges associated with dynamic models, we currently cannot go any faster than PyTorch.


That’s great! Currently I am only interested in testing the API, so it is totally fine for me to get results comparable to plain PyTorch models.

I would love to learn more about how I can get my models compiled with tvm and running on NVIDIA GPUs.

Earlier today I tried to compile MiDaS for GPUs like the V100 and RTX 2080. I can’t say it did not work at all, but the output of the compiled model was always filled with zeros. So I tested my script with a simpler network and got the same result.

Below is my simple test case and the results that I get:

The basic model

import torch


class SimpleNet(torch.nn.Module):
    def __init__(self):
        """
        A simple model to explore the behaviour of the tvm compiler
        """
        super(SimpleNet, self).__init__()
        self.in_layer = torch.nn.Conv2d(3, 10, 3, padding=1)
        self.out_layer = torch.nn.Conv2d(10, 1, 3, padding=1)
        self.relu = torch.nn.ReLU(inplace=True)

        self.init_weights()

    def forward(self, x):
        l1 = self.relu(self.in_layer(x))
        out = self.relu(self.out_layer(l1))
        return out

    def init_weights(self):
        # Set all conv weights to one so the output is easy to sanity-check.
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.ones_(m.weight)

One of the configurations I have tried out

import tvm
from tvm import relay
from tvm.contrib import graph_runtime

img = torch.rand(1, 3, 224, 224)  # example input; shape matches the workloads in the warnings below

model = SimpleNet().cuda().eval()
script_model = torch.jit.trace(model, img.cuda()).eval()

input_name = "input0"
shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(script_model, shape_list)

# Note: sm_80 corresponds to Ampere; a V100 is sm_70 and an RTX 2080 is sm_75.
target = tvm.target.cuda(model="nvidia/teslaV100",
                         options=['-keys=cuda', '-arch=sm_80',
                                  '-max_num_threads=1024', '-thread_warp_size=32'])
target_host = 'llvm'
ctx = tvm.gpu()

with tvm.transform.PassContext(opt_level=3):
    graph, lib, params = tvm.relay.build(mod,
                                         target=target,
                                         target_host=target_host,
                                         params=params)

m = graph_runtime.create(graph, lib, ctx)

m.set_input(input_name, tvm.nd.array(img, ctx))
m.run()
tvm_output = m.get_output(0)

The result

tvm_output.asnumpy()

    array([[[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]]], dtype=float32)

Kindly note that the outputs of the PyTorch model and its jit-traced version are as expected.
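For reference, the comparison I have in mind is along these lines (a minimal sketch; model, script_model, img, and tvm_output are the objects defined above):

import numpy as np

with torch.no_grad():
    eager_out = model(img.cuda()).cpu().numpy()
    jit_out = script_model(img.cuda()).cpu().numpy()

# The eager and jit-traced outputs agree with each other...
np.testing.assert_allclose(eager_out, jit_out, rtol=1e-5, atol=1e-5)

# ...but the tvm output above is all zeros, so the difference is large.
print(np.abs(tvm_output.asnumpy() - eager_out).max())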

The warnings

After the compilation, I saw the following warnings about the compiler not being able to find the right config. Since I got the same warnings on the CPU side and did not get a blank output there, I ignored them (a sketch of how these configs could be tuned away follows the log below). In terms of speed, the compiled model is also pretty close to PyTorch on CPU.

    Cannot find config for target=cuda -keys=cuda,gpu -arch=sm_80 -max_num_threads=1024 -model=nvidia/teslaV100 -thread_warp_size=32, workload=('conv2d_nchw.cuda', ('TENSOR', (1, 3, 224, 224), 'float32'), ('TENSOR', (10, 3, 3, 3), 'float32'), (1, 1), (1, 1, 1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
    Cannot find config for target=cuda -keys=cuda,gpu -arch=sm_80 -max_num_threads=1024 -model=nvidia/teslaV100 -thread_warp_size=32, workload=('conv2d_nchw.cuda', ('TENSOR', (1, 10, 224, 224), 'float32'), ('TENSOR', (1, 10, 3, 3), 'float32'), (1, 1), (1, 1, 1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
    /usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:24: DeprecationWarning: legacy graph runtime behavior of producing json / lib / params will be removed in the next release. Please see documents of tvm.contrib.graph_runtime.GraphModule for the new recommended usage.
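As far as I understand, these warnings only mean that no tuned schedule exists for the two conv2d workloads, so a fallback configuration is used. A rough sketch of how they could be tuned away with AutoTVM, assuming the mod, params, target, and target_host from the script above (the log file name is just an example):

from tvm import autotvm
from tvm.autotvm.tuner import XGBTuner

# Extract the tunable conv2d tasks named in the warnings.
tasks = autotvm.task.extract_from_program(mod["main"], target=target, params=params)

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=10, min_repeat_ms=150),
)

for task in tasks:
    tuner = XGBTuner(task)
    tuner.tune(
        n_trial=64,  # a small budget is enough to replace the fallback configs
        measure_option=measure_option,
        callbacks=[autotvm.callback.log_to_file("simplenet_cuda.log")],
    )

# Rebuild with the tuned schedules applied.
with autotvm.apply_history_best("simplenet_cuda.log"):
    with tvm.transform.PassContext(opt_level=3):
        graph, lib, params = tvm.relay.build(
            mod, target=target, target_host=target_host, params=params)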

The result should match PyTorch. Even for complicated models like MaskRCNN, we get the same output as PyTorch.

You need to set the params as well: m.set_input(**params)
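That is, a minimal sketch of the corrected run sequence (graph, lib, params, ctx, input_name, and img as in the script above):

m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)  # bind the compiled weights; without this they stay uninitialized, hence the all-zero output
m.set_input(input_name, tvm.nd.array(img, ctx))
m.run()
tvm_output = m.get_output(0)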


Worked like a charm! ✨

Thank you very much.