@tqchen @Ravenwater I modified the TensorFlow graph to use uint8 as the input data type, but it still doesn't run. The same error messages appear.
In short, the step where I import the graph and compile it with nnvm looks like this:
```python
import tensorflow as tf
import nnvm.compiler
import nnvm.frontend
import vta
from tvm.contrib import graph_runtime

# height, width, target, target_host and ctx are defined elsewhere in my script.
with tf.gfile.GFile('graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')

sym, params = nnvm.frontend.from_tensorflow(graph_def)  # <----------

shape_dict = {'input': (1, height, width, 3)}
dtype_dict = {'input': 'uint8'}
with vta.build_config():
    graph, lib, params = nnvm.compiler.build(graph=sym,  # <---------
                                             shape=shape_dict,
                                             dtype=dtype_dict,
                                             target=target,
                                             params=params,
                                             target_host=target_host)
...
m = graph_runtime.create(graph, lib, ctx)  # <------------------------
```
I believe the problem is one of the following:
- I didn't call the APIs properly.
- The graph fed to `nnvm.frontend.from_tensorflow()` is not compatible. I doubt this, though, because the graph is very simple. I attached a picture of the graph below.
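One way to narrow down the second possibility is to list which op types `graph.pb` actually contains and which nodes are `Placeholder` inputs, then compare the ops against what the nnvm TensorFlow frontend handles and confirm the input name matches the `'input'` key used in `shape_dict` and `dtype_dict`. The helper below is a hypothetical sketch (the name `summarize_graph_ops` is my own, not part of TensorFlow or nnvm); it only assumes that a `GraphDef` exposes a `node` list whose entries have `op` and `name` fields.

```python
from collections import Counter


def summarize_graph_ops(graph_def):
    """Count op types in a GraphDef and list its Placeholder (input) nodes.

    Hypothetical helper: works on any object with a `.node` sequence whose
    items expose `.op` and `.name`, which is how TensorFlow NodeDef entries look.
    """
    # How many times each op type (e.g. Conv2D, Relu) appears in the graph.
    op_counts = Counter(node.op for node in graph_def.node)
    # Names of the graph inputs; one of them should match the shape_dict key.
    placeholders = [node.name for node in graph_def.node if node.op == "Placeholder"]
    return op_counts, placeholders


# Usage, assuming graph_def was loaded from graph.pb as in the snippet above:
# op_counts, inputs = summarize_graph_ops(graph_def)
# for op, n in sorted(op_counts.items()):
#     print(op, n)
# print("inputs:", inputs)  # expected to contain 'input'
```

If an op shows up that the frontend does not convert, or the placeholder is named something other than `input`, that would point at the graph rather than the API calls.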
