When I load a TensorFlow .pb model and build it with TVM Relay, it raises the following error:
Traceback (most recent call last):
  File "east_tvm.py", line 60, in <module>
    graph, lib, params = relay.build(sym, target=target, target_host=target_host, params=params)
  File "/root/miniconda3/envs/tensorpack/lib/python3.6/site-packages/tvm-0.6.dev0-py3.6-linux-x86_64.egg/tvm/relay/build_module.py", line 304, in build
    graph_json, lowered_funcs, params = graph_gen.codegen(func)
I have tried printing the params extracted from the .pb model, and none of them has a negative shape.
Here is my code:
# coding=utf-8
from tvm import relay
import tensorflow as tf
import sys
import tvm.relay.testing.tf as tf_testing
import cv2
import os
import tvm
import numpy as np
# Path to the frozen TensorFlow .pb model, passed as the first CLI argument.
if len(sys.argv) < 2:
    sys.exit("usage: python %s <frozen_model.pb>" % sys.argv[0])
model_path = sys.argv[1]

# Output tensors of the model that Relay should compile.
output_nodes = [
    'feature_fusion/Conv_7/Sigmoid',
    'feature_fusion/Conv_8/Sigmoid',
    'feature_fusion/Conv_9/Sigmoid',
]

# Load the serialized GraphDef and import it into the default graph, so the
# Session below (which uses the default graph) can see and freeze it.
with tf.gfile.FastGFile(model_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    # With no return_elements requested, import_graph_def returns None,
    # so its result is deliberately not stored.
    tf.import_graph_def(graph_def, name='')
    # Pass the GraphDef through TVM's ProcessGraphDefParam test helper.
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)

# Freeze variables into constants. add_shapes=True attaches an
# '_output_shapes' attribute with the inferred output shapes to every node.
with tf.Session() as sess:
    graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(add_shapes=True),
        output_nodes,
    )

# NOTE(review): download_testdata is not used in this script.
from tvm.contrib.download import download_testdata

target = 'llvm'
target_host = 'llvm'
layout = 'NHWC'
ctx = tvm.cpu(0)

# Dummy all-ones test image. Only the shape of the batched array,
# (1, 512, 512, 3), is passed to the Relay frontend below.
image = np.ones((512, 512, 3))
x = np.expand_dims(image[:, :, ::-1], axis=0)
shape_dict = {'input_images': x.shape}

print("start to relay build")
sym, params = relay.frontend.from_tensorflow(
    graph_def, layout=layout, shape=shape_dict, outputs=output_nodes)
print("opt...")
# Print each param's name alongside its shape, so a suspicious (e.g. negative
# or zero) dimension can be traced back to the constant it belongs to.
for name, value in params.items():
    print(name, value.shape)

with relay.build_config(opt_level=2):
    graph, lib, params = relay.build(
        sym, target=target, target_host=target_host, params=params)
BTW, the same script runs fine on my MacBook, but when I move it to CentOS, it raises the error above.