Hi everyone,
does anyone know an elegant way to combine/fuse two blocks from different TIR PrimFuncs?
Here is an example of two PrimFuncs, `one` and `two`. I want to connect `one`'s buffer `B` to `two`'s buffer `C` in a new PrimFunc, so that `B` becomes an intermediate buffer that `one` writes and `two` reads:
@T.prim_func
def one(a: T.handle, b: T.handle) -> None:
    # Elementwise copy of a 16-element buffer: B[i] = A[i] for i in [0, 16).
    A = T.match_buffer(a, (16,))
    B = T.match_buffer(b, (16,))
    with T.block("root"):
        for i in T.serial(16):
            B[i] = A[i]
@T.prim_func
def two(c: T.handle, d: T.handle) -> None:
    # Elementwise copy of a 16-element buffer: D[i] = C[i] for i in [0, 16).
    C = T.match_buffer(c, (16,))
    D = T.match_buffer(d, (16,))
    with T.block("root"):
        for i in T.serial(16):
            D[i] = C[i]
Expected result (something like this):
@T.prim_func
def one_two(a: T.handle, d: T.handle) -> None:
    # Desired fusion of `one` and `two`: `one`'s output B is no longer a
    # function parameter but an intermediate buffer allocated here. The first
    # block writes it and the second block reads it in place of `two`'s C.
    A = T.match_buffer(a, (16,))
    B = T.alloc_buffer((16,))
    D = T.match_buffer(d, (16,))
    with T.block("root_one"):
        # Body of `one`: B[i] = A[i].
        for i in T.serial(16):
            B[i] = A[i]
    with T.block("root_two"):
        # Body of `two`, with its input C replaced by B: D[i] = B[i].
        for i in T.serial(16):
            D[i] = B[i]
Thanks