Relay 'conv2d' layer performance after auto-tuning same as fallback

Hi everyone,

I was trying to obtain the execution time of each layer in resnet-18 after auto-tuning. When I run the whole architecture as in the GPU tutorial, I obtain results very similar to the ones reported there (~1.10 ms).

However, when I tune a single layer and apply the best schedule, I observe poor performance. For instance, for the last layer of resnet-18 I obtain 0.99 ms (direct, not winograd), which is the same time I observe with the fallback configuration for this layer.

In addition, when I check the log file, there seems to be a configuration that performs better:

No: 76 GFLOPS: 871.64/871.64 result: MeasureResult(costs=(0.00026525944420131294,),

So I think it is likely that the best configuration is not being applied.

The code I execute is the following:

import os
import sys
import numpy as np

import tvm
import topi
import logging
from tvm import autotvm
from tvm import relay
from tvm.relay import testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
import tvm.contrib.graph_runtime as runtime

# Details about the target (CPU/GPU)
target = 'cuda'
target_host = 'llvm'
batch_size = 1
dtype = 'float32'
# Set number of threads used for tuning based on the number of physical CPU cores on your machine.
num_threads = 8
os.environ["TVM_NUM_THREADS"] = str(num_threads)
# Set the input name of the graph
input_name = "data"

#log_file = "conv2d_direct512_cuda.log" 
#graph_opt_sch_file = "conv2d_direct512_cuda_opt.log"
log_file = "conv2d_wino512_cuda.log" 
graph_opt_sch_file = "conv2d_wino512_cuda_opt.log"
# Arguments to create task
data_shape = (batch_size, 512, 7, 7)
kernel_shape = (512, 512, 3, 3)
out_shape = (batch_size, 512, 7, 7)

data_shape_type = data_shape + ('float32',)
kernel_shape_type = kernel_shape + ('float32',)
kernel_size = (kernel_shape[2], kernel_shape[3])
strides = (1,1)
padding = (1,1,1,1)
dilation = (1,1)

# Convolution parameters
args= (('TENSOR', data_shape, 'float32'), ('TENSOR', kernel_shape, 'float32'), strides, padding, dilation, 'NCHW', 'float32')

# Workload for the task
workload = ('conv2d_nchw', data_shape_type, kernel_shape_type, strides, padding, dilation, 'NCHW', 'float32')

data = relay.var("data", shape=data_shape, dtype=dtype)
kernel = relay.var("kernel", shape=kernel_shape, dtype=dtype)

# Create a module with given target and extract task from it for auto-tuning
ctx = tvm.gpu()
out = relay.nn.conv2d(data, kernel, strides=strides, padding=padding, dilation=dilation, channels = kernel_shape[0], kernel_size = kernel_size, data_layout='NCHW', out_dtype=dtype)
mod = relay.Module.from_expr(out)

kernel_weights = tvm.nd.array(np.ones(kernel_shape, dtype=dtype), ctx)
dict_params = {'kernel': kernel_weights}

# task is a list and has several entries; auto-tuning has to index into it (e.g. task[0])
task = autotvm.task.extract_from_program(mod, target=target, target_host=target_host, params=dict_params, ops=(relay.op.nn.conv2d,))

# Change to winograd
task[0] = autotvm.task.create(task[0].name, task[0].args, task[0].target, task[0].target_host, 'winograd')
# Define type of auto-tuner
tuner_obj = XGBTuner(task[0])
print(task[0])

# logging config (for printing tuning log to the screen)
logging.getLogger('autotvm').setLevel(logging.DEBUG)
logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))

# Each candidate is measured several times and averaged to reduce variance.
measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=20, repeat=3, min_repeat_ms=150, timeout=4))

#n_trial = len(task.config_space)
#print(n_trial)
n_trial = 100

tuner_obj.tune(n_trial=n_trial,
               early_stopping=None,
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file(log_file)])

# inspect the best config
dispatch_context = autotvm.apply_history_best(log_file)
best_config = dispatch_context.query(task[0].target, task[0].workload)
print("\nBest config:")
print(best_config)

