Confused about TIR cache-related schedule primitives
Take matrix multiplication as an example (with M = N = K = 16384):

```python
import tvm
from tvm.script import tir as T

M = N = K = 16384

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [M, K])
        B = T.match_buffer(b, [K, N])
        C = T.match_buffer(c, [M, N])
        for i, j, k in T.grid(M, N, K):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
```
Then I simply apply a cache_write schedule on this module:
```python
ir_module = MyModule
sch = tvm.tir.Schedule(ir_module)

block_h = 32
block_w = 32

# Split the spatial loops of block "B" and bind them to GPU block/thread axes
block_b = sch.get_block("B")
i, j, k = sch.get_loops(block_b)
by, yi = sch.split(i, factors=[None, block_h])
bx, xi = sch.split(j, factors=[None, block_w])
sch.reorder(by, bx, yi, xi)
sch.bind(by, "blockIdx.y")
sch.bind(bx, "blockIdx.x")
sch.bind(yi, "threadIdx.y")
sch.bind(xi, "threadIdx.x")

# Cache the output of block "B" in local scope, then try to move the
# write-back block under the thread loop
block_cl = sch.cache_write(block_b, 0, "local")
sch.compute_at(block_cl, yi, preserve_unit_loops=True)
```
We also need to do compute_at here because, after cache_write, the block that writes "C" back is not contained in a thread environment, but I got an error when executing the compute_at schedule:
```
Traceback (most recent call last):
  File "/workspace/v-leiwang3/tvm_gpu_gemm/tensorir_/1.blocked_gemm.py", line 77, in <module>
    sch.compute_at(block_cl, yi, preserve_unit_loops=True)
  File "/workspace/v-leiwang3/tvm/python/tvm/tir/schedule/_type_checker.py", line 274, in wrap
    return func(*args, **kwargs)
  File "/workspace/v-leiwang3/tvm/python/tvm/tir/schedule/schedule.py", line 1349, in compute_at
    preserve_unit_loops,
  File "/workspace/v-leiwang3/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
  4: TVMFuncCall
  3: _ZN3tvm7runtime13PackedFun
  2: tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::tir::Schedule, tvm::tir::ScheduleNode, void, tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool, void>(void (tvm::tir::ScheduleNode::*)(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool))::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool)#1}>(tvm::runtime::Registry::set_body_method<tvm::tir::Schedule, tvm::tir::ScheduleNode, void, tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool, void>(void (tvm::tir::ScheduleNode::*)(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool))::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  1: tvm::tir::TracedScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool)
  0: tvm::tir::ConcreteScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool)
ScheduleError: An error occurred in the schedule primitive 'compute-at'.
The IR with diagnostic is:
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer[(16384, 16384), "float32"], B: T.Buffer[(16384, 16384), "float32"], C: T.Buffer[(16384, 16384), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        C_local = T.alloc_buffer([16384, 16384], dtype="float32", scope="local")
        for i_0 in T.thread_binding(512, thread="blockIdx.y"):
            for j_0 in T.thread_binding(512, thread="blockIdx.x"):
                for i_1 in T.thread_binding(32, thread="threadIdx.y"):
                    for j_1 in T.thread_binding(32, thread="threadIdx.x"):
                        for k in T.serial(16384):
                            with T.block("B"):
                                vi = T.axis.spatial(16384, i_0 * 32 + i_1)
                                vj = T.axis.spatial(16384, j_0 * 32 + j_1)
                                vk = T.axis.reduce(16384, k)
                                T.reads(A[vi, vk], B[vk, vj])
                                T.writes(C_local[vi, vj])
                                with T.init():
                                    C_local[vi, vj] = T.float32(0)
                                C_local[vi, vj] = C_local[vi, vj] + A[vi, vk] * B[vk, vj]
        for ax0, ax1 in T.grid(16384, 16384):
            # tir.Block#0
            with T.block("C_local"):
            ^^^^^^^^^^^^^^^^^^^^^^^^
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(C_local[v0, v1])
                T.writes(C[v0, v1])
                C[v0, v1] = C_local[v0, v1]
Error message: The block tir.Block#0 is an output block
```
I rewrote the code according to my understanding, and this version works:

```python
block_b = sch.get_block("B")
block_cl = sch.cache_write(block_b, 0, "local")

# Split and bind the loops of the cache-write block instead of block "B"
i, j = sch.get_loops(block_cl)
by, yi = sch.split(i, factors=[None, block_h])
bx, xi = sch.split(j, factors=[None, block_w])
sch.reorder(by, bx, yi, xi)
sch.bind(by, "blockIdx.y")
sch.bind(bx, "blockIdx.x")
sch.bind(yi, "threadIdx.y")
sch.bind(xi, "threadIdx.x")

# Then compute the producer block "B" at the innermost thread loop
sch.compute_at(block_b, xi, preserve_unit_loops=True)
```
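To sanity-check that version I just print the scheduled TIR and build it. This is only a quick inspection sketch; the "cuda" target and the way I dump the kernel source are my own choices, not part of the script above:

```python
# Quick inspection of the working schedule (sketch; assumes a CUDA-enabled TVM build)
print(sch.mod.script())                          # scheduled TIR after cache_write + compute_at
rt_mod = tvm.build(sch.mod, target="cuda")
print(rt_mod.imported_modules[0].get_source())   # generated CUDA kernel source
```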
What confuses me is that the code that doesn't work is a direct translation of what I would write with the tensor expression API:
```python
from tvm import te

nn = 16384
n = te.var("n")
n = tvm.runtime.convert(nn)
m, l = n, n
_dtype = "float32"  # dtype used for all tensors

A = te.placeholder((l, n), dtype=_dtype, name="A")
B = te.placeholder((l, m), dtype=_dtype, name="B")
k = te.reduce_axis((0, l), name="k")
C = te.compute((m, n), lambda ii, jj: te.sum(A[k, jj] * B[k, ii], axis=k), name="C")

# schedule
s = te.create_schedule(C.op)
CC = s.cache_write(C, "local")
# write_code is a small local helper that dumps the lowered IR to a file
write_code(str(tvm.lower(s, [A, B, C], simple_mode=True)), "progress/1.cache.cu")

block_x = te.thread_axis("blockIdx.x")
block_y = te.thread_axis("blockIdx.y")
thread_x = te.thread_axis("threadIdx.x")
thread_y = te.thread_axis("threadIdx.y")

block_h = 32
block_w = 32
bx, xi = s[C].split(C.op.axis[0], factor=block_h)
by, yi = s[C].split(C.op.axis[1], factor=block_w)
s[C].bind(bx, block_x)
s[C].bind(by, block_y)
s[C].reorder(bx, by, xi, yi)
s[C].bind(xi, thread_x)
s[C].bind(yi, thread_y)
s[CC].compute_at(s[C], yi)
```
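If I lower this TE schedule, the cache-write stage C.local ends up nested inside the thread-bound loops of C, which (as far as I can tell from reading the lowered IR) is exactly the structure I was trying to get with compute_at in TIR:

```python
# Lower the final TE schedule to see that the C.local stage is nested
# inside the thread-bound loops of C after compute_at.
print(tvm.lower(s, [A, B, C], simple_mode=True))
```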
What actually confuses me is why we need a rule that checks whether a block is an output block when we do a compute_at schedule:
```cpp
// tvm/src/tir/schedule/primitive/compute_at.cc:557
// Check condition 4): `block` is not an output block
if (is_compute_at) {
  CheckNotOutputBlock(self, block_sref, scope_root_sref);
}
```
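For what it's worth, keeping my original order of operations (split and bind the loops of block "B" first) and using reverse_compute_at on the cache-write block also seems to avoid this check. The sketch below is just my reading of that primitive, and I would still like to understand the reasoning behind the output-block rule:

```python
# Same split/bind on block "B" as in my first attempt, but move the cache-write
# (consumer) block under the thread loop with reverse_compute_at instead of
# compute_at. Sketch only, based on my understanding of the primitive.
sch = tvm.tir.Schedule(MyModule)
block_b = sch.get_block("B")
i, j, k = sch.get_loops(block_b)
by, yi = sch.split(i, factors=[None, block_h])
bx, xi = sch.split(j, factors=[None, block_w])
sch.reorder(by, bx, yi, xi)
sch.bind(by, "blockIdx.y")
sch.bind(bx, "blockIdx.x")
sch.bind(yi, "threadIdx.y")
sch.bind(xi, "threadIdx.x")
block_cl = sch.cache_write(block_b, 0, "local")
sch.reverse_compute_at(block_cl, xi, preserve_unit_loops=True)
```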