I am trying to use TVMScript to write an [M x K x N] matmul that generates CUDA code using Tensor Cores, and I ran into the following problem.
When I set [M, K, N] = [1024, 1024, 1024], the script seems to generate correct code. But when I change K to 2048 (changing the k0 loop factor from 16 to 32 at the same time), I get "Error message: ScheduleError: The bindings of the inner block tir.Block#0 can not be blockized by the loops starting at tir.For#1." from `sch.tensorize(sch.get_loops(Y_local)[-2], "wmma_load_b")`.
Here is the specific error output, followed by my code. Thanks for your help.
for ax0_0 in range(2):
for ax1_0 in range(2):
# tir.For#1
for ax0_1 in range(16):
^^^^^^^^^^^^^^^^^^^^^^^
for ax1_1 in range(16):
^^^^^^^^^^^^^^^^^^^^^^^
# tir.Block#0
with T.block("Y_shared_wmma.matrix_b"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0 = T.axis.spatial(2048, i_0_0_j_0_0_fused % 8 * 128 + i_0_1_j_0_1_fused % 4 * 32 + ax0_0 * 16 + ax0_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1 = T.axis.spatial(1024, k_0_0 * 64 + k_0_1 * 32 + ax1_0 * 16 + ax1_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.where(k_0_0 * 64 + k_0_1 * 32 + (ax1_0 * 16 + ax1_1) < 1024)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(Y_shared[v0, v1])
^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(Y_shared_wmma_matrix_b[v0, v1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Y_shared_wmma_matrix_b[v0, v1] = Y_shared[v0, v1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import tvm
from tvm.script import tir as T
from tvm import tir
import numpy as np
# Matmul problem size: M and N are the output extents, K is the reduction extent.
M, K, N = 1024, 1024, 1024
@tvm.script.ir_module
class MatmulModule:
    # Reference workload: Z[i, j] = sum_k float32(X[i, k]) * float32(Y[j, k]),
    # i.e. the second operand is read with its reduction axis last.

    @T.prim_func
    def main(
        X: T.Buffer((M, K), "float16", align=128),
        # Y is indexed as Y[vj, vk] in the body, so its shape must be (N, K).
        # Declaring it as (K, N) only worked while K == N; with K != N the
        # reduction index runs past the second dimension, which forces a
        # T.where(...) bound check into the cache stages and stops tensorize
        # from blockizing them.
        Y: T.Buffer((N, K), "float16", align=128),
        Z: T.Buffer((M, N), "float32", align=64),
    ) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i, j, k in T.grid(M, N, K):
            with T.block("matmul"):
                # i and j are spatial axes, k is the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Z[vi, vj] = T.float32(0)
                # Inputs are float16; products are accumulated in float32.
                Z[vi, vj] += T.cast(X[vi, vk], "float32") * T.cast(Y[vj, vk], "float32")
@T.prim_func
def wmma_load_a_desc(a: T.handle, c: T.handle) -> None:
    # Description half of the "wmma_load_a" intrinsic: an element-wise copy of
    # a 16x16 float16 tile from "shared" scope into "wmma.matrix_a" scope.
    src = T.match_buffer(
        a, (16, 16), "float16", align=128, offset_factor=16, scope="shared"
    )
    dst = T.match_buffer(
        c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a"
    )
    with T.block("root"):
        T.reads(src[0:16, 0:16])
        T.writes(dst[0:16, 0:16])
        for ax0 in T.serial(16):
            for ax1 in T.serial(16):
                with T.block("load"):
                    v0 = T.axis.spatial(16, ax0)
                    v1 = T.axis.spatial(16, ax1)
                    dst[v0, v1] = src[v0, v1]
@T.prim_func
def wmma_load_a_impl(a: T.handle, c: T.handle) -> None:
    # Implementation half of the "wmma_load_a" intrinsic: the 16x16 copy is
    # expressed as a single tvm_load_matrix_sync call with "row_major" layout.
    s1 = T.int32()
    s0 = T.int32()
    # The source tile carries symbolic strides; s1 (stride of the first
    # dimension) is forwarded to the load call below.
    A = T.match_buffer(
        a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0]
    )
    C = T.match_buffer(
        c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a"
    )
    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_load_matrix_sync(
                C.data,
                16,
                16,
                16,
                # Fragment index computed from the destination element offset
                # (presumably 256 = 16 * 16 elements per fragment -- confirm).
                C.elem_offset // 256 + (C.elem_offset % 256) // 16,
                A.access_ptr("r"),
                s1,
                "row_major",
                dtype="handle",
            )
        )
@T.prim_func
def wmma_load_b_desc(a: T.handle, c: T.handle) -> None:
    # Description half of the "wmma_load_b" intrinsic: an element-wise copy of
    # a 16x16 float16 tile from "shared" scope into "wmma.matrix_b" scope.
    src = T.match_buffer(
        a, (16, 16), "float16", align=128, offset_factor=16, scope="shared"
    )
    dst = T.match_buffer(
        c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b"
    )
    with T.block("root"):
        T.reads(src[0:16, 0:16])
        T.writes(dst[0:16, 0:16])
        for ax0 in T.serial(16):
            for ax1 in T.serial(16):
                with T.block("load"):
                    v0 = T.axis.spatial(16, ax0)
                    v1 = T.axis.spatial(16, ax1)
                    dst[v0, v1] = src[v0, v1]
@T.prim_func
def wmma_load_b_impl(a: T.handle, c: T.handle) -> None:
    # Implementation half of the "wmma_load_b" intrinsic: the 16x16 copy is
    # expressed as a single tvm_load_matrix_sync call with "col_major" layout.
    s1 = T.int32()
    s0 = T.int32()
    # The source tile carries symbolic strides; s1 (stride of the first
    # dimension) is forwarded to the load call below.
    A = T.match_buffer(
        a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0]
    )
    C = T.match_buffer(
        c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b"
    )
    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_load_matrix_sync(
                C.data,
                16,
                16,
                16,
                # Fragment index computed from the destination element offset
                # (presumably 256 = 16 * 16 elements per fragment -- confirm).
                C.elem_offset // 256 + (C.elem_offset % 256) // 16,
                A.access_ptr("r"),
                s1,
                "col_major",
                dtype="handle",
            )
        )
@T.prim_func
def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Description half of the "wmma_sync" intrinsic: a 16x16x16 multiply-
    # accumulate, acc[i, j] += float32(lhs[i, k]) * float32(rhs[j, k]).
    lhs = T.match_buffer(
        a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a"
    )
    rhs = T.match_buffer(
        b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b"
    )
    acc = T.match_buffer(
        c, (16, 16), "float32", align=64, offset_factor=16, scope="wmma.accumulator"
    )
    with T.block("root"):
        T.reads(acc[0:16, 0:16], lhs[0:16, 0:16], rhs[0:16, 0:16])
        T.writes(acc[0:16, 0:16])
        for ax0 in T.serial(16):
            for ax1 in T.serial(16):
                for ax2 in T.serial(16):
                    with T.block(""):
                        v0 = T.axis.spatial(16, ax0)
                        v1 = T.axis.spatial(16, ax1)
                        v2 = T.axis.reduce(16, ax2)
                        acc[v0, v1] += T.cast(lhs[v0, v2], "float32") * T.cast(
                            rhs[v1, v2], "float32"
                        )
@T.prim_func
def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Implementation half of the "wmma_sync" intrinsic: one tvm_mma_sync call
    # taking (data, fragment index) pairs in the order C, A, B, C.
    A = T.match_buffer(
        a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a"
    )
    B = T.match_buffer(
        b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b"
    )
    C = T.match_buffer(
        c, (16, 16), "float32", align=64, offset_factor=16, scope="wmma.accumulator"
    )
    with T.block("root"):
        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_mma_sync(
                # Destination accumulator.
                C.data,
                C.elem_offset // 256 + (C.elem_offset % 256) // 16,
                # First operand.
                A.data,
                A.elem_offset // 256 + (A.elem_offset % 256) // 16,
                # Second operand.
                B.data,
                B.elem_offset // 256 + (B.elem_offset % 256) // 16,
                # Accumulator input: the same buffer as the destination.
                C.data,
                C.elem_offset // 256 + (C.elem_offset % 256) // 16,
                dtype="handle",
            )
        )
@T.prim_func
def wmma_fill_desc(c: T.handle) -> None:
    # Description half of the "wmma_fill" intrinsic: set every element of a
    # 16x16 float32 tile in "wmma.accumulator" scope to zero.
    acc = T.match_buffer(
        c, (16, 16), "float32", align=64, offset_factor=16, scope="wmma.accumulator"
    )
    with T.block("root"):
        T.reads()
        T.writes(acc[0:16, 0:16])
        for ax0 in T.serial(16):
            for ax1 in T.serial(16):
                with T.block("init"):
                    v0 = T.axis.spatial(16, ax0)
                    v1 = T.axis.spatial(16, ax1)
                    acc[v0, v1] = T.float32(0)
@T.prim_func
def wmma_fill_impl(c: T.handle) -> None:
    # Implementation half of the "wmma_fill" intrinsic: one tvm_fill_fragment
    # call that writes float32 zero into the accumulator tile.
    C = T.match_buffer(
        c, (16, 16), "float32", align=64, offset_factor=16, scope="wmma.accumulator"
    )
    with T.block("root"):
        T.reads()
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_fill_fragment(
                C.data,
                16,
                16,
                16,
                # Fragment index computed from the accumulator element offset.
                C.elem_offset // 256 + (C.elem_offset % 256) // 16,
                T.float32(0),
                dtype="handle",
            )
        )
