AnnotateTarget pass error for TF model

I want to use a TensorFlow model to test bring-your-own-codegen (BYOC).

mod = transform.AnnotateTarget("dnnl")(mod)
mod = transform.MergeCompilerRegions()(mod)
mod = transform.PartitionGraph()(mod)
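
For reference, the same pipeline can also be expressed as one sequential pass. A minimal sketch, assuming mod is a Relay IRModule:

import tvm
from tvm.relay import transform

seq = tvm.transform.Sequential([
    transform.AnnotateTarget("dnnl"),
    transform.MergeCompilerRegions(),
    transform.PartitionGraph(),
])
mod = seq(mod)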

The mod is imported from a TensorFlow model, but something went wrong:

  File "/Automation/hwl/2020/mlperf_inference/mlperf/inference-master/v0.5/classification_and_detection/python/tvm_acc_main_fp32_cpu_wda_fail.py", line 548, in main
    mod = transform.AnnotateTarget("dnnl")(mod)

  File "/home/wda/incubator-tvm-20200407/python/tvm/ir/transform.py", line 141, in __call__
    return _ffi_transform_api.RunPass(self, mod)

  File "/home/wda/incubator-tvm-20200407/python/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
    raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /home/wda/incubator-tvm-20200407/build/libtvm.so(tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)#5}::_FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)+0x27) [0x7f4db37b17b7]
  [bt] (7) /home/wda/incubator-tvm-20200407/build/libtvm.so(tvm::relay::annotate_target::AnnotateTargetWrapper::VisitExpr_(tvm::relay::FunctionNode const*)+0x2a) [0x7f4db36222da]
  [bt] (6) /home/wda/incubator-tvm-20200407/build/libtvm.so(tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)+0xccf) [0x7f4db37ace7f]
  [bt] (5) /home/wda/incubator-tvm-20200407/build/libtvm.so(tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)+0x96) [0x7f4db37af996]
  [bt] (4) /home/wda/incubator-tvm-20200407/build/libtvm.so(tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x94) [0x7f4db37b44a4]
  [bt] (3) /home/wda/incubator-tvm-20200407/build/libtvm.so(tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)#6}::_FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)+0x27) [0x7f4db37b1807]
  [bt] (2) /home/wda/incubator-tvm-20200407/build/libtvm.so(tvm::relay::annotate_target::AnnotateTargetWrapper::VisitExpr_(tvm::relay::CallNode const*)+0x7c7) [0x7f4db3623c77]
  [bt] (1) /home/wda/incubator-tvm-20200407/build/libtvm.so(tvm::relay::annotate_target::AnnotateTargetWrapper::IsSupported(tvm::RelayExpr const&)+0x2bb) [0x7f4db361fdeb]
  [bt] (0) /home/wda/incubator-tvm-20200407/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7f4db2fbd9c2]
  File "/home/wda/incubator-tvm-20200407/include/tvm/runtime/object.h", line 872
TVMError: Check failed: ref->template IsInstance<typename SubRef::ContainerType>(): Downcast from GlobalVar to relay.Op failed.

Please help me, thank you. @lhutton1 @comaniac @zhiics @tqchen

Could you print ‘mod’ before running the AnnotateTarget pass? It looks like it’s encountering a GlobalVar which isn’t handled yet.
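
For context, a call’s operator in Relay can be a GlobalVar (a call to another function in the module) rather than a relay.Op. A minimal sketch that builds such a module, with hypothetical names:

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 3))
mod = tvm.IRModule()
gv = relay.GlobalVar("helper")
mod[gv] = relay.Function([x], relay.nn.relu(x))

y = relay.var("y", shape=(1, 3))
# The op of this call is a GlobalVar, not a relay.Op, which is what
# the downcast in AnnotateTarget trips over.
mod["main"] = relay.Function([y], gv(y))
print(mod)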

I just use the model from https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/classify_image_graph_def-with_shapes.pb. I can print the ‘mod’ before running the AnnotateTarget pass, but it’s too long. I think the AnnotateTarget pass currently can’t handle GlobalVar.
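
(If the full module is too long to paste, printing just the entry function is usually enough:)

print(mod["main"])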

I have found the reason, and it turns out this problem was fixed yesterday.

And now I have got another error. I still use the model from https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/classify_image_graph_def-with_shapes.pb

Now the AnnotateTarget pass is OK, but the PartitionGraph pass fails:

tvm._ffi.base.TVMError: TVMError: Cannot find the corresponding region for start annotation:

Could you post the code you’re using so that I can reproduce this?

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Compile Tensorflow Models
=========================
This article is an introductory tutorial to deploy tensorflow models with TVM.

To begin, the TensorFlow Python module must be installed.

Please refer to https://www.tensorflow.org/install
"""

# tvm, relay
import tvm
from tvm import relay

# numpy
import numpy as np

# Tensorflow imports
import tensorflow as tf
try:
    tf_compat_v1 = tf.compat.v1
except ImportError:
    tf_compat_v1 = tf

# Tensorflow utility functions
import tvm.relay.testing.tf as tf_testing

# Base location for model related files.
#repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
data_dir = "./inception/"
# Test image
img_name = 'elephant-299.jpg'
#image_url = os.path.join(repo_base, img_name)

######################################################################
# Tutorials
# ---------
# Please refer to docs/frontend/tensorflow.md for more details on various models
# from TensorFlow.

model_name = 'classify_image_graph_def-with_shapes.pb'
#model_url = os.path.join(repo_base, model_name)

# Image label map
map_proto = 'imagenet_2012_challenge_label_map_proto.pbtxt'
#map_proto_url = os.path.join(repo_base, map_proto)

# Human readable text for labels
label_map = 'imagenet_synset_to_human_label_map.txt'
#label_map_url = os.path.join(repo_base, label_map)

# Target settings
# Use these commented settings to build for cuda.
#target = 'cuda'
#target_host = 'llvm'
#layout = "NCHW"
#ctx = tvm.gpu(0)
target = 'llvm'
target_host = 'llvm'
layout = None
ctx = tvm.cpu(0)

######################################################################
# Download required files
# -----------------------
# Download files listed above.
from tvm.contrib.download import download_testdata

# img_path = download_testdata(image_url, img_name, module='data')
# model_path = download_testdata(model_url, model_name, module=['tf', 'InceptionV1'])
# map_proto_path = download_testdata(map_proto_url, map_proto, module='data')
# label_path = download_testdata(label_map_url, label_map, module='data')
img_path = data_dir+img_name
model_path = data_dir + model_name
map_proto_path = data_dir + map_proto
label_path = data_dir + label_map

######################################################################
# Import model
# ------------
# Creates tensorflow graph definition from protobuf file.

with tf_compat_v1.gfile.GFile(model_path, 'rb') as f:
    graph_def = tf_compat_v1.GraphDef()
    graph_def.ParseFromString(f.read())
    graph = tf.import_graph_def(graph_def, name='')
    # Call the utility to import the graph definition into default graph.
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)
    # Add shapes to the graph.
    with tf_compat_v1.Session() as sess:
        graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')

######################################################################
# Decode image
# ------------
# .. note::
#
#   tensorflow frontend import doesn't support preprocessing ops like JpegDecode.
#   JpegDecode is bypassed (just return source node).
#   Hence we supply decoded frame to TVM instead.
#

from PIL import Image
image = Image.open(img_path).resize((299, 299))

x = np.array(image)

######################################################################
# Import the graph to Relay
# -------------------------
# Import tensorflow graph definition to relay frontend.
#
# Results:
#   sym: relay expr for given tensorflow protobuf.
#   params: params converted from tensorflow params (tensor protobuf).
shape_dict = {'DecodeJpeg/contents': x.shape}
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
mod, params = relay.frontend.from_tensorflow(graph_def,
                                             layout=layout,
                                             shape=shape_dict)
print(mod)
from tvm.relay import transform
mod = transform.AnnotateTarget("dnnl")(mod)
print("**********annotated mod:********\n", mod)
#mod = transform.MergeCompilerRegions()(mod)
#print("**********merged mod:********\n", mod)
mod = transform.PartitionGraph()(mod)
print("Tensorflow protobuf imported to relay frontend.")
######################################################################
# Relay Build
# -----------
# Compile the graph to llvm target with given input specification.
#
# Results:
#   graph: Final graph after compilation.
#   params: final params after compilation.
#   lib: target library which can be deployed on target with TVM runtime.

with relay.build_config(opt_level=4):
    graph, lib, params = relay.build(mod,
                                     target=target,
                                     target_host=target_host,
                                     params=params)

######################################################################
# Execute the portable graph on TVM
# ---------------------------------
# Now we can try deploying the compiled model on target.

#from tvm.contrib import graph_runtime
from tvm.contrib.debugger import debug_runtime as graph_runtime
dtype = 'uint8'
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
tvm_output = m.get_output(0, tvm.nd.empty(((1, 1008)), 'float32'))

######################################################################
# Process the output
# ------------------
# Process the model output to human readable text for InceptionV1.
predictions = tvm_output.asnumpy()
predictions = np.squeeze(predictions)

# Creates node ID --> English string lookup.
node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path,
                                    uid_lookup_path=label_path)

# Print top 5 predictions from TVM output.
top_k = predictions.argsort()[-5:][::-1]
for node_id in top_k:
    human_string = node_lookup.id_to_string(node_id)
    score = predictions[node_id]
    print('%s (score = %.5f)' % (human_string, score))

Thank you very much :grinning:

The reason is that you imported a model from TensorFlow. When a model is converted from another framework, many unused functions may be preserved in the module, and these cause failures. Adding one line to your code should resolve this issue:

from tvm.relay import transform
mod = transform.RemoveUnusedFunctions()(mod)
mod = transform.AnnotateTarget("dnnl")(mod)
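
For completeness, the full sequence with the fix in place would look like this (a sketch, assuming mod comes from relay.frontend.from_tensorflow as in the script above):

from tvm.relay import transform

mod = transform.RemoveUnusedFunctions()(mod)   # drop dead functions left by the TF importer
mod = transform.AnnotateTarget("dnnl")(mod)    # tag DNNL-supported operators
mod = transform.MergeCompilerRegions()(mod)    # merge adjacent supported regions
mod = transform.PartitionGraph()(mod)          # outline each region into an external function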

After that, your program will go through all the BYOC passes but fail when building the module. This looks like a limitation of the current DNNL codegen, which is not supposed to be used in practice but only for illustration purposes. I’ll take another look later and make a quick fix if the issue is straightforward.
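
(To check what was actually offloaded after partitioning, one quick option, assuming mod is the partitioned module, is to list its global functions; partitioned regions show up with names like dnnl_0:)

for gv in mod.get_global_vars():
    print(gv.name_hint)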

Thank you very much. I have found the reason the build fails with the DNNL codegen:

const auto* call = func->body.as<CallNode>();

The function’s body is not always a CallNode, so we need to handle the other cases as well.