Large Runtime Discrepancy between Profiler and Tutorial Script

I ran tuning (AutoTVM) and inference for a ResNet-50 model in two ways:

  1. Running the tutorial script in tutorials/autotvm/tune_relay_x86.py with some modifications.
  2. Running the built-in profiler, as used in tests/python/unittest/test_runtime_profiling.py.

The total time reported by (2) is more than 4x the time reported by (1): about 1.41 s versus 328.58 ms. I can’t work out the reason for this difference. Does anyone have any suggestions?

Here is the modified code in tune_relay_x86 for (1):

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
.. _tune_relay_x86:

Auto-tuning a Convolutional Network for x86 CPU
===============================================
**Author**: `Yao Wang <https://github.com/kevinthesun>`_, `Eddie Yan <https://github.com/eqy>`_

This is a tutorial about how to tune convolution neural network
for x86 CPU.

Note that this tutorial will not run on Windows or recent versions of macOS. To
get it to run, you will need to wrap the body of this tutorial in a :code:`if
__name__ == "__main__":` block.
"""
import os
import numpy as np

import tvm
from tvm import relay, autotvm
from tvm.relay import testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
import tvm.contrib.graph_executor as runtime
# import tvm.contrib.debugger.debug_executor as runtime

#################################################################
# Define network
# --------------
# First we need to define the network in relay frontend API.
# We can either load some pre-defined network from :code:`relay.testing`
# or building :any:`relay.testing.resnet` with relay.
# We can also load models from MXNet, ONNX and TensorFlow.
#
# In this tutorial, we choose resnet-18 as tuning example.


def get_network(name, batch_size):
    """Get the symbol definition and random weight of a network"""
    input_shape = (batch_size, 3, 224, 224)
    output_shape = (batch_size, 1000)

    if "resnet" in name:
        n_layer = int(name.split("-")[1])
        mod, params = relay.testing.resnet.get_workload(
            num_layers=n_layer, batch_size=batch_size, dtype=dtype
        )
    elif "vgg" in name:
        n_layer = int(name.split("-")[1])
        mod, params = relay.testing.vgg.get_workload(
            num_layers=n_layer, batch_size=batch_size, dtype=dtype
        )
    elif name == "mobilenet":
        mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
    elif name == "squeezenet_v1.1":
        mod, params = relay.testing.squeezenet.get_workload(
            batch_size=batch_size, version="1.1", dtype=dtype
        )
    elif name == "inception_v3":
        input_shape = (batch_size, 3, 299, 299)
        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
    elif name == "mxnet":
        # an example for mxnet model
        from mxnet.gluon.model_zoo.vision import get_model

        block = get_model("resnet18_v1", pretrained=True)
        mod, params = relay.frontend.from_mxnet(block, shape={input_name: input_shape}, dtype=dtype)
        net = mod["main"]
        net = relay.Function(
            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
        )
        mod = tvm.IRModule.from_expr(net)
    else:
        raise ValueError("Unsupported network: " + name)

    return mod, params, input_shape, output_shape


# Replace "llvm" with the correct target of your CPU.
# For example, for AWS EC2 c5 instance with Intel Xeon
# Platinum 8000 series, the target should be "llvm -mcpu=skylake-avx512".
# For AWS EC2 c4 instance with Intel Xeon E5-2666 v3, it should be
# "llvm -mcpu=core-avx2".
target = "llvm"

batch_size = 1
dtype = "float32"
model_name = "resnet-50"
log_file = "%s.log" % model_name
graph_opt_sch_file = "%s_graph_opt.log" % model_name

# Set the input name of the graph
# For ONNX models, it is typically "0".
input_name = "data"

# Set number of threads used for tuning based on the number of
# physical CPU cores on your machine.
num_threads = 1
os.environ["TVM_NUM_THREADS"] = str(num_threads)


#################################################################
# Configure tensor tuning settings and create tasks
# -------------------------------------------------
# To get better kernel execution performance on x86 CPU,
# we need to change data layout of convolution kernel from
# "NCHW" to "NCHWc". To deal with this situation, we define
# conv2d_NCHWc operator in topi. We will tune this operator
# instead of plain conv2d.
#
# We will use local mode for tuning configuration. RPC tracker
# mode can be setup similarly to the approach in
# :ref:`tune_relay_arm` tutorial.
#
# To perform a precise measurement, we should repeat the measurement several
# times and use the average of results. In addition, we need to flush the cache
# for the weight tensors between repeated measurements. This can make the measured
# latency of one operator closer to its actual latency during end-to-end inference.

tuning_option = {
    "log_filename": log_file,
    "tuner": "random",
    "early_stopping": None,
    "measure_option": autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(
            number=1, repeat=5, min_repeat_ms=0, enable_cpu_cache_flush=True
        ),
    ),
}


# You can skip the implementation of this function for this tutorial.
def tune_kernels(
    tasks, measure_option, tuner="gridsearch", early_stopping=None, log_filename="tuning.log"
):

    for i, task in enumerate(tasks):
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))

        # create tuner
        if tuner == "xgb" or tuner == "xgb-rank":
            tuner_obj = XGBTuner(task, loss_type="rank")
        elif tuner == "ga":
            tuner_obj = GATuner(task, pop_size=50)
        elif tuner == "random":
            tuner_obj = RandomTuner(task)
        elif tuner == "gridsearch":
            tuner_obj = GridSearchTuner(task)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        # do tuning
        n_trial = len(task.config_space)
        tuner_obj.tune(
            n_trial=n_trial,
            early_stopping=early_stopping,
            measure_option=measure_option,
            callbacks=[
                autotvm.callback.progress_bar(n_trial, prefix=prefix),
                autotvm.callback.log_to_file(log_filename),
            ],
        )


# Use graph tuner to achieve graph level optimal schedules
# Set use_DP=False if it takes too long to finish.
def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
    target_op = [
        relay.op.get("nn.conv2d"),
    ]
    Tuner = DPTuner if use_DP else PBQPTuner
    executor = Tuner(graph, {input_name: dshape}, records, target_op, target)
    # executor.benchmark_layout_transform(min_exec_num=2000, layout_records="resnet-18.log", infer_layout=True)
    executor.benchmark_layout_transform(min_exec_num=100)
    executor.run()
    executor.write_opt_sch2record_file(opt_sch_file)


########################################################################
# Finally, we launch tuning jobs and evaluate the end-to-end performance.


def tune_and_evaluate(tuning_opt):
    # extract workloads from relay program
    print("Extract tasks...")
    mod, params, data_shape, out_shape = get_network(model_name, batch_size)
    tasks = autotvm.task.extract_from_program(
        mod["main"], target=target, params=params, ops=(relay.op.get("nn.conv2d"),)
    )

    # run tuning tasks
    tune_kernels(tasks, **tuning_opt)
    tune_graph(mod["main"], data_shape, log_file, graph_opt_sch_file)

    # compile kernels with graph-level best records
    with autotvm.apply_graph_best(graph_opt_sch_file):
        print("Compile...")
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build_module.build(mod, target=target, params=params)

        # upload parameters to device
        dev = tvm.cpu()
        data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype))
        module = runtime.GraphModule(lib["default"](dev))
        module.set_input(input_name, data_tvm)

        # evaluate
        print("Evaluate inference time cost...")
        ftimer = module.module.time_evaluator("run", dev, number=100, repeat=3)
        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
        print(
            "Mean inference time (std dev): %.2f ms (%.2f ms)"
            % (np.mean(prof_res), np.std(prof_res))
        )

tune_and_evaluate(tuning_option)
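
For context on the figure printed by the evaluation step above: per the TVM documentation, time_evaluator with number=100 and repeat=3 invokes the function 1 + number x repeat times, discards the first call as warm-up, and returns repeat results, each being the average over number runs. The reported std dev is therefore computed over 3 averaged values, not over 300 individual runs. A minimal pure-Python sketch of that scheme follows; time_evaluator_sketch and dummy_run are hypothetical stand-ins for illustration, not TVM's implementation.

```python
import statistics
import time


def time_evaluator_sketch(func, number, repeat):
    """Hypothetical stand-in mimicking the documented number/repeat scheme."""
    func()  # warm-up call, discarded
    results = []
    for _ in range(repeat):
        start = time.perf_counter()
        for _ in range(number):
            func()
        elapsed = time.perf_counter() - start
        # One result per repeat: the average cost of `number` runs, in seconds.
        results.append(elapsed / number)
    return results


calls = []


def dummy_run():
    """Placeholder workload; a real run would execute the compiled graph."""
    calls.append(1)


results = time_evaluator_sketch(dummy_run, number=100, repeat=3)
mean_ms = statistics.mean(results) * 1000
std_ms = statistics.pstdev(results) * 1000  # population std, like np.std
print("Mean inference time (std dev): %.4f ms (%.4f ms)" % (mean_ms, std_ms))
```

The point of the sketch is only that the mean from (1) is an average over many back-to-back warm runs, which is worth keeping in mind when comparing it with a per-operator profile.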

Here’s the code for (2):

import numpy as np
import pytest
from io import StringIO
import csv
import os
import json

import tvm
import tvm.testing
from tvm.runtime import profiler_vm
from tvm import relay, autotvm
from tvm.relay.testing import mlp
from tvm.contrib.debugger import debug_executor
from tvm.autotvm.graph_tuner import DPTuner

@tvm.testing.parametrize_targets
def test_resnet(target, dev):
    n_layer = 50
    batch_size = 1
    input_shape = (batch_size, 3, 224, 224)
    dtype = "float32"

    log = "resnet-50.log"
    opt = "resnet-50.opt"
    early_stopping = None
    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(
            number=1, repeat=5, min_repeat_ms=0, enable_cpu_cache_flush=True
        ),
    )

    mod, params = relay.testing.resnet.get_workload(
        num_layers=n_layer, batch_size=batch_size, dtype=dtype
    )
    data = np.random.uniform(size=input_shape).astype(dtype)

    tasks = autotvm.task.extract_from_program(mod['main'], target=target,
        params=params, ops=(relay.op.get("nn.conv2d"),)
    )

    for i, task in enumerate(tasks):
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
        tuner = autotvm.tuner.RandomTuner(task)
        n_trial = len(task.config_space)
        tuner.tune(
            n_trial=n_trial,
            early_stopping=early_stopping,
            measure_option=measure_option,
            callbacks=[
                autotvm.callback.progress_bar(n_trial, prefix=prefix),
                autotvm.callback.log_to_file(log)
            ]
        )

    target_op = [
        relay.op.get("nn.conv2d"),
    ]
    executor = DPTuner(mod['main'], {"data": input_shape},
        log, target_op, target
    )

    executor.benchmark_layout_transform(min_exec_num=100)
    executor.run()
    executor.write_opt_sch2record_file(opt)

    with autotvm.apply_graph_best(opt):
        print("Compiling Kernels with graph-level best records")
        with tvm.transform.PassContext(opt_level=3):
            exe = relay.build(mod, target, params=params)
            gr = debug_executor.create(exe.get_graph_json(), exe.lib, dev)
            report = gr.profile(data=data)
            print(report)

if __name__ == "__main__":
    test_resnet("llvm", tvm.cpu())

The outputs of the two runs are as follows.

(1):

Extract tasks...
[Task  1/20]  Current/Best:   13.86/  28.03 GFLOPS | Progress: (308/308) | 217.99 s Done.
[Task  2/20]  Current/Best:   19.79/  27.28 GFLOPS | Progress: (980/980) | 401.33 s Done.
[Task  3/20]  Current/Best:   17.16/  27.93 GFLOPS | Progress: (980/980) | 572.77 s Done.
[Task  4/20]  Current/Best:   20.46/  26.31 GFLOPS | Progress: (1260/1260) | 1023.95 s Done.
[Task  5/20]  Current/Best:   15.34/  26.88 GFLOPS | Progress: (1260/1260) | 662.98 s Done.
[Task  6/20]  Current/Best:    4.07/  28.04 GFLOPS | Progress: (1152/1152) | 555.31 s Done.
[Task  7/20]  Current/Best:   11.46/  28.00 GFLOPS | Progress: (1024/1024) | 612.87 s Done.
[Task  8/20]  Current/Best:   18.05/  26.66 GFLOPS | Progress: (1440/1440) | 1615.15 s Done.
[Task  9/20]  Current/Best:   22.34/  26.67 GFLOPS | Progress: (1280/1280) | 1242.93 s Done.
[Task 10/20]  Current/Best:   22.39/  27.10 GFLOPS | Progress: (1280/1280) | 660.74 s Done.
[Task 11/20]  Current/Best:   13.05/  26.74 GFLOPS | Progress: (1080/1080) | 532.07 s Done.
[Task 12/20]  Current/Best:   12.13/  27.18 GFLOPS | Progress: (972/972) | 630.99 s Done.
[Task 13/20]  Current/Best:   14.14/  24.94 GFLOPS | Progress: (1320/1320) | 1671.61 s Done.
[Task 14/20]  Current/Best:    6.39/  26.33 GFLOPS | Progress: (1188/1188) | 1244.71 s Done.
[Task 15/20]  Current/Best:    6.52/  26.78 GFLOPS | Progress: (1188/1188) | 628.08 s Done.
[Task 16/20]  Current/Best:   21.34/  25.89 GFLOPS | Progress: (880/880) | 428.56 s Done.
[Task 17/20]  Current/Best:    8.56/  26.87 GFLOPS | Progress: (800/800) | 655.81 s Done.
[Task 18/20]  Current/Best:    4.70/  25.93 GFLOPS | Progress: (1056/1056) | 1342.64 s Done.
[Task 19/20]  Current/Best:    5.44/  26.70 GFLOPS | Progress: (960/960) | 1080.96 s Done.
[Task 20/20]  Current/Best:   15.95/  25.95 GFLOPS | Progress: (960/960) | 514.45 s Done.
2021-08-09 03:49:23,155 INFO Start to benchmark layout transformation...
2021-08-09 03:59:04,193 INFO Benchmarking layout transformation successful.
2021-08-09 03:59:04,233 INFO Start to run dynamic programming algorithm...
2021-08-09 03:59:04,234 INFO Start forward pass...
2021-08-09 03:59:05,056 INFO Finished forward pass.
2021-08-09 03:59:05,056 INFO Start backward pass...
2021-08-09 03:59:05,076 INFO Finished backward pass...
2021-08-09 03:59:05,077 INFO Finished DPExecutor run.
2021-08-09 03:59:05,091 INFO Writing optimal schedules to resnet-50_graph_opt.log successfully.
Compile...
Config for target=llvm -keys=cpu -link-params=0, workload=('dense_nopack.x86', ('TENSOR', (1, 2048), 'float32'), ('TENSOR', (1000, 2048), 'float32'), None, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
Config for target=llvm -keys=cpu -link-params=0, workload=('dense_pack.x86', ('TENSOR', (1, 2048), 'float32'), ('TENSOR', (1000, 2048), 'float32'), None, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
Evaluate inference time cost...
Mean inference time (std dev): 328.58 ms (0.60 ms)

(2):

[Task  1/20]  Current/Best:   86.48/  99.59 GFLOPS | Progress: (308/308) | 157.22 s Done.
[Task  2/20]  Current/Best:   41.10/  98.56 GFLOPS | Progress: (980/980) | 356.97 s Done.
[Task  3/20]  Current/Best:   28.23/  95.20 GFLOPS | Progress: (980/980) | 395.32 s Done.
[Task  4/20]  Current/Best:    9.76/  99.54 GFLOPS | Progress: (1260/1260) | 884.07 s Done.
[Task  5/20]  Current/Best:   15.84/  90.33 GFLOPS | Progress: (1260/1260) | 521.19 s Done.
[Task  6/20]  Current/Best:   14.44/  97.26 GFLOPS | Progress: (1152/1152) | 498.70 s Done.
[Task  7/20]  Current/Best:   63.92/  91.68 GFLOPS | Progress: (1024/1024) | 441.03 s Done.
[Task  8/20]  Current/Best:   25.57/  89.37 GFLOPS | Progress: (1440/1440) | 1493.47 s Done.
[Task  9/20]  Current/Best:   29.63/  93.04 GFLOPS | Progress: (1280/1280) | 1212.33 s Done.
[Task 10/20]  Current/Best:   23.18/  86.05 GFLOPS | Progress: (1280/1280) | 568.20 s Done.
[Task 11/20]  Current/Best:   31.45/  85.33 GFLOPS | Progress: (1080/1080) | 491.78 s Done.
[Task 12/20]  Current/Best:   45.85/  92.85 GFLOPS | Progress: (972/972) | 446.99 s Done.
[Task 13/20]  Current/Best:   60.31/  87.49 GFLOPS | Progress: (1320/1320) | 1350.96 s Done.
[Task 14/20]  Current/Best:   17.88/  95.81 GFLOPS | Progress: (1188/1188) | 1168.97 s Done.
[Task 15/20]  Current/Best:   65.15/  97.88 GFLOPS | Progress: (1188/1188) | 518.89 s Done.
[Task 16/20]  Current/Best:   18.00/  83.59 GFLOPS | Progress: (880/880) | 379.54 s Done.
[Task 17/20]  Current/Best:   35.93/  79.80 GFLOPS | Progress: (800/800) | 544.13 s Done.
[Task 18/20]  Current/Best:   25.29/  86.51 GFLOPS | Progress: (1056/1056) | 1132.86 s Done.
[Task 19/20]  Current/Best:   44.40/  92.23 GFLOPS | Progress: (960/960) | 1030.75 s Done.
[Task 20/20]  Current/Best:   52.05/  80.32 GFLOPS | Progress: (960/960) | 435.15 s Done.
2021-09-24 13:03:43,378 INFO Start to benchmark layout transformation...
2021-09-24 13:09:58,975 INFO Benchmarking layout transformation successful.
2021-09-24 13:09:59,002 INFO Start to run dynamic programming algorithm...
2021-09-24 13:09:59,013 INFO Start forward pass...
2021-09-24 13:09:59,816 INFO Finished forward pass.
2021-09-24 13:09:59,817 INFO Start backward pass...
2021-09-24 13:09:59,830 INFO Finished backward pass...
2021-09-24 13:09:59,830 INFO Finished DPExecutor run.
2021-09-24 13:09:59,832 INFO Writing optimal schedules to resnet-50.opt successfully.
Compiling Kernels with graph-level best records
Config for target=llvm -keys=cpu -link-params=0, workload=('dense_nopack.x86', ('TENSOR', (1, 2048), 'float32'), ('TENSOR', (1000, 2048), 'float32'), None, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
Config for target=llvm -keys=cpu -link-params=0, workload=('dense_pack.x86', ('TENSOR', (1, 2048), 'float32'), ('TENSOR', (1000, 2048), 'float32'), None, 'float32') is missing in ApplyGraphBest context. A fallback configuration is used, which may bring great performance regression.
Name                                                                   Duration (us)  Percent  data_layout  Device   layout  Count  src_layout  dst_layout                                                                                                                                                Argument Shapes              Hash  kernel_layout  out_layout  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu_7               3,98,550.80    28.25     NCHW128c    cpu0               3                                                                           float32[1, 1, 28, 28, 128], float32[16, 1, 3, 3, 128, 8], float32[1, 16, 1, 1, 8], float32[1, 16, 28, 28, 8]  c8ab54df360cbd39     OIHW128i8o      NCHW8c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu                 2,97,997.60    21.13       NCHW1c    cpu0               1                                                                           float32[1, 3, 224, 224, 1], float32[4, 3, 7, 7, 1, 16], float32[1, 4, 1, 1, 16], float32[1, 4, 112, 112, 16]  4f0ddf6181725272      OIHW1i16o     NCHW16c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu_6               1,73,025.86    12.27      NCHW16c    cpu0               3                                                                           float32[1, 32, 28, 28, 16], float32[16, 32, 1, 1, 16, 8], float32[1, 16, 1, 1, 8], float32[1, 16, 28, 28, 8]  e0220255e528d98d      OIHW16i8o      NCHW8c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_1                       1,50,020.14    10.64       NCHW8c    cpu0               3                                                                        float32[1, 16, 28, 28, 8], float32[32, 16, 1, 1, 8, 16], float32[1, 32, 28, 28, 16], float32[1, 32, 28, 28, 16]  6beba43d92784786      OIHW8i16o     NCHW16c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu_2                 65,746.29     4.66      NCHW32c    cpu0               3                                                                                float32[1, 2, 56, 56, 32], float32[8, 2, 3, 3, 32, 8], float32[1, 8, 1, 1, 8], float32[1, 8, 56, 56, 8]  a052a5b4742ae7bb      OIHW32i8o      NCHW8c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_add_nn_relu_1             52,376.01     3.71       NCHW8c    cpu0               1                                              float32[1, 16, 28, 28, 8], float32[32, 16, 1, 1, 8, 16], float32[1, 32, 28, 28, 16], float32[1, 32, 1, 1, 16], float32[1, 32, 28, 28, 16]  faa415ce8e443d42      OIHW8i16o     NCHW16c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu_5                 50,519.11     3.58      NCHW32c    cpu0               1                                                                             float32[1, 4, 28, 28, 32], float32[16, 4, 3, 3, 32, 8], float32[1, 16, 1, 1, 8], float32[1, 16, 28, 28, 8]  0dfbf6e394761131      OIHW32i8o      NCHW8c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_2                         48,376.87     3.43       NCHW8c    cpu0               5                                                                        float32[1, 32, 14, 14, 8], float32[64, 32, 1, 1, 8, 16], float32[1, 64, 14, 14, 16], float32[1, 64, 14, 14, 16]  c3c48546ccd1c8e4      OIHW8i16o     NCHW16c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_2                             36,190.51     2.57      NCHW16c    cpu0               1                                                                                                  float32[1, 32, 28, 28, 16], float32[64, 32, 1, 1, 16, 16], float32[1, 64, 14, 14, 16]  b6e66601adaeb1e3     OIHW16i16o     NCHW16c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add                           29,154.91     2.07       NCHW8c    cpu0               2                                                                          float32[1, 8, 56, 56, 8], float32[16, 8, 1, 1, 8, 16], float32[1, 16, 56, 56, 16], float32[1, 16, 56, 56, 16]  e56ff43c45465b27      OIHW8i16o     NCHW16c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc                               28,186.36     2.00      NCHW16c    cpu0               1                                                                                                    float32[1, 4, 56, 56, 16], float32[16, 4, 1, 1, 16, 16], float32[1, 16, 56, 56, 16]  7661eb48c0b8a7e6     OIHW16i16o     NCHW16c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu_9                 17,799.89     1.26      NCHW32c    cpu0               6                                                                             float32[1, 8, 14, 14, 32], float32[32, 8, 3, 3, 32, 8], float32[1, 32, 1, 1, 8], float32[1, 32, 14, 14, 8]  e5f43040054fe780      OIHW32i8o      NCHW8c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu_12                 9,253.03     0.66     NCHW512c    cpu0               3                                                                            float32[1, 1, 7, 7, 512], float32[128, 1, 3, 3, 512, 4], float32[1, 128, 1, 1, 4], float32[1, 128, 7, 7, 4]  9c69f3e4929f374e     OIHW512i4o      NCHW4c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu_8                  9,157.90     0.65      NCHW16c    cpu0               1                                                                           float32[1, 32, 28, 28, 16], float32[8, 32, 1, 1, 16, 32], float32[1, 8, 1, 1, 32], float32[1, 8, 14, 14, 32]  69a0179356dd24cf     OIHW16i32o     NCHW32c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1                  7,720.52     0.55      NCHW16c    cpu0               1                                                                             float32[1, 4, 56, 56, 16], float32[2, 4, 1, 1, 16, 32], float32[1, 2, 1, 1, 32], float32[1, 2, 56, 56, 32]  f298951e5a0123be     OIHW16i32o     NCHW32c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu_3                  7,692.81     0.55      NCHW16c    cpu0               2                                                                           float32[1, 16, 56, 56, 16], float32[2, 16, 1, 1, 16, 32], float32[1, 2, 1, 1, 32], float32[1, 2, 56, 56, 32]  bbcb68bd1349e736     OIHW16i32o     NCHW32c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu_10                 7,405.50     0.52      NCHW16c    cpu0               5                                                                           float32[1, 64, 14, 14, 16], float32[32, 64, 1, 1, 16, 8], float32[1, 32, 1, 1, 8], float32[1, 32, 14, 14, 8]  c15647de03f3cf01      OIHW16i8o      NCHW8c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu_13                 3,128.43     0.22       NCHW8c    cpu0               2                                                                               float32[1, 256, 7, 7, 8], float32[64, 256, 1, 1, 8, 8], float32[1, 64, 1, 1, 8], float32[1, 64, 7, 7, 8]  87141adc82ec7e11       OIHW8i8o      NCHW8c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_3                          2,934.52     0.21       NCHW8c    cpu0               2                                                                              float32[1, 64, 7, 7, 8], float32[256, 64, 1, 1, 8, 8], float32[1, 256, 7, 7, 8], float32[1, 256, 7, 7, 8]  13aacf6ea462872b       OIHW8i8o      NCHW8c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_3                              2,771.56     0.20    NCHW1024c    cpu0               1                                                                                                  float32[1, 1, 14, 14, 1024], float32[256, 1, 1, 1, 1024, 8], float32[1, 256, 7, 7, 8]  493c374dd5e37c2b    OIHW1024i8o      NCHW8c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_1                              2,603.12     0.18      NCHW16c    cpu0               1                                                                                                  float32[1, 16, 56, 56, 16], float32[32, 16, 1, 1, 16, 16], float32[1, 32, 28, 28, 16]  f8d89936127411c1     OIHW16i16o     NCHW16c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_multiply_add_nn_relu       1,460.02     0.10       NCHW8c    cpu0               1                          float32[1, 64, 7, 7, 8], float32[256, 64, 1, 1, 8, 8], float32[1, 256, 7, 7, 8], float32[1, 256, 1, 1, 8], float32[1, 256, 1, 1, 8], float32[1, 256, 7, 7, 8]  5641ba265b704699       OIHW8i8o      NCHW8c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_add_nn_relu                1,386.45     0.10       NCHW8c    cpu0               1                                                float32[1, 8, 56, 56, 8], float32[16, 8, 1, 1, 8, 16], float32[1, 16, 56, 56, 16], float32[1, 16, 1, 1, 16], float32[1, 16, 56, 56, 16]  3d774adf39794e79      OIHW8i16o     NCHW16c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_add_nn_relu_2              1,385.95     0.10       NCHW8c    cpu0               1                                              float32[1, 32, 14, 14, 8], float32[64, 32, 1, 1, 8, 16], float32[1, 64, 14, 14, 16], float32[1, 64, 1, 1, 16], float32[1, 64, 14, 14, 16]  f3caf71a3eca1b54      OIHW8i16o     NCHW16c  
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu_4                    885.80     0.06      NCHW16c    cpu0               1                                                                           float32[1, 16, 56, 56, 16], float32[4, 16, 1, 1, 16, 32], float32[1, 4, 1, 1, 32], float32[1, 4, 28, 28, 32]  66ef347de4f87e0d     OIHW16i32o     NCHW32c  
tvmgen_default_fused_add_nn_relu                                              874.12     0.06                 cpu0               2                                                                                                       float32[1, 16, 56, 56, 16], float32[1, 16, 1, 1, 16], float32[1, 16, 56, 56, 16]  e907ce81104cda7a                             
tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu_11                   709.64     0.05    NCHW1024c    cpu0               1                                                                        float32[1, 1, 14, 14, 1024], float32[32, 1, 1, 1, 1024, 16], float32[1, 32, 1, 1, 16], float32[1, 32, 7, 7, 16]  39427b820e22e106   OIHW1024i16o     NCHW16c  
tvmgen_default_fused_add_nn_relu_1                                            499.39     0.04                 cpu0               3                                                                                                       float32[1, 32, 28, 28, 16], float32[1, 32, 1, 1, 16], float32[1, 32, 28, 28, 16]  848825acfc73218b                             
tvmgen_default_fused_nn_contrib_dense_pack_add                                414.50     0.03                 cpu0               1                                                                                                              float32[1, 2048], float32[100, 2048, 10], float32[1000], float32[1, 1000]  9911caf9c87a4f0b                             
tvmgen_default_fused_nn_max_pool2d_add_nn_relu                                330.88     0.02                 cpu0  NCHW16c      1                                                                                                        float32[1, 4, 112, 112, 16], float32[1, 4, 1, 1, 16], float32[1, 4, 56, 56, 16]  6f701a4fa071030f                             
tvmgen_default_fused_add_nn_relu_2                                            190.15     0.01                 cpu0               5                                                                                                       float32[1, 64, 14, 14, 16], float32[1, 64, 1, 1, 16], float32[1, 64, 14, 14, 16]  f12067172f61c850                             
tvmgen_default_fused_layout_transform                                         121.78     0.01                 cpu0               3      NCHW8c    NCHW128c                                                                                                          float32[1, 16, 28, 28, 8], float32[1, 1, 28, 28, 128]  7648eead5ff7609b                             
tvmgen_default_fused_layout_transform_1                                        76.12     0.01                 cpu0               5      NCHW8c     NCHW32c                                                                                                           float32[1, 32, 14, 14, 8], float32[1, 8, 14, 14, 32]  c928642d4751d75f                             
tvmgen_default_fused_add_layout_transform                                      58.24     0.00                 cpu0               1        NCHW      NCHW1c                                                                                          float32[1, 3, 224, 224], float32[3, 1, 1], float32[1, 3, 224, 224, 1]  7d81e14288856695                             
tvmgen_default_fused_nn_global_avg_pool2d                                      32.09     0.00                 cpu0   NCHW8c      1                                                                                                                                     float32[1, 256, 7, 7, 8], float32[1, 256, 1, 1, 8]  38ce81ad1dab4aac                             
tvmgen_default_fused_add_nn_relu_3                                             26.77     0.00                 cpu0               2                                                                                                           float32[1, 256, 7, 7, 8], float32[1, 256, 1, 1, 8], float32[1, 256, 7, 7, 8]  c55f62331739aaca                             
tvmgen_default_fused_layout_transform_4                                        25.76     0.00                 cpu0               3      NCHW4c      NCHW8c                                                                                                              float32[1, 128, 7, 7, 4], float32[1, 64, 7, 7, 8]  8c4709e6f29bc282                             
tvmgen_default_fused_layout_transform_5                                        21.45     0.00                 cpu0               2      NCHW8c    NCHW512c                                                                                                              float32[1, 64, 7, 7, 8], float32[1, 1, 7, 7, 512]  7b3d71f8119b8dc8                             
tvmgen_default_fused_layout_transform_2                                        20.40     0.00                 cpu0               1     NCHW16c   NCHW1024c                                                                                                        float32[1, 64, 14, 14, 16], float32[1, 1, 14, 14, 1024]  6dda5720a553f260                             
tvmgen_default_fused_nn_softmax                                                12.62     0.00                 cpu0               1                                                                                                                                                     float32[1, 1000], float32[1, 1000]  ca61e79ea24e53f0                             
tvmgen_default_fused_layout_transform_3                                        11.69     0.00                 cpu0               1     NCHW16c    NCHW512c                                                                                                             float32[1, 32, 7, 7, 16], float32[1, 1, 7, 7, 512]  14b60ba8378fe7fe                             
tvmgen_default_fused_layout_transform_nn_batch_flatten                          5.45     0.00                 cpu0               1      NCHW8c        NCHW                                                                                                                     float32[1, 256, 1, 1, 8], float32[1, 2048]  083c8dff13069b17                             
----------                                                                                                                                                                                                                                                                                                                                                              
Sum                                                                     14,09,161.01    99.90                                   86                                                                                                                                                                                                                                      
Total                                                                   14,10,572.30                          cpu0               1                                                                                                                                                                                                                                      

(1) reports 328.58 ms while (2) reports 1.41 s. What is the reason for this?
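
To make the comparison concrete, the arithmetic can be checked with a short script. The profiler table prints durations in microseconds with comma digit grouping (for example 14,10,572.30), so the commas have to be stripped before converting. This is only a sketch using the values copied from the outputs above; parse_grouped is a hypothetical helper, not part of TVM.

```python
def parse_grouped(text):
    """Parse a number printed with comma digit grouping, e.g. '14,10,572.30'."""
    return float(text.replace(",", ""))


# Values copied from the outputs above.
profiler_total_us = parse_grouped("14,10,572.30")  # "Total" row of (2), in us
evaluator_mean_ms = 328.58                         # mean reported by (1), in ms

profiler_total_ms = profiler_total_us / 1000.0
ratio = profiler_total_ms / evaluator_mean_ms

print(f"profiler total: {profiler_total_ms:.2f} ms")  # 1410.57 ms
print(f"ratio (2)/(1): {ratio:.2f}x")                 # 4.29x
```

So the gap is roughly 4.3x, even though the tuning logs for (2) show higher per-task GFLOPS than those for (1).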