Hello,
I am very interested in how compute_at() would transform the TIR. Could anyone at least give an example on this?
Thanks for your interest. In general, compute_at in TensorIR does the same job as compute_at in TE. Here is a simple example:
# Imports assumed for the TVMScript syntax used in this post.
import tvm
from tvm import tir
from tvm.script import ty


@tvm.script.tir
def func_before(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128))
    C = tir.match_buffer(c, (128, 128))
    B = tir.alloc_buffer((128, 128))
    # Block B is computed in full before block C starts.
    for i0 in range(0, 128):
        for j0 in range(0, 128):
            with tir.block([128, 128], "B") as [vi, vj]:
                B[vi, vj] = A[vi, vj] * 2.0
    for i1 in range(0, 128):
        for j1 in range(0, 128):
            with tir.block([128, 128], "C") as [vi, vj]:
                C[vi, vj] = B[vi, vj] + 1.0

@tvm.script.tir
def func_after(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128))
    C = tir.match_buffer(c, (128, 128))
    B = tir.alloc_buffer((128, 128))
    for i0 in range(0, 128):
        for j0 in range(0, 128):
            with tir.block([128, 128], "B") as [vi, vj]:
                B[vi, vj] = A[vi, vj] * 2.0
        for j1 in range(0, 128):
            with tir.block([128, 128], "C") as [vi, vj]:
                C[vi, vj] = B[vi, vj] + 1.0

if __name__ == "__main__":
    s = tir.Schedule(func_before, debug_mode=True)
    B = s.get_block("B")
    C = s.get_block("C")
    outer, _ = s.get_loops(C)
    s.compute_at(B, outer)
    tvm.ir.assert_structural_equal(func_after, s.mod["main"])
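If it helps to see why the two functions compute the same thing, here is a plain-Python sketch of the two loop nests with no TVM dependency. The names run_before and run_after and the small size N = 4 are made up for illustration; they are not part of TVM.

```python
N = 4  # small stand-in for 128, chosen only to keep the example quick


def run_before(A):
    """Loop structure of func_before: all of B is computed, then all of C."""
    B = [[0.0] * N for _ in range(N)]
    C = [[0.0] * N for _ in range(N)]
    for i0 in range(N):
        for j0 in range(N):
            B[i0][j0] = A[i0][j0] * 2.0
    for i1 in range(N):
        for j1 in range(N):
            C[i1][j1] = B[i1][j1] + 1.0
    return C


def run_after(A):
    """Loop structure of func_after: row i of B is computed right before row i of C."""
    B = [[0.0] * N for _ in range(N)]
    C = [[0.0] * N for _ in range(N)]
    for i in range(N):
        for j0 in range(N):
            B[i][j0] = A[i][j0] * 2.0
        for j1 in range(N):
            C[i][j1] = B[i][j1] + 1.0
    return C


A = [[float(i * N + j) for j in range(N)] for i in range(N)]
assert run_before(A) == run_after(A)
```

Only the order in which elements of B are produced changes; every element of C still reads a value of B that has already been written.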
Going one step further, I would like to describe the algorithm briefly. compute_at tries to cover the region consumed under the target loop by moving the producer block under that loop and regenerating the loop nest around it.
In the above example, we want to compute_at block B under loop i1:
1. Calculate the region of buffer B consumed under the loop i1, which is B[i, 0:128] for the current iteration i.
2. Note that block B will produce only one element of B[i, j] at one time.
3. To cover the consumed region, we therefore need to run block B 128 times with the new loop j0.
4. Generate the new loop j0 with block B inside it and update the AST.
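The algorithm described above can be sketched in plain Python for this specific example. This is only an illustration under simplifying assumptions: regions are enumerated as explicit sets of indices, which is not how TVM represents them internally, and the helper names (consumed_region_of_B, produced_by_block_B, regenerate_loop) are hypothetical rather than TVM APIs.

```python
N = 128


def consumed_region_of_B(i1):
    """Indices of buffer B read by block C during one iteration of loop i1."""
    # Block C reads B[vi, vj] with vi = i1 and vj = j1 for j1 in [0, N).
    return {(i1, j1) for j1 in range(N)}


def produced_by_block_B(vi, vj):
    """One instance of block B writes exactly one element of buffer B."""
    return {(vi, vj)}


def regenerate_loop(i1):
    """Run block B under a new loop j0 and record what it covers."""
    covered = set()
    runs = 0
    for j0 in range(N):  # the regenerated loop j0
        covered |= produced_by_block_B(i1, j0)
        runs += 1
    return covered, runs


region = consumed_region_of_B(5)
covered, runs = regenerate_loop(5)
assert covered == region  # the new loop covers exactly what block C consumes
print(runs)
```

For any fixed i1, the new loop runs block B once per consumed element, which is where the extent of 128 for j0 comes from.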