Performance issue with convolutional layers that have many filters (OpenCL)

TVM version: 0.8, OS: Ubuntu 20.04

Issue: A model tuned with the auto-scheduler (more than 20k `num_measure_trials`) was much slower than other inference engines (such as the TFLite runtime).

I found that the time difference between TVM and TFLite widens as the number of filters (output channels) of the convolutions grows; the kernel size stays 3×3 in all configurations.

Below is the test script that I used to compare the two:

import os
import tarfile
import time

import numpy as np
import tensorflow as tf
import tvm
from tvm import relay, auto_scheduler
from tensorflow.keras import Model, Input, layers, models, initializers
import tflite
from tvm.contrib import ndk

###############################################################################
# RPC tracker settings used to reach the Android device
#
RPC_KEY = "android"
RPC_PORT = 9190
RPC_TRACKER = "0.0.0.0"

###############################################################################
# Tuning switches: SKIP_TUNE reuses an existing tuning log instead of tuning;
# NUM_TRIALS is the auto-scheduler's total measurement budget
#
SKIP_TUNE = False
NUM_TRIALS = 20_000

def generate_tflite_files(filter_sizes, input_shape=(300, 300, 3), kernel_size=(3, 3)):
    """Build a two-layer Conv2D Keras model per filter pair and save it as .tflite.

    Both Conv2D layers use ReLU activation and random-uniform kernel/bias
    initializers, so the saved models are only meant for timing, not accuracy.

    Args:
        filter_sizes: iterable of ``(first, second)`` pairs giving the number of
            filters (output channels) of the first and second Conv2D layer.
        input_shape: shape of the model input, excluding the batch dimension.
            Defaults to ``(300, 300, 3)``.
        kernel_size: kernel size used by both Conv2D layers. Defaults to ``(3, 3)``.

    Returns:
        List of the written file names, in the same order as ``filter_sizes``,
        each named ``conv2d_filter_<first>_<second>.tflite`` (written to the
        current working directory).
    """
    def conv2d_model(num_filters):
        # ``inputs`` rather than ``input`` so the builtin is not shadowed.
        inputs = Input(shape=input_shape, name='input')
        x = layers.Conv2D(num_filters[0], kernel_size,
                          kernel_initializer=initializers.RandomUniform(),
                          bias_initializer=initializers.RandomUniform(),
                          activation='relu')(inputs)
        output = layers.Conv2D(num_filters[1], kernel_size,
                               kernel_initializer=initializers.RandomUniform(),
                               bias_initializer=initializers.RandomUniform(),
                               activation='relu')(x)
        return Model(inputs, output)

    generated_files = []
    for filter_size in filter_sizes:
        model = conv2d_model(filter_size)
        # Convert the in-memory Keras model to a TFLite flatbuffer.
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        tflite_model = converter.convert()
        fname = f'conv2d_filter_{filter_size[0]}_{filter_size[1]}.tflite'
        generated_files.append(fname)
        with open(fname, 'wb') as f:
            f.write(tflite_model)

    return generated_files


###############################################################################
# One .tflite file per (first conv, second conv) filter-count pair
#
filter_sizes = [
    (3, 32),
    (32, 96),
]
tflite_files = generate_tflite_files(filter_sizes)



###############################################################################
# Tune (unless SKIP_TUNE), compile and package every generated .tflite model
#
from tvm.auto_scheduler.utils import request_remote

target_host = tvm.target.Target('llvm -mtriple=arm64-linux-android')
target = tvm.target.Target('opencl')

