For the integration of a new intrinsic, I would like to apply a TIR schedule transformation that inlines the addition of a bias into a matrix multiplication. I have created a very simple example to reproduce my problem; let's assume the following PrimFunc:
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def func(
    A: T.Buffer((16, 16), "int8"),
    B: T.Buffer((16, 16), "int8"),
    C: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
) -> None:
    temp = T.alloc_buffer((16, 16), dtype="int32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                temp[vi, vj] = T.int32(0)
            temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32")
    for i, j in T.grid(16, 16):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = temp[vi, vj] + C[vi, vj]
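For reference, in NumPy terms func computes a matrix multiplication with the second operand transposed, followed by the bias addition (this is only an illustration of the reference computation, not part of the schedule):

import numpy as np

# Reference computation of `func`: the int8 operands are widened to int32
# before accumulation, and B is indexed as B[j, k], i.e. transposed.
a = np.random.randint(-128, 128, (16, 16), dtype=np.int8)
b = np.random.randint(-128, 128, (16, 16), dtype=np.int8)
c = np.random.randint(-1000, 1000, (16, 16), dtype=np.int32)

temp = a.astype(np.int32) @ b.astype(np.int32).T  # "multiply" block
d = temp + c                                      # "add" block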
I want to transform it to achieve the following:
@T.prim_func
def expected_v1(
    A: T.Buffer((16, 16), "int8"),
    B: T.Buffer((16, 16), "int8"),
    C: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
) -> None:
    temp = T.alloc_buffer((16, 16), dtype="int32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                temp[vi, vj] = T.int32(0)
            temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + C[vi, vj]
    for i, j in T.grid(16, 16):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = temp[vi, vj]
Or, ideally:
@T.prim_func
def expected_v2(
    A: T.Buffer((16, 16), "int8"),
    B: T.Buffer((16, 16), "int8"),
    C: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
) -> None:
    temp = T.alloc_buffer((16, 16), dtype="int32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                D[vi, vj] = C[vi, vj]
            D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32")
As you can see, mathematically all of these computations are equivalent, so I would expect there to be some way of getting there, but everything I tried failed. I tried compute_inline on the multiply block, reverse_compute_inline on the add block, and decompose_reduction followed by reverse_compute_inline…
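For instance, the two direct inlining attempts looked roughly like this (a condensed sketch, each run on a fresh schedule; both of these fail as well):

# Attempt 1: inline the matmul into the bias addition (run separately).
sch = tir.Schedule(func)
sch.compute_inline(sch.get_block("multiply"))

# Attempt 2: inline the bias addition back into the matmul (run separately).
sch = tir.Schedule(func)
sch.reverse_compute_inline(sch.get_block("add"))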
Could someone confirm that this is indeed not possible? And if so, why? These seem like valid transformations that should be achievable in some way, but I am probably missing the reason why they aren't.
Here is some example code showing part of what I tried. It fails with the error "The consumer block tir.Block#0 to be inlined is required to have only a single producer block, and the producer block should be a complete block who has only a single consumer":
if __name__ == "__main__":
    sch = tir.Schedule(func)
    mult_block = sch.get_block("multiply")
    # Split the reduction into an init block and an update block; the init
    # block is inserted right before the innermost (k) loop.
    init_block = sch.decompose_reduction(mult_block, sch.get_loops(mult_block)[-1])
    update_block = sch.get_block("multiply_update")
    add_block = sch.get_block("add")
    sch.cache_write(add_block, 0, "local")
    # The script fails here with the error quoted above.
    sch.reverse_compute_inline(add_block)
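If it helps with reproducing, the intermediate TIR between the schedule primitives above (e.g. right after decompose_reduction) can be inspected by printing the current state of the module:

# Print the current state of the scheduled module between primitives.
print(sch.mod.script())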