I need a batched matmul CUDA kernel for square matrices of a fixed size (say batch x 512 x 512) as a starting point.
I have been following the code here:
https://github.com/apache/incubator-tvm/blob/master/topi/recipe/gemm/cuda_gemm_square.py
Unfortunately, my code is nowhere near competitive in speed with PyTorch's batched matmul: it's about 5x slower, and I don't have a sense of what more I can do.
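For reference, the kind of PyTorch baseline I have in mind is roughly the sketch below (the batch size of 64 and the repeat counts are arbitrary; since the compute further down reduces over the last axis of both tensors, i.e. A @ B^T, the comparison uses a transposed second operand):

import torch

batch = 64  # illustrative batch size
a = torch.randn(batch, 512, 512, device='cuda')
b = torch.randn(batch, 512, 512, device='cuda')

for _ in range(5):  # warm-up
    torch.bmm(a, b.transpose(1, 2))
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(10):
    torch.bmm(a, b.transpose(1, 2))
end.record()
torch.cuda.synchronize()
print('torch.bmm: %.3f ms per call' % (start.elapsed_time(end) / 10))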
Here is roughly my translation of that code:
import tvm
# Setup
num_thread = 16  # used for both the threadIdx extents here and the splits below
vthread = 2
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
block_z = tvm.thread_axis("blockIdx.z")
thread_z = tvm.thread_axis("threadIdx.z")  # extent comes from the batch tile (bt) bound to it below
# Virtual threads to avoid write conflicts
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
# Code
nn = 512
n = tvm.var('n')
n = tvm.convert(nn)  # fix n to the constant size for now
b = tvm.var('b')     # symbolic batch dimension
m, l = n, n
A = tvm.placeholder((b, n, l), name='A')
B = tvm.placeholder((b, m, l), name='B')
k = tvm.reduce_axis((0, l), name='k')
C = tvm.compute(
    (b, m, n),
    lambda bb, ii, jj: tvm.sum(A[bb, ii, k] * B[bb, jj, k], axis=k),
    name='C')
s = tvm.create_schedule(C.op)
# TVM scheduling
tk = 16                        # reduction (k) tile size
bt = 1                         # batch elements handled per block
block_factor = 8 * num_thread  # side length of the output tile computed per block
# Construct read caches:
# stage A and B through shared memory, then into per-thread registers ("local")
A_shared = s.cache_read(A, "shared", [C])
A_local = s.cache_read(A_shared, "local", [C])
B_shared = s.cache_read(B, "shared", [C])
B_local = s.cache_read(B_shared, "local", [C])
C_local = s.cache_write(C, "local")
# Split each axis into block axis, thread axis, and inner axis
b, x, y = s[C].op.axis
xb, xo = s[C].split(x, factor=block_factor)
yb, yo = s[C].split(y, factor=block_factor)
bo, bi = s[C].split(b, factor=bt)
t_x, xo = s[C].split(xo, nparts=vthread)
t_y, yo = s[C].split(yo, nparts=vthread)
xo, xi = s[C].split(xo, nparts=num_thread)
yo, yi = s[C].split(yo, nparts=num_thread)
# order: batch, block, virtual thread, thread, inner
s[C].reorder(bo, bi, xb, yb, t_x, t_y, xo, yo, xi, yi)
# Note that we bind yb to blockIdx.x instead of blockIdx.y
s[C].bind(t_x, thread_xz)
s[C].bind(t_y, thread_yz)
s[C].bind(bo, block_z)
s[C].bind(bi, thread_z)
s[C].bind(xb, block_y)
s[C].bind(yb, block_x)
s[C].bind(xo, thread_y)
s[C].bind(yo, thread_x)
# Move all the write computation inside the thread level.
s[C_local].compute_at(s[C], yo)
# Tile the reduction axis of the local computation.
# Returns the outer, middle, and inner reduction axes (ko, kt, ki).
def local(op):
    b, yi, xi = s[op].op.axis
    k, = s[op].op.reduce_axis
    ko, ki = s[op].split(k, factor=tk)
    kt, ki = s[op].split(ki, factor=1)
    s[op].reorder(b, ko, kt, ki, yi, xi)
    return ko, kt, ki
ko, kt, ki = local(C_local)
# Optimize read caches of A and B with cooperative fetching
def optimize_read_cache(shared, local, op, ki, ko, kt):
    s[shared].compute_at(s[op], ko)
    s[local].compute_at(s[op], kt)
    b, y, x = s[shared].op.axis
    # Note that we must split into num_thread parts to reuse
    # the thread axes bound above
    xo, xi = s[shared].split(x, nparts=num_thread)
    _, yi = s[shared].split(y, factor=num_thread * 4)
    yo, yi = s[shared].split(yi, nparts=num_thread)
    bo, bi = s[shared].split(b, nparts=1)
    # s[shared].reorder(bo, bi, yo, xo, yi, xi)
    s[shared].bind(bi, thread_z)
    s[shared].bind(xo, thread_x)
    s[shared].bind(yo, thread_y)
    s[shared].vectorize(xi)
    s[shared].double_buffer()
optimize_read_cache(A_shared, A_local, C_local, ki, ko, kt)
optimize_read_cache(B_shared, B_local, C_local, ki, ko, kt)
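Building and timing it looks roughly like this (a sketch; the batch size, repeat count, and the numpy check are just for illustration):

import numpy as np

f = tvm.build(s, [A, B, C], 'cuda', name='batched_gemm')
# print(f.imported_modules[0].get_source())  # useful for eyeballing the generated CUDA

ctx = tvm.gpu(0)
batch = 64  # illustrative; the batch dimension is left symbolic above
a_np = np.random.uniform(size=(batch, nn, nn)).astype('float32')
b_np = np.random.uniform(size=(batch, nn, nn)).astype('float32')
a_tvm = tvm.nd.array(a_np, ctx)
b_tvm = tvm.nd.array(b_np, ctx)
c_tvm = tvm.nd.array(np.zeros((batch, nn, nn), dtype='float32'), ctx)

f(a_tvm, b_tvm, c_tvm)
# the compute is A @ B^T, so compare against the transposed product
np.testing.assert_allclose(
    c_tvm.asnumpy(), np.matmul(a_np, b_np.transpose(0, 2, 1)), rtol=1e-3)

evaluator = f.time_evaluator(f.entry_name, ctx, number=10)
print('tvm: %.3f ms per call' % (evaluator(a_tvm, b_tvm, c_tvm).mean * 1e3))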