Hello! I made a very simple test to explore whether the data_alignment argument affects the performance of the intrinsic function:
from __future__ import absolute_import, print_function
import tvm
from tvm import te
import numpy as np
def intrin_gemm(m, n, p, alignment=64):
    a = te.placeholder((m, p), name='a')
    b = te.placeholder((p, n), name='b')
    k = te.reduce_axis((0, p), name='k')
    c = te.compute((m, n), lambda i, j: te.sum(a[i, k] * b[k, j], axis=k), name='c')
    Ab = tvm.tir.decl_buffer(a.shape, a.dtype,
                             name="A",
                             offset_factor=1,
                             strides=[te.var("s1"), 1],
                             data_alignment=alignment)  # strides = [K, 1]; e.g. (M, N, K) = (16, 16, 8) => s1 = 8
    Bb = tvm.tir.decl_buffer(b.shape, b.dtype,
                             name="B",
                             offset_factor=1,
                             strides=[te.var("s2"), 1],
                             data_alignment=alignment)  # strides = [N, 1]; e.g. => s2 = 16
    Cb = tvm.tir.decl_buffer(c.shape, c.dtype,
                             name="C",
                             offset_factor=1,
                             strides=[te.var("s3"), 1],
                             data_alignment=alignment)  # strides = [N, 1]; e.g. => s3 = 16

    def intrin_func(ins, outs):
        ib = tvm.tir.ir_builder.create()
        aa, bb = ins
        cc = outs[0]
        # Call the external C gemm implementation declared below
        ib.emit(tvm.tir.call_extern("float32", "gemm_impl",
                                    cc.access_ptr("w"),
                                    aa.access_ptr("r"),
                                    bb.access_ptr("r"),
                                    m, n, p,
                                    aa.strides[0], bb.strides[0], cc.strides[0]))
        return ib.get()

    return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})
def gemm_impl():
    cc_code = """
      extern "C" int gemm_impl(float *cc, float *aa, float *bb, int m, int n, int p, int stride_a, int stride_b, int stride_c) {
        for (int i = 0; i < m; ++i) {
          for (int j = 0; j < n; ++j) {
            float tmp = 0.0f;
            for (int k = 0; k < p; ++k) {
              tmp += aa[i * stride_a + k] * bb[k * stride_b + j];
            }
            cc[i * stride_c + j] = tmp;
          }
        }
        return 0;
      }
    """
    from tvm.contrib import util, clang
    temp = util.tempdir()
    ll_path = temp.relpath("temp.ll")
    # Create LLVM IR from the C source code
    ll_code = clang.create_llvm(cc_code, output=ll_path)
    return ll_code
M, N, K = 256, 256, 256
A = te.placeholder((M, K), name='A')
B = te.placeholder((K, N), name='B')
k = te.reduce_axis((0, K), name='k')
C = te.compute((M, N), lambda i, j:
               te.sum(A[i, k] * B[k, j], axis=k), name='C')
for alignment in [8, 16, 32, 64, 128, 256]:
    s = te.create_schedule(C.op)
    # print(tvm.lower(s, [A, B, C], simple_mode=True))
    factor_x = 64
    factor_y = 64
    x, y = C.op.axis
    z, = C.op.reduce_axis
    xo, xi = s[C].split(x, factor=factor_x)
    yo, yi = s[C].split(y, factor=factor_y)
    s[C].reorder(xo, yo, xi, yi, z)
    # print(tvm.lower(s, [A, B, C], simple_mode=True))
    gemm = intrin_gemm(factor_x, factor_y, K, alignment)
    s[C].tensorize(xi, gemm)
    # print(tvm.lower(s, [A, B, C], simple_mode=True))
    s[C].pragma(yo, "import_llvm", gemm_impl())
    # print(tvm.lower(s, [A, B, C], simple_mode=True))
    func = tvm.build(s, [A, B, C], target="llvm -mcpu=core-avx2", name="gemm")

    from topi.util import get_const_tuple
    dtype = A.dtype
    ctx = tvm.context("cpu", 0)
    a = np.random.uniform(size=get_const_tuple(A.shape)).astype(dtype)
    b = np.random.uniform(size=get_const_tuple(B.shape)).astype(dtype)
    c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=dtype), ctx)
    nd_arrays = [tvm.nd.array(a, ctx), tvm.nd.array(b, ctx), c]

    # Measure the average running time over `run_number` runs
    run_number = 50
    timer = func.time_evaluator(func.entry_name, ctx, number=run_number)
    tcost_d = timer(*nd_arrays).mean
    tvm.testing.assert_allclose(c.asnumpy(), np.dot(a, b), rtol=1e-3)
    print("Average running time for alignment={} is {:.2f} us.".format(alignment, tcost_d * 1e6))
And the results are:
Average running time for alignment=8 is 15504.82 us.
Average running time for alignment=16 is 15055.16 us.
Average running time for alignment=32 is 15917.96 us.
Average running time for alignment=64 is 15728.37 us.
Average running time for alignment=128 is 15988.68 us.
[17:16:52] /home/moderato/Documents/incubator-tvm/src/tir/transforms/arg_binder.cc:95: Trying to bind buffer to another one with lower alignment requirement required_alignment=256, provided_alignment=128
[17:16:52] /home/moderato/Documents/incubator-tvm/src/tir/transforms/arg_binder.cc:95: Trying to bind buffer to another one with lower alignment requirement required_alignment=256, provided_alignment=128
[17:16:52] /home/moderato/Documents/incubator-tvm/src/tir/transforms/arg_binder.cc:95: Trying to bind buffer to another one with lower alignment requirement required_alignment=256, provided_alignment=128
Average running time for alignment=256 is 15010.73 us.
It looks like the runtimes are quite close to each other. I know this test might not simulate the complex situations encountered in real applications very well, but is there any insight here into how data_alignment works and whether it really affects performance?
Thanks in advance!
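
P.S. In case it helps the discussion, here is a minimal sketch of how I imagine one could check the byte alignment the allocated buffers actually get at runtime, to see whether the requested data_alignment is already satisfied by the allocator anyway. It assumes that arr.handle is a ctypes pointer to the underlying DLTensor (which is how the ctypes FFI exposes it), so the exact attribute access may need adjusting, and check_alignment is just a hypothetical helper name:

    def check_alignment(name, arr, requested):
        # Assumption: `arr.handle.contents.data` is the raw data pointer of the
        # DLTensor backing this tvm.nd.array (true for the ctypes FFI).
        ptr = arr.handle.contents.data
        natural = ptr & -ptr  # largest power-of-two alignment the pointer satisfies
        ok = "satisfied" if ptr % requested == 0 else "NOT satisfied"
        print("{}: pointer 0x{:x} is {}-byte aligned, requested {} -> {}".format(
            name, ptr, natural, requested, ok))

    # e.g. right before invoking the kernel:
    # for name, arr in zip("ABC", nd_arrays):
    #     check_alignment(name, arr, alignment)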