Optimizing Reduction Initialization in Generated CUDA Code

Hi,

I’m working with an IRModule (see below) that generates CUDA code for matrix multiplication. In the generated code, there’s an initialization block that only executes when `((ax0_0 * 8) + ax0_1) == 0`:

for (int ax0_0 = 0; ax0_0 < 640; ++ax0_0) {
    // ...
    for (int ax0_1 = 0; ax0_1 < 8; ++ax0_1) {
        // ...
        if (((ax0_0 * 8) + ax0_1) == 0) {
            compute_local[0] = 0.000000e+00f;
            compute_local[4] = 0.000000e+00f;
            // ... (all other initializations)
        }
        // ...
    }
    // ...
}

This initialization could be hoisted out of the loops entirely, since the condition is true only on the very first reduction iteration — i.e., when both loop variables are zero. The branch instructions compiled from the if statement slow down thread execution: in my test, about 1/5 of the kernel execution time is wasted on this check.

I know that decompose_reduction can move the initialization block earlier, but it splits the computation into two separate kernels. The launch overhead of two kernels (the host must dispatch thread blocks to the SMs twice) is greater than that of a single kernel.

Is there a way to optimize this initialization without splitting into multiple kernels?

Here’s the relevant part of the IRModule:

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((5120, 5120), "float32"), B: T.Buffer((5120, 5120), "float32"), compute: T.Buffer((5120, 5120), "float32")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        A_shared = T.alloc_buffer((5120, 5120), scope="shared")
        B_shared = T.alloc_buffer((5120, 5120), scope="shared")
        A_shared_local = T.alloc_buffer((5120, 5120), scope="local")
        B_shared_local = T.alloc_buffer((5120, 5120), scope="local")
        compute_local = T.alloc_buffer((5120, 5120), scope="local")
        for ax0_0_ax1_0_fused in T.thread_binding(1600, thread="blockIdx.x"):
            for ax0_1 in T.thread_binding(8, thread="vthread.x"):
                for ax1_1 in T.thread_binding(4, thread="vthread.y"):
                    for ax0_2_ax1_2_fused in T.thread_binding(512, thread="threadIdx.x"):
                        for ax0_0 in range(640):
                            for ax0_ax1_fused_0_0 in T.unroll(2):
                                for ax0_ax1_fused_0_1 in T.thread_binding(512, thread="threadIdx.x"):
                                    for ax0_ax1_fused_1 in T.vectorized(1):
                                        with T.block("A_shared"):
                                            v0 = T.axis.spatial(5120, ax0_0_ax1_0_fused // 40 * 128 + (ax0_ax1_fused_0_0 * 512 + ax0_ax1_fused_0_1 + ax0_ax1_fused_1) // 8)
                                            v1 = T.axis.spatial(5120, ax0_0 * 8 + (ax0_ax1_fused_0_0 * 512 + ax0_ax1_fused_0_1 + ax0_ax1_fused_1) % 8)
                                            T.reads(A[v0, v1])
                                            T.writes(A_shared[v0, v1])
                                            A_shared[v0, v1] = A[v0, v1]
                            for ax0_ax1_fused_0_0 in T.unroll(2):
                                for ax0_ax1_fused_0_1 in T.thread_binding(512, thread="threadIdx.x"):
                                    for ax0_ax1_fused_1 in T.vectorized(1):
                                        with T.block("B_shared"):
                                            v0 = T.axis.spatial(5120, ax0_0 * 8 + (ax0_ax1_fused_0_0 * 512 + ax0_ax1_fused_0_1 + ax0_ax1_fused_1) // 128)
                                            v1 = T.axis.spatial(5120, ax0_0_ax1_0_fused % 40 * 128 + (ax0_ax1_fused_0_0 * 512 + ax0_ax1_fused_0_1 + ax0_ax1_fused_1) % 128)
                                            T.reads(B[v0, v1])
                                            T.writes(B_shared[v0, v1])
                                            B_shared[v0, v1] = B[v0, v1]
                            for ax0_1_1, ax0_2 in T.grid(8, 1):
                                with T.block("A_shared_local"):
                                    v0 = T.axis.spatial(5120, ax0_0_ax1_0_fused // 40 * 128 + ax0_1 * 16 + ax0_2_ax1_2_fused // 32)
                                    v1 = T.axis.spatial(5120, ax0_0 * 8 + ax0_1_1)
                                    T.reads(A_shared[v0, v1])
                                    T.writes(A_shared_local[v0, v1])
                                    A_shared_local[v0, v1] = A_shared[v0, v1]
                                with T.block("B_shared_local"):
                                    v0 = T.axis.spatial(5120, ax0_0 * 8 + ax0_1_1)
                                    v1 = T.axis.spatial(5120, ax0_0_ax1_0_fused % 40 * 128 + ax1_1 * 32 + ax0_2_ax1_2_fused % 32)
                                    T.reads(B_shared[v0, v1])
                                    T.writes(B_shared_local[v0, v1])
                                    B_shared_local[v0, v1] = B_shared[v0, v1]
                                with T.block("compute"):
                                    v_y = T.axis.spatial(5120, ax0_0_ax1_0_fused // 40 * 128 + ax0_1 * 16 + ax0_2_ax1_2_fused // 32)
                                    v_x = T.axis.spatial(5120, ax0_0_ax1_0_fused % 40 * 128 + ax1_1 * 32 + ax0_2_ax1_2_fused % 32)
                                    v_k = T.axis.reduce(5120, ax0_0 * 8 + ax0_1_1 + ax0_2)
                                    T.reads(A_shared_local[v_y, v_k], B_shared_local[v_k, v_x])
                                    T.writes(compute_local[v_y, v_x])
                                    with T.init():
                                        compute_local[v_y, v_x] = T.float32(0.0)
                                    compute_local[v_y, v_x] = compute_local[v_y, v_x] + A_shared_local[v_y, v_k] * B_shared_local[v_k, v_x]
                        for ax0_3, ax1_3 in T.grid(1, 1):
                            with T.block("compute_local"):
                                v0 = T.axis.spatial(5120, ax0_0_ax1_0_fused // 40 * 128 + ax0_1 * 16 + ax0_2_ax1_2_fused // 32 + ax0_3)
                                v1 = T.axis.spatial(5120, ax0_0_ax1_0_fused % 40 * 128 + ax1_1 * 32 + ax0_2_ax1_2_fused % 32 + ax1_3)
                                T.reads(compute_local[v0, v1])
                                T.writes(compute[v0, v1])
                                compute[v0, v1] = compute_local[v0, v1]

The complete generated CUDA kernel is quite large but shows the initialization pattern clearly:

#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>
#include "cu_helper.h"
#include <cuda_fp16.h>
#include <mma.h>
#include <string>

//full_dimensions: [5120, 5120, 5120]
#include <cuda.h>

#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
     (__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif
#include <cstdint>
using uint = unsigned int;
using uchar = unsigned char;
using ushort = unsigned short;
extern "C" __global__ void __launch_bounds__(512) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ compute);
extern "C" __global__ void __launch_bounds__(512) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ compute) {
  float compute_local[32];
  __shared__ float A_shared[1024];
  __shared__ float B_shared[1024];
  float A_shared_local[8];
  float B_shared_local[4];
  for (int ax0_0 = 0; ax0_0 < 640; ++ax0_0) {
    __syncthreads();
    A_shared[((int)threadIdx.x)] = A[(((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 3) * 5120)) + (ax0_0 * 8)) + (((int)threadIdx.x) & 7))];
    A_shared[(((int)threadIdx.x) + 512)] = A[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 3) * 5120)) + (ax0_0 * 8)) + (((int)threadIdx.x) & 7)) + 327680)];
    B_shared[((int)threadIdx.x)] = B[((((ax0_0 * 40960) + ((((int)threadIdx.x) >> 7) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 127))];
    B_shared[(((int)threadIdx.x) + 512)] = B[(((((ax0_0 * 40960) + ((((int)threadIdx.x) >> 7) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 127)) + 20480)];
    __syncthreads();
    for (int ax0_1 = 0; ax0_1 < 8; ++ax0_1) {
      A_shared_local[0] = A_shared[(((((int)threadIdx.x) >> 5) * 8) + ax0_1)];
      A_shared_local[1] = A_shared[((((((int)threadIdx.x) >> 5) * 8) + ax0_1) + 128)];
      A_shared_local[2] = A_shared[((((((int)threadIdx.x) >> 5) * 8) + ax0_1) + 256)];
      A_shared_local[3] = A_shared[((((((int)threadIdx.x) >> 5) * 8) + ax0_1) + 384)];
      A_shared_local[4] = A_shared[((((((int)threadIdx.x) >> 5) * 8) + ax0_1) + 512)];
      A_shared_local[5] = A_shared[((((((int)threadIdx.x) >> 5) * 8) + ax0_1) + 640)];
      A_shared_local[6] = A_shared[((((((int)threadIdx.x) >> 5) * 8) + ax0_1) + 768)];
      A_shared_local[7] = A_shared[((((((int)threadIdx.x) >> 5) * 8) + ax0_1) + 896)];
      B_shared_local[0] = B_shared[((ax0_1 * 128) + (((int)threadIdx.x) & 31))];
      B_shared_local[1] = B_shared[(((ax0_1 * 128) + (((int)threadIdx.x) & 31)) + 32)];
      B_shared_local[2] = B_shared[(((ax0_1 * 128) + (((int)threadIdx.x) & 31)) + 64)];
      B_shared_local[3] = B_shared[(((ax0_1 * 128) + (((int)threadIdx.x) & 31)) + 96)];
      if (((ax0_0 * 8) + ax0_1) == 0) {
        compute_local[0] = 0.000000e+00f;
        compute_local[4] = 0.000000e+00f;
        compute_local[8] = 0.000000e+00f;
        compute_local[12] = 0.000000e+00f;
        compute_local[16] = 0.000000e+00f;
        compute_local[20] = 0.000000e+00f;
        compute_local[24] = 0.000000e+00f;
        compute_local[28] = 0.000000e+00f;
        compute_local[1] = 0.000000e+00f;
        compute_local[5] = 0.000000e+00f;
        compute_local[9] = 0.000000e+00f;
        compute_local[13] = 0.000000e+00f;
        compute_local[17] = 0.000000e+00f;
        compute_local[21] = 0.000000e+00f;
        compute_local[25] = 0.000000e+00f;
        compute_local[29] = 0.000000e+00f;
        compute_local[2] = 0.000000e+00f;
        compute_local[6] = 0.000000e+00f;
        compute_local[10] = 0.000000e+00f;
        compute_local[14] = 0.000000e+00f;
        compute_local[18] = 0.000000e+00f;
        compute_local[22] = 0.000000e+00f;
        compute_local[26] = 0.000000e+00f;
        compute_local[30] = 0.000000e+00f;
        compute_local[3] = 0.000000e+00f;
        compute_local[7] = 0.000000e+00f;
        compute_local[11] = 0.000000e+00f;
        compute_local[15] = 0.000000e+00f;
        compute_local[19] = 0.000000e+00f;
        compute_local[23] = 0.000000e+00f;
        compute_local[27] = 0.000000e+00f;
        compute_local[31] = 0.000000e+00f;
      }
      compute_local[0] = (compute_local[0] + (A_shared_local[0] * B_shared_local[0]));
      compute_local[4] = (compute_local[4] + (A_shared_local[1] * B_shared_local[0]));
      compute_local[8] = (compute_local[8] + (A_shared_local[2] * B_shared_local[0]));
      compute_local[12] = (compute_local[12] + (A_shared_local[3] * B_shared_local[0]));
      compute_local[16] = (compute_local[16] + (A_shared_local[4] * B_shared_local[0]));
      compute_local[20] = (compute_local[20] + (A_shared_local[5] * B_shared_local[0]));
      compute_local[24] = (compute_local[24] + (A_shared_local[6] * B_shared_local[0]));
      compute_local[28] = (compute_local[28] + (A_shared_local[7] * B_shared_local[0]));
      compute_local[1] = (compute_local[1] + (A_shared_local[0] * B_shared_local[1]));
      compute_local[5] = (compute_local[5] + (A_shared_local[1] * B_shared_local[1]));
      compute_local[9] = (compute_local[9] + (A_shared_local[2] * B_shared_local[1]));
      compute_local[13] = (compute_local[13] + (A_shared_local[3] * B_shared_local[1]));
      compute_local[17] = (compute_local[17] + (A_shared_local[4] * B_shared_local[1]));
      compute_local[21] = (compute_local[21] + (A_shared_local[5] * B_shared_local[1]));
      compute_local[25] = (compute_local[25] + (A_shared_local[6] * B_shared_local[1]));
      compute_local[29] = (compute_local[29] + (A_shared_local[7] * B_shared_local[1]));
      compute_local[2] = (compute_local[2] + (A_shared_local[0] * B_shared_local[2]));
      compute_local[6] = (compute_local[6] + (A_shared_local[1] * B_shared_local[2]));
      compute_local[10] = (compute_local[10] + (A_shared_local[2] * B_shared_local[2]));
      compute_local[14] = (compute_local[14] + (A_shared_local[3] * B_shared_local[2]));
      compute_local[18] = (compute_local[18] + (A_shared_local[4] * B_shared_local[2]));
      compute_local[22] = (compute_local[22] + (A_shared_local[5] * B_shared_local[2]));
      compute_local[26] = (compute_local[26] + (A_shared_local[6] * B_shared_local[2]));
      compute_local[30] = (compute_local[30] + (A_shared_local[7] * B_shared_local[2]));
      compute_local[3] = (compute_local[3] + (A_shared_local[0] * B_shared_local[3]));
      compute_local[7] = (compute_local[7] + (A_shared_local[1] * B_shared_local[3]));
      compute_local[11] = (compute_local[11] + (A_shared_local[2] * B_shared_local[3]));
      compute_local[15] = (compute_local[15] + (A_shared_local[3] * B_shared_local[3]));
      compute_local[19] = (compute_local[19] + (A_shared_local[4] * B_shared_local[3]));
      compute_local[23] = (compute_local[23] + (A_shared_local[5] * B_shared_local[3]));
      compute_local[27] = (compute_local[27] + (A_shared_local[6] * B_shared_local[3]));
      compute_local[31] = (compute_local[31] + (A_shared_local[7] * B_shared_local[3]));
    }
  }
  compute[(((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31))] = compute_local[0];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 81920)] = compute_local[4];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 163840)] = compute_local[8];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 245760)] = compute_local[12];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 327680)] = compute_local[16];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 409600)] = compute_local[20];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 491520)] = compute_local[24];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 573440)] = compute_local[28];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 32)] = compute_local[1];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 81952)] = compute_local[5];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 163872)] = compute_local[9];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 245792)] = compute_local[13];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 327712)] = compute_local[17];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 409632)] = compute_local[21];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 491552)] = compute_local[25];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 573472)] = compute_local[29];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 64)] = compute_local[2];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 81984)] = compute_local[6];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 163904)] = compute_local[10];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 245824)] = compute_local[14];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 327744)] = compute_local[18];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 409664)] = compute_local[22];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 491584)] = compute_local[26];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 573504)] = compute_local[30];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 96)] = compute_local[3];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 82016)] = compute_local[7];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 163936)] = compute_local[11];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 245856)] = compute_local[15];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 327776)] = compute_local[19];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 409696)] = compute_local[23];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 491616)] = compute_local[27];
  compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 573536)] = compute_local[31];
}

This issue has been resolved.