TVM CUDA warp-level sync?

I wonder whether TVM can support CUDA warp-level sync operations. For example, if I want to use shuffle intrinsics, what should I do? If that is not possible, I have to use shared memory, but then TVM generates __syncthreads(), which is overkill. If I load and consume shared memory only within a warp, i.e. no memory is shared across warps, I don't need __syncthreads(), right?

Let’s look at the following naive example:

```python
import tvm
from tvm.te import hybrid

@hybrid.script
def demo_sync(indices):
    # output_tensor, allocate and bind are hybrid-script intrinsics.
    out = output_tensor((indices.shape[0],), 'int32')
    # Scratch buffer in CUDA shared memory, one per thread block.
    sm_i = allocate((128,), 'int32', 'shared')
    for b in bind('blockIdx.x', indices.shape[0] // 128):
        for y in range(4):
            for x in bind('threadIdx.x', 32):
                # Thread x writes one element, then reads the element
                # written by thread 31 - x of the same warp.
                sm_i[y * 4 + x] = indices[b * 128 + y * 4 + x]
                out[y * 4 + x] = sm_i[y * 4 + 31 - x]
    return out

indices = tvm.te.placeholder((1024,), 'int32', 'indices')
out = demo_sync(indices)
sched = tvm.te.create_schedule(out.op)
f = tvm.build(sched, [indices, out], target='cuda')
# The CUDA kernel lives in the device module imported by the host module.
print(f.imported_modules[0].get_source())
```

It generates the following CUDA code:

```cuda
extern "C" __global__ void default_function_kernel0(int* __restrict__ indices, int* __restrict__ demo_sync) {
  __shared__ int sm_i[128];
  for (int y = 0; y < 4; ++y) {
    __syncthreads();
    sm_i[(((y * 4) + ((int)threadIdx.x)))] = indices[((((((int)blockIdx.x) * 128) + (y * 4)) + ((int)threadIdx.x)))];
    __syncthreads();
    demo_sync[(((y * 4) + ((int)threadIdx.x)))] = sm_i[((((y * 4) + 31) - ((int)threadIdx.x)))];
  }
}
```

I think the two __syncthreads() calls are unnecessary. Is it possible to get rid of them, either by using shuffles instead, by not generating sync operations at all, or by using __syncwarp()?
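To make the question concrete, here is a minimal pure-Python model (an illustration, not TVM or CUDA code) of the exchange the example performs: thread x reads the value written by thread 31 - x of the same warp. With a shuffle such as __shfl_sync(0xffffffff, v, 31 - lane), that value moves register to register, and the _sync intrinsic itself waits for the lanes named in the mask, so neither shared memory nor a block-wide barrier is needed. If shared memory is kept instead, note that on Volta and later GPUs the lanes of a warp are not guaranteed to run in lockstep, so a __syncwarp() between the write and the read is still required. The helpers shfl_sync and reverse_rows_with_shuffle are hypothetical, and the sketch assumes each loop iteration owns its own 32-element row (a y * 32 + x layout).

```python
# Pure-Python model of a full-warp shuffle; not a TVM or CUDA API.
WARP_SIZE = 32


def shfl_sync(regs, src_lanes):
    """Model __shfl_sync(0xffffffff, var, srcLane) over one warp.

    regs[i] is the value of `var` held by lane i. Lane i receives the value
    held by lane src_lanes[i] % WARP_SIZE, mirroring how CUDA wraps srcLane
    with the default width of 32.
    """
    assert len(regs) == len(src_lanes) == WARP_SIZE
    return [regs[src % WARP_SIZE] for src in src_lanes]


def reverse_rows_with_shuffle(block):
    """Reverse every 32-element row of `block` without shared memory.

    Assumption: row y occupies block[y * 32 : y * 32 + 32].
    """
    assert len(block) % WARP_SIZE == 0
    out = []
    for y in range(len(block) // WARP_SIZE):
        # Each lane x loads one element into its own register.
        regs = [block[y * WARP_SIZE + x] for x in range(WARP_SIZE)]
        # Lane x reads the register of lane 31 - x: no shared memory, no barrier.
        out.extend(shfl_sync(regs, [WARP_SIZE - 1 - x for x in range(WARP_SIZE)]))
    return out


if __name__ == "__main__":
    result = reverse_rows_with_shuffle(list(range(128)))
    print(result[:4], result[32:36])  # [31, 30, 29, 28] [63, 62, 61, 60]
```

The per-lane source index 31 - x plays the role that the shared-memory address y * 4 + 31 - x plays in the generated kernel.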

TVM has a warp memory abstraction. If you use allocate((128,), 'int32', 'warp'), TVM will put the data in thread-local registers and then use shuffle operations to make the data available to other threads in the warp. You can also use the shuffles directly if you want. I'm not sure exactly how to use warp shuffles in hybrid script, but you can grep the codebase for tvm_warp_shuffle.
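The idea behind the 'warp' scope can be pictured with the following conceptual model in plain Python. The WarpBuffer class and its round-robin layout (element i in lane i % 32, register slot i // 32) are hypothetical and not necessarily what TVM's own lowering (the LowerWarpMemory pass) produces. A 128-element warp buffer costs each of the 32 lanes 4 registers, a store only touches the storing lane's registers, and reading an element owned by another lane becomes a shuffle. Because every lane sends exactly the value it passes as var, all lanes must read the same slot in one shuffle, which is why the model insists on a regular access pattern.

```python
# Conceptual model of a "warp"-scope buffer in plain Python.
# WarpBuffer and its layout are hypothetical, not TVM's actual lowering.
WARP_SIZE = 32


class WarpBuffer:
    """Element i lives in lane i % 32, local register slot i // 32."""

    def __init__(self, size):
        assert size % WARP_SIZE == 0
        self.slots = size // WARP_SIZE
        # regs[lane][slot] stands in for each thread's private registers.
        self.regs = [[0] * self.slots for _ in range(WARP_SIZE)]

    def store(self, lane, index, value):
        # A lane may only write into its own registers.
        assert index % WARP_SIZE == lane, "lane does not own this element"
        self.regs[lane][index // WARP_SIZE] = value

    def warp_load(self, indices):
        """Lane k reads element indices[k]; modeled as a single shuffle."""
        assert len(indices) == WARP_SIZE
        slots = {i // WARP_SIZE for i in indices}
        # In a shuffle every lane sends one value (`var`), so all lanes
        # must agree on which slot they send.
        assert len(slots) == 1, "lanes must read the same slot"
        slot = slots.pop()
        var = [self.regs[lane][slot] for lane in range(WARP_SIZE)]
        # The source lane for reader k is indices[k] % 32.
        return [var[i % WARP_SIZE] for i in indices]


if __name__ == "__main__":
    buf = WarpBuffer(128)  # 4 registers per lane instead of shared memory
    for y in range(4):
        for x in range(WARP_SIZE):
            buf.store(x, y * WARP_SIZE + x, 1000 + y * WARP_SIZE + x)
    # Same pattern as the example: lane x reads the element owned by lane 31 - x.
    row1 = buf.warp_load([WARP_SIZE + 31 - x for x in range(WARP_SIZE)])
    print(row1[:4])  # [1063, 1062, 1061, 1060]
```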