# Save optimal config to log file
with open(graph_opt_sch_file, "w") as text_file:
    text_file.write(str(best_config))

# compile kernels with history best records
with autotvm.apply_history_best(log_file):
    ctx = tvm.gpu()
    print("Compile...")

    with relay.build_config(opt_level=3):

        kernel_weights = tvm.nd.array(np.ones(kernel_shape, dtype=dtype), ctx)
        dict_params = {'kernel': kernel_weights}
        graph, lib, params = relay.build_module.build(mod, params = dict_params, target=target, target_host = target_host)
        #print(params)
        #print(dict_params)


    # BENCHMARKING: Measure time with and without optimizations
    # upload parameters to device
    input_name = "data"

    data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype), ctx)
    module = runtime.create(graph, lib, ctx)
    module.set_input(input_name, data_tvm)
    module.set_input(**params)

    # evaluate
    print("Evaluate inference time cost...")
    ftimer = module.module.time_evaluator("run", ctx, number=20, repeat=100)

    prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
    #print(prof_res)
    print("Mean inference time auto-tuning (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)))


out1 = relay.nn.conv2d(data, kernel, strides=strides, padding=padding, dilation=dilation, channels = kernel_shape[0], kernel_size = kernel_size, data_layout='NCHW', out_dtype=dtype)
mod1 = relay.Module.from_expr(out1)
#print(mod)

ctx1 = tvm.gpu()
graph1, lib1, params1 = relay.build_module.build(mod1, params = dict_params, target=target)
#print(params)

# BENCHMARKING: Measure time with and without optimizations
input_name = "data"
data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype), ctx1)
module1 = runtime.create(graph1, lib1, ctx1)
module1.set_input(input_name, data_tvm)
module1.set_input(**params1)

# evaluate
print("Evaluate inference time cost...")
ftimer1 = module1.module.time_evaluator("run", ctx1, number=20, repeat=100)

prof_res = np.array(ftimer1().results) * 1000  # convert to millisecond
#print(prof_res)
print("Mean inference time fallback (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)))

Moreover, I have tested the same program using the winograd convolution (still just one convolutional layer), and in that case the time measured after applying the best configuration does match the log file (0.08 ms):

No: 83 GFLOPS: 2891.20/2891.20 result: MeasureResult(costs=(7.997063408939294e-05,)

I was wondering if you know whether I am missing something. I would appreciate any help you can provide on this issue.

You just tuned for 100 trials? If so, please try 3,000 or 4,000 trials.


Hi @comaniac ,

Thank you for your prompt reply. I have updated the question a little so that things are clearer. Basically, I use the same program and comment/uncomment one line, which is the following:

task[0] = autotvm.task.create(task[0].name, task[0].args, task[0].target, task[0].target_host,'winograd')

The measured execution time matches the log file when ‘winograd’ is set; however, for the ‘direct’ approach the configuration does not seem to be applied, and the auto-tuned performance matches the fallback one (even though the log shows configurations producing better results). As a sanity check, I will run the code with more trials as you suggested and let you know if the results change. Thanks again!

There are some possibilities:

  1. Try using pick_best to extract the best config for each workload from a log file. AutoTVM applies the best config it finds for a given workload, so if you tune both direct and winograd for the same conv2d workload and put them in the same log file, only the better of the two will be applied (see the sketch after this list).

  2. It’s possible that the second build reuses the cached result of the first one. You could try adding compile_engine.get().clear() before each call to relay.build to make sure you get a real performance comparison between the two configs (also shown in the sketch below).

  3. Also, please note that you might not get exactly the same latency as shown in the log file, because that number was measured on a single TVM-compiled LLVM function, while the graph runtime adds some extra overhead.

Thanks a lot for your help @comaniac. I forgot to mention that I save the logs in different files to avoid conflicts. After 4,000 trials I still get the same results for the ‘direct’ method, so it seems to be a problem with applying the best configuration. I will try the steps you mentioned above and see if that fixes it :)