代码如下:
# Element-wise addition of two 128-element buffers, written as a TVMScript
# prim_func: block "C" stores B[i] + A[i] into C[i] on every loop iteration.
@T.prim_func
def static_add(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Bind each handle parameter to a one-dimensional buffer of extent 128.
    A = T.match_buffer(a, (128,))
    B = T.match_buffer(b, (128,))
    C = T.match_buffer(c, (128,))
    for i in range(128):
        # The block declares no T.axis bindings; it indexes the buffers with
        # the loop variable directly.
        with T.block("C"):
            C[i] = B[i] + A[i]
# Schedule static_add: split the loop, stage the result of block "C" through a
# "local" cache, and place the cache copy-back under the outer loop.
# debug_mask="all" asks the schedule to run all of its verification checks.
sch = tir.Schedule(static_add, debug_mask="all")

# Block "C" sits under exactly one loop; split it into an outer loop of
# extent 4 and an inner loop whose extent is inferred (None).
(i,) = sch.get_loops("C")
io, inn = sch.split(i, factors=[4, None])

# cache_write redirects block "C" to write a "local"-scope buffer and returns
# the newly created copy-back block, which writes the function output C.
lc = sch.cache_write("C", 0, "local")

# The copy-back block consumes the cache and writes an output buffer, so
# compute_at (which moves producers) rejects it with
# "The block ... is an output block". reverse_compute_at is the primitive that
# moves a consumer block under a loop of its producer.
sch.reverse_compute_at(lc, io)
sch.mod.show()
调用 sch.compute_at(lc, io) 时报错如下:
The block tir.Block#0 is an output block
报错信息中打印的 prim_func 如下(出错的 block 以 ^ 标出):
class Module:
@T.prim_func
def main(A: T.Buffer[128, "float32"], B: T.Buffer[128, "float32"], C: T.Buffer[128, "float32"]):
# body
# with T.block("root")
C_local = T.alloc_buffer([128], dtype="float32", scope="local")
for i_0, i_1 in T.grid(4, 32):
with T.block("C"):
T.reads(B[i_0 * 32 + i_1], A[i_0 * 32 + i_1])
T.writes(C_local[i_0 * 32 + i_1])
C_local[i_0 * 32 + i_1] = B[i_0 * 32 + i_1] + A[i_0 * 32 + i_1]
for ax0 in T.serial(128):
# tir.Block#0
with T.block("C_local"):
^^^^^^^^^^^^^^^^^^^^^^^^
v0 = T.axis.spatial(128, ax0)
T.reads(C_local[v0])
T.writes(C[v0])
C[v0] = C_local[v0]