[TIR] Problem inlining addition into matmul block

Thanks @tqchen. I am familiar with that resource, and I see that in Section 2.4.4.1 they actually arrive at something similar to what I already achieved using the following transformations:

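from tvm import tir

sch = tir.Schedule(func)
mult_block = sch.get_block("multiply")
add_block = sch.get_block("add")
# Move the consumer ("add") under the loop nest of the producer ("multiply").
sch.reverse_compute_at(add_block, sch.get_loops(mult_block)[-1])
# Split the reduction into "multiply_init" and "multiply_update".
init_block = sch.decompose_reduction(mult_block, sch.get_loops(mult_block)[-1])
update_block = sch.get_block("multiply_update")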

With this schedule, I arrive at something like this:

class Module:
    @T.prim_func
    def main(A: T.Buffer((16, 16), "int8"), B: T.Buffer((16, 16), "int8"), C: T.Buffer((16, 16), "int32"), D: T.Buffer((16, 16), "int32")):
        T.func_attr({"global_symbol": "func"})
        # with T.block("root"):
        temp = T.alloc_buffer((16, 16), "int32")
        for i, j in T.grid(16, 16):
            with T.block("multiply_init"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads()
                T.writes(temp[vi, vj])
                temp[vi, vj] = 0
            for k in range(16):
                with T.block("multiply_update"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    T.reads(temp[vi, vj], A[vi, vk], B[vj, vk])
                    T.writes(temp[vi, vj])
                    temp[vi, vj] = temp[vi, vj] + T.Cast("int32", A[vi, vk]) * T.Cast("int32", B[vj, vk])
                with T.block("add"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    T.reads(temp[vi, vj], C[vi, vj])
                    T.writes(D[vi, vj])
                    D[vi, vj] = temp[vi, vj] + C[vi, vj]

Now I would like to merge the “add” and “multiply_update” blocks, but this does not seem to be possible:

  • Using reverse_compute_inline on the “add” block fails with an error saying that the consumer can only have a single producer block; in this case it has two (because of the init block).
  • Using compute_inline on the “multiply_update” block does not work either, because update_block is not a complete block.
  • Using sch.blockize([add_block, update_block]) gives a cryptic error: InternalError: Check failed: (0 <= i && i < p->size_) is false: IndexError: indexing 2 on an array of size 2
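
To make the goal concrete, here is a plain-Python sketch of the semantics only (it is not TIR and not a schedule). The function names `staged` and `fused` are hypothetical, introduced just for this illustration. `staged` mirrors the loop nest printed above, and `fused` is one possible reading of what a merged block could compute, assuming the accumulator is seeded with C instead of 0:

```python
import random

N = 16  # matches the 16x16 buffers in the module above


def staged(A, B, C):
    """Mirror the scheduled loop nest: init, update over k, then add."""
    D = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            temp = 0  # "multiply_init"
            for k in range(N):
                temp += A[i][k] * B[j][k]  # "multiply_update"
            D[i][j] = temp + C[i][j]  # "add"
    return D


def fused(A, B, C):
    """Hypothetical merged form: seed the accumulator with C, no temp buffer."""
    D = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            D[i][j] = C[i][j]  # the add is folded into the init step
            for k in range(N):
                D[i][j] += A[i][k] * B[j][k]
    return D


# Random int8-range inputs for A and B, arbitrary int32-range values for C.
rng = random.Random(0)
A = [[rng.randint(-128, 127) for _ in range(N)] for _ in range(N)]
B = [[rng.randint(-128, 127) for _ in range(N)] for _ in range(N)]
C = [[rng.randint(-1000, 1000) for _ in range(N)] for _ in range(N)]

# Both forms compute D[i][j] = C[i][j] + sum_k A[i][k] * B[j][k].
assert staged(A, B, C) == fused(A, B, C)
```

Both functions produce the same D, so the merge is well defined mathematically; what I cannot find is a sequence of schedule primitives that gets TIR into a form like `fused`.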