@T.prim_func
def wmma_store_desc(a: T.handle, c: T.handle) -> None:
    # Description half of the "wmma_store" intrinsic: an element-wise copy of
    # a 16x16 float32 tile from "wmma.accumulator" scope into "global" scope.
    src = T.match_buffer(
        a, (16, 16), "float32", align=64, offset_factor=16, scope="wmma.accumulator"
    )
    dst = T.match_buffer(
        c, (16, 16), "float32", align=64, offset_factor=16, scope="global"
    )
    with T.block("root"):
        T.reads(src[0:16, 0:16])
        T.writes(dst[0:16, 0:16])
        for ax0 in T.serial(16):
            for ax1 in T.serial(16):
                with T.block("store"):
                    v0 = T.axis.spatial(16, ax0)
                    v1 = T.axis.spatial(16, ax1)
                    dst[v0, v1] = src[v0, v1]
@T.prim_func
def wmma_store_impl(a: T.handle, c: T.handle) -> None:
    # Implementation half of the "wmma_store" intrinsic: one
    # tvm_store_matrix_sync call writing the accumulator tile out "row_major".
    s1 = T.int32()
    s0 = T.int32()
    A = T.match_buffer(
        a, (16, 16), "float32", align=64, offset_factor=16, scope="wmma.accumulator"
    )
    # The destination tile carries symbolic strides; s1 (stride of the first
    # dimension) is forwarded to the store call below.
    C = T.match_buffer(
        c, (16, 16), "float32", align=64, offset_factor=16, scope="global", strides=[s1, s0]
    )
    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_store_matrix_sync(
                A.data,
                16,
                16,
                16,
                # Fragment index computed from the accumulator element offset.
                A.elem_offset // 256 + (A.elem_offset % 256) // 16,
                C.access_ptr("w"),
                s1,
                "row_major",
                dtype="handle",
            )
        )
