[CUDA Codegen] could it generate warp shuffle instructions?

Hi all, I used TVM to generate code for reduce_sum on CUDA.

It generated the code below. It seems to use shared memory, but a more efficient way would be to use warp shuffle instructions. Could TVM support that?

extern "C" __global__ void __launch_bounds__(1024) tvmgen_default_fused_add_nn_relu_sum_kernel0(float* __restrict__ p0, float* __restric
t__ p1, float* __restrict__ T_relu_red) {
  float T_relu_red_rf[1];
  __shared__ float red_buf0[1024];
  T_relu_red_rf[0] = 0.000000e+00f;
  if (((int)threadIdx.x) < 4) {
    T_relu_red_rf[0] = (T_relu_red_rf[0] + max((p0[((int)threadIdx.x)] + p1[((int)threadIdx.x)]), 0.000000e+00f));
  }
  __syncthreads();
  ((volatile float*)red_buf0)[((int)threadIdx.x)] = T_relu_red_rf[0];
  __syncthreads();
  if (((int)threadIdx.x) < 512) {
    ((volatile float*)red_buf0)[((int)threadIdx.x)] = (((volatile float*)red_buf0)[((int)threadIdx.x)] + ((volatile float*)red_buf0)[(((int)threadIdx.x) + 512)]);
  }
  __syncthreads();
  if (((int)threadIdx.x) < 256) {
    ((volatile float*)red_buf0)[((int)threadIdx.x)] = (((volatile float*)red_buf0)[((int)threadIdx.x)] + ((volatile float*)red_buf0)[(((int)threadIdx.x) + 256)]);
  }
  __syncthreads();
  if (((int)threadIdx.x) < 128) {
    ((volatile float*)red_buf0)[((int)threadIdx.x)] = (((volatile float*)red_buf0)[((int)threadIdx.x)] + ((volatile float*)red_buf0)[(((int)threadIdx.x) + 128)]);
  }
  __syncthreads();
  if (((int)threadIdx.x) < 64) {
    ((volatile float*)red_buf0)[((int)threadIdx.x)] = (((volatile float*)red_buf0)[((int)threadIdx.x)] + ((volatile float*)red_buf0)[(((int)threadIdx.x) + 64)]);
  }
  __syncthreads();
  if (((int)threadIdx.x) < 32) {
    ((volatile float*)red_buf0)[((int)threadIdx.x)] = (((volatile float*)red_buf0)[((int)threadIdx.x)] + ((volatile float*)red_buf0)[(((int)threadIdx.x) + 32)]);
  }
  __syncthreads();
  if (((int)threadIdx.x) < 16) {
    float w_16_0 = (((volatile float*)red_buf0)[((int)threadIdx.x)] + ((volatile float*)red_buf0)[(((int)threadIdx.x) + 16)]);
    ((volatile float*)red_buf0)[((int)threadIdx.x)] = w_16_0;
    float w_8_0 = (((volatile float*)red_buf0)[((int)threadIdx.x)] + ((volatile float*)red_buf0)[(((int)threadIdx.x) + 8)]);
    ((volatile float*)red_buf0)[((int)threadIdx.x)] = w_8_0;
    float w_4_0 = (((volatile float*)red_buf0)[((int)threadIdx.x)] + ((volatile float*)red_buf0)[(((int)threadIdx.x) + 4)]);
    ((volatile float*)red_buf0)[((int)threadIdx.x)] = w_4_0;
    float w_2_0 = (((volatile float*)red_buf0)[((int)threadIdx.x)] + ((volatile float*)red_buf0)[(((int)threadIdx.x) + 2)]);
    ((volatile float*)red_buf0)[((int)threadIdx.x)] = w_2_0;
    float w_1_0 = (((volatile float*)red_buf0)[((int)threadIdx.x)] + ((volatile float*)red_buf0)[(((int)threadIdx.x) + 1)]);
    ((volatile float*)red_buf0)[((int)threadIdx.x)] = w_1_0;
  }
  __syncthreads();
  if (((int)threadIdx.x) == 0) {
    T_relu_red[0] = ((volatile float*)red_buf0)[0];
  }
}

@yzh119 has experience with those intrinsics.

@freshbird2023 Yes, TVM supports generating shuffle-down primitives for reductions.

The reason you get this code is that the reduction extent exceeds 32, which requires cross-warp communication, so shared memory is allocated for that.

To use warp-level reduction, you can split the reduction loop by a factor of 32 and bind the inner loop with extent 32 to threadIdx.x.

Below is an example:

import tvm
from tvm.script import tir as T

@T.prim_func
def reduce_sum(A: T.Buffer[(1024,), "float32"], B: T.Buffer[(1,), "float32"]) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    for i, j in T.grid(1, 1024):
        with T.block("reduce_sum"):
            vi, vj = T.axis.remap("SR", [i, j])
            with T.init():
                B[0] = T.float32(0)
            B[0] = B[0] + A[vj]


mod = tvm.IRModule.from_expr(reduce_sum)
sch = tvm.tir.Schedule(mod)
blk = sch.get_block("reduce_sum")
i, j = sch.get_loops(blk)
# split the reduction loop 1024 -> (32, 32); the inner extent-32 loop corresponds to one warp
jo, ji = sch.split(j, [32, 32])
sch.bind(ji, "threadIdx.x")  # an extent-32 reduction loop bound to threadIdx.x is lowered to warp shuffles
sch.unroll(jo)
sch.bind(i, "blockIdx.x")
f = tvm.build(sch.mod["main"], target="cuda")
print(f.imported_modules[0].get_source())

The generated code uses warp shuffle intrinsics:

extern "C" __global__ void __launch_bounds__(32) default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
  float in_thread_0[1];
  uint mask[1];
  float t0[1];
  in_thread_0[0] = 0.000000e+00f;
  in_thread_0[0] = (in_thread_0[0] + A[((int)threadIdx.x)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 32)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 64)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 96)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 128)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 160)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 192)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 224)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 256)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 288)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 320)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 352)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 384)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 416)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 448)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 480)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 512)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 544)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 576)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 608)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 640)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 672)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 704)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 736)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 768)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 800)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 832)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 864)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 896)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 928)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 960)]);
  in_thread_0[0] = (in_thread_0[0] + A[(((int)threadIdx.x) + 992)]);
  mask[0] = __activemask();
  t0[0] = __shfl_down_sync(mask[0], in_thread_0[0], 16, 32);
  in_thread_0[0] = (in_thread_0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], in_thread_0[0], 8, 32);
  in_thread_0[0] = (in_thread_0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], in_thread_0[0], 4, 32);
  in_thread_0[0] = (in_thread_0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], in_thread_0[0], 2, 32);
  in_thread_0[0] = (in_thread_0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], in_thread_0[0], 1, 32);
  in_thread_0[0] = (in_thread_0[0] + t0[0]);
  in_thread_0[0] = __shfl_sync(mask[0], in_thread_0[0], 0, 32);
  B[0] = in_thread_0[0];
}
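
For a quick sanity check of the kernel built above (a small sketch; it assumes the Python script above has been run, and the array names are just illustrative), you can run it on device and compare against the expected sum:

import numpy as np

dev = tvm.gpu(0)
x_nd = tvm.nd.array(np.arange(1024).astype("float32"), dev)  # [0, 1, ..., 1023]
y_nd = tvm.nd.array(np.zeros(1).astype("float32"), dev)
f(x_nd, y_nd)
print(y_nd)  # expect [523776.], i.e. 0 + 1 + ... + 1023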

Thank you! I learned that for a reduction extent larger than 32, we can use a recursive warp reduce to perform the reduction within a block. Could TVM support this?

Yes, TVM supports it. That is what the rfactor primitive is for: it helps us factorize the reduction into several stages.

Here is an example of using rfactor:

import tvm
import numpy as np
from tvm.script import tir as T

@T.prim_func
def reduce_sum(A: T.Buffer[(1024,), "float32"], B: T.Buffer[(1,), "float32"]) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    for i, j in T.grid(1, 1024):
        with T.block("reduce_sum"):
            vi, vj = T.axis.remap("SR", [i, j])
            with T.init():
                B[0] = T.float32(0)
            B[0] = B[0] + A[vj]


mod = tvm.IRModule.from_expr(reduce_sum)
sch = tvm.tir.Schedule(mod)
blk = sch.get_block("reduce_sum")
i, j = sch.get_loops(blk)
jo, ji = sch.split(j, [32, 32])
rf_blk = sch.rfactor(jo, 1)  # stage 1: one partial result per value of jo (each handled by one warp)
sch.compute_at(rf_blk, i)
print(sch.mod["main"].script())
# schedule rf_blk
sch.set_scope(rf_blk, 0, "shared")
i, jo, ji = sch.get_loops(rf_blk)
sch.bind(ji, "threadIdx.x")
sch.bind(jo, "threadIdx.y")
sch.bind(i, "blockIdx.x")
# schedule blk
i, j = sch.get_loops(blk)
sch.bind(j, "threadIdx.x")
f = tvm.build(sch.mod["main"], target="cuda")

dev = tvm.gpu(0)
X_nd = tvm.nd.array(np.arange(1024).astype("float32"), dev)  # [0, 1, ... ,1023]
Y_nd = tvm.nd.array(np.zeros(1).astype("float32"), dev)
print(f.imported_modules[0].get_source())

f(X_nd, Y_nd)
print(Y_nd)

The output of the first print is:

@T.prim_func
def main(A: T.Buffer((1024,), "float32"), B: T.Buffer((1,), "float32")):
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # with T.block("root"):
    B_rf = T.alloc_buffer((1, 32))
    for i in range(1):
        for ax0, ax1 in T.grid(32, 32):
            with T.block("reduce_sum_rf"):
                vj_0 = T.axis.spatial(32, ax0)
                vi = T.axis.spatial(1, 0)
                vj_1 = T.axis.reduce(32, ax1)
                T.reads(A[vj_0 * 32 + vj_1])
                T.writes(B_rf[0, vj_0])
                with T.init():
                    B_rf[0, vj_0] = T.float32(0)
                B_rf[0, vj_0] = B_rf[0, vj_0] + A[vj_0 * 32 + vj_1]
        for j_0 in range(32):
            with T.block("reduce_sum"):
                vj_0, vi = T.axis.remap("RS", [j_0, i])
                T.reads(B_rf[0, vj_0])
                T.writes(B[0])
                with T.init():
                    B[0] = T.float32(0)
                B[0] = B[0] + B_rf[0, vj_0]

Here rfactor generates an intermediate buffer B_rf that stores the per-warp aggregation result. We later set the scope of B_rf to shared; this is necessary because the aggregation results are computed by different warps in the first stage, so the buffer must be visible to all warps. Even so, this consumes much less shared memory than a single-stage reduction, because we only need to store one aggregation result per warp.
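
If you print sch.mod["main"].script() again after the sch.set_scope call, the B_rf allocation should carry the shared scope, roughly like the line below (a sketch; the exact buffer name and printed form may differ across TVM versions):

B_rf_shared = T.alloc_buffer((1, 32), scope="shared")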

The output of the second print is:

extern "C" __global__ void __launch_bounds__(1024) default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
  float red_buf0[1];
  __shared__ float B_rf_shared[32];
  float red_buf0_1[1];
  uint mask[1];
  float t0[1];
  red_buf0[0] = A[((((int)threadIdx.y) * 32) + ((int)threadIdx.x))];
  mask[0] = (__activemask() & ((uint)(0 << (((int)threadIdx.y) * 32))));
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 16, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 8, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 4, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 2, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  red_buf0[0] = __shfl_sync(mask[0], red_buf0[0], (((int)threadIdx.y) * 32), 32);
  if (((int)threadIdx.x) == 0) {
    B_rf_shared[((int)threadIdx.y)] = red_buf0[0];
  }
  __syncthreads();
  uint mask_1[1];
  float t0_1[1];
  red_buf0_1[0] = B_rf_shared[((int)threadIdx.x)];
  mask_1[0] = (__activemask() & ((uint)(0 << (((int)threadIdx.y) * 32))));
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 16, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 8, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 4, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 2, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 1, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  red_buf0_1[0] = __shfl_sync(mask_1[0], red_buf0_1[0], (((int)threadIdx.y) * 32), 32);
  if (((int)threadIdx.x) == 0) {
    B[0] = red_buf0_1[0];
  }
}

which exactly implements the “recursive” warp shuffle you mentioned.

The computed result is [523776.], which equals 0 + 1 + … + 1023.

Have fun with TVM!

Hi @freshbird2023, we expect this PR https://github.com/apache/tvm/pull/15327 to ease a lot of this. Now that it has been merged, we can write TIR code like


@T.prim_func
def main(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [1, 1024], dtype="float32")
    B = T.match_buffer(b, [1], dtype="float32")
    for i in T.serial(0, 1):
        for k in T.thread_binding(0, 1024, thread="threadIdx.x"):
            with T.block("B"):
                vi, vk = T.axis.remap("SR", [i, k])
                T.reads([A[vi, vk]])
                T.writes([B[vi]])
                with T.init():
                    B[vi] = T.float32(0)
                B[vi] = B[vi] + A[vi, vk]

and run tvm.build on it directly.
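
For reference, a minimal build-and-inspect sketch (assuming the prim_func above is bound to the Python name main):

import tvm

rt_mod = tvm.build(main, target="cuda")
print(rt_mod.imported_modules[0].get_source())

This gives the generated CUDA code below: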

extern "C" __global__ void __launch_bounds__(1024) default_function_kernel(float* __restrict__ A, float* __restrict__ B) {
  float red_buf0[1];
  uint mask[1];
  float t0[1];
  float red_buf0_1[1];
  uint mask_1[1];
  float t0_1[1];
  __shared__ float red_buf_staging[32];
  red_buf0_1[0] = A[((int)threadIdx.x)];
  mask_1[0] = __activemask();
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 16, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 8, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 4, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 2, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 1, 32);
  red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
  red_buf0_1[0] = __shfl_sync(mask_1[0], red_buf0_1[0], 0, 32);
  if ((((int)threadIdx.x) % 32) == 0) {
    red_buf_staging[(((int)threadIdx.x) >> 5)] = red_buf0_1[0];
  }
  __syncthreads();
  if (((int)threadIdx.x) < 32) {
    red_buf0[0] = red_buf_staging[((int)threadIdx.x)];
  }
  mask[0] = __activemask();
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 16, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 8, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 4, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 2, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  red_buf0[0] = __shfl_sync(mask[0], red_buf0[0], 0, 32);
  if (((int)threadIdx.x) == 0) {
    B[0] = red_buf0[0];
  }
}

I also encountered the same problem when tuning reduce_sum with Ansor. The schedule Ansor finds is 30x slower than torch.sum. I printed the CUDA code it found and its implementation is very ordinary, and the tuning log only printed "Get 50 programs to measure". Is this a bug in Ansor?

Please refer to my reply in the GitHub issue: https://github.com/apache/tvm/issues/15342#issuecomment-1641079322

Hi, I have a question: with MetaSchedule, do we still need to write specific schedules by hand to implement efficient kernels?

I’m tuning a GEMV kernel. Here is my code; is it correct and efficient enough? I’m afraid it doesn’t use the full power of MetaSchedule.

import tvm
from tvm import te
from tvm import meta_schedule as ms

def matmul_layer(m, n, k):
    A = te.placeholder((m, k), name="A",  dtype="float16")
    B = te.placeholder((k, n), name="B", dtype="float16")
    k = te.reduce_axis((0, k), name="k")
    C = te.compute((m, n), lambda m, n: te.sum(A[m, k].astype("float32") * B[k, n].astype("float32"), axis=k), name="C")
    D = te.compute((m, n), lambda i, j: C[i, j].astype("float16"))
    f = te.create_prim_func([A, B, D])
    mod = tvm.IRModule.from_expr(f)
    return mod

m, n, k = 15360, 1, 5120

mod = matmul_layer(m, n, k)
target = tvm.target.Target("nvidia/nvidia-v100", host='llvm')
database = ms.tune_tir(
    mod=mod,
    target=target,
    max_trials_global=512,
    num_trials_per_iter=64,
    space="cuda",
    work_dir="./tune_tmp",
)

sch = ms.tir_integration.compile_tir(database, mod, target)
f = tvm.build(sch.mod["main"], target=target)
print(f.imported_modules[0].get_source())

@yzh119 could you give some advice?