Thanks @tqchen. I am familiar with that resource, and I see that in Section 2.4.4.1 they actually arrive at something similar to what I already achieved using the following transformations:
sch = tir.Schedule(func)
mult_block = sch.get_block("multiply")
add_block = sch.get_block("add")
sch.reverse_compute_at(add_block, sch.get_loops(mult_block)[-1])
init_block = sch.decompose_reduction(mult_block, sch.get_loops(mult_block)[-1])
update_block = sch.get_block("multiply_update")
With these transformations, I get the following module:
class Module:
    @T.prim_func
    def main(A: T.Buffer((16, 16), "int8"), B: T.Buffer((16, 16), "int8"), C: T.Buffer((16, 16), "int32"), D: T.Buffer((16, 16), "int32")):
        T.func_attr({"global_symbol": "func"})
        # with T.block("root"):
        temp = T.alloc_buffer((16, 16), "int32")
        for i, j in T.grid(16, 16):
            with T.block("multiply_init"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads()
                T.writes(temp[vi, vj])
                temp[vi, vj] = 0
            for k in range(16):
                with T.block("multiply_update"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    T.reads(temp[vi, vj], A[vi, vk], B[vj, vk])
                    T.writes(temp[vi, vj])
                    temp[vi, vj] = temp[vi, vj] + T.Cast("int32", A[vi, vk]) * T.Cast("int32", B[vj, vk])
            with T.block("add"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(temp[vi, vj], C[vi, vj])
                T.writes(D[vi, vj])
                D[vi, vj] = temp[vi, vj] + C[vi, vj]
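For reference, the loop nest above computes D = A·Bᵀ + C, accumulating in int32. Below is a minimal pure-Python sketch of the same structure, which I find handy for checking results after each schedule change. The function name `reference`, the list-of-lists layout and the `n` parameter are only for illustration and are not part of TVM.

```python
def reference(A, B, C, n=16):
    """Plain-Python mirror of the scheduled loop nest: D = A @ B^T + C.

    A, B, C are n x n lists of ints (illustrative layout, not a TVM type).
    """
    D = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            temp = 0  # corresponds to block "multiply_init"
            for k in range(n):
                # corresponds to block "multiply_update" (reduction over k)
                temp += A[i][k] * B[j][k]
            # corresponds to block "add", executed once per (i, j)
            D[i][j] = temp + C[i][j]
    return D
```

Note that the "add" step runs once per (i, j) after the reduction over k has finished, while the update step runs once per (i, j, k).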
Now, I would like to merge both blocks “add” and “multiply_update”, but it seems this is not possible.
- Using reverse_compute_inline on the “add” block fails with an error saying that the consumer can only have a single producer block, and in this case it has 2 (because of the init block).
- Using compute_inline on the “multiply_update” block does not work either, because update_block is not a complete block.
- Using sch.blockize([add_block, update_block]) gives a cryptic error: InternalError: Check failed: (0 <= i && i < p->size_) is false: IndexError: indexing 2 on an array of size 2