# Register every (description, implementation) pair under the name that is
# later passed to sch.tensorize.
#
# The previous form wrapped all five registrations in one try/except
# ValueError: as soon as one name was already registered, the remaining
# registrations were skipped and any stale definitions were silently kept.
# Registering each intrinsic with override=True keeps the script re-runnable
# in the same process while always installing the current definitions.
for _intrin_name, _intrin_desc, _intrin_impl in (
    ("wmma_load_a", wmma_load_a_desc, wmma_load_a_impl),
    ("wmma_load_b", wmma_load_b_desc, wmma_load_b_impl),
    ("wmma_sync", wmma_sync_desc, wmma_sync_impl),
    ("wmma_fill", wmma_fill_desc, wmma_fill_impl),
    ("wmma_store", wmma_store_desc, wmma_store_impl),
):
    tir.TensorIntrin.register(_intrin_name, _intrin_desc, _intrin_impl, override=True)
sch = tir.Schedule(MatmulModule)
block = sch.get_block("matmul")

# Carve a 16x16x16 inner tile out of each axis and blockize it so that it can
# later be replaced by the "wmma_sync" intrinsic.
i, j, k = sch.get_loops(block)
i, ii = sch.split(i, factors=[None, 16])
j, ji = sch.split(j, factors=[None, 16])
k, ki = sch.split(k, factors=[None, 16])
sch.reorder(i, j, k, ii, ji, ki)
wmma_sync = sch.blockize(ii)

