I’ve tried out TVM with several ONNX Model Zoo models, but surprisingly many don’t work:
ok - https://github.com/onnx/models/raw/master/vision/classification/resnet/model/resnet50-v2-7.tar.gz
ok - https://github.com/onnx/models/raw/master/vision/classification/mnist/model/mnist-8.tar.gz
ok - https://github.com/onnx/models/raw/master/vision/classification/rcnn_ilsvrc13/model/rcnn-ilsvrc13-9.tar.gz
ok - https://github.com/onnx/models/raw/master/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.tar.gz
not ok - https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/mask-rcnn/model/MaskRCNN-10.tar.gz
not ok - https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/yolov4/model/yolov4.tar.gz
not ok - https://github.com/onnx/models/raw/master/text/machine_comprehension/bert-squad/model/bertsquad-10.tar.gz
not ok - https://github.com/onnx/models/raw/master/text/machine_comprehension/roberta/model/roberta-base-11.tar.gz
not ok - https://github.com/onnx/models/raw/master/text/machine_comprehension/gpt-2/model/gpt2-10.tar.gz
I’ve filed a GitHub issue so that I could attach files:
(GitHub issue opened 02:03PM, 17 Jun 2021 UTC; closed 09:46AM, 18 Jun 2021 UTC)
I'm filing an issue before posting on the forum because GitHub allows me to attach the script to reproduce the issue.
When I try to use TVM with the models from https://github.com/onnx/models, many fail:
```
$ python3 tvm_onnx_model_zoo.py
[...]
Summary:
ok - https://github.com/onnx/models/raw/master/vision/classification/resnet/model/resnet50-v2-7.tar.gz
ok - https://github.com/onnx/models/raw/master/vision/classification/mnist/model/mnist-8.tar.gz
ok - https://github.com/onnx/models/raw/master/vision/classification/rcnn_ilsvrc13/model/rcnn-ilsvrc13-9.tar.gz
ok - https://github.com/onnx/models/raw/master/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.tar.gz
not ok - https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/mask-rcnn/model/MaskRCNN-10.tar.gz
not ok - https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/yolov4/model/yolov4.tar.gz
not ok - https://github.com/onnx/models/raw/master/text/machine_comprehension/bert-squad/model/bertsquad-10.tar.gz
not ok - https://github.com/onnx/models/raw/master/text/machine_comprehension/roberta/model/roberta-base-11.tar.gz
not ok - https://github.com/onnx/models/raw/master/text/machine_comprehension/gpt-2/model/gpt2-10.tar.gz
```
This is the full output:
[tvm_onnx_model_zoo.log](https://github.com/apache/tvm/files/6670776/tvm_onnx_model_zoo.log)
And this is the Python script I used:
```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# See:
# - https://tvm.apache.org/docs/tutorials/frontend/from_onnx.html
# - https://github.com/apache/tvm/blob/main/tutorials/frontend/from_onnx.py
# - https://github.com/onnx/models
import subprocess
import os
import sys
import posixpath
from six.moves.urllib.request import urlretrieve
import glob
import onnx
from onnx import numpy_helper
import numpy as np
import tvm
import tvm.relay as relay
from tvm.contrib import graph_executor


def get_value_info_shape(value_info):
    return tuple([max(d.dim_value, 1) for d in value_info.type.tensor_type.shape.dim])


urls = [
    'https://github.com/onnx/models/raw/master/vision/classification/resnet/model/resnet50-v2-7.tar.gz',
    'https://github.com/onnx/models/raw/master/vision/classification/mnist/model/mnist-8.tar.gz',
    'https://github.com/onnx/models/raw/master/vision/classification/rcnn_ilsvrc13/model/rcnn-ilsvrc13-9.tar.gz',
    'https://github.com/onnx/models/raw/master/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.tar.gz',
    'https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/mask-rcnn/model/MaskRCNN-10.tar.gz',
    'https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/yolov4/model/yolov4.tar.gz',
    'https://github.com/onnx/models/raw/master/text/machine_comprehension/bert-squad/model/bertsquad-10.tar.gz',
    'https://github.com/onnx/models/raw/master/text/machine_comprehension/roberta/model/roberta-base-11.tar.gz',
    # XXX: Often segfaults
    'https://github.com/onnx/models/raw/master/text/machine_comprehension/gpt-2/model/gpt2-10.tar.gz',
]

target = "cuda"
ctx = tvm.device(target, 0)

summary = []

for url in urls:
    print(f'==> {url} <==')

    archive = posixpath.basename(url)
    if not os.path.exists(archive):
        print(f'Downloading {url} ...')
        urlretrieve(url, archive)
    assert os.path.exists(archive)

    import tarfile
    tar = tarfile.open(archive, 'r:gz')
    for n in tar.getnames():
        if n.endswith('.onnx'):
            model_file = n
            name = os.path.dirname(n)
            break

    if not os.path.exists(model_file):
        print(f'Extracting {archive} ...')
        #subprocess.call(['tar', 'xzf', archive])
        tar.extractall()
    assert os.path.exists(model_file)

    print(f'Loading {model_file} ...')
    onnx_model = onnx.load(model_file)
    graph = onnx_model.graph

    initializers = set()
    for initializer in graph.initializer:
        initializers.add(initializer.name)

    input_values = []
    test_data_set = glob.glob(os.path.join(name, 'test_data_set_*'))[0]
    shape_dict = {}
    assert os.path.exists(test_data_set)
    for input in graph.input:
        if input.name not in initializers:
            i = len(input_values)
            input_data = os.path.join(test_data_set, f'input_{i}.pb')
            tensor = onnx.TensorProto()
            input_data = open(input_data, 'rb').read()
            tensor.ParseFromString(input_data)
            x = numpy_helper.to_array(tensor)
            input_values.append(x)
            shape_dict[input.name] = x.shape

    print(f'Input shapes: {shape_dict}')

    try:
        print(f'Importing graph from ONNX to TVM Relay IR ...')
        mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

        print(f'Compiling graph from Relay IR to {target} ...')
        with tvm.transform.PassContext(opt_level=1):
            lib = relay.build(mod, target=target, params=params)

        m = lib["default"](ctx)
        module = graph_executor.GraphModule(m)

        for i, input_value in enumerate(input_values):
            input_value = tvm.nd.array(input_value, ctx)
            module.set_input(i, input_value)

        print(f"Running inference...")
        module.run()

        output0 = module.get_output(0)
        output0 = output0.asnumpy()
    except KeyboardInterrupt:
        raise
    except Exception as ex:
        print(f'Caught an exception {ex}')
        result = 'not ok'
    else:
        print(f'Succeeded!')
        result = 'ok'
    summary.append((result, url))
    print()

print('Summary:')
for result, url in summary:
    print(f'{result}\t- {url}')
```
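One fragile spot in the script above is the archive handling: if a tarball contained no `.onnx` member, `model_file` and `name` would silently keep their values from the previous loop iteration, and `glob(...)[0]` picks whichever `test_data_set_*` directory the filesystem happens to list first. A minimal sketch of stricter helpers, using only the standard library; the names `find_onnx_member`, `sorted_input_files`, and `first_test_data_set` are hypothetical and not part of TVM or ONNX:

```python
import os
import re


def find_onnx_member(names):
    """Return (model_file, model_dir) for the first .onnx member of an archive listing.

    Raises instead of silently reusing a value left over from a previous loop iteration.
    """
    for n in names:
        if n.endswith('.onnx'):
            return n, os.path.dirname(n)
    raise FileNotFoundError('no .onnx file in archive')


def first_test_data_set(model_dir):
    """Pick the test_data_set_* directory deterministically (sorted by name)."""
    sets = sorted(d for d in os.listdir(model_dir) if d.startswith('test_data_set_'))
    if not sets:
        raise FileNotFoundError(f'no test_data_set_* in {model_dir}')
    return os.path.join(model_dir, sets[0])


def sorted_input_files(test_data_dir):
    """List input_<i>.pb files in numeric order, so input_2 comes before input_10."""
    pattern = re.compile(r'input_(\d+)\.pb$')
    found = []
    for entry in os.listdir(test_data_dir):
        m = pattern.match(entry)
        if m:
            found.append((int(m.group(1)), os.path.join(test_data_dir, entry)))
    return [path for _, path in sorted(found)]
```

In the loop, `model_file, name = find_onnx_member(tar.getnames())` would replace the inner `for`/`break`, and `first_test_data_set(name)` would replace the `glob(...)[0]` lookup.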
This was with:
* latest TVM (commit 1fac10b359dec1bd6ad45ce36541a882aaba586b)
* Ubuntu 20.04.2 LTS
* NVIDIA Tesla T4, Driver Version 460.80, CUDA Version 11.2
Am I doing something wrong? Are these known TVM limitations? Something else?
I’d appreciate it if folks could confirm whether I’m doing something wrong here, or whether these results faithfully reflect the current state of ONNX support in TVM.
Thanks.
masahi
June 17, 2021, 8:23pm
I think models other than MaskRCNN should work. Two pieces of advice from me:
1. People often use `mod = relay.transform.DynamicToStatic()(mod)` after ONNX import. I don’t know exactly when it is required.
2. `relay.build` and `graph_executor` only work on static modules (no dynamic shapes, no control flow). For dynamic models such as MaskRCNN and yolo4, you need to use the VM compiler and runtime. Example: torchscript-to-tvm/yolo5_test.py at master · masahi/torchscript-to-tvm · GitHub
I tried the two changes above in your script, but it still doesn’t work. cc @mbrookhart @jwfromm
Explicit use of DynamicToStatic should only really be needed if we’re autotuning, and then only in some cases. You should probably freeze the parameters of the ONNX model: the TF and PyTorch exporters end up storing shape information as weights in the ONNX model, and freezing the parameters tends to make for a more robust import.
I’ll poke around your script, give me a bit.
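To make the point about shape information concrete: an ONNX `Reshape` takes its target shape as a second input tensor, so when an exporter stores that tensor as a weight and the importer treats weights as runtime parameters, the output shape is unknown at compile time; once the parameters are frozen into constants, it can be computed ahead of time. The toy sketch below models only that decision in plain Python. `infer_reshape` is a hypothetical helper, not TVM's shape inference; it follows ONNX's default `Reshape` rules (`0` copies the input dimension, `-1` is inferred).

```python
from math import prod


def infer_reshape(input_shape, target_shape):
    """Toy static shape inference for ONNX Reshape.

    input_shape:  tuple of ints (a fully static input).
    target_shape: tuple of ints if the shape tensor is a known constant
                  (e.g. a frozen weight), or None if it is only known at runtime.
    Returns the output shape, or None if it cannot be known ahead of time.
    """
    if target_shape is None:
        # The shape tensor is a runtime parameter: the output shape is dynamic.
        return None
    out = []
    infer_at = None
    for i, d in enumerate(target_shape):
        if d == 0:
            out.append(input_shape[i])  # 0 means "copy this input dimension"
        elif d == -1:
            infer_at = i
            out.append(1)  # placeholder, filled in below
        else:
            out.append(d)
    if infer_at is not None:
        # The product of the other dims (placeholder counts as 1) divides the element count.
        out[infer_at] = prod(input_shape) // prod(out)
    return tuple(out)
```

For example, `infer_reshape((1, 3, 4, 4), (0, -1))` yields `(1, 48)`, while `infer_reshape((1, 3, 4, 4), None)` yields `None`. This only illustrates why constant shape tensors help; it says nothing about how TVM's ONNX frontend implements `freeze_params`.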
The following change to the core compilation/execution step of your script gets YoloV4, Bertsquad, and GPT-2 running:
```python
    try:
        print(f'Importing graph from ONNX to TVM Relay IR ...')
        mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)

        print(f'Compiling graph from Relay IR to {target} ...')
        ex = relay.create_executor("vm", mod=mod, device=ctx, target=target)

        print(f"Running inference...")
        output = ex.evaluate()(*input_values, **params)
    # ... except/else clauses unchanged from the original script ...
```
We don’t do very well with this MaskRCNN model; it has dynamically shaped convolutions in it that TVM doesn’t handle very well yet.
Roberta looks like an input datatype problem; I’ll see if I can fix it quickly.
masahi
June 17, 2021, 9:09pm
I can confirm that `freeze_params=True` did the trick and got yolo4 and bert-squad working. GPT-2 seems to work with or without `freeze_params`.
My script with the modification: onnx_zoo_test.py · GitHub
Sorry for the delay; my afternoon was packed with meetings. The issue with Roberta is that somehow the int64 input tensor is getting loaded as float64.
If I make this hacky change to the way you’re importing tensors, it works:
```python
tensor = onnx.TensorProto()
input_data = open(input_data, 'rb').read()
tensor.ParseFromString(input_data)
x = numpy_helper.to_array(tensor)
if "roberta" in url:
    x = x.astype("int64")
input_values.append(x)
shape_dict[input.name] = x.shape
```
I’m not sure why ONNX’s `numpy_helper` isn’t getting that datatype right; the values are definitely integers, just cast to float64.
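A slightly less hacky variant of the workaround above, which avoids matching on the URL: cast a loaded test input to the dtype the model declares, and refuse the cast if it would lose information. This is a sketch, not part of ONNX or TVM; `coerce_to_declared_dtype` is a hypothetical helper, and the caller is assumed to derive `declared_dtype` from the graph input's `type.tensor_type.elem_type` (mapping that enum to a NumPy dtype is left to the caller, since the helper for it differs across ONNX versions).

```python
import numpy as np


def coerce_to_declared_dtype(x, declared_dtype):
    """Cast a loaded test input to the dtype the model declares.

    Float -> integer casts are only allowed when every value is a finite whole
    number, so a genuinely wrong input is reported instead of silently truncated.
    Other casts (e.g. float64 -> float32) are passed through to astype unchanged.
    """
    declared = np.dtype(declared_dtype)
    if x.dtype == declared:
        return x
    if declared.kind in "iu" and x.dtype.kind == "f":
        if not np.all(np.isfinite(x)) or not np.all(x == np.trunc(x)):
            raise ValueError(f"cannot losslessly cast {x.dtype} values to {declared}")
    return x.astype(declared)
```

For an input the graph declares as INT64, `x = coerce_to_declared_dtype(x, np.int64)` would replace the `if "roberta" in url:` branch.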
Anyway, of this list, that just leaves MaskRCNN, which we know is a limitation. I think @ziheng has been working on better dynamic kernel generation for TVM, but I don’t know the current status.
I also have problems running SSD-MobileNetV1.
It looks like Conv2d cannot handle dynamic shapes:
```
  File "/home/chlu/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/chlu/tvm/python/tvm/relay/op/strategy/generic.py", line 240, in _compute_conv2d
    return [topi_compute(*args)]
  File "/home/chlu/tvm/python/tvm/topi/x86/conv2d.py", line 129, in conv2d_nchw
    packed_out = conv2d_NCHWc(data, kernel, strides, padding, dilation, layout, layout, out_dtype)
  File "/home/chlu/tvm/python/tvm/autotvm/task/topi_integration.py", line 165, in wrapper
    node = topi_compute(cfg, *args)
  File "/home/chlu/tvm/python/tvm/topi/x86/conv2d.py", line 191, in conv2d_NCHWc
    oh = (ih - kernel_height + pt + pb) // sh + 1
TypeError: unsupported operand type(s) for -: 'Any' and 'int'
```
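The failing line in the traceback is the usual convolution output-size formula. With concrete integers it is plain arithmetic; when the input height is a symbolic dimension (`Any` in Relay), the x86 `conv2d_NCHWc` compute in the traceback evaluates that arithmetic with Python operators, which the `Any` object does not support. The sketch below reproduces both the arithmetic and the failure mode; `conv_out_dim` and `UnknownDim` are hypothetical stand-ins, not TVM code.

```python
def conv_out_dim(in_dim, kernel, pad_before, pad_after, stride, dilation=1):
    """Output size of a convolution along one spatial axis.

    With dilation == 1 this is the formula from the traceback:
    (ih - kernel_height + pt + pb) // sh + 1
    """
    dilated_kernel = (kernel - 1) * dilation + 1
    return (in_dim - dilated_kernel + pad_before + pad_after) // stride + 1


class UnknownDim:
    """Stand-in for a dimension only known at runtime; it defines no arithmetic."""

    def __repr__(self):
        return "?"
```

`conv_out_dim(224, 7, 3, 3, 2)` gives 112, whereas `conv_out_dim(UnknownDim(), 3, 1, 1, 1)` raises the same kind of `TypeError: unsupported operand type(s) for -` seen above, because the subtraction is attempted on an object that has no integer value.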
@masahi @mbrookhart, thank you for your prompt replies. Indeed, passing `freeze_params=True` allowed me to get all models except Mask-RCNN working.
Roberta worked fine here without further changes. I’m using onnx==1.8.1. Maybe the ONNX version you’re using has a broken `numpy_helper.to_array` implementation.
Regarding the different ways of executing graphs, I was following apps/benchmark/gpu_imagenet_bench.py as a blueprint, since it gave the best performance when I last tried. Where can I learn more about this VM compiler?
AFAICT, `tests/python/frontend/onnx/test_forward.py` contains unit tests, and `tests/python/contrib/test_bnns/test_onnx_topologies.py` only covers a sliver of the ONNX models, doesn’t check the expected outputs, and I couldn’t get it to run without errors. Are the ONNX Model Zoo models part of TVM’s regression tests in some other way?
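On checking expected outputs: each Model Zoo `test_data_set_*` directory also ships `output_<i>.pb` files next to the `input_<i>.pb` files, which could be loaded the same way as the inputs and compared against TVM's result. A minimal sketch of the comparison step, assuming NumPy arrays on both sides; `compare_to_reference` is a hypothetical helper, and the default tolerances are illustrative guesses rather than values used by TVM's or ONNX's test suites.

```python
import numpy as np


def compare_to_reference(actual, expected, rtol=1e-3, atol=1e-5):
    """Return (ok, max_abs_err) for one model output against its reference.

    A shape mismatch counts as a failure rather than being broadcast away.
    """
    actual = np.asarray(actual)
    expected = np.asarray(expected)
    if actual.shape != expected.shape:
        return False, float("inf")
    diff = np.abs(actual.astype(np.float64) - expected.astype(np.float64))
    max_abs_err = float(diff.max()) if diff.size else 0.0
    ok = bool(np.allclose(actual, expected, rtol=rtol, atol=atol))
    return ok, max_abs_err
```

In the issue's script, this would compare `output0` against the `output_0.pb` from the same test data set, and the returned `max_abs_err` could be printed next to `ok`/`not ok` in the summary.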
SSD-MobileNetV1 worked fine with my script after modifying it as:
```diff
--- tvm_onnx_model_zoo.py.orig 2021-06-18 15:11:25.140135692 +0000
+++ tvm_onnx_model_zoo.py 2021-06-18 15:12:00.945439030 +0000
@@ -46,6 +46,7 @@
     'https://github.com/onnx/models/raw/master/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.tar.gz',
     'https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/mask-rcnn/model/MaskRCNN-10.tar.gz',
     'https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/yolov4/model/yolov4.tar.gz',
+    'https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/ssd-mobilenetv1/model/ssd_mobilenet_v1_10.tar.gz',
     'https://github.com/onnx/models/raw/master/text/machine_comprehension/bert-squad/model/bertsquad-10.tar.gz',
     'https://github.com/onnx/models/raw/master/text/machine_comprehension/roberta/model/roberta-base-11.tar.gz',
     # XXX: Often segfaults
@@ -109,7 +110,7 @@
     try:
         print(f'Importing graph from ONNX to TVM Relay IR ...')
-        mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
+        mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)
         print(f'Compiling graph from Relay IR to {target} ...')
         with tvm.transform.PassContext(opt_level=1):
```
@jfonseca Thank you for your test and reply, but I forgot to mention that I use the VM and the target is llvm, with a script like the one @masahi used.