for tflite_file in tflite_files:
    record_filename = f'{tflite_file}.records'

    # Use a context manager so the model file handle is closed promptly.
    with open(tflite_file, "rb") as f:
        tf_model_buf = f.read()
    tflite_model = tflite.Model.GetRootAsModel(tf_model_buf, 0)
    mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=None, dtype_dict=None)

    print('Request remote...')
    remote = request_remote(RPC_KEY, RPC_TRACKER, RPC_PORT, timeout=5)
    dev = remote.cl()

    # Remote OpenCL device properties; several feed the hardware params below.
    print(dev.max_clock_rate)
    print(dev.max_shared_memory_per_block)
    print(dev.max_thread_dimensions)
    print(dev.max_threads_per_block)
    print(dev.multi_processor_count)
    print(dev.warp_size)

    if not SKIP_TUNE:
        INT32_MAX = 2147483647
        # Same value as int(warp_size / 4) for non-negative integer warp sizes.
        quarter_warp = dev.warp_size // 4
        # NOTE(review): positional order assumed to be HardwareParams(num_cores,
        # vector_unit_bytes, cache_line_bytes, max_shared_memory_per_block,
        # max_local_memory_per_block, max_threads_per_block, max_vthread_extent,
        # warp_size) — verify against the installed TVM version.
        hardware_params = auto_scheduler.HardwareParams(
            dev.multi_processor_count,
            16,
            64,
            dev.max_shared_memory_per_block,
            INT32_MAX,
            dev.max_threads_per_block,
            quarter_warp if quarter_warp > 1 else dev.warp_size,
            dev.warp_size,
        )

        # Wait for the remote session opened above to time out; presumably this
        # frees the device on the tracker before RPCRunner requests it.
        time.sleep(5)

        tasks, task_weights = auto_scheduler.extract_tasks(
            mod["main"], params, target,
            target_host=target_host,
            hardware_params=hardware_params)

        print('Begin tuning...')

        runner = auto_scheduler.RPCRunner(
            key=RPC_KEY, host=RPC_TRACKER, port=RPC_PORT,
            repeat=3, timeout=50, n_parallel=5)
        builder = auto_scheduler.LocalBuilder(build_func="ndk")
        # load_log_file points at the same log that RecordToFile appends to;
        # presumably this lets repeated runs continue from earlier records.
        tuner = auto_scheduler.TaskScheduler(tasks, task_weights, load_log_file=record_filename)
        tune_option = auto_scheduler.TuningOptions(
            num_measure_trials=NUM_TRIALS,
            builder=builder,
            runner=runner,
            measure_callbacks=[auto_scheduler.RecordToFile(record_filename)],
            verbose=1,
        )

        print("Tuning...")
        tuner.tune(tune_option)

    # Without a tuning log, ApplyHistoryBest has no tuned schedules to apply
    # (presumably falling back to untuned defaults), so fail early and loudly,
    # e.g. when SKIP_TUNE is set before any log has been produced.
    if not os.path.exists(record_filename):
        raise FileNotFoundError(
            f'Tuning log {record_filename!r} not found; run with SKIP_TUNE = False first.')
    config = {'relay.backend.use_auto_scheduler': True}

    with auto_scheduler.ApplyHistoryBest(record_filename):
        print("Compiling...")
        with tvm.transform.PassContext(opt_level=3, config=config, disabled_pass=None):
            # NOTE(review): unpacking into (graph, lib, params) is the legacy
            # relay.build form; confirm it is still supported by your TVM version.
            graph, lib, params = relay.build(mod, target=target, target_host=target_host, params=params)

    lib_file = tflite_file + '.so'
    graph_file = tflite_file + '.json'
    param_file = tflite_file + '.params'

    lib.export_library(lib_file, ndk.create_shared)

    with open(graph_file, 'w') as f:
        f.write(graph)
    with open(param_file, 'wb') as f:
        f.write(relay.save_param_dict(params))

    # Bundle library, params and graph under fixed member names mod.{so,params,json}.
    with tarfile.open(tflite_file + '.tar', 'w') as tar:
        tar.add(lib_file, arcname='mod.so')
        tar.add(param_file, arcname='mod.params')
        tar.add(graph_file, arcname='mod.json')

The script above generates two TFLite model files and two TVM tar files, each pair using a different number of filters. The TVM models can be run via TVM RPC, and the TFLite model files can be run with an Android app called “TFLite Lab” available on the Google Play Store (link below).