# Tile the outer loops. The outermost factor is left as None so that it is
# derived from the problem size instead of being hard-coded for
# M = N = K = 1024; for that size the inferred factors are the previous
# 8 / 8 / 16, so the schedule is unchanged there.
i0, i1, i2 = sch.split(i, factors=[None, 4, 2])
j0, j1, j2 = sch.split(j, factors=[None, 4, 2])
k0, k1, k2 = sch.split(k, factors=[None, 2, 2])
sch.reorder(i0, j0, i1, j1, k0, k1, i2, j2, k2)

# Outermost spatial tiles are bound to blockIdx.x; the fused i1 * j1
# (4 * 4 = 16) iterations are bound to threadIdx.y.
bx = sch.fuse(i0, j0)
sch.bind(bx, "blockIdx.x")
ty = sch.fuse(i1, j1)
sch.bind(ty, "threadIdx.y")

# Stage both matmul operands through "shared" scope.
X_shared = sch.cache_read(wmma_sync, read_buffer_index=0, storage_scope="shared")
Y_shared = sch.cache_read(wmma_sync, read_buffer_index=1, storage_scope="shared")
def schedule_shared(block):
    """Place a "shared"-scope cache stage under ``k0`` and distribute its copy.

    The stage's two innermost loops are fused and re-split with factors
    ``[None, 16, 32, 8]``; the 16-wide loop is bound to ``threadIdx.y``, the
    32-wide loop to ``threadIdx.x``, and the innermost 8-wide loop is
    vectorized.
    """
    sch.compute_at(block, k0)
    outer_loop, inner_loop = sch.get_loops(block)[-2:]
    flat_loop = sch.fuse(outer_loop, inner_loop)
    _serial_loop, ty_loop, tx_loop, vec_loop = sch.split(
        flat_loop, factors=[None, 16, 32, 8]
    )
    sch.bind(ty_loop, "threadIdx.y")
    sch.bind(tx_loop, "threadIdx.x")
    sch.vectorize(vec_loop)
# Distribute both "shared"-scope copies across the thread block.
for _shared_stage in (X_shared, Y_shared):
    schedule_shared(_shared_stage)

# Second-level cache stages in the wmma fragment scopes, computed under k1.
X_local = sch.cache_read(
    wmma_sync, read_buffer_index=0, storage_scope="wmma.matrix_a"
)
Y_local = sch.cache_read(
    wmma_sync, read_buffer_index=1, storage_scope="wmma.matrix_b"
)
for _fragment_stage in (X_local, Y_local):
    sch.compute_at(_fragment_stage, k1)

# Accumulate in "wmma.accumulator" scope and write the result back under the
# threadIdx.y loop.
write_back_block = sch.cache_write(
    wmma_sync, write_buffer_index=0, storage_scope="wmma.accumulator"
)
sch.reverse_compute_at(write_back_block, ty)
def schedule_copy(block):
    """Tile the two innermost loops of ``block`` into 16x16 pieces.

    Each of the two loops is split by a factor of 16 and the results are
    reordered so that both 16-wide loops end up innermost.
    """
    first_loop, second_loop = sch.get_loops(block)[-2:]
    first_outer, first_inner = sch.split(first_loop, factors=[None, 16])
    second_outer, second_inner = sch.split(second_loop, factors=[None, 16])
    sch.reorder(first_outer, second_outer, first_inner, second_inner)
# Expose a 16x16 inner loop nest on every fragment copy stage.
for _copy_stage in (X_local, Y_local, write_back_block):
    schedule_copy(_copy_stage)

# Split the reduction's initialisation out at k0 so it can be tensorized
# separately from the update.
init = sch.decompose_reduction(wmma_sync, k0)

# Replace each 16x16 region with its registered intrinsic. For the copy
# stages, the second-innermost loop is the start of the 16x16 nest.
for _load_stage, _load_intrin in (
    (X_local, "wmma_load_a"),
    (Y_local, "wmma_load_b"),
):
    sch.tensorize(sch.get_loops(_load_stage)[-2], _load_intrin)
sch.tensorize(init, "wmma_fill")
sch.tensorize(wmma_sync, "wmma_sync")
sch.tensorize(sch.get_loops(write_back_block)[-2], "wmma_store")

# Build for CUDA and dump the device-side module and its generated source.
func = tvm.build(sch.mod, target="cuda")
dev_module = func.imported_modules[0]
# NOTE(review): this prints the imported device module object; if the
# scheduled TIR was intended under this heading, print sch.mod instead.
print("-----Final IR-----")
print(dev_module)
print("-----GPU code-----")
print(dev_module.get_source())
"""