Competitive gemm/matmul example?

I need a batched matmul CUDA kernel for square matrices of a fixed size (say batch x 512 x 512) as a starting point.

I have been following the code here:

https://github.com/apache/incubator-tvm/blob/master/topi/recipe/gemm/cuda_gemm_square.py

Unfortunately, my code is not nearly competitive in speed with PyTorch's matmul: it's about 5x slower, and I don't have a sense of what more I can do.
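
For reference, the baseline I'm comparing against is essentially torch.bmm on the same shapes, timed with CUDA events. A rough sketch (the batch size, warm-up, and repeat counts are arbitrary):

import torch

batch, n = 32, 512  # same square 512 x 512 matrices as above; batch is arbitrary
a = torch.randn(batch, n, n, device="cuda")
b = torch.randn(batch, n, n, device="cuda")

# Warm up, then time with CUDA events so we measure GPU time rather than launch overhead
for _ in range(3):
    torch.bmm(a, b)
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    torch.bmm(a, b)
end.record()
torch.cuda.synchronize()
print("ms per batched matmul:", start.elapsed_time(end) / 100)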

Roughly, my TVM translation of that recipe:


import tvm

# Setup
num_thread = 8
vthread = 2
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
block_z = tvm.thread_axis("blockIdx.z")
thread_z = tvm.thread_axis((0, num_thread), "threadIdx.z")
# Virtual threads to avoid write conflicts
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")

# Code
nn = 512
n = tvm.convert(nn)  # fixed square size; swap in tvm.var('n') for a symbolic dimension
b = tvm.var('b')
m, l = n, n
A = tvm.placeholder((b, n, l), name='A')
B = tvm.placeholder((b, m, l), name='B')
k = tvm.reduce_axis((0, l), name='k')

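# C[bb, ii, jj] = sum_k A[bb, ii, k] * B[bb, jj, k], i.e. each batch computes A[b] @ B[b]^T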
C = tvm.compute(
    (b, m, n),
    lambda bb, ii, jj: tvm.sum(A[bb, ii, k] * B[bb, jj, k], axis=k),
    name='C')
s = tvm.create_schedule(C.op)

# TVM scheduling
tk = 16
num_thread = 16  # NB: thread_x / thread_y above were declared with extent num_thread = 8; keep these consistent
bt = 1
block_factor = 8 * num_thread
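# Each block computes a block_factor x block_factor (128 x 128) tile of C;
# with 16 x 16 threads and 2 x 2 vthreads, each thread accumulates an 8 x 8 sub-tile.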


# Construct read caches:
# stage A and B through shared memory, then into per-thread registers ("local")
A_shared = s.cache_read(A, "shared", [C])
A_local  = s.cache_read(A_shared, "local", [C])
B_shared = s.cache_read(B, "shared", [C])
B_local  = s.cache_read(B_shared, "local", [C])
C_local  = s.cache_write(C, "local")

# Split each axis into block axis, thread axis, and inner axis
b, x, y = s[C].op.axis

xb, xo = s[C].split(x, factor=block_factor)
yb, yo = s[C].split(y, factor=block_factor)
bo, bi = s[C].split(b, factor=bt)

t_x, xo = s[C].split(xo, nparts=vthread)
t_y, yo = s[C].split(yo, nparts=vthread)

xo, xi = s[C].split(xo, nparts=num_thread)
yo, yi = s[C].split(yo, nparts=num_thread)


# batch, block, thread, inner
s[C].reorder(bo, bi, xb, yb, t_x, t_y, xo, yo, xi, yi)

# Bind the virtual thread axes
s[C].bind(t_x, thread_xz)
s[C].bind(t_y, thread_yz)

# Note that we bind yb to blockIdx.x and xb to blockIdx.y
s[C].bind(bo, block_z)
s[C].bind(bi, thread_z)
s[C].bind(xb, block_y)
s[C].bind(yb, block_x)
s[C].bind(xo, thread_y)
s[C].bind(yo, thread_x)


# Move all the write computation inside the thread level.
s[C_local].compute_at(s[C], yo)

# Split the reduction axis of the local compute stage.
# Returns the outer, anchor, and inner reduction axes.
def local(op):
    b, yi, xi = s[op].op.axis
    k, = s[op].op.reduce_axis
    ko, ki = s[op].split(k, factor=tk)
    kt, ki = s[op].split(ki, factor=1)
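    # kt has unit extent and only serves as a compute_at anchor: the shared
    # caches are refilled at ko (every tk steps), the register caches at kt.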
    s[op].reorder(b, ko, kt, ki, yi, xi)
    return ko, kt, ki
ko, kt, ki = local(C_local)


# Optimize read caches of A and B with cooperative fetching
def optimize_read_cache(shared, local, op, ki, ko, kt):
    s[shared].compute_at(s[op], ko)
    s[local].compute_at(s[op], kt)
    b, y, x = s[shared].op.axis

    # Note that we must split into num_thread parts so the cooperative
    # fetch reuses the thread axes already bound in the main stage
    xo, xi = s[shared].split(x, nparts=num_thread)
    _, yi = s[shared].split(y, factor=num_thread * 4)
    yo, yi = s[shared].split(yi, nparts=num_thread)
    
    bo, bi = s[shared].split(b, nparts=1)
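    # The nparts=1 split just creates a unit-extent axis to bind to threadIdx.z,
    # mirroring the batch-axis binding in the main stage.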

    #s[shared].reorder(bo, bi, yo, xo, yi, xi)
    s[shared].bind(bi, thread_z)
    s[shared].bind(xo, thread_x)
    s[shared].bind(yo, thread_y)
    s[shared].vectorize(xi)
    s[shared].double_buffer()

optimize_read_cache(A_shared, A_local, C_local, ki, ko, kt)
optimize_read_cache(B_shared, B_local, C_local, ki, ko, kt)
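
Continuing from the schedule above, this is roughly how I build and benchmark it (a sketch; the batch size, device index, and repeat count are arbitrary, and the GFLOPS line just assumes 2 * batch * n^3 flops):

import numpy as np

fmm = tvm.build(s, [A, B, C], "cuda", name="batched_matmul")
ctx = tvm.gpu(0)

batch = 32
a_np = np.random.uniform(size=(batch, nn, nn)).astype("float32")
b_np = np.random.uniform(size=(batch, nn, nn)).astype("float32")
a_tvm = tvm.nd.array(a_np, ctx)
b_tvm = tvm.nd.array(b_np, ctx)
c_tvm = tvm.nd.array(np.zeros((batch, nn, nn), dtype="float32"), ctx)

# The compute contracts the last axis of both operands, so check against A @ B^T
fmm(a_tvm, b_tvm, c_tvm)
np.testing.assert_allclose(c_tvm.asnumpy(), a_np @ b_np.transpose(0, 2, 1), rtol=1e-3)

timer = fmm.time_evaluator(fmm.entry_name, ctx, number=20)
t = timer(a_tvm, b_tvm, c_tvm).mean
print("%.3f ms, %.1f GFLOPS" % (t * 1e3, 2 * batch * nn ** 3 / t / 1e9))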

@srush, in case you are still working on this: I had the same question here, “Optimizing matrix multiplication for GPU”, and you might find the discussion relevant.

Thanks! I will check it out.

Would love to do some TVM / NLP projects. I feel like it is criminally underused (if it just worked slightly better).


I agree, I had a lot of fun using it for Longformer, and I would love to see it used for more NLP applications.

(if it just worked slightly better)

I chatted with the TVM folks, @jknight, @jwfromm, @antinucleon, a few weeks ago, and they mentioned they will consider providing better support for our use case. Things like a pip-installable runtime and better tutorials would make TVM much easier to use. Better support for fp16 and GEMM performance closer to PyTorch’s are also important to make it practical.


Sorry to hear about your difficulties @srush.

Indeed we have a number of irons in the fire to address this:

  • Pip installability of TVM is something we are still targeting for Q2 (so this month!)
  • We have an offer out to someone for a full-time Developer Advocate role to help build docs (e.g., tutorials like “Help, my schedule is slow!”), among other activities in the TVM community.
  • We are publishing a paper on TVM auto-scheduling which would obviate the need for writing schedules altogether. This work is now being polished up for RFC and upstreaming. I’ve seen the preprint and the performance results are quite extraordinary! Stay tuned!

Let us know if you have additional ideas as well and we’ll see what we can do on our side!


Thanks, @jknight for the updates. This is great.

Let us know if you have additional ideas as well and we’ll see what we can do on our side!

If I had to mention something, it would be supporting fp16 using half2, which usually leads to compute savings, not just memory savings.
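
To make the ask a bit more concrete, here is a minimal sketch (old-style TVM API; the names are mine) of where fp16 enters at the tensor-expression level. Declaring float16 placeholders gets you the memory savings; the compute savings only show up if the CUDA backend packs pairs of halves into half2 loads and half2 math, which is the part that needs better support:

import tvm

n = 512
A = tvm.placeholder((n, n), name="A", dtype="float16")
B = tvm.placeholder((n, n), name="B", dtype="float16")
k = tvm.reduce_axis((0, n), name="k")

# fp16 storage with fp32 accumulation: half the bytes moved (memory savings),
# but the arithmetic only gets faster if the generated kernel uses half2 FMAs.
C = tvm.compute(
    (n, n),
    lambda i, j: tvm.sum(A[i, k].astype("float32") * B[j, k].astype("float32"), axis=k),
    name="C")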

I’ll make sure to bring up fp16 in the GPU-themed TVM meetup that I’ll start scheduling shortly.

CC @YuanLin


I think your docs are pretty good. But I personally only care about three things:

  • Fast MM sample code (as fast as the built-in matmul)
  • Fast LSTM sample code (as fast as cuDNN)
  • pip install

Without these it’s really hard to build non-toy applications.
