[Questions] TVMScript sch.tensorize error when codegen CUDA code using Tensor Core

I try to use TVMScript to write a [M x K x N] matmul to generate CUDA code using tensor core and met this trouble.

When i set [M, K, N] = [1024, 1024, 1024]. It seems the script can generate correct code. But when i change K ot 2024(the k0 loop change from 16 to 32 at same time), I got “Error message: ScheduleError: The bindings of the inner block tir.Block#0 can not be blockized by the loops starting at tir.For#1.” from "sch.tensorize(sch.get_loops(Y_local)[-2], “wmma_load_b”) "

Here are the specific error info and my code. Thanks for your help.

for ax0_0 in range(2):
    for ax1_0 in range(2):
        # tir.For#1
        for ax0_1 in range(16):
        ^^^^^^^^^^^^^^^^^^^^^^^
            for ax1_1 in range(16):
            ^^^^^^^^^^^^^^^^^^^^^^^
                # tir.Block#0
                with T.block("Y_shared_wmma.matrix_b"):
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    v0 = T.axis.spatial(2048, i_0_0_j_0_0_fused % 8 * 128 + i_0_1_j_0_1_fused % 4 * 32 + ax0_0 * 16 + ax0_1)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    v1 = T.axis.spatial(1024, k_0_0 * 64 + k_0_1 * 32 + ax1_0 * 16 + ax1_1)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    T.where(k_0_0 * 64 + k_0_1 * 32 + (ax1_0 * 16 + ax1_1) < 1024)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    T.reads(Y_shared[v0, v1])
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
                    T.writes(Y_shared_wmma_matrix_b[v0, v1])
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    Y_shared_wmma_matrix_b[v0, v1] = Y_shared[v0, v1]
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                          
import tvm

from tvm.script import tir as T

from tvm import tir

import numpy as np

M = 1024

K = 1024

N = 1024

@tvm.script.ir_module

class MatmulModule:

    @T.prim_func

    def main(

        X: T.Buffer((M, K), "float16", align=128),

        Y: T.Buffer((K, N), "float16", align=128),

        Z: T.Buffer((M, N), "float32", align=64),

    ) -> None:

        T.func_attr({"global_symbol": "main", "tir.noalias": True})

        for i, j, k in T.grid(M, N, K):

            with T.block("matmul"):

                vi, vj, vk = T.axis.remap("SSR", [i, j, k])

                with T.init():

                    Z[vi, vj] = T.float32(0)

                Z[vi, vj] += T.cast(X[vi, vk], "float32") * T.cast(Y[vj, vk], "float32")

@T.prim_func

def wmma_load_a_desc(a: T.handle, c: T.handle) -> None:

    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")

    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")

    with T.block("root"):

        T.reads(A[0:16, 0:16])

        T.writes(C[0:16, 0:16])

        for i, j in T.grid(16, 16):

            with T.block("load"):

                vii, vjj = T.axis.remap("SS", [i, j])

                C[vii, vjj] = A[vii, vjj]

@T.prim_func

def wmma_load_a_impl(a: T.handle, c: T.handle) -> None:

    s1 = T.int32()

    s0 = T.int32()

    A = T.match_buffer(

        a,

        (16, 16),

        "float16",

        align=128,

        offset_factor=16,

        scope="shared",

        strides=[s1, s0],

    )

    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")

    with T.block("root"):

        T.reads(A[0:16, 0:16])

        T.writes(C[0:16, 0:16])

        T.evaluate(

            T.tvm_load_matrix_sync(

                C.data,

                16,

                16,

                16,

                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),

                A.access_ptr("r"),

                s1,

                "row_major",

                dtype="handle",

            )

        )

@T.prim_func

def wmma_load_b_desc(a: T.handle, c: T.handle) -> None:

    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")

    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b")

    with T.block("root"):

        T.reads(A[0:16, 0:16])

        T.writes(C[0:16, 0:16])

        for i, j in T.grid(16, 16):

            with T.block("load"):

                vii, vjj = T.axis.remap("SS", [i, j])

                C[vii, vjj] = A[vii, vjj]

@T.prim_func

def wmma_load_b_impl(a: T.handle, c: T.handle) -> None:

    s1 = T.int32()

    s0 = T.int32()

    A = T.match_buffer(

        a,

        (16, 16),

        "float16",

        align=128,

        offset_factor=16,

        scope="shared",

        strides=[s1, s0],

    )

    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b")

    with T.block("root"):

        T.reads(A[0:16, 0:16])

        T.writes(C[0:16, 0:16])

        T.evaluate(

            T.tvm_load_matrix_sync(

                C.data,

                16,

                16,

                16,

                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),

                A.access_ptr("r"),

                s1,

                "col_major",

                dtype="handle",

            )

        )

@T.prim_func

def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:

    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")

    B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b")

    C = T.match_buffer(

        c, (16, 16), "float32", align=64, offset_factor=16, scope="wmma.accumulator"

    )

    with T.block("root"):

        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])

        T.writes(C[0:16, 0:16])

        for i, j, k in T.grid(16, 16, 16):

            with T.block(""):

                vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])

                C[vii, vjj] += T.cast(A[vii, vkk], "float32") * T.cast(B[vjj, vkk], "float32")

