[TensorIR] Align Warp Memory for ROCm RDNA3 Matrix Core

Hi all, I’m currently working with the ROCm RDNA3 matrix core, and I’ve run into an unusual challenge with warp memory alignment. Unlike the CDNA matrix core and NVIDIA’s Tensor Core, RDNA3 uses 32 threads for a 16x16x16 computation, with each thread handling 16 elements (a rather weird design). That gives 32 × 16 = 512 warp-level slots for a 16x16 tile that has only 256 elements.
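To make the arithmetic concrete, the small sketch below counts warp-level slots per 16x16 tile. The RDNA3 figures (32 lanes, 16 elements per lane) come from the description above; the NVIDIA (32 × 8) and CDNA (64 × 4) per-lane counts are my assumptions for comparison, and warp_slots is a hypothetical helper.

```python
def warp_slots(warp_size: int, elems_per_lane: int) -> int:
    """Number of warp-level storage slots one matrix fragment occupies."""
    return warp_size * elems_per_lane


TILE_ELEMS = 16 * 16  # logical elements in one 16x16 tile

# (warp size, elements per lane); NVIDIA and CDNA values are assumptions
LAYOUTS = {
    "nvidia_tensorcore": (32, 8),
    "amd_cdna_mfma": (64, 4),
    "amd_rdna3_wmma": (32, 16),  # as described for RDNA3 above
}

for name, (lanes, per_lane) in LAYOUTS.items():
    slots = warp_slots(lanes, per_lane)
    print(f"{name}: {lanes} x {per_lane} = {slots} slots, "
          f"{slots // TILE_ELEMS}x the tile size")
```

Only the RDNA3 layout ends up at 2x the tile size, which is why a warp buffer sized from the logical [16, 16] tile is too small for a [32, 16] fragment.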

This makes tensorization in TensorIR awkward. To illustrate, consider a WMMA fill operation. The raw IR looks like this:

C_warp = T.alloc_buffer([16, 16, 16, 16], dtype="float16", scope="warp")
for ii_2_init, jj_2_init, i_init, j_init in T.grid(2, 1, 16, 16):
    with T.block("B_init"):
        vii = T.axis.spatial(16, ii_0 * 2 + ii_1 * 2 + ii_2_init)
        vjj = T.axis.spatial(16, jj_2_init + jj_0 + jj_1)
        vi, vj = T.axis.remap("SS", [i_init, j_init])
        T.reads()
        T.writes(C_warp[vii, vjj, vi, vj])
        C_warp[vii, vjj, vi, vj] = T.float16(0)  # match C_warp's float16 dtype

After tensorizing with the RDNA3 wmma_fill intrinsic, the IR becomes:

C_warp = T.alloc_buffer([16, 16, 16, 16], dtype="float16", scope="warp")
for ii_2_init, jj_2_init in T.grid(2, 1):
    with T.block("B_init_o"):
        vii = T.axis.spatial(16, ii_0 * 2 + ii_1 * 2 + ii_2_init)
        vjj = T.axis.spatial(16, jj_0 + jj_1 + jj_2_init)
        vi_o = T.axis.spatial(1, 0)
        vj_o = T.axis.spatial(1, 0)
        T.reads()
        T.writes(C_warp[vii, vjj, 0 : 16, 0 : 16])
        C_warp_1 = T.match_buffer(C_warp[vii, vjj, 0 : 32, 0 : 16], [32, 16], dtype="float16", scope="warp", offset_factor=1)
        for tx in T.thread_binding(32, thread="threadIdx.x"):
            for local_id in T.serial(16):
                C_warp_1[tx, local_id] = T.float16(0)

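The problem is buffer size alignment. The inner tile of C_warp is allocated as [16, 16], but the RDNA3 intrinsic views each tile as [32, 16] (32 lanes × 16 elements). As a result, match_buffer takes the region C_warp[vii, vjj, 0 : 32, 0 : 16], which runs past the allocated extent of 16 on the third axis and disagrees with the block's T.writes region (0 : 16). This mismatch can cause code generation problems.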
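One way the mismatch can corrupt results, assuming C_warp is flattened in the default row-major order: with an inner extent of 16, rows 16 to 31 of the region C_warp[vii, vjj, 0 : 32, 0 : 16] land on the same flat offsets as rows 0 to 15 of the neighbouring tile C_warp[vii, vjj + 1, ...]. The hypothetical flat_index helper below reproduces only that row-major arithmetic in plain Python; it is an illustration of the aliasing, not a trace of what TVM's lowering actually emits.

```python
from typing import Sequence


def flat_index(shape: Sequence[int], idx: Sequence[int]) -> int:
    """Row-major (C-order) flat offset of idx in a buffer of the given shape.

    Bounds are deliberately not checked, mirroring what happens after
    flattening when an access region exceeds the allocated extent.
    """
    offset = 0
    for extent, i in zip(shape, idx):
        offset = offset * extent + i
    return offset


original = (16, 16, 16, 16)  # C_warp as allocated before the fix
aligned = (16, 16, 32, 16)   # C_warp as allocated in the hand-fixed IR

vii, vjj = 0, 0
# Lane 16, element 0 of tile (vii, vjj) in the [32, 16] fragment view ...
spill = flat_index(original, (vii, vjj, 16, 0))
# ... lands on lane 0, element 0 of the neighbouring tile (vii, vjj + 1).
neighbour = flat_index(original, (vii, vjj + 1, 0, 0))
print(spill, neighbour, spill == neighbour)

# With the aligned allocation the two tiles no longer overlap.
print(flat_index(aligned, (vii, vjj, 16, 0)),
      flat_index(aligned, (vii, vjj + 1, 0, 0)))
```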

The IR I modified by hand to get correct codegen, with the allocation widened to [16, 16, 32, 16] and the write region widened to 0 : 32, looks like this:

C_warp = T.alloc_buffer([16, 16, 32, 16], dtype="float16", scope="warp")
for ii_2_init, jj_2_init in T.grid(2, 1):
    with T.block("B_init_o"):
        vii = T.axis.spatial(16, ii_0 * 2 + ii_1 * 2 + ii_2_init)
        vjj = T.axis.spatial(16, jj_0 + jj_1 + jj_2_init)
        vi_o = T.axis.spatial(1, 0)
        vj_o = T.axis.spatial(1, 0)
        T.reads()
        T.writes(C_warp[vii, vjj, 0 : 32, 0 : 16])
        C_warp_1 = T.match_buffer(C_warp[vii, vjj, 0 : 32, 0 : 16], [32, 16], dtype="float16", scope="warp", offset_factor=1)
        for tx in T.thread_binding(32, thread="threadIdx.x"):
            for local_id in T.serial(16):
                C_warp_1[tx, local_id] = T.float16(0)

I plan to write an align_warp_memory pass that resizes each warp buffer allocation to match the regions its inner blocks access, but is there a more efficient or cleaner way to resolve this?
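The core of such a pass could be a shape computation: for each axis of a warp allocation, take the maximum of its current extent and the end (start + extent) of every constant-offset region that inner blocks match or access on that axis. The sketch below models only that step in plain Python; AxisRange and align_alloc_shape are hypothetical names, and a real pass would still have to collect the regions from the TIR (match_buffer, reads and writes) and rewrite both the allocation and the block's T.writes region.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple


@dataclass(frozen=True)
class AxisRange:
    """One axis of an accessed region; start is None when it is symbolic."""
    start: Optional[int]
    extent: int


def align_alloc_shape(alloc_shape: Sequence[int],
                      regions: List[Sequence[AxisRange]]) -> Tuple[int, ...]:
    """Grow each axis so every constant-offset region fits inside it.

    Axes with a symbolic start (e.g. the tile indices vii, vjj) are left
    unchanged, since their bounds are governed by the loop structure.
    """
    new_shape = list(alloc_shape)
    for region in regions:
        if len(region) != len(new_shape):
            raise ValueError("region rank does not match allocation rank")
        for axis, rng in enumerate(region):
            if rng.start is not None:
                new_shape[axis] = max(new_shape[axis],
                                      rng.start + rng.extent)
    return tuple(new_shape)


# Region matched by the tensorized fill: C_warp[vii, vjj, 0:32, 0:16]
fill_region = [AxisRange(None, 1), AxisRange(None, 1),
               AxisRange(0, 32), AxisRange(0, 16)]
print(align_alloc_shape((16, 16, 16, 16), [fill_region]))
```

Growing the allocation, rather than changing the intrinsic, keeps the tensorized blocks untouched; the cost is that each RDNA3 tile's warp buffer is twice the logical tile size.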
