compute_at() on TIR

Hello,

I am very interested in how compute_at() transforms the TIR. Could anyone give at least one example of this?

CC: @Hzfengsy @spectrometerHBH

Thanks for your interest. In general, compute_at in TensorIR does the same job as compute_at in TE. Here is one simple example:

import tvm
from tvm import tir
from tvm.script import ty

@tvm.script.tir
def func_before(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128))
    C = tir.match_buffer(c, (128, 128))

    B = tir.alloc_buffer((128, 128))
    for i0 in range(0, 128):
        for j0 in range(0, 128):
            with tir.block([128, 128], "B") as [vi, vj]:
                B[vi, vj] = A[vi, vj] * 2.0
    for i1 in range(0, 128):
        for j1 in range(0, 128):
            with tir.block([128, 128], "C") as [vi, vj]:
                C[vi, vj] = B[vi, vj] + 1.0

@tvm.script.tir
def func_after(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128))
    C = tir.match_buffer(c, (128, 128))
    B = tir.alloc_buffer((128, 128))

    for i0 in range(0, 128):
        for j0 in range(0, 128):
            with tir.block([128, 128], "B") as [vi, vj]:
                B[vi, vj] = A[vi, vj] * 2.0
        for j1 in range(0, 128):
            with tir.block([128, 128], "C") as [vi, vj]:
                C[vi, vj] = B[vi, vj] + 1.0

if __name__ == "__main__":
    s = tir.Schedule(func_before, debug_mode=True)
    B = s.get_block("B")
    C = s.get_block("C")
    outer, _ = s.get_loops(C)
    s.compute_at(B, outer)
    tvm.ir.assert_structural_equal(func_after, s.mod["main"])
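
For readers without a TVM build at hand, the same two loop structures can be emulated in plain Python to confirm that they compute identical results. This is only an illustrative sketch, not TVM code: `run_before`, `run_after`, and the small size `N = 8` (standing in for 128) are assumptions made for this example.

```python
N = 8  # small stand-in for the 128 used in the example above


def run_before(A):
    """Emulates func_before: block B fills the whole buffer, then block C reads it."""
    B = [[0.0] * N for _ in range(N)]
    C = [[0.0] * N for _ in range(N)]
    for i0 in range(N):
        for j0 in range(N):
            B[i0][j0] = A[i0][j0] * 2.0
    for i1 in range(N):
        for j1 in range(N):
            C[i1][j1] = B[i1][j1] + 1.0
    return C


def run_after(A):
    """Emulates func_after: block B sits under the outer loop of block C."""
    B = [[0.0] * N for _ in range(N)]
    C = [[0.0] * N for _ in range(N)]
    for i in range(N):
        # One row of B is produced right before block C consumes it.
        for j0 in range(N):
            B[i][j0] = A[i][j0] * 2.0
        for j1 in range(N):
            C[i][j1] = B[i][j1] + 1.0
    return C


A = [[float(i * N + j) for j in range(N)] for i in range(N)]
assert run_before(A) == run_after(A)
```

Both versions do the same work; the only difference is that after compute_at each row of B is produced immediately before the iteration of the outer loop that consumes it.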

Going one step further, I would like to describe the algorithm briefly. compute_at moves the producer block under the target loop and regenerates the loop nest around it so that the block covers the region consumed under that loop.

In the above example, we want to compute_at block B under loop i1:

  1. We first collect the region of buffer B consumed under loop i1, which is B[i1, 0:128].
  2. We find that each instance of block B produces only one element, B[vi, vj], at a time.
  3. So we need to "run" block B 128 times, using a new loop j0.
  4. We generate the loop j0 around block B and update the AST.
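
The region reasoning in these steps can be sketched in plain Python. This is a toy model written for this example only, not TVM's actual implementation; the helper names `consumed_region`, `producer_instances`, and `regenerated_loop_extent` are hypothetical.

```python
def consumed_region(i1, n=128):
    """Step 1: elements of B read by block C during one iteration of loop i1."""
    return [(i1, j1) for j1 in range(n)]


def producer_instances(region):
    """Steps 2 and 3: block B instance (vi, vj) writes exactly B[vi, vj],
    so covering the region takes one instance per consumed element."""
    return sorted(set(region))


def regenerated_loop_extent(region):
    """Step 4: extent of the new inner loop (j0) needed to cover the region."""
    cols = [j for _, j in region]
    return max(cols) - min(cols) + 1


region = consumed_region(i1=5)
instances = producer_instances(region)
extent = regenerated_loop_extent(region)
```

For any fixed i1 the consumed region is a single row of 128 elements, so the model arrives at 128 instances of block B and a regenerated loop of extent 128, matching the j0 loop in func_after.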