Hi,
I’m working with an IRModule (see below) that generates CUDA code for matrix multiplication. The generated code contains an initialization block that only executes when ((ax0_0 * 8) + ax0_1) == 0:
for (int ax0_0 = 0; ax0_0 < 640; ++ax0_0) {
// ...
for (int ax0_1 = 0; ax0_1 < 8; ++ax0_1) {
// ...
if (((ax0_0 * 8) + ax0_1) == 0) {
compute_local[0] = 0.000000e+00f;
compute_local[4] = 0.000000e+00f;
// ... (all other initializations)
}
// ...
}
// ...
}
This initialization could be performed much earlier, since the condition is true only when both loop variables are zero, i.e., on the very first iteration. The branch compiled from the if statement slows down every iteration of the loop: in my test, it wastes about 1/5 of the kernel execution time.
I know that decompose_reduction can move the initialization block earlier, but it generates two separate kernels, and the launch overhead of two kernels (the host has to dispatch thread blocks to the SMs twice) is greater than that of a single kernel.
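For reference, here is roughly how I applied it (a minimal sketch; the loop I pass and the variable names are illustrative, and Module is the IRModule shown below):

import tvm

sch = tvm.tir.Schedule(Module)      # Module: the IRModule below
compute = sch.get_block("compute")
loops = sch.get_loops(compute)      # thread-binding loops first, then ax0_0, ax0_1_1, ax0_2
# decompose_reduction splits the T.init() of "compute" into its own init
# block, inserted right before the given loop; when that block lands outside
# the thread bindings, codegen emits it as a second kernel.
init_block = sch.decompose_reduction(compute, loops[0])
print(sch.mod)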
Is there a way to hoist this initialization out of the loop without splitting the computation into multiple kernels?
Here’s the relevant part of the IRModule:
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((5120, 5120), "float32"), B: T.Buffer((5120, 5120), "float32"), compute: T.Buffer((5120, 5120), "float32")):
T.func_attr({"tir.noalias": True})
# with T.block("root"):
A_shared = T.alloc_buffer((5120, 5120), scope="shared")
B_shared = T.alloc_buffer((5120, 5120), scope="shared")
A_shared_local = T.alloc_buffer((5120, 5120), scope="local")
B_shared_local = T.alloc_buffer((5120, 5120), scope="local")
compute_local = T.alloc_buffer((5120, 5120), scope="local")
for ax0_0_ax1_0_fused in T.thread_binding(1600, thread="blockIdx.x"):
for ax0_1 in T.thread_binding(8, thread="vthread.x"):
for ax1_1 in T.thread_binding(4, thread="vthread.y"):
for ax0_2_ax1_2_fused in T.thread_binding(512, thread="threadIdx.x"):
for ax0_0 in range(640):
for ax0_ax1_fused_0_0 in T.unroll(2):
for ax0_ax1_fused_0_1 in T.thread_binding(512, thread="threadIdx.x"):
for ax0_ax1_fused_1 in T.vectorized(1):
with T.block("A_shared"):
v0 = T.axis.spatial(5120, ax0_0_ax1_0_fused // 40 * 128 + (ax0_ax1_fused_0_0 * 512 + ax0_ax1_fused_0_1 + ax0_ax1_fused_1) // 8)
v1 = T.axis.spatial(5120, ax0_0 * 8 + (ax0_ax1_fused_0_0 * 512 + ax0_ax1_fused_0_1 + ax0_ax1_fused_1) % 8)
T.reads(A[v0, v1])
T.writes(A_shared[v0, v1])
A_shared[v0, v1] = A[v0, v1]
for ax0_ax1_fused_0_0 in T.unroll(2):
for ax0_ax1_fused_0_1 in T.thread_binding(512, thread="threadIdx.x"):
for ax0_ax1_fused_1 in T.vectorized(1):
with T.block("B_shared"):
v0 = T.axis.spatial(5120, ax0_0 * 8 + (ax0_ax1_fused_0_0 * 512 + ax0_ax1_fused_0_1 + ax0_ax1_fused_1) // 128)
v1 = T.axis.spatial(5120, ax0_0_ax1_0_fused % 40 * 128 + (ax0_ax1_fused_0_0 * 512 + ax0_ax1_fused_0_1 + ax0_ax1_fused_1) % 128)
T.reads(B[v0, v1])
T.writes(B_shared[v0, v1])
B_shared[v0, v1] = B[v0, v1]
for ax0_1_1, ax0_2 in T.grid(8, 1):
with T.block("A_shared_local"):
v0 = T.axis.spatial(5120, ax0_0_ax1_0_fused // 40 * 128 + ax0_1 * 16 + ax0_2_ax1_2_fused // 32)
v1 = T.axis.spatial(5120, ax0_0 * 8 + ax0_1_1)
T.reads(A_shared[v0, v1])
T.writes(A_shared_local[v0, v1])
A_shared_local[v0, v1] = A_shared[v0, v1]
with T.block("B_shared_local"):
v0 = T.axis.spatial(5120, ax0_0 * 8 + ax0_1_1)
v1 = T.axis.spatial(5120, ax0_0_ax1_0_fused % 40 * 128 + ax1_1 * 32 + ax0_2_ax1_2_fused % 32)
T.reads(B_shared[v0, v1])
T.writes(B_shared_local[v0, v1])
B_shared_local[v0, v1] = B_shared[v0, v1]
with T.block("compute"):
v_y = T.axis.spatial(5120, ax0_0_ax1_0_fused // 40 * 128 + ax0_1 * 16 + ax0_2_ax1_2_fused // 32)
v_x = T.axis.spatial(5120, ax0_0_ax1_0_fused % 40 * 128 + ax1_1 * 32 + ax0_2_ax1_2_fused % 32)
v_k = T.axis.reduce(5120, ax0_0 * 8 + ax0_1_1 + ax0_2)
T.reads(A_shared_local[v_y, v_k], B_shared_local[v_k, v_x])
T.writes(compute_local[v_y, v_x])
with T.init():
compute_local[v_y, v_x] = T.float32(0.0)
compute_local[v_y, v_x] = compute_local[v_y, v_x] + A_shared_local[v_y, v_k] * B_shared_local[v_k, v_x]
for ax0_3, ax1_3 in T.grid(1, 1):
with T.block("compute_local"):
v0 = T.axis.spatial(5120, ax0_0_ax1_0_fused // 40 * 128 + ax0_1 * 16 + ax0_2_ax1_2_fused // 32 + ax0_3)
v1 = T.axis.spatial(5120, ax0_0_ax1_0_fused % 40 * 128 + ax1_1 * 32 + ax0_2_ax1_2_fused % 32 + ax1_3)
T.reads(compute_local[v0, v1])
T.writes(compute[v0, v1])
compute[v0, v1] = compute_local[v0, v1]
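For completeness, the kernel source shown next was obtained roughly like this (a minimal sketch of the standard build flow; my actual harness differs):

import tvm

rt_mod = tvm.build(Module, target="cuda")        # Module: the IRModule above
print(rt_mod.imported_modules[0].get_source())   # dump the generated CUDA C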
The complete generated CUDA kernel is quite large but shows the initialization pattern clearly:
#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>
#include "cu_helper.h"
#include <cuda_fp16.h>
#include <mma.h>
#include <string>
//full_dimensions: [5120, 5120, 5120]
#include <cuda.h>
#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
(__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif
#include <cstdint>
using uint = unsigned int;
using uchar = unsigned char;
using ushort = unsigned short;
extern "C" __global__ void __launch_bounds__(512) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ compute);
extern "C" __global__ void __launch_bounds__(512) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ compute) {
float compute_local[32];
__shared__ float A_shared[1024];
__shared__ float B_shared[1024];
float A_shared_local[8];
float B_shared_local[4];
for (int ax0_0 = 0; ax0_0 < 640; ++ax0_0) {
__syncthreads();
A_shared[((int)threadIdx.x)] = A[(((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 3) * 5120)) + (ax0_0 * 8)) + (((int)threadIdx.x) & 7))];
A_shared[(((int)threadIdx.x) + 512)] = A[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 3) * 5120)) + (ax0_0 * 8)) + (((int)threadIdx.x) & 7)) + 327680)];
B_shared[((int)threadIdx.x)] = B[((((ax0_0 * 40960) + ((((int)threadIdx.x) >> 7) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 127))];
B_shared[(((int)threadIdx.x) + 512)] = B[(((((ax0_0 * 40960) + ((((int)threadIdx.x) >> 7) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 127)) + 20480)];
__syncthreads();
for (int ax0_1 = 0; ax0_1 < 8; ++ax0_1) {
A_shared_local[0] = A_shared[(((((int)threadIdx.x) >> 5) * 8) + ax0_1)];
A_shared_local[1] = A_shared[((((((int)threadIdx.x) >> 5) * 8) + ax0_1) + 128)];
A_shared_local[2] = A_shared[((((((int)threadIdx.x) >> 5) * 8) + ax0_1) + 256)];
A_shared_local[3] = A_shared[((((((int)threadIdx.x) >> 5) * 8) + ax0_1) + 384)];
A_shared_local[4] = A_shared[((((((int)threadIdx.x) >> 5) * 8) + ax0_1) + 512)];
A_shared_local[5] = A_shared[((((((int)threadIdx.x) >> 5) * 8) + ax0_1) + 640)];
A_shared_local[6] = A_shared[((((((int)threadIdx.x) >> 5) * 8) + ax0_1) + 768)];
A_shared_local[7] = A_shared[((((((int)threadIdx.x) >> 5) * 8) + ax0_1) + 896)];
B_shared_local[0] = B_shared[((ax0_1 * 128) + (((int)threadIdx.x) & 31))];
B_shared_local[1] = B_shared[(((ax0_1 * 128) + (((int)threadIdx.x) & 31)) + 32)];
B_shared_local[2] = B_shared[(((ax0_1 * 128) + (((int)threadIdx.x) & 31)) + 64)];
B_shared_local[3] = B_shared[(((ax0_1 * 128) + (((int)threadIdx.x) & 31)) + 96)];
if (((ax0_0 * 8) + ax0_1) == 0) {
compute_local[0] = 0.000000e+00f;
compute_local[4] = 0.000000e+00f;
compute_local[8] = 0.000000e+00f;
compute_local[12] = 0.000000e+00f;
compute_local[16] = 0.000000e+00f;
compute_local[20] = 0.000000e+00f;
compute_local[24] = 0.000000e+00f;
compute_local[28] = 0.000000e+00f;
compute_local[1] = 0.000000e+00f;
compute_local[5] = 0.000000e+00f;
compute_local[9] = 0.000000e+00f;
compute_local[13] = 0.000000e+00f;
compute_local[17] = 0.000000e+00f;
compute_local[21] = 0.000000e+00f;
compute_local[25] = 0.000000e+00f;
compute_local[29] = 0.000000e+00f;
compute_local[2] = 0.000000e+00f;
compute_local[6] = 0.000000e+00f;
compute_local[10] = 0.000000e+00f;
compute_local[14] = 0.000000e+00f;
compute_local[18] = 0.000000e+00f;
compute_local[22] = 0.000000e+00f;
compute_local[26] = 0.000000e+00f;
compute_local[30] = 0.000000e+00f;
compute_local[3] = 0.000000e+00f;
compute_local[7] = 0.000000e+00f;
compute_local[11] = 0.000000e+00f;
compute_local[15] = 0.000000e+00f;
compute_local[19] = 0.000000e+00f;
compute_local[23] = 0.000000e+00f;
compute_local[27] = 0.000000e+00f;
compute_local[31] = 0.000000e+00f;
}
compute_local[0] = (compute_local[0] + (A_shared_local[0] * B_shared_local[0]));
compute_local[4] = (compute_local[4] + (A_shared_local[1] * B_shared_local[0]));
compute_local[8] = (compute_local[8] + (A_shared_local[2] * B_shared_local[0]));
compute_local[12] = (compute_local[12] + (A_shared_local[3] * B_shared_local[0]));
compute_local[16] = (compute_local[16] + (A_shared_local[4] * B_shared_local[0]));
compute_local[20] = (compute_local[20] + (A_shared_local[5] * B_shared_local[0]));
compute_local[24] = (compute_local[24] + (A_shared_local[6] * B_shared_local[0]));
compute_local[28] = (compute_local[28] + (A_shared_local[7] * B_shared_local[0]));
compute_local[1] = (compute_local[1] + (A_shared_local[0] * B_shared_local[1]));
compute_local[5] = (compute_local[5] + (A_shared_local[1] * B_shared_local[1]));
compute_local[9] = (compute_local[9] + (A_shared_local[2] * B_shared_local[1]));
compute_local[13] = (compute_local[13] + (A_shared_local[3] * B_shared_local[1]));
compute_local[17] = (compute_local[17] + (A_shared_local[4] * B_shared_local[1]));
compute_local[21] = (compute_local[21] + (A_shared_local[5] * B_shared_local[1]));
compute_local[25] = (compute_local[25] + (A_shared_local[6] * B_shared_local[1]));
compute_local[29] = (compute_local[29] + (A_shared_local[7] * B_shared_local[1]));
compute_local[2] = (compute_local[2] + (A_shared_local[0] * B_shared_local[2]));
compute_local[6] = (compute_local[6] + (A_shared_local[1] * B_shared_local[2]));
compute_local[10] = (compute_local[10] + (A_shared_local[2] * B_shared_local[2]));
compute_local[14] = (compute_local[14] + (A_shared_local[3] * B_shared_local[2]));
compute_local[18] = (compute_local[18] + (A_shared_local[4] * B_shared_local[2]));
compute_local[22] = (compute_local[22] + (A_shared_local[5] * B_shared_local[2]));
compute_local[26] = (compute_local[26] + (A_shared_local[6] * B_shared_local[2]));
compute_local[30] = (compute_local[30] + (A_shared_local[7] * B_shared_local[2]));
compute_local[3] = (compute_local[3] + (A_shared_local[0] * B_shared_local[3]));
compute_local[7] = (compute_local[7] + (A_shared_local[1] * B_shared_local[3]));
compute_local[11] = (compute_local[11] + (A_shared_local[2] * B_shared_local[3]));
compute_local[15] = (compute_local[15] + (A_shared_local[3] * B_shared_local[3]));
compute_local[19] = (compute_local[19] + (A_shared_local[4] * B_shared_local[3]));
compute_local[23] = (compute_local[23] + (A_shared_local[5] * B_shared_local[3]));
compute_local[27] = (compute_local[27] + (A_shared_local[6] * B_shared_local[3]));
compute_local[31] = (compute_local[31] + (A_shared_local[7] * B_shared_local[3]));
}
}
compute[(((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31))] = compute_local[0];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 81920)] = compute_local[4];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 163840)] = compute_local[8];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 245760)] = compute_local[12];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 327680)] = compute_local[16];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 409600)] = compute_local[20];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 491520)] = compute_local[24];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 573440)] = compute_local[28];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 32)] = compute_local[1];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 81952)] = compute_local[5];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 163872)] = compute_local[9];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 245792)] = compute_local[13];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 327712)] = compute_local[17];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 409632)] = compute_local[21];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 491552)] = compute_local[25];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 573472)] = compute_local[29];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 64)] = compute_local[2];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 81984)] = compute_local[6];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 163904)] = compute_local[10];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 245824)] = compute_local[14];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 327744)] = compute_local[18];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 409664)] = compute_local[22];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 491584)] = compute_local[26];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 573504)] = compute_local[30];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 96)] = compute_local[3];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 82016)] = compute_local[7];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 163936)] = compute_local[11];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 245856)] = compute_local[15];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 327776)] = compute_local[19];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 409696)] = compute_local[23];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 491616)] = compute_local[27];
compute[((((((((int)blockIdx.x) / 40) * 655360) + ((((int)threadIdx.x) >> 5) * 5120)) + ((((int)blockIdx.x) % 40) * 128)) + (((int)threadIdx.x) & 31)) + 573536)] = compute_local[31];
}