Errors from relay.frontend.from_tensorflow on an ssd_mobilenet model

Hi:

I am trying to use relay.frontend.from_tensorflow to convert an SSD model from the TensorFlow 1 Detection Model Zoo.

This is the URL of the inference graph_def I used.

It seems that the errors come from the argwhere and where operators.

Many errors like the ones below occur:

[10:02:08] /root/Codes/lsy_tvm/src/printer/doc.cc:55: text node: an internal invariant was violated while typechecking your program [10:02:08] /root/Codes/lsy_tvm/src/tir/ir/expr.cc:647: Check failed: lanes > 1 (0 vs. 1) : 
Stack trace:
  [bt] (0) /root/Codes/lsy_tvm/build/libtvm.so(+0x109c4c2) [0x7ff7a31ef4c2]
  [bt] (1) /root/Codes/lsy_tvm/build/libtvm.so(tvm::tir::Broadcast::Broadcast(tvm::PrimExpr, int)+0x12a) [0x7ff7a31f39aa]
  [bt] (2) /root/Codes/lsy_tvm/build/libtvm.so(tvm::BinaryOpMatchTypes(tvm::PrimExpr&, tvm::PrimExpr&)+0x445) [0x7ff7a323df55]
  [bt] (3) /root/Codes/lsy_tvm/build/libtvm.so(tvm::operator-(tvm::PrimExpr, tvm::PrimExpr)+0x32) [0x7ff7a323e6b2]
  [bt] (4) /root/Codes/lsy_tvm/build/libtvm.so(tvm::relay::TypeSolver::Reporter::AssertEQ(tvm::PrimExpr const&, tvm::PrimExpr const&)+0x4b) [0x7ff7a36d6fcb]
  [bt] (5) /root/Codes/lsy_tvm/build/libtvm.so(tvm::relay::WhereRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)+0x527) [0x7ff7a35e9a67]
  [bt] (6) /root/Codes/lsy_tvm/build/libtvm.so(std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)+0x63b) [0x7ff7a303a4cb]
  [bt] (7) /root/Codes/lsy_tvm/build/libtvm.so(tvm::relay::TypeSolver::Solve()+0x380) [0x7ff7a36d3fc0]
  [bt] (8) /root/Codes/lsy_tvm/build/libtvm.so(tvm::relay::TypeInferencer::Infer(tvm::RelayExpr)+0x55) [0x7ff7a383fcb5]


Did you solve this problem? I have the same issue.

Can you provide a script showing how you compiled the model?

@tiandiao123 Not yet. The where problem is caused by WhereRel in src/relay/op/tensor/transform.cc; you can work around it by changing the shape check like this, so that the static equality check is skipped whenever a dimension is dynamic (Any):

  for (size_t i = 0; i < x_shape.size(); i++) {
    // skip the static equality check when either dimension is dynamic (Any)
    if (x_shape[i].as<AnyNode>() || y_shape[i].as<AnyNode>()) {
      continue;
    }

    CHECK(reporter->AssertEQ(x_shape[i], y_shape[i]))
        << "x and y must have the same shape: " << x_shape << " vs " << y_shape;

    if (i < cond_shape.size()) {
      CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i]))
          << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
    }
  }

But there are other problems.

@kevinthesun Sure. This is the .pb of ssd_mobilenet_v1 I used. I uploaded it to Google Drive, and you can download it from the link.

The script to compile the model:

import tvm 
from tvm import relay
import tensorflow as tf

BOXES_NAME = 'detection_boxes'
CLASSES_NAME = 'detection_classes'
SCORES_NAME = 'detection_scores'
NUM_DETECTIONS_NAME = 'num_detections'
out_names = [BOXES_NAME, CLASSES_NAME, SCORES_NAME, NUM_DETECTIONS_NAME]


def load_graph_def(graph_def_path):
    infer_graph_def = tf.GraphDef()
    with open(graph_def_path, 'rb') as f:
        infer_graph_def.ParseFromString(f.read())  
    return infer_graph_def


def convert_ssd_mobilenet_v1_coco():
    model_name = 'ssd_mobilenet_v1_coco'

    graph_def_path = './frozen_inference_graph.pb'
    graph_def = load_graph_def(graph_def_path)

    mod, params = relay.frontend.from_tensorflow(
        graph_def, layout='NCHW', outputs=out_names
    )

if __name__ == '__main__':
    convert_ssd_mobilenet_v1_coco()

@kevinthesun In case you are interested: the model is downloaded from the TensorFlow 1 Detection Model Zoo.

Besides, I froze the model with TensorFlow == 1.14, since nms_v5 is not supported by the TensorFlow frontend yet.

The way I freeze the model mostly comes from the build_model function in this file, which requires installing the Object Detection API with TensorFlow 1.

The script I used to freeze the model:

import tvm
import tensorflow as tf
import os
import subprocess

import shutil
from google.protobuf import text_format
from object_detection import exporter
from object_detection.protos import pipeline_pb2, image_resizer_pb2


BOXES_NAME = 'detection_boxes'
CLASSES_NAME = 'detection_classes'
SCORES_NAME = 'detection_scores'
NUM_DETECTIONS_NAME = 'num_detections'
out_names = [BOXES_NAME, CLASSES_NAME, SCORES_NAME, NUM_DETECTIONS_NAME]


def remove_node(graph_def, node):
    # detach `node` from the inputs (including control inputs) of every other node, then drop it
    for n in graph_def.node:
        if node.name in n.input:
            n.input.remove(node.name)
        ctrl_name = '^' + node.name
        if ctrl_name in n.input:
            n.input.remove(ctrl_name)
    graph_def.node.remove(node)


def remove_op(graph_def, op_name):
    # remove every node whose op type matches op_name
    matches = [node for node in graph_def.node if node.op == op_name]
    for match in matches:
        remove_node(graph_def, match)


def remove_assert(frozen_graph):
    # Assert nodes are only needed for runtime checks and are stripped before conversion
    remove_op(frozen_graph, 'Assert')
    return frozen_graph


def build_fronzen_model(model_dir):
    model_name = 'ssd_mobilenet_v1_coco'
    config_path = os.path.join(model_dir, 'pipeline.config')
    checkpoint_path = os.path.join(model_dir, 'model.ckpt')

    # load config from file
    config = pipeline_pb2.TrainEvalPipelineConfig()
    with open(config_path, 'r') as f:
        text_format.Merge(f.read(), config, allow_unknown_extension=True)
    #config.model.ssd.num_classes = 4
    #print(config)
    
    input_height = config.model.ssd.image_resizer.fixed_shape_resizer.height
    input_width = config.model.ssd.image_resizer.fixed_shape_resizer.width
    input_shape = [1, input_height, input_width, 3]
    #return 

    tmp_dir='.optimize_model_tmp_dir'
    if os.path.exists(tmp_dir):
        subprocess.call(['rm', '-rf', tmp_dir])

    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    with tf.Session(config=tf_config):
        with tf.Graph().as_default():
            exporter.export_inference_graph(
                'image_tensor',
                config,
                checkpoint_path,
                tmp_dir,
                input_shape=input_shape)

    # read frozen graph from file
    frozen_graph_path = os.path.join(tmp_dir, 'frozen_inference_graph.pb')
    frozen_graph = tf.GraphDef()
    with open(frozen_graph_path, 'rb') as f:
        frozen_graph.ParseFromString(f.read())

    subprocess.call(['rm', '-rf', tmp_dir])

    # sanity check: the input Placeholder should now carry a fixed shape
    for node in frozen_graph.node:
        if node.op == 'Placeholder':
            print(node)

    frozen_graph = remove_assert(frozen_graph)

    return frozen_graph
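
For completeness, this is roughly how the function above can be driven and the result saved for the compile script (a sketch; the model_dir path is just an example of where the extracted checkpoint archive might live):

if __name__ == '__main__':
    model_dir = './ssd_mobilenet_v1_coco_2018_01_28'  # example path to the extracted archive
    frozen_graph = build_fronzen_model(model_dir)
    # serialize the cleaned GraphDef so the compile script can load it
    with open('./frozen_inference_graph.pb', 'wb') as f:
        f.write(frozen_graph.SerializeToString())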


You need to set the input shape. You can check https://github.com/apache/incubator-tvm/blob/main/tests/python/frontend/tensorflow/test_forward.py#L2985
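
Something like this is what I mean (a sketch; the input name and shape are just examples for this model, and graph_def / out_names come from your script above):

shape_dict = {'image_tensor': (1, 300, 300, 3)}  # example input name and fixed shape
mod, params = relay.frontend.from_tensorflow(
    graph_def, layout='NCHW', shape=shape_dict, outputs=out_names
)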

@kevinthesun

I changed my script to set the input shape, and the same problem still occurs:

import tvm
from tvm import relay
import tensorflow as tf
import numpy as np

BOXES_NAME = 'detection_boxes'
CLASSES_NAME = 'detection_classes'
SCORES_NAME = 'detection_scores'
NUM_DETECTIONS_NAME = 'num_detections'
out_names = [BOXES_NAME, CLASSES_NAME, SCORES_NAME, NUM_DETECTIONS_NAME]


def load_graph_def(graph_def_path):
    infer_graph_def = tf.GraphDef()
    with open(graph_def_path, 'rb') as f:
        infer_graph_def.ParseFromString(f.read())  
    return infer_graph_def


def convert_ssd_mobilenet_v1_coco():
    model_name = 'ssd_mobilenet_v1_coco'

    graph_def_path = './frozen_inference_graph.pb'
    graph_def = load_graph_def(graph_def_path)

    input_node = ['image_tensor']
    input_data = [np.random.uniform(0.0, 255.0, size=(1, 300, 300, 3)).astype("uint8")]
    shape_dict = {
        e: i.shape if hasattr(i, "shape") else () for e, i in zip(input_node, input_data)
    }

    mod, params = relay.frontend.from_tensorflow(
        graph_def, layout='NCHW', shape=shape_dict, outputs=out_names
    )

if __name__ == '__main__':
    convert_ssd_mobilenet_v1_coco()

Besides, the input shape is already fixed when the graph_def is frozen.
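
For instance, a quick check of the Placeholder in the frozen graph (a sketch, assuming TF 1.x and the .pb sitting next to the script) already shows a concrete shape:

import tensorflow as tf

graph_def = tf.GraphDef()
with open('./frozen_inference_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

for node in graph_def.node:
    if node.op == 'Placeholder':
        # for the export above this should print a fixed shape such as (1, 300, 300, 3)
        print(node.name, node.attr['shape'].shape)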

try:

import tvm
from tvm import relay
from tvm.relay.frontend.tensorflow_parser import TFParser

input_name = "image_tensor"
input_shape = (1, 512, 512, 3)
outputs = ['detection_boxes', "detection_scores", "detection_classes"]
graph_def = TFParser(model_path, outputs).parse()
mod, params = relay.frontend.from_tensorflow(graph_def, shape={input_name: input_shape}, outputs=outputs)

desired_layouts = {'nn.conv2d': ['NCHW', 'default']}
seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                relay.transform.ConvertLayout(desired_layouts)])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)

This should work for the official TF OD model zoo, but there is no guarantee if you generate the model using another method.
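
After the layout conversion, a rough sketch of compiling and running the module with the Relay VM, which handles the dynamic shapes in this model (the target, device, and random input here are only examples):

import numpy as np
from tvm.runtime.vm import VirtualMachine

with tvm.transform.PassContext(opt_level=3):
    # the Relay VM executor is needed because the SSD graph has dynamic shapes
    vm_exec = relay.vm.compile(mod, target="llvm", params=params)

vm = VirtualMachine(vm_exec, tvm.cpu())
data = np.random.uniform(0.0, 255.0, size=input_shape).astype("uint8")
vm.set_input("main", **{input_name: data})
tvm_res = vm.run()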

@kevinthesun

The frozen_inference_graph.pb extracted directly from the official ssd_mobilenet_v1 tar.gz archive can be parsed correctly by relay.frontend.from_tensorflow.

However, if the frozen_inference_graph.pb is generated from the checkpoints of that same archive by exporter.export_inference_graph, then relay.frontend.from_tensorflow fails.

Compared to the official frozen_inference_graph.pb, the one generated by exporter.export_inference_graph does not have control-flow nodes like LoopCond, Merge, and so on.
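
This is roughly how I compared the two graphs (a sketch; the file names are just how I saved the two versions locally):

import collections
import tensorflow as tf

def op_histogram(pb_path):
    # count how many nodes of each op type a frozen GraphDef contains
    graph_def = tf.GraphDef()
    with open(pb_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    return collections.Counter(node.op for node in graph_def.node)

official = op_histogram('./official_frozen_inference_graph.pb')
exported = op_histogram('./exported_frozen_inference_graph.pb')
# control-flow ops show up only in the official graph
for op in ['Enter', 'Merge', 'LoopCond', 'Exit', 'NextIteration', 'Switch']:
    print(op, official.get(op, 0), exported.get(op, 0))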

Would you give some suggestions on how to make the frozen_inference_graph.pb generated by exporter.export_inference_graph work too? The models to be deployed in production are normally generated by our engineers through methods like exporter.export_inference_graph.