[Autoscheduler] Performance with small difference in number of channels

Hi

I've been using the auto-scheduler to tune various convolutional layers and have occasionally found large performance differences between a standard 2D convolution with weight shape (out_channels, in_channels, H, W) and a slight variant of it with weight shape (out_channels + 2, in_channels, H, W).

For example, the script below autotunes an AlexNet layer with 384 output channels and a corresponding layer with 386 output channels. The benchmark results are as follows:

Channels   Mean time (s)
384        0.000200
386        0.001026

I'm aware that performance can suffer in scenarios like this due to so-called tile quantization, but I wasn't expecting such a stark difference.
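
As a back-of-the-envelope check (the tile width of 32 below is a hypothetical value, not one taken from the generated schedule), pure tile quantization should only account for a few percent of extra work:

import math

tile = 32                               # hypothetical tile width
padded = math.ceil(386 / tile) * tile   # 386 rounds up to 416
print(padded / 386)                     # ~1.08: about 8% extra work, nowhere near the ~5x gap above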

Are there any suggested strategies for improving performance here?

Thanks!

Setup details:

  • Target: 'llvm -mcpu=skylake-avx512 -num-cores 2'
  • Hardware: AWS c5.4xlarge

Script:

import numpy as np
import os
import shutil
import tvm
from tvm import auto_scheduler
from tvm import relay
from tvm.contrib import graph_executor
from tvm.ir.module import IRModule
from tvm.target.target import Target
import tvm.testing
import tvm.topi.testing


KEY_WEIGHT = 'w'
KEY_DATA = 'x'


def evaluate(lib, target, inputs):
    # Build a graph executor for the compiled library, bind the inputs, and
    # report the mean wall-clock time per inference.
    dev = tvm.device(target.kind.name, 0)
    mod = graph_executor.GraphModule(lib["default"](dev))
    for k, v in inputs.items():
        mod.set_input(k, v)
    result = mod.benchmark(dev, number=500)
    print('%f' % result.mean)
    return result.mean


def autoscheduler_tune(mod, params, target, workdir, num_trials_total=20000, num_trials_per_iter=64):
    # One tuning task is extracted per unique operator workload in the module.
    tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)

    log_file = os.path.join(workdir, 'autoscheduler_log.json')
    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=num_trials_total,
        num_measures_per_round=num_trials_per_iter,
        runner=auto_scheduler.LocalRunner(repeat=10, enable_cpu_cache_flush=True),
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )

    tuner.tune(tune_option)

    # Compile with the best schedules recorded in the tuning log.
    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
            lib = relay.build(mod, target=target, params=params)
    return lib


def run(dtype, scale, dshape, kshape, padding, groups, dilation, channels, kernel_size, workdir):
    kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
    data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
    # dilate_python is an identity for dilation == (1, 1), as used here; note
    # that conv2d below also receives dilation, so pre-dilating the weight
    # would dilate twice for any other setting.
    dkernel = tvm.topi.testing.dilate_python(kernel, (1, 1) + dilation)
    x = relay.var(KEY_DATA, shape=dshape, dtype=dtype)
    w = relay.var(KEY_WEIGHT, shape=kshape, dtype=dtype)
    y = relay.nn.conv2d(
        x,
        w,
        padding=padding,
        dilation=dilation,
        groups=groups,
        channels=channels,
        kernel_size=kernel_size,
    )
    func = relay.Function([x, w], y)
    params = {KEY_WEIGHT: dkernel}
    inputs = {KEY_DATA: data}

    mod = IRModule.from_expr(func)
    target = Target('llvm -mcpu=skylake-avx512 -num-cores 2')

    if os.path.isdir(workdir):
        shutil.rmtree(workdir)
    os.makedirs(workdir)

    tuned_mod = autoscheduler_tune(mod, params, target, workdir)
    return evaluate(tuned_mod, target, inputs)


if __name__ == "__main__":
    dtype = "float32"
    scale = 1
    hw = 13
    in_channels = 192
    batch_size = 1
    dshape = (batch_size, in_channels, hw, hw)
    kernel_size = (3, 3)
    padding = (1, 1)
    groups = 1
    dilation = (1, 1)

    overall_workdir = '/tmp/tune'
    results = []
    oc = [384, 386]
    for out_channels in oc:
        kshape = (out_channels, in_channels,) + kernel_size
        workdir = os.path.join(overall_workdir, str(out_channels))
        res = run(dtype, scale, dshape, kshape, padding, groups, dilation, out_channels, kernel_size, workdir)
        results.append(res)

    for out_channels, time in zip(oc, results):
        print('%d: %f' % (out_channels, time))

Auto-scheduler selects tile sizes from the factors of the channel dimension. 386 = 2 × 193 (193 is prime), so the only nontrivial tile sizes are 2 and 193, which means auto-scheduler cannot tile this dimension effectively. You can fix this by manually padding the channel dimension.
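
To make the contrast concrete, a quick divisor count (plain Python, nothing TVM-specific):

def divisors(n):
    return [d for d in range(1, n + 1) if n % d == 0]

print(len(divisors(384)))   # 16: 384 = 2**7 * 3
print(len(divisors(386)))   # 4 (1, 2, 193, 386): 386 = 2 * 193

And here is a minimal sketch of the padding workaround, assuming NCHW data and OIHW weights; the helpers next_multiple and build_padded_conv, and the choice of rounding up to a multiple of 16, are my own illustration rather than a TVM API. The idea is to zero-pad the output-channel axis of the weight to a factor-rich size, run conv2d at the padded width, and slice the extra channels back off:

import numpy as np
from tvm import relay


def next_multiple(n, m=16):
    # e.g. 386 -> 400; 400 = 2**4 * 5**2 has 15 divisors vs. 4 for 386
    return ((n + m - 1) // m) * m


def build_padded_conv(dshape, kshape, dtype='float32', padding=(1, 1)):
    out_channels = kshape[0]
    padded = next_multiple(out_channels)
    x = relay.var('x', shape=dshape, dtype=dtype)
    w = relay.var('w', shape=(padded,) + kshape[1:], dtype=dtype)
    y = relay.nn.conv2d(x, w, padding=padding, channels=padded, kernel_size=kshape[2:])
    # Slice the zero-padded channels back off; the spatial dims are unchanged
    # here because a 3x3 kernel with (1, 1) padding preserves H and W.
    y = relay.strided_slice(y, begin=[0, 0, 0, 0],
                            end=[dshape[0], out_channels, dshape[2], dshape[3]])
    return relay.Function([x, w], y), padded


# The weight constant bound as a param must be zero-padded to match, e.g.:
#   padded_kernel = np.zeros((next_multiple(kshape[0]),) + kshape[1:], dtype=dtype)
#   padded_kernel[:kshape[0]] = kernel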

This is very helpful. Thank you!

@merrymercy, could you point me to where in the TVM codebase I can find the enumeration of valid tile sizes for a given channel size?

Thank you for any help!