[TIR] Problem inlining addition into matmul block

For the integration of a new intrinsic, I would like to transform a TIR schedule so that the addition of a bias is inlined into a matrix multiplication. I have created a very simple example to reproduce my problem; let's assume the following PrimFunc:

from tvm import tir
from tvm.script import tir as T


@T.prim_func
def func(
    A: T.Buffer((16, 16), "int8"),
    B: T.Buffer((16, 16), "int8"),
    C: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
    ) -> None:
 
    temp  = T.alloc_buffer((16, 16), dtype="int32")

    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                temp[vi, vj] = T.int32(0)
            temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32")
    
    for i, j in T.grid(16, 16):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = temp[vi, vj] + C[vi, vj] 

I want to transform it to achieve the following:

@T.prim_func
def expected_v1(
    A: T.Buffer((16, 16), "int8"),
    B: T.Buffer((16, 16), "int8"),
    C: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
    ) -> None:
 
    temp  = T.alloc_buffer((16, 16), dtype="int32")

    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                temp[vi, vj] = T.int32(0)
            temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + C[vi, vj]
    
    for i, j in T.grid(16, 16):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = temp[vi, vj]

Or, ideally:

@T.prim_func
def expected_v2(
    A: T.Buffer((16, 16), "int8"),
    B: T.Buffer((16, 16), "int8"),
    C: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
    ) -> None:
 
    temp  = T.alloc_buffer((16, 16), dtype="int32")

    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                D[vi, vj] = C[vi, vj]
            D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32")

As you can see, mathematically all these computations are equivalent, so I would expect there is some way of getting there. But everything I tried failed: compute_inline on the multiply block, reverse_compute_inline on the add block, decompose_reduction followed by reverse_compute_inline…
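For reference, this is the computation all three variants describe (a NumPy sketch added for illustration, not part of the original PrimFuncs; note that B is indexed as B[vj, vk], i.e. transposed):

import numpy as np

# D[i, j] = sum_k A[i, k] * B[j, k] + C[i, j], accumulated in int32
A = np.random.randint(-128, 128, size=(16, 16), dtype=np.int8)
B = np.random.randint(-128, 128, size=(16, 16), dtype=np.int8)
C = np.random.randint(-1000, 1000, size=(16, 16), dtype=np.int32)

D_ref = A.astype(np.int32) @ B.astype(np.int32).T + C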

Could someone confirm this is indeed not possible? And if that is the case, why? These seem like valid transformations that should be achievable in some way, but I am probably missing the reason why they are not.

Here is some example code to show some of what I tried (it fails with the error: The consumer block tir.Block#0 to be inlined is required to have only a single producer block, and the producer block should be a complete block who has only a single consumer):

if __name__ == "__main__":
    sch = tir.Schedule(func)

    mult_block = sch.get_block("multiply")
    init_block = sch.decompose_reduction(mult_block, sch.get_loops(mult_block)[-1])
    update_block = sch.get_block("multiply_update")
    add_block = sch.get_block("add")
    sch.cache_write(add_block, 0, "local")
    # After decompose_reduction, "add" has two producers (init + update),
    # so this call raises the error quoted above:
    sch.reverse_compute_inline(add_block)

Not immediately the same as this code, but 2.4. TensorIR: Tensor Program Abstraction Case Study — Machine Learning Compiler 0.0.1 documentation should be relevant for reverse inlining.
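For reference, a minimal sketch of the case reverse_compute_inline is designed for (an illustration added here, not taken from that chapter): a purely elementwise consumer whose single producer is a complete, non-reduction block.

from tvm import tir
from tvm.script import tir as T

@T.prim_func
def elemwise_chain(
    A: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
    ) -> None:

    temp = T.alloc_buffer((16, 16), dtype="int32")

    for i, j in T.grid(16, 16):
        with T.block("scale"):
            vi, vj = T.axis.remap("SS", [i, j])
            temp[vi, vj] = A[vi, vj] * 2

    for i, j in T.grid(16, 16):
        with T.block("bias"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = temp[vi, vj] + 1

sch = tir.Schedule(elemwise_chain)
# Succeeds: "bias" has exactly one producer ("scale"), and that producer is a
# complete block, matching the precondition in the error message quoted above.
sch.reverse_compute_inline(sch.get_block("bias"))
print(sch.mod.script())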

Thanks @tqchen. I am familiar with that resource, and I see that in Section 2.4.4.1 they actually arrive at something similar to what I already achieved using the following transformations:

sch = tir.Schedule(func)
mult_block = sch.get_block("multiply")
add_block = sch.get_block("add")
sch.reverse_compute_at(add_block, sch.get_loops(mult_block)[-1])
init_block = sch.decompose_reduction(mult_block, sch.get_loops(mult_block)[-1])
update_block = sch.get_block("multiply_update")
print(sch.mod.script())  # prints the module shown below

Using this, I arrive at something like this:

class Module:
    @T.prim_func
    def main(A: T.Buffer((16, 16), "int8"), B: T.Buffer((16, 16), "int8"), C: T.Buffer((16, 16), "int32"), D: T.Buffer((16, 16), "int32")):
        T.func_attr({"global_symbol": "func"})
        # with T.block("root"):
        temp = T.alloc_buffer((16, 16), "int32")
        for i, j in T.grid(16, 16):
            with T.block("multiply_init"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads()
                T.writes(temp[vi, vj])
                temp[vi, vj] = 0
            for k in range(16):
                with T.block("multiply_update"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    T.reads(temp[vi, vj], A[vi, vk], B[vj, vk])
                    T.writes(temp[vi, vj])
                    temp[vi, vj] = temp[vi, vj] + T.Cast("int32", A[vi, vk]) * T.Cast("int32", B[vj, vk])
                with T.block("add"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    T.reads(temp[vi, vj], C[vi, vj])
                    T.writes(D[vi, vj])
                    D[vi, vj] = temp[vi, vj] + C[vi, vj]

Now I would like to merge the "add" and "multiply_update" blocks, but it seems this is not possible (the attempts are sketched in code after the list):

  • Using reverse_compute_inline on the "add" block gives the error saying that the consumer can only have a single producer block, and in this case it has two (because of the init block).
  • Using compute_inline on the "multiply_update" block cannot be done, because update_block is not a complete block.
  • Using sch.blockize([add_block, update_block]) gives a cryptic error: InternalError: Check failed: (0 <= i && i < p->size_) is false: IndexError: indexing 2 on an array of size 2
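Roughly what each attempt looks like in code, continuing from the schedule above (sch, add_block and update_block still in scope); each call is wrapped so the quoted errors can be printed without aborting:

for attempt in (
    lambda: sch.reverse_compute_inline(add_block),    # two producers: init + update
    lambda: sch.compute_inline(update_block),         # multiply_update is not a complete block
    lambda: sch.blockize([add_block, update_block]),  # InternalError quoted above
):
    try:
        attempt()
    except Exception as err:
        print(type(err).__name__, ":", err)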

I see, indeed. For this case, the code should already be efficient enough. The temp buffer will get narrowed into a size-1 buffer and then into a register during codegen.

Indeed, that could be the case. But if, for example, I wanted to tensorize this TIR to leverage specific instructions of a particular hardware target, being able to actually merge the bias into the update block could be highly beneficial.

It is very common, for example, for vector ISAs to have some kind of multiply-accumulate (MAC) instruction. I could preload the bias into a register and then accumulate the multiplication directly onto the preloaded register.
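To make that concrete, here is a rough sketch of how the merged form in expected_v2 could then be consumed (all names here are hypothetical placeholders, not existing TVM intrinsics): a tensor intrinsic whose description mirrors the 16x16x16 "multiply" block (bias preload in the init, multiply-accumulate in the update) would be registered and applied with tensorize. The desc/impl PrimFuncs are not written out.

from tvm import tir

# bias_macc_desc / bias_macc_impl are placeholders: the desc would mirror the
# "multiply" block of expected_v2, the impl would emit the target's
# accumulator-preload and MAC instructions.
tir.TensorIntrin.register("hypothetical_bias_macc_16x16x16",
                          bias_macc_desc, bias_macc_impl)

sch = tir.Schedule(expected_v2)
i, j, k = sch.get_loops(sch.get_block("multiply"))
# Map the whole 16x16x16 tile onto the hypothetical intrinsic.
sch.tensorize(i, "hypothetical_bias_macc_16x16x16")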

I see, and I agree with the point. Unfortunately it is not supported by the current transformations, but if you are interested in adding a new primitive (e.g. fold reduction bias), happy to bring it in.
