[TensorFlow] Failed to run autoTVM with a TensorFlow-converted model

Here are the component versions:

  • LLVM 10.0
  • TensorFlow 1.14
  • TVM: commit 36a0bf94cf93c5d4b067ae4359b8807ae2dde2d2

Download the frozen model here:

wget https://zenodo.org/record/2535873/files/resnet50_v1.pb

Here is the script to run autoTVM:

# tvm, relay
import tvm
from tvm import te
from tvm import relay

# os and numpy
import os
import numpy as np

# Tensorflow imports
import tensorflow as tf
try:
    tf_compat_v1 = tf.compat.v1
except ImportError:
    tf_compat_v1 = tf

# Tensorflow utility functions
import tvm.relay.testing.tf as tf_testing
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
from tensorflow.python.framework import dtypes

target = 'llvm'
# target_host = 'llvm'
target_host = None
layout = None
# ctx = tvm.cpu(0)


model_path = "./resnet50_v1.pb"
INPUTS = 'input_tensor'
OUTPUTS = 'softmax_tensor'
Batchsize = 128

with tf_compat_v1.gfile.GFile(model_path, 'rb') as f:
    graph_def = tf_compat_v1.GraphDef()
    graph_def.ParseFromString(f.read())

    graph_def = optimize_for_inference(graph_def, [INPUTS], [OUTPUTS],
                                       dtypes.float32.as_datatype_enum, False)

    # import_graph_def loads graph_def into the default graph (and returns None)
    tf_compat_v1.import_graph_def(graph_def, name='')

    # Call the utility to import the graph definition into default graph.
    graph_def = tf_testing.ProcessGraphDefParam(graph_def)

    # Add shapes to the graph.
    with tf_compat_v1.Session() as sess:
        graph_def = tf_testing.AddShapesToGraphDef(sess, OUTPUTS)

shape_dict = {INPUTS: (Batchsize, 224, 224, 3)}
#shape_dict = None

mod, params = relay.frontend.from_tensorflow(graph_def,
                                             layout=layout,
                                             shape=shape_dict)

print("Tensorflow protobuf imported to relay frontend.")

direct_convert = False # will use autoTVM
if direct_convert:
    with tvm.transform.PassContext(opt_level=3):
        graph, lib, params = relay.build(mod,
                                         target=target,
                                         target_host=target_host,
                                         params=params)
else:
    ## AutoTVM
    from tvm import autotvm
    from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
    from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
    #import tvm.contrib.graph_runtime as runtime

    model_name = "resnet50_v15"
    log_file = "%s.log" % model_name
    graph_opt_sch_file = "%s_graph_opt.log" % model_name

    num_threads = 4
    os.environ["TVM_NUM_THREADS"] = str(num_threads)
    input_shape = (Batchsize, 224, 224, 3)
    output_shape = (Batchsize, 1001)

    tuning_option = {
        'log_filename': log_file,
        'tuner': 'random',
        'early_stopping': None,

        'measure_option': autotvm.measure_option(
            builder=autotvm.LocalBuilder(),
            runner=autotvm.LocalRunner(number=10, repeat=1,
                                       min_repeat_ms=1000),
        ),
    }

    # You can skip the implementation of this function for this tutorial.
    def tune_kernels(tasks,
                     measure_option,
                     tuner='gridsearch',
                     early_stopping=None,
                     log_filename='tuning.log'):

        for i, task in enumerate(tasks):
            prefix = "[Task %2d/%2d] " % (i+1, len(tasks))

            # create tuner
            if tuner == 'xgb' or tuner == 'xgb-rank':
                tuner_obj = XGBTuner(task, loss_type='rank')
            elif tuner == 'ga':
                tuner_obj = GATuner(task, pop_size=50)
            elif tuner == 'random':
                tuner_obj = RandomTuner(task)
            elif tuner == 'gridsearch':
                tuner_obj = GridSearchTuner(task)
            else:
                raise ValueError("Invalid tuner: " + tuner)

            # do tuning
            n_trial = len(task.config_space)
            tuner_obj.tune(n_trial=n_trial,
                           early_stopping=early_stopping,
                           measure_option=measure_option,
                           callbacks=[
                               autotvm.callback.progress_bar(n_trial, prefix=prefix),
                               autotvm.callback.log_to_file(log_filename)])


    # Use graph tuner to achieve graph level optimal schedules
    # Set use_DP=False if it takes too long to finish.
    def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
        target_op = [relay.op.get("nn.conv2d"),]
        Tuner = DPTuner if use_DP else PBQPTuner
        executor = Tuner(graph, {INPUTS: dshape}, records, target_op, target)
        executor.benchmark_layout_transform(min_exec_num=2000)
        executor.run()
        executor.write_opt_sch2record_file(opt_sch_file)

    def tune_and_evaluate(tuning_opt, mod, params, data_shape, out_shape):
        # extract workloads from relay program
        print("Extract tasks...")
        #mod, params, data_shape, out_shape = get_network(model_name, batch_size)
        tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                                  params=params,
                                                  ops=(relay.op.get("nn.conv2d"),))

        # run tuning tasks
        tune_kernels(tasks, **tuning_opt)
        tune_graph(mod["main"], data_shape, log_file, graph_opt_sch_file)

        # compile kernels with graph-level best records
        with autotvm.apply_graph_best(graph_opt_sch_file):
            print("Compile...")
            with tvm.transform.PassContext(opt_level=3):
                graph, lib, params = relay.build_module.build(
                    mod, target=target, params=params)
        return graph, lib, params

    graph, lib, params = tune_and_evaluate(tuning_option, mod, params, input_shape, output_shape)

# save the graph, lib and params into separate files
import os

os.system("rm -rf ./export && mkdir export")
path_lib = "./export/deploy_lib.tar"
lib.export_library(path_lib)
with open("./export/deploy_graph.json", "w") as fo:
    fo.write(graph)
with open("./export/deploy_param.params", "wb") as fo:
    fo.write(relay.save_param_dict(params))

It failed with the following error message:

Traceback (most recent call last):
  File "tensorflow_tvm.py", line 183, in <module>
    graph, lib, params = tune_and_evaluate(tuning_option, mod, params, input_shape, output_shape)
  File "tensorflow_tvm.py", line 173, in tune_and_evaluate
    tune_graph(mod["main"], data_shape, log_file, graph_opt_sch_file)
  File "tensorflow_tvm.py", line 158, in tune_graph
    executor = Tuner(graph, {INPUTS: dshape}, records, target_op, target)
  File "/home/lesliefang/tvm/tvm/python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py", line 43, in __init__
    super(DPTuner, self).__init__(*args, **kwargs)
  File "/home/lesliefang/tvm/tvm/python/tvm/autotvm/graph_tuner/base_graph_tuner.py", line 149, in __init__
    expr2graph(graph, self._target_ops, node_dict, self._node_list)
  File "/home/lesliefang/tvm/tvm/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py", line 65, in expr2graph
    task_name, args = env.task_collection[task_pos]
IndexError: list index out of range

It seems env.task_collection is empty, and I am not sure whether it relates to this PR: https://github.com/apache/incubator-tvm/pull/5938

Any suggestions for debugging this issue? I find that env.get_tasks() (https://github.com/apache/incubator-tvm/blob/575a3835315a533a19e871ee913f01142befebab/python/tvm/autotvm/task/relay_integration.py#L148) returns an empty list.
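
One way to narrow this down is to run the task extraction step in isolation and inspect whatever comes back, before involving the graph tuner at all. A minimal debugging sketch, reusing mod, params, and target from the script above:

from tvm import autotvm, relay

tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                          params=params,
                                          ops=(relay.op.get("nn.conv2d"),))
print("extracted %d conv2d task(s)" % len(tasks))
for t in tasks:
    # each task records its schedule template name and the workload
    # (op, shapes, layouts) it was extracted for
    print(t.name, t.workload)

If this prints zero tasks, the graph tuner has nothing to work with, which would explain the IndexError above.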

It seems that the nn.conv2d ops converted from my model are treated as un-tunable in autoTVM: if I extract tasks for nn.dense instead of nn.conv2d, it gets past this error but fails with another hard-coded error.
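
If the suspicion is a layout mismatch (the TensorFlow graph is NHWC, while the tunable x86 conv2d templates expect NCHW), one way to confirm it is to print the data layout of every nn.conv2d call in the imported module. A sketch using Relay's ExprVisitor, again assuming mod from the script above (Conv2DLayoutPrinter is just an ad-hoc helper name):

from tvm.relay.expr_functor import ExprVisitor

class Conv2DLayoutPrinter(ExprVisitor):
    """Print the data layout of every nn.conv2d call in an expression."""
    def visit_call(self, call):
        if hasattr(call.op, "name") and call.op.name == "nn.conv2d":
            print("nn.conv2d data_layout =", call.attrs.data_layout)
        super().visit_call(call)  # keep traversing the arguments

Conv2DLayoutPrinter().visit(mod["main"])

For a frozen TensorFlow ResNet-50 this should print NHWC for every convolution, which would match the symptom: NHWC conv2d workloads yield no tunable tasks, so env.task_collection stays empty.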

Solved. It relates to the memory layout. I will focus on the performance comparison first and dig deeper afterwards. Thanks.
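
For the performance comparison, a minimal sketch of loading the artifacts exported by the script above and timing inference with graph_runtime (paths, input name, and batch size taken from that script):

import numpy as np
import tvm
from tvm.contrib import graph_runtime

loaded_lib = tvm.runtime.load_module("./export/deploy_lib.tar")
loaded_graph = open("./export/deploy_graph.json").read()
loaded_params = open("./export/deploy_param.params", "rb").read()

ctx = tvm.cpu(0)
module = graph_runtime.create(loaded_graph, loaded_lib, ctx)
module.load_params(loaded_params)

data = np.random.uniform(size=(128, 224, 224, 3)).astype("float32")
module.set_input("input_tensor", data)

# time_evaluator runs the model several times for a stable measurement
ftimer = module.module.time_evaluator("run", ctx, number=10, repeat=3)
res = np.array(ftimer().results) * 1000  # milliseconds per run
print("Mean inference time: %.2f ms (std %.2f ms)" % (np.mean(res), np.std(res)))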

Hi, how did you fix the error? Can you share?

How did you fix the error? How do you change the memory layout, NHWC -> NCHW? And how do you set data_shape? Looking forward to your reply.

@turboLIU @david Sorry for the delayed response. Yes, it relates to the memory layout. I added the graph layout pass (ConvertLayout) in Relay to solve this problem:

desired_layouts = {'nn.conv2d': ['NCHW', 'default']}  # desired data and kernel layouts
seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                relay.transform.ConvertLayout(desired_layouts),
                                relay.transform.FoldConstant()])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
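
Note that ConvertLayout only changes the layout inside the graph: it inserts layout_transform nodes at the boundaries, so the input placeholder stays NHWC and shape_dict keeps the (Batchsize, 224, 224, 3) shape from the original script. The pass is applied to mod right after relay.frontend.from_tensorflow and before autotvm.task.extract_from_program, so the extracted conv2d workloads are NCHW and tunable.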