【TIR】 cache_write + compute_at "The block tir.Block#0 is an output block"

代码如下:

@T.prim_func
def static_add(a:T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a,(128))
    B = T.match_buffer(b,(128))
    C = T.match_buffer(c,(128))
    for i in range(128):
        with T.block("C"):
          C[i]=B[i]+A[i]

sch = tir.Schedule(static_add, debug_mask="all")
i,= sch.get_loops("C")
io,inn = sch.split(i,factors=[4,None])
lc = sch.cache_write("C", 0, "local")
sch.compute_at(lc,io)
sch.mod.show()

报错如下

The block tir.Block#0 is an output block

prim_func是:

class Module:
    @T.prim_func
    def main(A: T.Buffer[128, "float32"], B: T.Buffer[128, "float32"], C: T.Buffer[128, "float32"]):
        # body
        # with T.block("root")
        C_local = T.alloc_buffer([128], dtype="float32", scope="local")
        for i_0, i_1 in T.grid(4, 32):
            with T.block("C"):
                T.reads(B[i_0 * 32 + i_1], A[i_0 * 32 + i_1])
                T.writes(C_local[i_0 * 32 + i_1])
                C_local[i_0 * 32 + i_1] = B[i_0 * 32 + i_1] + A[i_0 * 32 + i_1]
        for ax0 in T.serial(128):
            # tir.Block#0
            with T.block("C_local"):
            ^^^^^^^^^^^^^^^^^^^^^^^^
                v0 = T.axis.spatial(128, ax0)
                T.reads(C_local[v0])
                T.writes(C[v0])
                C[v0] = C_local[v0]

see see primitive reserve_compute_at

OK, 用reverse_compute_at 问题解决了,感谢