@T.prim_func

def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:

    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")

    B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b")

    C = T.match_buffer(

        c, (16, 16), "float32", align=64, offset_factor=16, scope="wmma.accumulator"

    )

    with T.block("root"):

        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])

        T.writes(C[0:16, 0:16])

        T.evaluate(

            T.tvm_mma_sync(

                C.data,

                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),

                A.data,

                A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16),

                B.data,

                B.elem_offset // 256 + T.floordiv(T.floormod(B.elem_offset, 256), 16),

                C.data,

                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),

                dtype="handle",

            )

        )

@T.prim_func

def wmma_fill_desc(c: T.handle) -> None:

    C = T.match_buffer(

        c, (16, 16), "float32", align=64, offset_factor=16, scope="wmma.accumulator"

    )

    with T.block("root"):

        T.reads()

        T.writes(C[0:16, 0:16])

        for i, j in T.grid(16, 16):

            with T.block("init"):

                vii, vjj = T.axis.remap("SS", [i, j])

                C[vii, vjj] = T.float32(0)

@T.prim_func

def wmma_fill_impl(c: T.handle) -> None:

    C = T.match_buffer(

        c, (16, 16), "float32", align=64, offset_factor=16, scope="wmma.accumulator"

    )

    with T.block("root"):

        T.reads()

        T.writes(C[0:16, 0:16])

        T.evaluate(

            T.tvm_fill_fragment(

                C.data,

                16,

                16,

                16,

                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),

                T.float32(0),

                dtype="handle",

            )

        )

@T.prim_func

def wmma_store_desc(a: T.handle, c: T.handle) -> None:

    A = T.match_buffer(

        a, (16, 16), "float32", align=64, offset_factor=16, scope="wmma.accumulator"

    )

    C = T.match_buffer(c, (16, 16), "float32", align=64, offset_factor=16, scope="global")

    with T.block("root"):

        T.reads(A[0:16, 0:16])

        T.writes(C[0:16, 0:16])

        for i, j in T.grid(16, 16):

            with T.block("store"):

                vii, vjj = T.axis.remap("SS", [i, j])

                C[vii, vjj] = A[vii, vjj]

@T.prim_func

def wmma_store_impl(a: T.handle, c: T.handle) -> None:

    s1 = T.int32()

    s0 = T.int32()

    A = T.match_buffer(

        a, (16, 16), "float32", align=64, offset_factor=16, scope="wmma.accumulator"

    )

    C = T.match_buffer(

        c,

        (16, 16),

        "float32",

        align=64,

        offset_factor=16,

        scope="global",

        strides=[s1, s0],

    )

    with T.block("root"):

        T.reads(A[0:16, 0:16])

        T.writes(C[0:16, 0:16])

        T.evaluate(

            T.tvm_store_matrix_sync(

                A.data,

                16,

                16,

                16,

                A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16),

                C.access_ptr("w"),

                s1,

                "row_major",

                dtype="handle",

            )

        )

try:

    tir.TensorIntrin.register("wmma_load_a", wmma_load_a_desc, wmma_load_a_impl)

    tir.TensorIntrin.register("wmma_load_b", wmma_load_b_desc, wmma_load_b_impl)

    tir.TensorIntrin.register("wmma_sync", wmma_sync_desc, wmma_sync_impl)

    tir.TensorIntrin.register("wmma_fill", wmma_fill_desc, wmma_fill_impl)

    tir.TensorIntrin.register("wmma_store", wmma_store_desc, wmma_store_impl)

except ValueError:

    pass

   

sch = tir.Schedule(MatmulModule)

block = sch.get_block("matmul")

i, j, k = sch.get_loops(block)

i, ii = sch.split(i, factors=[None, 16])

j, ji = sch.split(j, factors=[None, 16])

k, ki = sch.split(k, factors=[None, 16])

sch.reorder(i, j, k, ii, ji, ki)

wmma_sync = sch.blockize(ii)

i0, i1, i2 = sch.split(i, factors=[8, 4, 2])

j0, j1, j2 = sch.split(j, factors=[8, 4, 2])

k0, k1, k2 = sch.split(k, factors=[16, 2, 2])

sch.reorder(i0, j0, i1, j1, k0, k1, i2, j2, k2)

bx = sch.fuse(i0, j0)

sch.bind(bx, "blockIdx.x")

ty = sch.fuse(i1, j1)

sch.bind(ty, "threadIdx.y")

X_shared = sch.cache_read(wmma_sync, read_buffer_index=0, storage_scope="shared")

Y_shared = sch.cache_read(wmma_sync, read_buffer_index=1, storage_scope="shared")

def schedule_shared(block):

    sch.compute_at(block, k0)

    x, y = sch.get_loops(block)[-2:]

    fused = sch.fuse(x, y)

    x0, x1, x2, x3 = sch.split(fused, factors=[None, 16, 32, 8])

    sch.bind(x1, "threadIdx.y")

    sch.bind(x2, "threadIdx.x")

    sch.vectorize(x3)

