Here is a simple dynamic LSTM example that throws an error. It looks like the first handful of errors can be resolved by using the recently added _infer_value_simulated.
After that, I am seeing a different error.
Let me know if this works for you. I think this would be a great scenario to support in TVM. Thanks a lot for the help!
import tvm
import numpy as np
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from tvm import relay
from tvm.contrib import graph_runtime
# File that the frozen graph is written to and later read back from.
pb_name = "lstm.pb"

# Size passed to the LSTM cell constructor.
hidden_dim = 4

# Shapes of the two graph inputs.
shape_0 = (1, 4, 4)  # input (batch, seq, input)
shape_1 = (1,)  # seq len

# Input name -> shape mapping handed to the Relay importer.
shape_dict = dict(Placeholder=shape_0, Placeholder_1=shape_1)

# Random feed values shared by the TensorFlow and TVM runs.
input_0_np = np.random.random(shape_0)
input_1_np = np.random.randint(low=1, high=3, size=shape_1)
def _remove_assert(all_nodes):
    """Build a new GraphDef from ``all_nodes`` with assert nodes stripped out.

    A node is dropped when its name contains "assert" (case-insensitive).
    Every surviving node also loses any input whose name contains "assert",
    so no kept node still refers to a dropped one.

    Args:
        all_nodes: Iterable of graph nodes exposing ``name`` and a mutable,
            list-like ``input`` field (the caller passes ``graph_def.node``).

    Returns:
        A new ``graph_pb2.GraphDef`` whose ``node`` field holds the
        surviving nodes, in their original order.

    Note:
        The ``input`` field of each surviving node is rewritten in place
        before the node is added to the new graph.
    """
    kept_nodes = []
    for node in all_nodes:
        if "assert" in node.name.lower():
            continue

        kept_inputs = [inp for inp in node.input if "assert" not in inp.lower()]
        # Replace the contents of the input field rather than rebinding it.
        del node.input[:]
        node.input.extend(kept_inputs)
        kept_nodes.append(node)

    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend(kept_nodes)
    return graph_def
def create_pb():
    """Build a dynamic-LSTM graph, freeze it, and write it to ``pb_name``.

    The graph has two placeholders (shapes ``shape_0`` and ``shape_1``)
    feeding ``tf.nn.dynamic_rnn`` over an ``LSTMCell(hidden_dim)``. The RNN
    output and both state tensors are each added to themselves to form the
    three graph outputs. Assert nodes are removed, variables are converted
    to constants, the graph is run through ``TransformGraph``, and the
    serialized result is written to ``pb_name``.

    The output file is only opened once the graph has been fully built, so
    a failure during construction does not leave a truncated file behind.
    """
    input_names = ["Placeholder", "Placeholder_1"]
    output_names = ["Add", "Add_1", "Add_2"]

    with tf.Graph().as_default() as graph:
        x = tf.placeholder(tf.float32, shape=shape_0)
        y = tf.placeholder(tf.int32, shape=shape_1)
        lstm_cell = tf.nn.rnn_cell.LSTMCell(hidden_dim)
        output, (c_state, h_state) = tf.nn.dynamic_rnn(lstm_cell, x, y, dtype=tf.float32)

        # These ops are only referenced by name afterwards.
        # NOTE(review): input_names / output_names are presumably the names
        # TensorFlow auto-assigns to the placeholders and adds, in creation
        # order - confirm for the TensorFlow version in use.
        tf.add(output, output)
        tf.add(c_state, c_state)
        tf.add(h_state, h_state)

        # Context manager so the session is closed even if freezing fails.
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            graph_def = graph.as_graph_def(add_shapes=True)
            graph_def = _remove_assert(graph_def.node)
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess, graph_def, output_names
            )

    graph_def = TransformGraph(
        graph_def,
        input_names,
        output_names,
        [
            "strip_unused_nodes",
            "sort_by_execution_order",
            "fold_batch_norms",
            "sort_by_execution_order",
            "fold_old_batch_norms",
            "sort_by_execution_order",
        ],
    )

    with tf.gfile.GFile(pb_name, "wb") as f:
        f.write(graph_def.SerializeToString())
def get_graph():
    """Load the serialized graph stored at ``pb_name``.

    Returns:
        A ``(graph_def, graph)`` tuple: the parsed ``tf.GraphDef`` and a
        ``tf.Graph`` into which that definition has been imported with an
        empty name prefix.
    """
    with tf.gfile.GFile(pb_name, "rb") as f:
        serialized = f.read()

    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)

    # tf.import_graph_def returns None when return_elements is not given, so
    # its result cannot serve as the graph. Import into a dedicated Graph and
    # return that instead, leaving the global default graph untouched.
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="")
    return graph_def, graph
def run_tf(graph):
    """Run the graph in a TensorFlow session and return its three outputs.

    Args:
        graph: Graph to launch the session on. It is passed straight to
            ``tf.Session``, which falls back to the default graph when
            given ``None``.

    Returns:
        The list returned by ``sess.run`` for the tensors ``Add:0``,
        ``Add_1:0`` and ``Add_2:0``, fed with ``input_0_np`` and
        ``input_1_np``.
    """
    with tf.Session(graph=graph) as sess:
        # Resolve tensors on the session's own graph so the lookup does not
        # depend on which graph is currently the default.
        session_graph = sess.graph
        fetches = [
            session_graph.get_tensor_by_name("Add" + ":0"),
            session_graph.get_tensor_by_name("Add_1" + ":0"),
            session_graph.get_tensor_by_name("Add_2" + ":0"),
        ]
        feed_dict = {
            session_graph.get_tensor_by_name("Placeholder:0"): input_0_np,
            session_graph.get_tensor_by_name("Placeholder_1:0"): input_1_np,
        }
        return sess.run(fetches, feed_dict)
def run_tvm(graph_def):
    """Import the graph into Relay, build it for CPU, run it, return outputs.

    Args:
        graph_def: TensorFlow GraphDef to import; input shapes are taken
            from ``shape_dict``.

    Returns:
        A list with the three graph outputs ("Add", "Add_1", "Add_2", in
        that order) converted to NumPy arrays.
    """
    output_names = ["Add", "Add_1", "Add_2"]

    print("Before importing...")
    sym, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict, outputs=output_names)
    print("Finished from tensorflow")

    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(sym, target="llvm -mcpu=core-avx2", params=params)

    m = graph_runtime.create(graph, lib, tvm.cpu())
    m.set_input("Placeholder", input_0_np)
    m.set_input("Placeholder_1", input_1_np)
    m.set_input(**params)
    m.run()

    # Convert to NumPy so callers get plain arrays they can compare directly.
    return [m.get_output(index).asnumpy() for index in range(len(output_names))]
# Driver: write the frozen LSTM graph to disk, load it back, then run it
# through TensorFlow and through TVM. The order matters, because get_graph()
# reads the file that create_pb() writes.
create_pb()
graph_def, graph = get_graph()
run_tf(graph)
run_tvm(graph_def)