Annotation error with Vitis-AI in TVM

To target the Vitis-AI edge DPUCZDX8G-zcu104 device, I need to compile the model on the host side and generate the TVM lib.so for the edge. After importing a convolutional neural network model using the usual Relay APIs, I annotate the Relay expression for the given Vitis-AI DPU target and partition the graph. My code and the error are as follows:

code:

input_size = 416

original_image = cv2.imread('horses.jpg')
original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
original_image_size = original_image.shape[:2]
image_data = image_preporcess(np.copy(original_image), [input_size, input_size])
image_data = image_data[np.newaxis, ...]
image_data = image_data.astype(np.float32)

CHECKPOINT = './saved/'
layout = "NCHW"

data_shape = image_data.shape
shape_dict = {'input/input_data': image_data.shape}
dtype_dict = {'input/input_data': 'uint8'}
return_elements = ["pred_sbbox/concat_2:0", "pred_mbbox/concat_2:0", "pred_lbbox/concat_2:0"]

parser = TFParser(CHECKPOINT)
graph_def = parser.parse()
mod, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict, outputs=return_elements)

#category_file = os.path.join(CHECKPOINT, 'category.txt')
classes = read_class_names('category.txt')
num_classes = len(classes)

net = mod["main"]
mod = tvm.IRModule.from_expr(net)

tvm_target = 'llvm'
target = 'DPUCZDX8G-zcu104'

mod["main"] = bind_params_by_name(mod["main"], params)
mod = annotation(mod, params, target)
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.PartitionGraph()(mod)

Errors are reported during this process, as shown below:

WARNING:root:Attribute T is ignored in relay.sym.concatenate
WARNING:root:Attribute _node_name is ignored in relay.sym.concatenate
WARNING:root:Attribute _target_layout is ignored in relay.sym.concatenate
Traceback (most recent call last):
  File "run.py", line 280, in <module>
    mod = annotation(mod, params, target)
  File "/workspace/tvm/python/tvm/relay/op/contrib/vitis_ai.py", line 88, in annotation
    xgraph = pyxir.frontend.tvm.from_relay(mod, params, postprocessing=None)
  File "/workspace/.local/lib/python3.6/site-packages/pyxir-0.1.3-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay.py", line 58, in from_relay
    cvx_preprocessing=cvx_preprocessing
  ... ...
  File "/workspace/.local/lib/python3.6/site-packages/pyxir-0.1.3-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_2_xlayer_registry.py", line 139, in __base_relay_2_xlayer
    X = specific_relay_2_xlayer(op_name, expr, iXs)
  File "/workspace/.local/lib/python3.6/site-packages/pyxir-0.1.3-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 355, in relay_op
    ty = expr.checked_type
  File "/workspace/tvm/python/tvm/ir/expr.py", line 50, in checked_type
    raise ValueError("The type checker has not populated" " the checked_type for this node")
ValueError: The type checker has not populated the checked_type for this node

Can you give me a complete example of using Vitis AI in TVM?

@jtuyls @mak could you guys take a look?

@williamyang4978 Can you try adding mod = relay.transform.InferType()(mod) after mod["main"] = bind_params_by_name(mod["main"], params)? For some models we require shape information which gets added with this call. This is missing in the documentation so we will update that and/or take care of this internally.
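For reference, a minimal sketch of where the suggested call would go, using the variable names from the script above:

mod["main"] = bind_params_by_name(mod["main"], params)
# Populate checked_type/shape information before annotating for the DPU target
mod = relay.transform.InferType()(mod)
mod = annotation(mod, params, target)
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.PartitionGraph()(mod)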

I revised the python script as you suggested, but there are still errors. The errors are as follows:

Traceback (most recent call last):
  File "yolov3_tvm_host.py", line 281, in <module>
    mod = annotation(mod, params, target)
  File "/workspace/tvm/python/tvm/relay/op/contrib/vitis_ai.py", line 88, in annotation
    xgraph = pyxir.frontend.tvm.from_relay(mod, params, postprocessing=None)
  File "/workspace/.local/lib/python3.6/site-packages/pyxir-0.1.3-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay.py", line 58, in from_relay
    cvx_preprocessing=cvx_preprocessing
  ... ...
  File "/workspace/.local/lib/python3.6/site-packages/pyxir-0.1.3-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 105, in call
    RELAY_2_XLAYER, **kwargs)
  File "/workspace/.local/lib/python3.6/site-packages/pyxir-0.1.3-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l1_basic.py", line 379, in concatenate
    relay_id=relay_idx)
  File "/workspace/.local/lib/python3.6/site-packages/pyxir-0.1.3-py3.6-linux-x86_64.egg/pyxir/graph/layer/xlayer_factory.py", line 113, in factory_wrapper
    X = XLAYER_FACTORY[xop_name](*args, **kwargs)
  File "/workspace/.local/lib/python3.6/site-packages/pyxir-0.1.3-py3.6-linux-x86_64.egg/pyxir/graph/ops/l1_basic_nn.py", line 164, in concat
    len(set([il.shapes[i] for il in input_layers])) == 1
AssertionError

Can you give a complete example of deployment on zcu104?

@williamyang4978 You can follow edge setup from here.

Is this model public? If so, could you share a link to it? We would like to replicate the problem you are seeing.

Please download my experiment script, network model and parameters from the link below: https://github.com/williamyang4978/yolov3_fpga_tvm_test.

I used the resnet_v2_101 model for my experiments and encountered the same annotation problem as in the previous experiment. Please see https://github.com/williamyang4978/resnet_v2_101_tvm_fpga_ZCU104.git for the detailed log and model.

@williamyang4978 We are looking at the issue you are seeing and will get back to you.

@williamyang4978 If you use pyxir version 0.1.5 your YoloV3 model should work. The issue was that the mul + max = leaky relu pattern wasn’t recognized. I also added complete host and board example scripts for compiling and running ResNet 18 here for your reference: pyxir/examples/tvm at master · Xilinx/pyxir · GitHub.

I implemented YOLOv3 following the ResNet 18 example you provided, but the following error occurred:

Traceback (most recent call last):
  File "yolov3_tvm_host.py", line 370, in <module>
    InferenceSession.run()
  File "/workspace/tvm/python/tvm/contrib/graph_runtime.py", line 207, in run
    self._run()
  File "tvm/_ffi/_cython/./packed_func.pxi", line 322, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 160, in tvm._ffi._cy3.core.CALL
ValueError: ValueError: Missing inputs. The provided inputs are: dict_keys(['vitis_ai_0_i0']), but the expected inputs are: {'vitis_ai_0_i0', 'vitis_ai_0_i2'}
At:
  /workspace/.local/lib/python3.6/site-packages/pyxir-0.1.5-py3.6-linux-x86_64.egg/pyxir/runtime/tensorflow/runtime_tf.py(244): run
  /workspace/.local/lib/python3.6/site-packages/pyxir-0.1.5-py3.6-linux-x86_64.egg/pyxir/base.py(515): rt_func
  /workspace/.local/lib/python3.6/site-packages/pyxir-0.1.5-py3.6-linux-x86_64.egg/pyxir/opaque_func.py(113): opaque_func_wrapper
  /workspace/tvm/python/tvm/contrib/graph_runtime.py(207): run
  yolov3_tvm_host.py(370): <module>

For edge deployment, please refer to my host-side processing script at the following address:

I'd like to ask whether the error above is a TVM problem. Is there a solution?

@jtuyls I am trying to compile a DeepLabV3 model (ResNet 101 backbone) from PyTorch and am getting the same assertion error when using the annotation API. I am using the master branch of pyxir and TVM. Is AnnotateTarget different from annotation? (I do not see this error when using AnnotateTarget.) Any suggestions?

(vitis-ai-pytorch) Vitis-AI /workspace/python/compile > python3 compile_pytorch_deeplab.py
2021-03-18 04:18:25.490123: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/xilinx/xrt/lib:/usr/lib:/usr/lib/x86_64-linux-gnu:/usr/local/lib:/opt/vitis_ai/conda/envs/vitis-ai-tensorflow/lib
2021-03-18 04:18:25.490159: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/quantization/decent_quantizer.py:50: UserWarning: Could not import decent_q module. Please check if installed.
  warnings.warn("Could not import decent_q module. Please check"
File /home/vitis-ai-user/.tvm_test_data/data/cat.png exists, skip.
input img size (224, 224)
transform_image_torchvision torch.Size([3, 224, 224])
(1, 3, 224, 224) <class 'tuple'>
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
INFO:pyxir:
**************************************************
* RELAY IR TO PYXIR
**************************************************
/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l10_temporary.py:64: UserWarning: Convert Relay Adaptive Avg pool2d layer into normal average pool2d layer
  warnings.warn("Convert Relay Adaptive Avg pool2d layer into normal"
Traceback (most recent call last):
  File "compile_pytorch_deeplab.py", line 179, in <module>
    mod = annotation(mod, params, target)
  File "/workspace/python/tvm/relay/op/contrib/vitis_ai.py", line 92, in annotation
    xgraph = pyxir.frontend.tvm.from_relay(mod, params, postprocessing=None)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay.py", line 58, in from_relay
    cvx_preprocessing=cvx_preprocessing
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_2_xgraph_converter.py", line 96, in from_relay_to_xgraph
    cvx_prep=cvx_preprocessing)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 75, in function
    op_idx, RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 184, in tuple_expr
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 117, in call
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_2_xlayer_registry.py", line 122, in __base_relay_2_xlayer
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 117, in call
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_2_xlayer_registry.py", line 122, in __base_relay_2_xlayer
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 117, in call
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l1_basic.py", line 78, in add
    **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 117, in call
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l2_convolution.py", line 216, in nn_conv2d
    **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 117, in call
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_2_xlayer_registry.py", line 122, in __base_relay_2_xlayer
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 239, in tuple_get_item
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 117, in call
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l1_basic.py", line 239, in nn_batch_norm
    **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 117, in call
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l2_convolution.py", line 216, in nn_conv2d
    **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 239, in tuple_get_item
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 117, in call
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_2_xlayer_registry.py", line 122, in __base_relay_2_xlayer
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 117, in call
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_2_xlayer_registry.py", line 122, in __base_relay_2_xlayer
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 239, in tuple_get_item
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 117, in call
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l1_basic.py", line 239, in nn_batch_norm
    **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 117, in call
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l2_convolution.py", line 216, in nn_conv2d
    **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 117, in call
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_2_xlayer_registry.py", line 122, in __base_relay_2_xlayer
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l0_expr_and_others.py", line 117, in call
    RELAY_2_XLAYER, **kwargs)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/frontend/tvm/relay_tools/relay_l1_basic.py", line 414, in concatenate
    X = px.ops.concat(op_name, data_layers, axis, relay_id=relay_idx)
  File "/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/graph/ops/l1_basic_nn.py", line 167, in concat
    assert i == axis or len(check) == 1
AssertionError

@williamyang4978 It worked for me when setting the layout to 'NHWC' here: yolov3_fpga_tvm_test/yolov3_tvm_host.py at 7242a889aa889b0895cb5c4f06f988c7710bbf5d · williamyang4978/yolov3_fpga_tvm_test · GitHub (as the model is NHWC), and this line: yolov3_fpga_tvm_test/yolov3_tvm_host.py at 7242a889aa889b0895cb5c4f06f988c7710bbf5d · williamyang4978/yolov3_fpga_tvm_test · GitHub should be desired_layouts = {'nn.conv2d': ['NHWC', 'default']}. PS: you can export and load the Relay module to avoid going through the from_tensorflow frontend again and again:

json_file = "relay_mod_yolov3.json"
params_file = "relay_params_yolov3.params"

# Uncomment following lines to save the Relay module
# with open(json_file, 'w') as f:
#     f.write(tvm.ir.save_json(mod))
# with open(params_file, "wb") as fo:
#     fo.write(relay.save_param_dict(params))

with open(json_file, 'rb') as f:
    mod = tvm.ir.load_json(f.read())
with open(params_file, "rb") as f:
    params = relay.load_param_dict(bytearray(f.read()))
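As a rough sketch of the layout suggestion above (assuming graph_def, shape_dict and return_elements are defined as in the original script), the import and conversion would look like:

layout = "NHWC"
mod, params = relay.frontend.from_tensorflow(graph_def,
                                             layout=layout,
                                             shape=shape_dict,
                                             outputs=return_elements)

# Convert convolutions to NHWC before partitioning for the DPU
desired_layouts = {'nn.conv2d': ['NHWC', 'default']}
seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                relay.transform.ConvertLayout(desired_layouts),
                                relay.transform.FoldConstant()])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)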

@abdulazizm I expect this to be a different issue. The line that fails checks whether the dimensions that are not concatenated over have the same size for all inputs. Could you share the Relay representation of the model (print(mod['main'])) after doing an InferType pass (mod = relay.transform.InferType()(mod))? That way I can see what the input layers to this concatenation are.
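For example, a minimal sketch of what is being asked for:

mod = relay.transform.InferType()(mod)
print(mod['main'])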

It didn't work for me when I set the layout to 'NHWC'; it reported the error 'Segmentation fault (core dumped)':

/workspace/yolov3/yolov3_tvm_host.py(325)<module>()
-> mod = relay.transform.PartitionGraph()(mod)
(Pdb) n
Segmentation fault (core dumped)

@williamyang4978

If you have an upsampling layer in the model, you can register an upsampling layout transform. Below is the code to register it.

from tvm import relay
from tvm.relay import reg

@reg.register_convert_op_layout("nn.upsampling")
def convert_upsampling(attrs, inputs, tinfos, desired_layouts):
    # Force the upsampling layer to the NHWC layout during ConvertLayout
    data = inputs
    new_attrs = dict(attrs)
    new_attrs['layout'] = 'NHWC'
    return relay.nn.upsampling(data[0], **new_attrs)

Add 'nn.upsampling' to the desired layouts:

desired_layouts = {'nn.conv2d': ['NHWC', 'default'], 'nn.upsampling': ['NHWC']}
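A sketch of how the registered transform is exercised, following the same ConvertLayout sequence used elsewhere in this thread:

desired_layouts = {'nn.conv2d': ['NHWC', 'default'], 'nn.upsampling': ['NHWC']}
seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                relay.transform.ConvertLayout(desired_layouts),
                                relay.transform.FoldConstant()])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)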

@mak I modified the code according to your suggestion, but it still reported 'Segmentation fault'. However, if I set the layout to 'NCHW' instead of 'NHWC', the following errors are reported:

Traceback (most recent call last):
  File "./yolov3_tvm_host.py", line 406, in <module>
    InferenceSession.run()
  File "/workspace/tvm/python/tvm/contrib/graph_runtime.py", line 209, in run
    self._run()
  File "tvm/_ffi/_cython/./packed_func.pxi", line 322, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 160, in tvm._ffi._cy3.core.CALL
ValueError: ValueError: Missing inputs. The provided inputs are: dict_keys(['vitis_ai_0_i0']), but the expected inputs are: {'vitis_ai_0_i2', 'vitis_ai_0_i0'}
At:
  /workspace/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/runtime/tensorflow/runtime_tf.py(244): run
  /workspace/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/base.py(516): rt_func
  /workspace/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/opaque_func.py(113): opaque_func_wrapper
  /workspace/tvm/python/tvm/contrib/graph_runtime.py(209): run
  ./yolov3_tvm_host.py(406): <module>

@williamyang4978 Could you try this script? If it doesn’t work for you, please specify the TVM and PyXIR versions you are using.

# To be able to target the Vitis-AI edge DPUCZDX8G-zcu104 target we first have to import the target in PyXIR.
# This PyXIR package is the interface being used by TVM to integrate with the Vitis-AI stack. Additionally,
# import the typical TVM and Relay modules and the Vitis-AI contrib module inside TVM.
import os
import sys
import numpy as np
from pathlib import Path


import pyxir
import pyxir.contrib.target.DPUCZDX8G

import tvm
import tvm.relay as relay
from tvm import contrib
from tvm.relay import transform
from tvm.contrib import utils, graph_runtime
from tvm.contrib.target import vitis_ai
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.op.contrib.vitis_ai import annotation
from tvm.relay.frontend.tensorflow_parser import TFParser

import colorsys
import random
import cv2
from PIL import Image

# After importing a convolutional neural network model using the usual Relay API's, 
# annotate the Relay expression for the given Vitis-AI DPU target and partition the graph.
#from tensorflow.contrib import decent_q

def read_class_names(class_file_name):
    '''loads class name from a file'''
    names = {}
    with open(class_file_name, 'r') as data:
        for ID, name in enumerate(data):
            names[ID] = name.strip('\n')
    return names


def draw_bbox(image, bboxes, classes, show_label=True):
    """
    bboxes: [x_min, y_min, x_max, y_max, probability, cls_id] format coordinates.
    """

    num_classes = len(classes)
    image_h, image_w, _ = image.shape
    hsv_tuples = [(1.0 * x / num_classes, 1., 1.) for x in range(num_classes)]
    colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
    colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))

    random.seed(0)
    random.shuffle(colors)
    random.seed(None)

    for i, bbox in enumerate(bboxes):
        coor = np.array(bbox[:4], dtype=np.int32)
        fontScale = 0.5
        score = bbox[4]
        class_ind = int(bbox[5])
        bbox_color = colors[class_ind]
        bbox_thick = int(0.6 * (image_h + image_w) / 600)
        c1, c2 = (coor[0], coor[1]), (coor[2], coor[3])
        cv2.rectangle(image, c1, c2, bbox_color, bbox_thick)

        if show_label:
            bbox_mess = '%s: %.2f' % (classes[class_ind], score)
            t_size = cv2.getTextSize(bbox_mess, 0, fontScale, thickness=bbox_thick // 2)[0]
            cv2.rectangle(image, c1, (c1[0] + t_size[0], c1[1] - t_size[1] - 3), bbox_color, -1)  # filled

            cv2.putText(image, bbox_mess, (c1[0], c1[1] - 2), cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale, (0, 0, 0), bbox_thick // 2, lineType=cv2.LINE_AA)
    return image


def bboxes_iou(boxes1, boxes2):
    boxes1 = np.array(boxes1)
    boxes2 = np.array(boxes2)

    boxes1_area = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1])
    boxes2_area = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1])

    left_up = np.maximum(boxes1[..., :2], boxes2[..., :2])
    right_down = np.minimum(boxes1[..., 2:], boxes2[..., 2:])

    inter_section = np.maximum(right_down - left_up, 0.0)
    inter_area = inter_section[..., 0] * inter_section[..., 1]
    union_area = boxes1_area + boxes2_area - inter_area
    ious = np.maximum(1.0 * inter_area / union_area, np.finfo(np.float32).eps)

    return ious


def nms(bboxes, iou_threshold, sigma=0.3, method='nms'):
    """
    :param bboxes: (xmin, ymin, xmax, ymax, score, class)
    Note: soft-nms, https://arxiv.org/pdf/1704.04503.pdf
          https://github.com/bharatsingh430/soft-nms
    """
    classes_in_img = list(set(bboxes[:, 5]))
    best_bboxes = []

    for cls in classes_in_img:
        cls_mask = (bboxes[:, 5] == cls)
        cls_bboxes = bboxes[cls_mask]

        while len(cls_bboxes) > 0:
            max_ind = np.argmax(cls_bboxes[:, 4])
            best_bbox = cls_bboxes[max_ind]
            best_bboxes.append(best_bbox)
            cls_bboxes = np.concatenate([cls_bboxes[: max_ind], cls_bboxes[max_ind + 1:]])
            iou = bboxes_iou(best_bbox[np.newaxis, :4], cls_bboxes[:, :4])
            weight = np.ones((len(iou),), dtype=np.float32)

            assert method in ['nms', 'soft-nms']

            if method == 'nms':
                iou_mask = iou > iou_threshold
                weight[iou_mask] = 0.0

            if method == 'soft-nms':
                weight = np.exp(-(1.0 * iou ** 2 / sigma))

            cls_bboxes[:, 4] = cls_bboxes[:, 4] * weight
            score_mask = cls_bboxes[:, 4] > 0.
            cls_bboxes = cls_bboxes[score_mask]

    return best_bboxes


def postprocess_boxes(pred_bbox, org_img_shape, input_size, score_threshold):
    valid_scale = [0, np.inf]
    pred_bbox = np.array(pred_bbox)

    pred_xywh = pred_bbox[:, 0:4]
    pred_conf = pred_bbox[:, 4]
    pred_prob = pred_bbox[:, 5:]

    # # (1) (x, y, w, h) --> (xmin, ymin, xmax, ymax)
    pred_coor = np.concatenate([pred_xywh[:, :2] - pred_xywh[:, 2:] * 0.5,
                                pred_xywh[:, :2] + pred_xywh[:, 2:] * 0.5], axis=-1)
    # # (2) (xmin, ymin, xmax, ymax) -> (xmin_org, ymin_org, xmax_org, ymax_org)
    org_h, org_w = org_img_shape
    resize_ratio = min(input_size / org_w, input_size / org_h)

    dw = (input_size - resize_ratio * org_w) / 2
    dh = (input_size - resize_ratio * org_h) / 2

    pred_coor[:, 0::2] = 1.0 * (pred_coor[:, 0::2] - dw) / resize_ratio
    pred_coor[:, 1::2] = 1.0 * (pred_coor[:, 1::2] - dh) / resize_ratio

    # # (3) clip some boxes those are out of range
    pred_coor = np.concatenate([np.maximum(pred_coor[:, :2], [0, 0]),
                                np.minimum(pred_coor[:, 2:], [org_w - 1, org_h - 1])], axis=-1)
    invalid_mask = np.logical_or((pred_coor[:, 0] > pred_coor[:, 2]), (pred_coor[:, 1] > pred_coor[:, 3]))
    pred_coor[invalid_mask] = 0

    # # (4) discard some invalid boxes
    bboxes_scale = np.sqrt(np.multiply.reduce(pred_coor[:, 2:4] - pred_coor[:, 0:2], axis=-1))
    scale_mask = np.logical_and((valid_scale[0] < bboxes_scale), (bboxes_scale < valid_scale[1]))

    # # (5) discard some boxes with low scores
    classes = np.argmax(pred_prob, axis=-1)
    scores = pred_conf * pred_prob[np.arange(len(pred_coor)), classes]
    score_mask = scores > score_threshold
    mask = np.logical_and(scale_mask, score_mask)
    coors, scores, classes = pred_coor[mask], scores[mask], classes[mask]

    return np.concatenate([coors, scores[:, np.newaxis], classes[:, np.newaxis]], axis=-1)

def image_preporcess(image, target_size, gt_boxes=None):
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)

    ih, iw = target_size
    h, w, _ = image.shape

    scale = min(iw / w, ih / h)
    nw, nh = int(scale * w), int(scale * h)
    image_resized = cv2.resize(image, (nw, nh))

    image_paded = np.full(shape=[ih, iw, 3], fill_value=128.0)
    dw, dh = (iw - nw) // 2, (ih - nh) // 2
    image_paded[dh:nh + dh, dw:nw + dw, :] = image_resized
    image_paded = image_paded / 255.

    if gt_boxes is None:
        return image_paded

    else:
        gt_boxes[:, [0, 2]] = gt_boxes[:, [0, 2]] * scale + dw
        gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] * scale + dh
        return image_paded, gt_boxes


def inputs_func(img_files):
    """Utility function to read images from a list"""
    inputs = []
    input_size = 416
    dtype = 'float32'
    for img_path in img_files:
        original_image = cv2.imread(img_path)
        #print(img_path)
        original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
        image_data = image_preporcess(np.copy(original_image), [input_size, input_size])
        image_data = image_data[np.newaxis, ...]
        image_data = image_data.astype(np.float32)
        image_data = tvm.nd.array(image_data.astype(dtype))

        inputs.append(image_data)
    return inputs



input_size = 416
original_image = cv2.imread('horses.jpg')
#original_image = cv2.imread('../voc/000002.jpg')
original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
original_image_size = original_image.shape[:2]
image_data = image_preporcess(np.copy(original_image), [input_size, input_size])
image_data = image_data[np.newaxis, ...]
image_data = image_data.astype(np.float32)

CHECKPOINT = './saved/'
#CHECKPOINT = '../.tvm_test_data/tf/YoloV3/yolov3_coco.pb'
layout = "NHWC"

data_shape = image_data.shape
input_name = 'input/input_data'
#shape_dict = {'input/input_data': image_data.shape}
#dtype_dict = {'input/input_data': 'uint8'}
shape_dict = {input_name: data_shape}
dtype_dict = {input_name: 'uint8'}

return_elements = ["pred_sbbox/concat_2:0", "pred_mbbox/concat_2:0", "pred_lbbox/concat_2:0"]
parser = TFParser(CHECKPOINT)
graph_def = parser.parse()
print("---------------------------front of from_tensorflow----------------") 
mod, params = relay.frontend.from_tensorflow(graph_def,
                                             layout=layout,
                                             shape=shape_dict,
                                             outputs=return_elements)

#json_file = "relay_mod_yolov3.json"
#params_file = "relay_params_yolov3.params"
# with open(json_file, 'w') as f:
#     f.write(tvm.ir.save_json(mod))
# with open(params_file, "wb") as fo:
#     fo.write(relay.save_param_dict(params))

# with open(json_file, 'rb') as f:
#     mod = tvm.ir.load_json(f.read())
# with open(params_file, "rb") as f:
#     params = relay.load_param_dict(bytearray(f.read()))
# mod = relay.transform.InferType()(mod)

print("---------------------------after of from_tensorflow----------------") 

classes = read_class_names('category.txt')
num_classes = len(classes)

tvm_target = 'llvm'
dpu_target ='DPUCZDX8G-zcu104'


#-----------------------------------------------------------------------------------
# For the edge target we recommend converting the layout to NHWC for best performance
desired_layouts = {'nn.conv2d': ['NHWC', 'default']}
seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                relay.transform.ConvertLayout(desired_layouts),
                                relay.transform.FoldConstant()])
with tvm.transform.PassContext(opt_level=3):
     mod = seq(mod)

mod["main"] = bind_params_by_name(mod["main"], params)
mod = annotation(mod, params, dpu_target)
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.PartitionGraph()(mod)

# Convert convolutions that won't be executed on DPU back to NCHW
desired_layouts = {'nn.conv2d': ['NCHW', 'default']}
seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                relay.transform.ConvertLayout(desired_layouts),
                                relay.transform.FoldConstant()])
with tvm.transform.PassContext(opt_level=3):
     mod = seq(mod)

#-----------------------------------------------------------------------------------


# build the TVM runtime library for executing the model
export_rt_mod_file = os.path.join(os.getcwd(), 'vitis_ai.rtmod')
with tvm.transform.PassContext(opt_level=3, 
                               config= {'relay.ext.vitis_ai.options.target': dpu_target,
                                        'relay.ext.vitis_ai.options.export_runtime_module': export_rt_mod_file}):
     lib = relay.build(mod, tvm_target, params=params)
    
    
InferenceSession = graph_runtime.GraphModule(lib["default"](tvm.cpu()))
px_quant_size = int(os.environ['PX_QUANT_SIZE']) \
                if 'PX_QUANT_SIZE' in os.environ else 128

print("Quantize on first {} inputs".format(px_quant_size))

for i in range(px_quant_size):
    InferenceSession.set_input('input/input_data', tvm.nd.array(image_data.astype(np.float32)))
    InferenceSession.run()


# get outputs
print(InferenceSession.get_output(0).shape)
print(InferenceSession.get_output(1).shape)
print(InferenceSession.get_output(2).shape)


pred_sbbox = InferenceSession.get_output(0).asnumpy()
pred_mbbox = InferenceSession.get_output(1).asnumpy()
pred_lbbox = InferenceSession.get_output(2).asnumpy()

pred_bbox = np.concatenate([np.reshape(pred_sbbox, (-1, 5 + num_classes)),
                            np.reshape(pred_mbbox, (-1, 5 + num_classes)),
                            np.reshape(pred_lbbox, (-1, 5 + num_classes))], axis=0)

bboxes = postprocess_boxes(pred_bbox, original_image_size, input_size, 0.3)
bboxes = nms(bboxes, 0.45, method='nms')
image = draw_bbox(original_image, bboxes, classes)
image = Image.fromarray(image)

first_name, last_name = 'horses.jpg'.split(".")
predict_image_name = first_name + "_predict." + last_name
image.save(predict_image_name)


# Save the TVM lib module so that the Vitis-AI runtime module will also be exported (to the 'export_runtime_module' path we previously passed as a config).
temp = utils.tempdir()
lib.export_library("tvm_lib.so")

# Export lib for aarch64 target
tvm_target = tvm.target.arm_cpu('ultra96')
lib_kwargs = {
     'fcompile': contrib.cc.create_shared,
     'cc': "/usr/aarch64-linux-gnu/bin/ld"
}

with tvm.transform.PassContext(opt_level=3,
                               config={'relay.ext.vitis_ai.options.load_runtime_module': export_rt_mod_file}):
     lib_edge_dpu = relay.build(mod, target=tvm_target, params=params)

lib_edge_dpu.export_library('tvm_dpu_arm.so', **lib_kwargs)



TVM=0.8 and PyXIR=0.1.6

After I reinstalled the latest TVM and pyxir, the following error occurred:

/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/runtime/__init__.py:34: UserWarning: Could not load cpu-tf runtime because of error: No module named 'tensorflow'
  .format(e))

/home/vitis-ai-user/.local/lib/python3.6/site-packages/pyxir-0.1.6-py3.6-linux-x86_64.egg/pyxir/quantization/decent_quantizer.py:53: UserWarning: Could not import decent_q module. Please check if installed.

  warnings.warn("Could not import decent_q module. Please check"
-------------data_shape: (1, 416, 416, 3)
Traceback (most recent call last):
  File "yolov3_tvm_host.py", line 298, in <module>
    parser = TFParser(CHECKPOINT)
  File "/workspace/tvm/python/tvm/relay/frontend/tensorflow_parser.py", line 46, in __init__
    from tensorflow.core.framework import graph_pb2
ModuleNotFoundError: No module named 'tensorflow'

@williamyang4978 I guess we do not actually need to worry about the UserWarning; it just seems that the tensorflow module is not installed. My docker setup with TVM and pyxir seems to have tensorflow. Have you followed the same steps as in the Vitis-AI Integration — tvm 0.8.dev0 documentation? (Alternatively, you can install the tensorflow module in the docker/target device.)

FYI: I got the same error (though it was thrown as a UserWarning for me) while trying to use the tvm.relay module on the target edge device (ZCU104) on top of PYNQ. @jtuyls suggested ignoring this error if we only have to run inference on the edge device (for inference we do not need Relay). If you are really worried about this warning on the target device, you can install tensorflow from this wheel package (built for Python 3.6 / AArch64 64-bit; my PYNQ device was configured for this version) → https://github.com/itdaniher/aarch64-tensorflow/releases/download/v1.5.0/tensorflow-1.5.0-cp36-cp36m-linux_aarch64.whl