Auto-scheduler tuned OpenCL kernel produces Invalid Workgroup Size Error

TVM version: 0.8
OS: Ubuntu 20.04


Below is the script I used to reproduce the issue. It creates a TFLite model, then repeatedly tunes it for 10 trials and runs the compiled model to check for errors, until an error occurs. I have been able to reproduce the issue in about 100 iterations or less.

import numpy as np
import tensorflow as tf
import tvm
from tvm import relay, auto_scheduler
from tensorflow.keras import Model, Input, layers, models, initializers
import tflite
from tvm.contrib import ndk
import tarfile
from tvm.driver.tvmc.runner import run_module
from tvm.auto_scheduler.utils import request_remote
import os
import time

###############################################################################
# RPC configuration
#
RPC_KEY = 'android'
RPC_PORT = 9190
RPC_TRACKER = '0.0.0.0'

###############################################################################
# Tune config
#
SKIP_TUNE = False
NUM_TRIALS = 20000
MIN_TRIALS = 10
EPOCH_CNT = NUM_TRIALS // MIN_TRIALS

def generate_tflite_file(N, H, W, CO, CI, KH, KW):
    def test_model(N, H, W, CO, CI, KH, KW):
        input = Input(shape=(CI, H, W), batch_size=N, name='input')
        x = layers.Conv2D(CO, (KH, KW),
                    kernel_initializer=initializers.RandomUniform(),
                    bias_initializer=initializers.RandomUniform(),
                    activation='relu')(input)
        x = layers.Conv2DTranspose(CO, (KH, KW),
                    kernel_initializer=initializers.RandomUniform(),
                    bias_initializer=initializers.RandomUniform(),
                    activation='relu')(x)
        x = layers.Conv2D(CO, (KH, KW),
                    kernel_initializer=initializers.RandomUniform(),
                    bias_initializer=initializers.RandomUniform(),
                    activation='relu')(x)
        output = layers.Conv2DTranspose(CO, (KH, KW),
                    kernel_initializer=initializers.RandomUniform(),
                    bias_initializer=initializers.RandomUniform(),
                    activation='relu')(x)
        model = Model(input, output)
        return model

    model = test_model(N,H,W,CO,CI,KH,KW)
    # save as tflite
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()
    params = '_'.join(str(x) for x in (N,H,W,CO,CI,KH,KW))
    tflite_file = f'invalid_ws_test_{params}.tflite'
    with open(tflite_file, 'wb') as f:
        f.write(tflite_model)

    return tflite_file


###############################################################################
# Generate a tflite file
#
tflite_file = generate_tflite_file(1,140,108,64,64,2,2)



###############################################################################
# tune and run until Invalid workgroup size error is generated
#
target_host = tvm.target.Target('llvm -mtriple=arm64-linux-android')
target = tvm.target.Target('opencl')

record_filename = f'{tflite_file}.records'

print('Request remote...')
remote = request_remote(RPC_KEY, RPC_TRACKER, RPC_PORT, timeout=3)
dev = remote.cl()

print(dev.max_clock_rate)
print(dev.max_shared_memory_per_block)
print(dev.max_thread_dimensions)
print(dev.max_threads_per_block)
print(dev.multi_processor_count)
print(dev.warp_size)

INT32_MAX = 2147483647
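# HardwareParams positional arguments (TVM 0.8): num_cores, vector_unit_bytes,
# cache_line_bytes, max_shared_memory_per_block, max_local_memory_per_block,
# max_threads_per_block, max_vthread_extent, warp_size.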
hardware_params = auto_scheduler.HardwareParams(
    dev.multi_processor_count,
    16,
    64,
    dev.max_shared_memory_per_block,
    INT32_MAX,
    dev.max_threads_per_block,
    int(dev.warp_size / 4) if int(dev.warp_size / 4) > 1 else dev.warp_size,
    dev.warp_size,
)

# wait for remote session to timeout
time.sleep(3)

epoch_cnt = 0
while epoch_cnt < EPOCH_CNT:
    print('===============================================================================')
    print(f'Starting Epoch #{epoch_cnt}')
    print('===============================================================================')

    tf_model_buf = open(tflite_file, "rb").read()
    tflite_model = tflite.Model.GetRootAsModel(tf_model_buf, 0)
    mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=None, dtype_dict=None)

    if not SKIP_TUNE:
        tasks, task_weights = auto_scheduler.extract_tasks(
                            mod["main"], params, target,
                            target_host=target_host,
                            hardware_params=hardware_params)

        runner = auto_scheduler.RPCRunner(key=RPC_KEY, host=RPC_TRACKER, port=RPC_PORT, repeat=1, timeout=50, n_parallel=1)

        builder = auto_scheduler.LocalBuilder(build_func="ndk")

        tuner = auto_scheduler.TaskScheduler(tasks, task_weights, load_log_file=record_filename)
        tune_option = auto_scheduler.TuningOptions(
            num_measure_trials=MIN_TRIALS,
            builder=builder,
            runner=runner,
            measure_callbacks=[auto_scheduler.RecordToFile(record_filename)],
            verbose = 1
        )

        print("Tuning...")
        tuner.tune(tune_option)
        # save best records
        auto_scheduler.measure_record.distill_record_file(record_filename, record_filename+'.best')

    print("Compiling...")
    if os.path.isfile(record_filename):
        config = {'relay.backend.use_auto_scheduler':True}
        with auto_scheduler.ApplyHistoryBest(record_filename):
            with tvm.transform.PassContext(opt_level=3, config=config, disabled_pass=None):
                graph, lib, params = relay.build(mod, target=target, target_host=target_host, params=params)
    else:
        graph, lib, params = relay.build(mod, target=target, target_host=target_host, params=params)

    lib_file = tflite_file + '.so'
    graph_file = tflite_file + '.json'
    param_file = tflite_file + '.params'

    lib.export_library(lib_file, ndk.create_shared)

    with open(graph_file, 'w') as f:
        f.write(graph)
    with open(param_file, 'wb') as f:
        f.write(relay.save_param_dict(params))

    tvm_model_file = tflite_file + '.tar'
    with tarfile.open(tvm_model_file, 'w') as tar:
        tar.add(lib_file, arcname='mod.so')
        tar.add(param_file, arcname='mod.params')
        tar.add(graph_file, arcname='mod.json')

    # Run on target
    print(f'Running {tvm_model_file} on target device...')
    outputs, run_times = run_module(tvm_model_file, hostname=RPC_TRACKER, port=RPC_PORT, rpc_key=RPC_KEY,
        device='cl', inputs=None, fill_mode='random', repeat=1, profile=False)

    print(run_times)
    #print(outputs)
    epoch_cnt += 1

With a little more effort I was able to capture what the auto-scheduler ran during tuning and what the final compiled kernel was.
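For reference, the OpenCL source of the final compilation can be dumped from the compiled module, for example (a sketch, assuming lib is the device module returned by relay.build in the script above):

# Sketch: dump the generated OpenCL device code so it can be diffed against the
# kernel source that the auto-scheduler measured during tuning.
opencl_src = lib.imported_modules[0].get_source()
with open(tflite_file + '.cl', 'w') as f:
    f.write(opencl_src)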

The kernel compiled during tuning was:

__kernel void default_function_kernel0(__global float* restrict placeholder, __global float* restrict placeholder1, __global float* restrict compute) {
float compute_local[4];
__local float data_pad_shared[3408];
__local float placeholder_shared[1024];
for (int c_c_inner_init = 0; c_c_inner_init < 2; ++c_c_inner_init) {
compute_local[(c_c_inner_init)] = 0.000000e+00f;
compute_local[((c_c_inner_init + 2))] = 0.000000e+00f;
}
for (int dc_outer_outer = 0; dc_outer_outer < 4; ++dc_outer_outer) {
barrier(CLK_LOCAL_MEM_FENCE);
for (int ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer = 0; ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer < 7; ++ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer) {
if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) < 3408) {
data_pad_shared[(((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))))] = (((((1 <= (((((int)get_group_id(0)) & 1) * 70) + (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) % 71))) && ((((((int)get_group_id(0)) & 1) * 70) + (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) % 71)) < 140)) && (1 <= ((((((int)get_group_id(0)) & 63) >> 1) * 2) + ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) % 213) / 71)))) && (((((((int)get_group_id(0)) & 63) >> 1) * 2) + ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) % 213) / 71)) < 64)) ? placeholder[((((((((dc_outer_outer * 140112) + ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) / 213) * 8757)) + (((((int)get_group_id(0)) & 63) >> 1) * 278)) + (((((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) % 213) / 71) * 139)) + ((((int)get_group_id(0)) & 1) * 70)) + (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) % 71)) - 140))] : 0.000000e+00f);
}
}
for (int ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 = 0; ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 < 2; ++ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1) {
if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 140) + (((int)get_local_id(0)) >> 2)) < 256) {
if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 280) + (((int)get_local_id(0)) >> 1)) < 512) {
if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 560) + ((int)get_local_id(0))) < 1024) {
placeholder_shared[(((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 560) + ((int)get_local_id(0))))] = placeholder1[((((((dc_outer_outer * 4096) + ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 140) + (((int)get_local_id(0)) >> 2)) >> 4) * 256)) + ((((int)get_group_id(0)) >> 6) * 64)) + ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 140) + (((int)get_local_id(0)) >> 2)) & 15) * 4)) + (((int)get_local_id(0)) & 3)))];
}
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
for (int dc_outer_inner = 0; dc_outer_inner < 8; ++dc_outer_inner) {
for (int dh_outer_inner = 0; dh_outer_inner < 2; ++dh_outer_inner) {
for (int dc_inner = 0; dc_inner < 2; ++dc_inner) {
for (int dw_inner = 0; dw_inner < 2; ++dw_inner) {
for (int c_c_inner = 0; c_c_inner < 2; ++c_c_inner) {
compute_local[(c_c_inner)] = (compute_local[(c_c_inner)] + (data_pad_shared[(((((((dc_outer_inner * 426) + (dc_inner * 213)) + (((((int)get_local_id(0)) % 140) / 70) * 71)) + (dh_outer_inner * 71)) + dw_inner) + (((int)get_local_id(0)) % 70)))] * placeholder_shared[((((((((dc_outer_inner * 128) + (dc_inner * 64)) + ((((int)get_local_id(0)) / 140) * 8)) + (c_c_inner * 4)) + 3) - dw_inner) - (dh_outer_inner * 2)))]));
compute_local[((c_c_inner + 2))] = (compute_local[((c_c_inner + 2))] + (data_pad_shared[(((((((dc_outer_inner * 426) + (dc_inner * 213)) + (((((int)get_local_id(0)) % 140) / 70) * 71)) + (dh_outer_inner * 71)) + dw_inner) + (((int)get_local_id(0)) % 70)))] * placeholder_shared[((((((((dc_outer_inner * 128) + (dc_inner * 64)) + ((((int)get_local_id(0)) / 140) * 8)) + (c_c_inner * 4)) + 35) - dw_inner) - (dh_outer_inner * 2)))]));
}
}
}
}
}
}
for (int c_inner = 0; c_inner < 2; ++c_inner) {
compute[(((((((((((int)get_group_id(0)) >> 6) * 143360) + ((((int)get_local_id(0)) / 140) * 17920)) + (c_inner * 8960)) + (((((int)get_group_id(0)) & 63) >> 1) * 280)) + (((((int)get_local_id(0)) % 140) / 70) * 140)) + ((((int)get_group_id(0)) & 1) * 70)) + (((int)get_local_id(0)) % 70)))] = compute_local[(c_inner)];
compute[((((((((((((int)get_group_id(0)) >> 6) * 143360) + ((((int)get_local_id(0)) / 140) * 17920)) + (c_inner * 8960)) + (((((int)get_group_id(0)) & 63) >> 1) * 280)) + (((((int)get_local_id(0)) % 140) / 70) * 140)) + ((((int)get_group_id(0)) & 1) * 70)) + (((int)get_local_id(0)) % 70)) + 71680))] = compute_local[((c_inner + 2))];
}
}

The kernel compiled using the record written from the above tuning results was:

__kernel void fused_nn_conv2d_transpose_83_kernel0(__global float* restrict placeholder, __global float* restrict placeholder1, __global float* restrict compute) {

float compute_local[4];
__local float data_pad_shared[3408];
__local float placeholder_shared[1024];
for (int c_c_outer_inner_init = 0; c_c_outer_inner_init < 2; ++c_c_outer_inner_init) {
compute_local[(c_c_outer_inner_init)] = 0.000000e+00f;
compute_local[((c_c_outer_inner_init + 2))] = 0.000000e+00f;
}
for (int dc_outer_outer = 0; dc_outer_outer < 4; ++dc_outer_outer) {
barrier(CLK_LOCAL_MEM_FENCE);
for (int ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer = 0; ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer < 7; ++ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer) {
if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) < 3408) {
data_pad_shared[(((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))))] = (((((1 <= (((((int)get_group_id(0)) & 1) * 70) + (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_o
}
}
for (int ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 = 0; ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 < 2; ++ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1) {
if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 140) + (((int)get_local_id(0)) >> 2)) < 256) {
if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 280) + (((int)get_local_id(0)) >> 1)) < 512) {
if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 560) + ((int)get_local_id(0))) < 1024) {
placeholder_shared[(((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 560) + ((int)get_local_id(0))))] = placeholder1[((((((dc_outer_outer * 4096) + ((((ax0_ax1_fused_ax2_fused_ax3_fused_out
}
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
for (int dc_outer_inner = 0; dc_outer_inner < 8; ++dc_outer_inner) {
for (int dh_outer_inner = 0; dh_outer_inner < 2; ++dh_outer_inner) {
for (int c_c_outer_inner = 0; c_c_outer_inner < 2; ++c_c_outer_inner) {
for (int dc_inner = 0; dc_inner < 2; ++dc_inner) {
for (int dw_inner = 0; dw_inner < 2; ++dw_inner) {
compute_local[(c_c_outer_inner)] = (compute_local[(c_c_outer_inner)] + (data_pad_shared[(((((((dc_outer_inner * 426) + (dc_inner * 213)) + (((((int)get_local_id(0)) % 140) / 70) * 71)) + (d
compute_local[((c_c_outer_inner + 2))] = (compute_local[((c_c_outer_inner + 2))] + (data_pad_shared[(((((((dc_outer_inner * 426) + (dc_inner * 213)) + (((((int)get_local_id(0)) % 140) / 70)
}
}
}
}
}
}
for (int c_inner = 0; c_inner < 2; ++c_inner) {
compute[(((((((((((int)get_group_id(0)) >> 6) * 143360) + ((((int)get_local_id(0)) / 140) * 17920)) + (c_inner * 8960)) + (((((int)get_group_id(0)) & 63) >> 1) * 280)) + (((((int)get_local_id(0)) % 1
compute[((((((((((((int)get_group_id(0)) >> 6) * 143360) + ((((int)get_local_id(0)) / 140) * 17920)) + (c_inner * 8960)) + (((((int)get_group_id(0)) & 63) >> 1) * 280)) + (((((int)get_local_id(0)) %
}
}

When the final compiled program is produced independently of the TVM compilation stage, it is quite hard to take all conditions and limitations into account. We hit the same kind of problem with Metal, for example. According to Apple, the maxTotalThreadsPerThreadgroup property is dependent on the device, the register usage of your compute kernel, and threadgroup memory usage. In other words, it depends on the compiler and how it performs memory management. There are two possible approaches:

  1. Know how the generated kernel is reflected in the compiled kernel. We could calculate this while compiling the model in TVM, understand whether we have problems, and select the required parameters for executing the kernel. This is quite hard because compilers can change, almost all of them are closed source, and vendors usually do not expose much information about the details.
  2. Have the ability to go to the device, compile kernels, and execute them there. We might not know the exact parameters, but a large number of trials should help find a proper configuration. The only assumption is that a kernel that has been compiled and executed successfully in standalone mode should also compile and execute as part of a large set of kernels.

TVM tries to do its best with the first approach, but there is a tradeoff between performance and stability. The most stable approach would be to have a single thread executing a naive version of the kernel.

The second approach is exactly what autotuning does.

I looked into your script and do not understand the logic related to the epoch count. The auto-scheduler divides the requested trials into groups of measurements by itself; such a group of measurements is called a round. The auto-scheduler algorithm can be summarized as:

  • split the whole topology into tasks
  • for each task, perform one round of measurements to get a baseline for the cost model
  • select the most promising task and perform one round of tuning for it

The last step is repeated until we exceed num_measure_trials.

Each round contains 64 measurements by default; you can change this by passing a different value of num_measures_per_round in auto_scheduler.TuningOptions.
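In pseudo-code, the behaviour described above looks roughly like this (a rough sketch, not the actual TaskScheduler implementation; measure_round and select_most_promising_task are hypothetical helpers used only for illustration):

# Rough sketch of the round-robin-then-greedy loop described above.
trials_done = 0
for task in tasks:                           # warm-up: one round per task
    measure_round(task, num_measures_per_round)
    trials_done += num_measures_per_round
while trials_done < num_measure_trials:      # then focus on the best candidate
    task = select_most_promising_task(tasks)
    measure_round(task, num_measures_per_round)
    trials_done += num_measures_per_round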

So I recommend that you:

  1. Remove the epoch loop.
  2. Set num_measure_trials = 500 (see the sketch after this list). This should be enough for 3 tasks to get quite a good result.
  3. (Optional) Set num_measures_per_round = 16. I prefer smaller values like 8 to speed up the round-robin stage and concentrate on the most promising kernels, but for 3 tasks this might not be important.
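Applied to your script, the single tuning call would then look roughly like this (a sketch; builder, runner, and record_filename are the objects the script already creates):

tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=500,            # total trials shared across all tasks
    num_measures_per_round=16,         # default is 64
    builder=builder,
    runner=runner,
    measure_callbacks=[auto_scheduler.RecordToFile(record_filename)],
    verbose=1,
)
tuner.tune(tune_option)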

Thanks for your reply.

The reason I split the tuning into small steps is to find the problem quickly, since the problem can only be detected when running the model. The actual problem occurred when tuning for more than 20000 trials in a single tune call, like below:

auto_scheduler.TuningOptions(num_measure_trials=20000, ...)

Maybe I was not clear in my problem description.

I understand that an identical compute kernel can have a different kernel work group size depending on the device and the OpenCL library (meaning different results from clGetKernelWorkGroupInfo depending on the device and/or OpenCL library). But the problem I have occurs on the same device that I tuned and ran on.

Below is the sequence of steps when the problem occurs.

  1. The auto-scheduler compiles a selected “state” for a task.
  2. It runs this “program” on the target device.
  3. The “program” runs on the target without any problem.
  4. The measured result is recorded to a file (this is the best record).
  5. Compiling with the best record chooses the record written in step 4.
  6. Running the TVM model compiled in step 5 gives a CL_INVALID_WORK_GROUP_SIZE error.

So basically it seems that the compute kernel measured during auto-scheduling and the compute kernel generated during the final compilation are different, even though the task and the parameters are the same.
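To confirm that step 5 really picks the record measured in step 4, the record file can also be inspected directly, for example (a sketch using auto_scheduler.load_records and the record_filename from the script above):

# Sketch: list the measured records and their costs so the entry chosen by
# ApplyHistoryBest can be matched against what was measured during tuning.
from tvm import auto_scheduler

for inp, res in auto_scheduler.load_records(record_filename + '.best'):
    print(inp.task.workload_key, [c.value for c in res.costs])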

Hello, @euntaik!

I tried to reproduce your issue with your script on my Xiaomi Poco X3 Pro, and I didn’t see such a problem with CL_INVALID_WORK_GROUP_SIZE… Which device are you using for tuning and running this model?