Description:
When `batch_size != 1` and `target = 'cuda'`, the script crashes at the statement `m.run()`.
Note that the script runs correctly when `batch_size = 1` or when `target = 'llvm'`.
The runnable script:
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
import torch
import torchvision.models as models
from torchvision import transforms
import numpy as np
from PIL import Image
# Size of the leading (batch) dimension fed to the model.
batch_size = 2  # crash when batch_size != 1
input_shape = [batch_size, 3, 224, 224]

# Pretrained MobileNetV2, switched to inference mode (eval() returns the module).
model = models.mobilenet_v2(pretrained=True)
model = model.eval()

# Trace the model with a random example input of the chosen shape.
# randn is deliberately drawn after the model is built, keeping the original
# order of random-number draws.
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data)
scripted_model = scripted_model.eval()
from tvm.contrib.download import download_testdata

# Fetch the sample cat image via TVM's test-data helper and shrink it to 224x224.
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))

# Preprocessing pipeline: Resize(256) then CenterCrop(224), tensor conversion,
# and per-channel normalization with the mean/std values below.
# NOTE(review): the image is already 224x224 at this point, so Resize(256)
# upsamples before the crop; harmless for reproducing the crash.
my_preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# Apply the preprocessing pipeline to the single image.
img = my_preprocess(img)
# Add a leading batch axis, then tile the image so the batch holds exactly
# `batch_size` copies, matching `input_shape` used for tracing.
# (Appending `img` to itself in a loop doubles the batch on every pass and
# yields 2 ** (batch_size - 1) images instead of batch_size.)
img = np.repeat(np.expand_dims(img, 0), batch_size, axis=0)
# Import the traced TorchScript module into Relay, declaring the input's name
# and its (batched) shape.
input_name = "input0"
shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

# Compile for CUDA. Per the report, any batch_size other than 1 crashes at
# m.run() below with this target, while target = 'llvm' works.
target = 'cuda'
dev = tvm.gpu(0)
with tvm.transform.PassContext(opt_level=3):
    # Build inside the PassContext so its opt_level applies to relay.build.
    lib = relay.build(mod, target=target, params=params)

# Instantiate the compiled graph on the device, feed the batch, and execute.
dtype = "float32"
m = graph_runtime.GraphModule(lib["default"](dev))
m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
m.run()  # crash here!!!
tvm_output = m.get_output(0).asnumpy()