Does data_alignment in decl_buffer affect performance?

Hello! I made a very simple test to explore whether the data_alignment argument affects the performance of an intrinsic function:

from __future__ import absolute_import, print_function

import tvm
from tvm import te
import numpy as np

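# Declare a (m x p) * (p x n) GEMM tensor intrinsic whose buffers are bound
# with the given data_alignment and whose body calls the external gemm_impl.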
def intrin_gemm(m, n, p, alignment=64):
    a = te.placeholder((m, p), name='a')
    b = te.placeholder((p, n), name='b')
    k = te.reduce_axis((0, p), name='k')
    c = te.compute((m, n), lambda i, j: te.sum(a[i, k] * b[k, j], axis=k), name='c')
    Ab = tvm.tir.decl_buffer(a.shape, a.dtype,
                             name="A",
                             offset_factor=1,
                             strides=[te.var("s1"), 1],
                             data_alignment=alignment)  # K, 1, e.g. (M, N, K) = (16, 16, 8) => 8
    Bb = tvm.tir.decl_buffer(b.shape, b.dtype,
                             name="B",
                             offset_factor=1,
                             strides=[te.var("s2"), 1],
                             data_alignment=alignment)  # N, 1, e.g. => 16
    Cb = tvm.tir.decl_buffer(c.shape, c.dtype,
                             name="C",
                             offset_factor=1,
                             strides=[te.var("s3"), 1],
                             data_alignment=alignment)  # N, 1, e.g. => 16
    def intrin_func(ins, outs):
        ib = tvm.tir.ir_builder.create()
        aa, bb = ins
        cc = outs[0]
        # Emit a call to the external C implementation (gemm_impl) with raw pointers and strides
        ib.emit(tvm.tir.call_extern("float32", "gemm_impl",
                                cc.access_ptr("w"),
                                aa.access_ptr("r"),
                                bb.access_ptr("r"),
                                m, n, p, aa.strides[0], bb.strides[0], cc.strides[0]))
        return ib.get()
    return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})

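# Compile the embedded C GEMM kernel to LLVM IR (via clang) so it can be
# linked into the generated code with the "import_llvm" pragma below.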
def gemm_impl():
    cc_code = """
      extern "C" int gemm_impl(float *cc, float *aa, float *bb, int m, int n, int p, int stride_a, int stride_b, int stride_c) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                float tmp = 0.0f;
                for (int k = 0; k < p; ++k) {
                    tmp += aa[i * stride_a + k] * bb[k * stride_b + j]; 
                }
                cc[i * stride_c + j] = tmp;
            }
        }
        return 0;
      }
    """
    from tvm.contrib import util, clang
    temp = util.tempdir()
    ll_path = temp.relpath("temp.ll")
    # Compile the C source code to LLVM IR with clang
    ll_code = clang.create_llvm(cc_code, output=ll_path)
    return ll_code

M, N, K = 256, 256, 256
A = te.placeholder((M, K), name='A')
B = te.placeholder((K, N), name='B')
k = te.reduce_axis((0, K), name='k')
C = te.compute((M, N), lambda i, j:
                te.sum(A[i, k] * B[k, j], axis=k), name='C')

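# For each alignment value: build a fresh schedule, tile C into 64x64 blocks,
# tensorize the inner tile with the GEMM intrinsic, then build and time it.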
for alignment in [8, 16, 32, 64, 128, 256]:
    s = te.create_schedule(C.op)
    # print(tvm.lower(s, [A, B, C], simple_mode=True))

    factor_x = 64
    factor_y = 64
    x, y = C.op.axis
    z, = C.op.reduce_axis
    xo, xi = s[C].split(x, factor=factor_x)
    yo, yi = s[C].split(y, factor=factor_y)
    s[C].reorder(xo, yo, xi, yi, z)
    # print(tvm.lower(s, [A, B, C], simple_mode=True))

    gemm = intrin_gemm(factor_x, factor_y, K, alignment)
    s[C].tensorize(xi, gemm)
    # print(tvm.lower(s, [A, B, C], simple_mode=True))

    s[C].pragma(yo, "import_llvm", gemm_impl())
    # print(tvm.lower(s, [A, B, C], simple_mode=True))

    func = tvm.build(s, [A, B, C], target="llvm -mcpu=core-avx2", name="gemm")

    from topi.util import get_const_tuple
    dtype = A.dtype
    ctx = tvm.context("cpu", 0)
    a = np.random.uniform(size=get_const_tuple(A.shape)).astype(dtype)
    b = np.random.uniform(size=get_const_tuple(B.shape)).astype(dtype)
    c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=dtype), ctx)
    nd_arrays = [tvm.nd.array(a, ctx), tvm.nd.array(b, ctx), c]

    # Measure the average running time over run_number runs
    run_number = 50
    timer = func.time_evaluator(func.entry_name, ctx, number=run_number)
    tcost_d = timer(*nd_arrays).mean
    tvm.testing.assert_allclose(c.asnumpy(), np.dot(a, b), rtol=1e-3)
    print("Average running time for alignment={} is {:.2f} us.".format(alignment, tcost_d * 1e6))

And the results are:

Average running time for alignment=8 is 15504.82 us.
Average running time for alignment=16 is 15055.16 us.
Average running time for alignment=32 is 15917.96 us.
Average running time for alignment=64 is 15728.37 us.
Average running time for alignment=128 is 15988.68 us.
[17:16:52] /home/moderato/Documents/incubator-tvm/src/tir/transforms/arg_binder.cc:95: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=128
[17:16:52] /home/moderato/Documents/incubator-tvm/src/tir/transforms/arg_binder.cc:95: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=128
[17:16:52] /home/moderato/Documents/incubator-tvm/src/tir/transforms/arg_binder.cc:95: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=128
Average running time for alignment=256 is 15010.73 us.

It looks like the runtimes are all quite close to each other. I know this test might not reflect the more complex situations encountered in real applications, but is there any insight here into how data_alignment works and whether it really affects performance?
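
For reference, here is a minimal sketch of a side check I can think of (assuming the same TVM version as above, and that the Buffer fields are exposed the same way): it only confirms that the data_alignment value is recorded on the declared Buffer node; whether and how it actually influences the generated code is exactly what I am unsure about.

import tvm
from tvm import te

m, p = 64, 256
a = te.placeholder((m, p), name='a')
for alignment in [8, 64, 256]:
    # Same decl_buffer call as in the test above, just reading the value back.
    Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A",
                             offset_factor=1,
                             strides=[te.var("s1"), 1],
                             data_alignment=alignment)
    print(alignment, Ab.data_alignment)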

Thanks in advance!

@tqchen @kevinthesun

Hi, have you found the answer? Thanks!