lsy643
September 25, 2020, 10:19am
1
Hi,
I am trying to use relay.frontend.from_tensorflow to convert an SSD model from the TensorFlow Object Detection Models V1.
This is the URL of the inference_graph_def:
It seems that the errors come from the argwhere and where operators.
Many errors like the ones below occur:
[10:02:08] /root/Codes/lsy_tvm/src/printer/doc.cc:55: text node: an internal invariant was violated while typechecking your program [10:02:08] /root/Codes/lsy_tvm/src/tir/ir/expr.cc:647: Check failed: lanes > 1 (0 vs. 1) :
Stack trace:
[bt] (0) /root/Codes/lsy_tvm/build/libtvm.so(+0x109c4c2) [0x7ff7a31ef4c2]
[bt] (1) /root/Codes/lsy_tvm/build/libtvm.so(tvm::tir::Broadcast::Broadcast(tvm::PrimExpr, int)+0x12a) [0x7ff7a31f39aa]
[bt] (2) /root/Codes/lsy_tvm/build/libtvm.so(tvm::BinaryOpMatchTypes(tvm::PrimExpr&, tvm::PrimExpr&)+0x445) [0x7ff7a323df55]
[bt] (3) /root/Codes/lsy_tvm/build/libtvm.so(tvm::operator-(tvm::PrimExpr, tvm::PrimExpr)+0x32) [0x7ff7a323e6b2]
[bt] (4) /root/Codes/lsy_tvm/build/libtvm.so(tvm::relay::TypeSolver::Reporter::AssertEQ(tvm::PrimExpr const&, tvm::PrimExpr const&)+0x4b) [0x7ff7a36d6fcb]
[bt] (5) /root/Codes/lsy_tvm/build/libtvm.so(tvm::relay::WhereRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)+0x527) [0x7ff7a35e9a67]
[bt] (6) /root/Codes/lsy_tvm/build/libtvm.so(std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)+0x63b) [0x7ff7a303a4cb]
[bt] (7) /root/Codes/lsy_tvm/build/libtvm.so(tvm::relay::TypeSolver::Solve()+0x380) [0x7ff7a36d3fc0]
[bt] (8) /root/Codes/lsy_tvm/build/libtvm.so(tvm::relay::TypeInferencer::Infer(tvm::RelayExpr)+0x55) [0x7ff7a383fcb5]
Did you solve this problem? I have the same issue.
Can you provide a script showing how you compiled the model?
lsy643
October 14, 2020, 3:13am
4
@tiandiao123
Not yet. The where problem is caused by WhereRel in src/relay/op/tensor/transform.cc; you can fix it by changing it like this to handle dynamic shapes:
for (size_t i = 0; i < x_shape.size(); i++) {
  if (x_shape[i].as<AnyNode>() || y_shape[i].as<AnyNode>()) {
    continue;
  }
  CHECK(reporter->AssertEQ(x_shape[i], y_shape[i]))
      << "x and y must have the same shape: " << x_shape << " vs " << y_shape;
  if (i < cond_shape.size()) {
    CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i]))
        << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
  }
}
But there are other problems.
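For context, here is a minimal Relay sketch (my own illustration, not from the original post) that exercises where with dynamic (Any) dimensions, which is the situation the unpatched WhereRel check runs into; whether it reproduces the exact error depends on the TVM version:

import tvm
from tvm import relay

# Sketch: where() over tensors whose first dimension is dynamic (Any).
# The unpatched WhereRel calls AssertEQ on these Any dims; the patch above
# skips the check when either side is an AnyNode.
cond = relay.var("cond", shape=(relay.Any(),), dtype="bool")
x = relay.var("x", shape=(relay.Any(),), dtype="float32")
y = relay.var("y", shape=(relay.Any(),), dtype="float32")

func = relay.Function([cond, x, y], relay.where(cond, x, y))
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)  # type checking is where WhereRel runs
print(mod)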
lsy643
October 14, 2020, 3:22am
5
@kevinthesun
Sure.
The pb of ssd_mobilenet_v1 I used: I uploaded it to Google Drive, and you can download it from the link.
The script to compile the model:
import tvm
from tvm import relay
import tensorflow as tf

BOXES_NAME = 'detection_boxes'
CLASSES_NAME = 'detection_classes'
SCORES_NAME = 'detection_scores'
NUM_DETECTIONS_NAME = 'num_detections'
out_names = [BOXES_NAME, CLASSES_NAME, SCORES_NAME, NUM_DETECTIONS_NAME]


def load_graph_def(graph_def_path):
    infer_graph_def = tf.GraphDef()
    with open(graph_def_path, 'rb') as f:
        infer_graph_def.ParseFromString(f.read())
    return infer_graph_def


def convert_ssd_mobilenet_v1_coco():
    model_name = 'ssd_mobilenet_v1_coco'
    graph_def_path = './frozen_inference_graph.pb'
    graph_def = load_graph_def(graph_def_path)
    mod, params = relay.frontend.from_tensorflow(
        graph_def, layout='NCHW', outputs=out_names
    )


if __name__ == '__main__':
    convert_ssd_mobilenet_v1_coco()
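As a follow-up, the converted module has dynamic shapes, so as far as I understand it needs the Relay VM rather than the graph runtime. A minimal sketch of what the compile and run step might look like, assuming an llvm target, that mod and params are returned from the from_tensorflow call above, and the TVM 0.7-era VM API (relay.vm.compile / tvm.runtime.vm.VirtualMachine); exact calls may differ between versions:

import numpy as np
import tvm
from tvm import relay
from tvm.runtime.vm import VirtualMachine

# Assumes `mod` and `params` come from relay.frontend.from_tensorflow above.
target = 'llvm'
with tvm.transform.PassContext(opt_level=3, disabled_pass=['FoldScaleAxis']):
    vm_exec = relay.vm.compile(mod, target=target, params=params)

ctx = tvm.cpu()
vm = VirtualMachine(vm_exec, ctx)
image = np.random.uniform(0.0, 255.0, size=(1, 300, 300, 3)).astype('uint8')
result = vm.invoke('main', tvm.nd.array(image, ctx))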
lsy643
October 14, 2020, 3:31am
6
@kevinthesun
If you are interested: the model is downloaded from the TensorFlow 1 Detection Model Zoo.
Besides, I freeze the model with TensorFlow 1.14, since nms_v5 is not supported by the TensorFlow frontend yet.
The way to freeze the model mostly comes from the build_model function in this file, which requires installing the Object Detection API with TensorFlow 1.
The script I used to freeze the model:
import tvm
import tensorflow as tf
import os
import subprocess
import shutil
from google.protobuf import text_format
from object_detection import exporter
from object_detection.protos import pipeline_pb2, image_resizer_pb2

BOXES_NAME = 'detection_boxes'
CLASSES_NAME = 'detection_classes'
SCORES_NAME = 'detection_scores'
NUM_DETECTIONS_NAME = 'num_detections'
out_names = [BOXES_NAME, CLASSES_NAME, SCORES_NAME, NUM_DETECTIONS_NAME]


def remove_node(graph_def, node):
    for n in graph_def.node:
        if node.name in n.input:
            n.input.remove(node.name)
        ctrl_name = '^' + node.name
        if ctrl_name in n.input:
            n.input.remove(ctrl_name)
    graph_def.node.remove(node)


def remove_op(graph_def, op_name):
    matches = [node for node in graph_def.node if node.op == op_name]
    for match in matches:
        remove_node(graph_def, match)


def remove_assert(frozen_graph):
    remove_op(frozen_graph, 'Assert')
    return frozen_graph


def build_fronzen_model(model_dir):
    model_name = 'ssd_mobilenet_v1_coco'
    config_path = os.path.join(model_dir, 'pipeline.config')
    checkpoint_path = os.path.join(model_dir, 'model.ckpt')
    # load config from file
    config = pipeline_pb2.TrainEvalPipelineConfig()
    with open(config_path, 'r') as f:
        text_format.Merge(f.read(), config, allow_unknown_extension=True)
    #config.model.ssd.num_classes = 4
    #print(config)
    input_height = config.model.ssd.image_resizer.fixed_shape_resizer.height
    input_width = config.model.ssd.image_resizer.fixed_shape_resizer.width
    input_shape = [1, input_height, input_width, 3]
    #return
    tmp_dir = '.optimize_model_tmp_dir'
    if os.path.exists(tmp_dir):
        subprocess.call(['rm', '-rf', tmp_dir])
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    with tf.Session(config=tf_config):
        with tf.Graph().as_default():
            exporter.export_inference_graph(
                'image_tensor',
                config,
                checkpoint_path,
                tmp_dir,
                input_shape=input_shape)
    # read frozen graph from file
    frozen_graph_path = os.path.join(tmp_dir, 'frozen_inference_graph.pb')
    frozen_graph = tf.GraphDef()
    with open(frozen_graph_path, 'rb') as f:
        frozen_graph.ParseFromString(f.read())
    subprocess.call(['rm', '-rf', tmp_dir])
    for node in frozen_graph.node:
        if node.op == 'Placeholder':
            print(node)
    frozen_graph = remove_assert(frozen_graph)
    return frozen_graph
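For completeness, a hypothetical driver for the script above could look like this (the model_dir path and output file name are my assumptions, not from the original post):

if __name__ == '__main__':
    # Hypothetical paths: point model_dir at the extracted checkpoint
    # directory from the TF1 detection model zoo archive.
    frozen_graph = build_fronzen_model('./ssd_mobilenet_v1_coco_2018_01_28')
    with open('./frozen_inference_graph.pb', 'wb') as f:
        f.write(frozen_graph.SerializeToString())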
lsy643
October 15, 2020, 1:54am
8
@kevinthesun
I changed my script to set the input shape, and the same problem occurs:
import numpy as np
import tvm
from tvm import relay
import tensorflow as tf

out_names = ['detection_boxes', 'detection_classes',
             'detection_scores', 'num_detections']


def load_graph_def(graph_def_path):
    infer_graph_def = tf.GraphDef()
    with open(graph_def_path, 'rb') as f:
        infer_graph_def.ParseFromString(f.read())
    return infer_graph_def


def convert_ssd_mobilenet_v1_coco():
    model_name = 'ssd_mobilenet_v1_coco'
    graph_def_path = './frozen_inference_graph.pb'
    graph_def = load_graph_def(graph_def_path)
    input_node = ['image_tensor']
    input_data = [np.random.uniform(0.0, 255.0, size=(1, 300, 300, 3)).astype("uint8")]
    shape_dict = {
        e: i.shape if hasattr(i, "shape") else () for e, i in zip(input_node, input_data)
    }
    mod, params = relay.frontend.from_tensorflow(
        graph_def, layout='NCHW', shape=shape_dict, outputs=out_names
    )


if __name__ == '__main__':
    convert_ssd_mobilenet_v1_coco()
Besides, the input shape is already fixed when the graph_def is frozen (see the quick check below).
import tvm
from tvm import relay
from tvm.relay.frontend.tensorflow_parser import TFParser

model_path = './frozen_inference_graph.pb'  # assumed path to the frozen pb discussed above
input_name = "image_tensor"
input_shape = (1, 512, 512, 3)
outputs = ['detection_boxes', "detection_scores", "detection_classes"]
graph_def = TFParser(model_path, outputs).parse()
mod, params = relay.frontend.from_tensorflow(graph_def, shape={input_name: input_shape}, outputs=outputs)
desired_layouts = {'nn.conv2d': ['NCHW', 'default']}
seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                relay.transform.ConvertLayout(desired_layouts)])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
This should work for the TF OD official model zoo, but there is no guarantee if you generate the model using other methods.
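To verify that the input shape is indeed fixed in the frozen graph, here is a small sketch of my own (using TF 1.x GraphDef APIs) that prints the shape recorded on each Placeholder:

import tensorflow as tf

# Print the shape baked into each Placeholder of the frozen graph.
graph_def = tf.GraphDef()
with open('./frozen_inference_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

for node in graph_def.node:
    if node.op == 'Placeholder':
        dims = [d.size for d in node.attr['shape'].shape.dim]
        print(node.name, dims)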
lsy643
October 16, 2020, 3:09am
10
@kevinthesun
The frozen_inference_graph.pb extracted directly from the official ssd_mobilenet_v1.tar.gz file can be parsed correctly by relay.frontend.from_tensorflow.
However, if the frozen_inference_graph.pb is generated from the checkpoints of the same ssd_mobilenet_v1.tar.gz by exporter.export_inference_graph, then relay.frontend.from_tensorflow fails.
Compared to the official frozen_inference_graph.pb, the one generated by exporter.export_inference_graph does not have control-flow nodes such as LoopCond, Merge, and so on.
Would you give some suggestions on how to make the frozen_inference_graph.pb generated by exporter.export_inference_graph work too? The models to be deployed in production are normally generated by our engineers through methods like exporter.export_inference_graph.
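As an aside, a small sketch of my own (TF 1.x APIs) that I would use to compare the two frozen graphs by counting their control-flow ops:

import collections
import tensorflow as tf

def count_control_flow_ops(pb_path):
    # Count the control-flow ops in a frozen graph, to compare the official
    # pb with the one produced by exporter.export_inference_graph.
    graph_def = tf.GraphDef()
    with open(pb_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    ops = collections.Counter(node.op for node in graph_def.node)
    return {op: ops[op] for op in
            ('LoopCond', 'Merge', 'Switch', 'Enter', 'Exit', 'NextIteration')}

print(count_control_flow_ops('./frozen_inference_graph.pb'))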