Different performance of conv2d between relay and te

Hi! I'm new to TVM. I was trying to use Ansor to tune conv2d in te or relay, and I noticed that the te version is about 24x faster than the relay version (<1 ms vs. 24 ms). Which result is correct, and how should I measure the performance of a single op?

Here is my code.

Relay

import numpy as np
import tvm
from tvm import te, auto_scheduler, topi
from tvm import relay
from tvm.contrib import graph_executor

# config
target = tvm.target.Target("llvm -mcpu=skylake-avx512")
N, H, W, CO, CI, KH, KW, stride, padding = 1, 224, 224, 64, 3, 7, 7, (2, 2), (3, 3)

# relay function
data = relay.var("data", shape=(N, CI, H, W))
kernel = relay.var("kernel", shape=(CO, CI, KH, KW))
bias = relay.var("bias", shape=(1, CO, 1, 1))
conv = relay.nn.conv2d(data, kernel, stride, padding, dilation=1, channels=CO, kernel_size=(KH, KW))
tmp = relay.add(conv, bias)
out = relay.nn.relu(tmp)
fun = relay.Function([data, bias, kernel], out)

data_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
weight_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
bias_np = np.random.uniform(size=(1, CO, 1, 1)).astype(np.float32)

mod = tvm.IRModule.from_expr(fun)
# Note: the keys must match the relay var names ("kernel", "bias"),
# otherwise relay.build will not bind them as constants.
params = {"kernel": weight_np, "bias": bias_np}

# auto schedule
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)

log_file = "conv2d_cpu.json"
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=1000,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)

tuner.tune(tune_option)

with auto_scheduler.ApplyHistoryBest(log_file):
    # Without this config flag, relay.build ignores the auto_scheduler log
    # and falls back to the default TOPI schedules.
    with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
        lib = relay.build_module.build(mod, target=target, params=params)


dev = tvm.cpu()
data_tvm = tvm.nd.array(data_np, device=dev)
module = graph_executor.GraphModule(lib["default"](dev))
module.set_input("data", data_tvm)

# Evaluate
print("Evaluate inference time cost...")
print(module.benchmark(dev, number=1, repeat=100, min_repeat_ms=500))

Te

import numpy as np
import tvm
from tvm import te, auto_scheduler, topi
from tvm.topi.testing import conv2d_nchw_python

# config
target = tvm.target.Target("llvm -mcpu=skylake-avx512")
N, H, W, CO, CI, KH, KW, stride, padding = 1, 224, 224, 64, 3, 7, 7, (2, 2), (3, 3)

@auto_scheduler.register_workload
def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
    data = te.placeholder((N, CI, H, W), name="data")
    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
    bias = te.placeholder((1, CO, 1, 1), name="bias")
    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32")
    out = topi.nn.relu(conv + bias)
    return [data, kernel, bias, out]

# Create the search task
task = auto_scheduler.SearchTask(
    func=conv2d_layer, args=(N, H, W, CO, CI, KH, KW, stride, padding), target=target
)

log_file = "conv2d_cpu_te.json"
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=1000,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)

task.tune(tune_option)
sch, args = task.apply_best(log_file)

func = tvm.build(sch, args, target)
# Check correctness
data_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
weight_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
bias_np = np.random.uniform(size=(1, CO, 1, 1)).astype(np.float32)
conv_np = conv2d_nchw_python(data_np, weight_np, stride, padding)
out_np = np.maximum(conv_np + bias_np, 0.0)

dev = tvm.cpu()
data_tvm = tvm.nd.array(data_np, device=dev)
weight_tvm = tvm.nd.array(weight_np, device=dev)
bias_tvm = tvm.nd.array(bias_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)

# Evaluate
evaluator = func.time_evaluator(func.entry_name, dev, number=1, repeat=100, min_repeat_ms=500)
print((np.median(evaluator(data_tvm, weight_tvm, bias_tvm, out_tvm).results) * 1000))

By checking the compute DAGs (printed as sketched below and dumped at the end of this post), I found that a weight layout transformation occurs in the relay version, but there is no pre-transformation in the te flow. So I have two more questions.

  1. The Ansor paper says that the layouts of constant tensors, e.g., the weights, will be rewritten, because this can be done at compile time. But in the example above, I can't see any layout change in the te version. How can I pre-transform the weight using te?
  2. In my examples, only the relay version performs the weight layout transformation (which should happen at compile time), yet it takes more runtime (24 ms) than the te version (<1 ms). Am I using the wrong timing API?

Thanks in advance.
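For reference, here is roughly how the DAGs below were printed (a sketch; it assumes the tasks list from the relay script and the task object from the te script above):

# Relay flow: each extracted task carries its own compute DAG
for i, t in enumerate(tasks):
    print("Task %d:" % i)
    print(t.compute_dag)

# te flow: the single search task also exposes its compute DAG
print(task.compute_dag)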

Relay

# Computational DAG:
placeholder = PLACEHOLDER [1, 1, 224, 224, 3]
data_pad(i0, i1, i2, i3, i4) = tir.if_then_else(((((i2 >= 3) && (i2 < 227)) && (i3 >= 3)) && (i3 < 227)), placeholder[i0, i1, (i2 - 3), (i3 - 3), i4], 0f)
placeholder = PLACEHOLDER [2, 1, 7, 7, 3, 32]
conv2d_NCHWc(n, oc_chunk, oh, ow, oc_block) += (data_pad[n, floordiv(ic, 3), ((oh*2) + kh), ((ow*2) + kw), floormod(ic, 3)]*placeholder[oc_chunk, floordiv(ic, 3), kh, kw, floormod(ic, 3), oc_block])
placeholder = PLACEHOLDER [1, 2, 1, 1, 32]
T_add(ax0, ax1, ax2, ax3, ax4) = (conv2d_NCHWc[ax0, ax1, ax2, ax3, ax4] + placeholder[ax0, ax1, 0, 0, ax4])
T_relu(ax0, ax1, ax2, ax3, ax4) = max(T_add[ax0, ax1, ax2, ax3, ax4], 0f)

Te

# Computational DAG:
data = PLACEHOLDER [1, 3, 224, 224]
pad_temp(i0, i1, i2, i3) = tir.if_then_else(((((i2 >= 3) && (i2 < 227)) && (i3 >= 3)) && (i3 < 227)), data[i0, i1, (i2 - 3), (i3 - 3)], 0f)
kernel = PLACEHOLDER [64, 3, 7, 7]
compute(nn, ff, yy, xx) += (pad_temp[nn, rc, ((yy*2) + ry), ((xx*2) + rx)]*kernel[ff, rc, ry, rx])
bias = PLACEHOLDER [1, 64, 1, 1]
T_add(ax0, ax1, ax2, ax3) = (compute[ax0, ax1, ax2, ax3] + bias[ax0, ax1, 0, 0])
compute(i0, i1, i2, i3) = max(T_add[i0, i1, i2, i3], 0f)

The major difference between your relay and te versions is the memory layout of the inputs. Your te version uses plain NCHW, while the relay version uses the packed NCHWc layout (visible in the DAG above). This difference probably explains most of the performance gap. auto_scheduler only optimizes the memory layout of layout-free placeholders (the weight and bias) in your example.

As far as I know, there is no way of optimizing memory layouts when tuning te expressions.

You can also try this option for TE expressions and see how it works: https://github.com/apache/tvm/blob/8f6543e9e6173cd45b678e91b5a637ff7f8e0e02/gallery/tutorial/auto_scheduler_matmul_x86.py#L69
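If that line refers to the layout-rewrite option of SearchTask, it would look roughly like this for your workload (a sketch reusing the conv2d_layer workload from the te script; INSERT_TRANSFORM_STAGE asks Ansor to add an explicit transform stage for the layout-free placeholders):

task = auto_scheduler.SearchTask(
    func=conv2d_layer,
    args=(N, H, W, CO, CI, KH, KW, stride, padding),
    target=target,
    # Rewrite the layout of layout-free placeholders (the weight) by inserting
    # a transform stage into the compute DAG; the transform then runs inside
    # the generated function on every call.
    layout_rewrite_option=auto_scheduler.LayoutRewriteOption.INSERT_TRANSFORM_STAGE,
)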

However, this approach applies the layout transformation every time the function runs. The transformation cannot be eliminated when tuning a single TE expression; only a Relay pass can eliminate it, by folding the transformed weight into a constant at compile time.
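Concretely, that is what the Relay flow from the question already does (a sketch of the relevant part, assuming mod, params, and log_file from the first script): with the weight bound as a constant and opt_level=3, relay.build constant-folds the layout transform, so only the pre-transformed weight ends up in the compiled module.

with auto_scheduler.ApplyHistoryBest(log_file):
    with tvm.transform.PassContext(
        opt_level=3, config={"relay.backend.use_auto_scheduler": True}
    ):
        # The kernel is bound as a constant, so its NCHWc layout transform is
        # folded at compile time instead of being executed at runtime.
        lib = relay.build_module.build(mod, target=target, params=params)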