schedule_shared(X_shared)

schedule_shared(Y_shared)

X_local = sch.cache_read(wmma_sync, 0, storage_scope="wmma.matrix_a")

Y_local = sch.cache_read(wmma_sync, 1, storage_scope="wmma.matrix_b")

sch.compute_at(X_local, k1)

sch.compute_at(Y_local, k1)

write_back_block = sch.cache_write(wmma_sync, 0, storage_scope="wmma.accumulator")

sch.reverse_compute_at(write_back_block, ty)

def schedule_copy(block):

    x, y = sch.get_loops(block)[-2:]

    x0, x1 = sch.split(x, factors=[None, 16])

    y0, y1 = sch.split(y, factors=[None, 16])

    sch.reorder(x0, y0, x1, y1)

schedule_copy(X_local)

schedule_copy(Y_local)

schedule_copy(write_back_block)

init = sch.decompose_reduction(wmma_sync, k0)

sch.tensorize(sch.get_loops(X_local)[-2], "wmma_load_a")

sch.tensorize(sch.get_loops(Y_local)[-2], "wmma_load_b")

sch.tensorize(init, "wmma_fill")

sch.tensorize(wmma_sync, "wmma_sync")

sch.tensorize(sch.get_loops(write_back_block)[-2], "wmma_store")

func = tvm.build(sch.mod, target="cuda")

dev_module = func.imported_modules[0]

print("-----Final IR-----")

print(dev_module)

print("-----GPU code-----")

print(dev_module.get_source())
"""

There is a T.where in the block Y_shared_wmma.matrix_b, however, the intrinsic of load matrix does not support predicate. So it cannot tensorize

1 Like

Thanks for your replay.

T.read, T.write, and T.where are not defined by me, so they are added by the optimation path in TVM and help TVM generate code. Is that so? Can you help me to explain what information T.where contains and how to solve this problem?

The T.where appear after schedule_shared(Y_shared). And I found the value [v0, v1] may bind to the wrong asis. The k_0_0 should be related to matrix_b’s v0. Is it a mistake here?

with T.block("Y_shared"):
    v0 = T.axis.spatial(2048,i_0_0_j_0_0_fused % 8 * 128+ (ax0_ax1_fused_0 * 4096+ ax0_ax1_fused_1 * 256+ ax0_ax1_fused_2 * 8+ ax0_ax1_fused_3)// 64,
    )
    v1 = T.axis.spatial(1024,k_0_0 * 64+ (ax0_ax1_fused_0 * 4096+ ax0_ax1_fused_1 * 256+ ax0_ax1_fused_2 * 8+ ax0_ax1_fused_3)% 64,
    )
    T.where(k_0_0 * 64+ (( (ax0_ax1_fused_0 * 16+ ax0_ax1_fused_1 ) * 32 + ax0_ax1_fused_2)* 8+ ax0_ax1_fused_3)% 64< 1024
    )

Belowing is the scheduling to reproduce this situation.


sch = tir.Schedule(MatmulModule)
block = sch.get_block("matmul")
i, j, k = sch.get_loops(block)


i, ii = sch.split(i, factors=[None, 16])
j, ji = sch.split(j, factors=[None, 16])
k, ki = sch.split(k, factors=[None, 16])
sch.reorder(i, j, k, ii, ji, ki)
wmma_sync = sch.blockize(ii)

i0, i1, i2 = sch.split(i, factors=[8, 4, 2])
j0, j1, j2 = sch.split(j, factors=[8, 4, 2])
k0, k1, k2 = sch.split(k, factors=[None, 2, 2])
sch.reorder(i0, j0, i1, j1, k0, k1, i2, j2, k2) 
bx = sch.fuse(i0, j0) 
sch.bind(bx, "blockIdx.x")
ty = sch.fuse(i1, j1)
sch.bind(ty, "threadIdx.y")

X_shared = sch.cache_read(wmma_sync, read_buffer_index=0, storage_scope="shared")
Y_shared = sch.cache_read(wmma_sync, read_buffer_index=1, storage_scope="shared")

def schedule_shared(block):
    sch.compute_at(block, k0)
    x, y = sch.get_loops(block)[-2:] 
    fused = sch.fuse(x, y)
    x0, x1, x2, x3 = sch.split(fused, factors=[None, 16, 32, 8])
    sch.bind(x1, "threadIdx.y")
    sch.bind(x2, "threadIdx.x") 
    sch.vectorize(x3)

schedule_shared(X_shared) 
schedule_shared(Y_shared)
sch.mod.show()

T.where means that the computation is with a predicate, and only do the computation when the expression is true. In other words, it’s equal to

if k_0_0 * 64 + k_0_1 * 32 + (ax1_0 * 16 + ax1_1) < 1024:
    Y_shared_wmma_matrix_b[v0, v1] = Y_shared[v0, v1]

It’s usually caused by in-complete splitting.

I just find the mistake in the original ir_module. The vk and vj in Y are reversed. Although there is still a new issue “block doesn’t match the tensor intrin”. It seems much more reasonable. Thanks for your kind again.