Combining/Fusing TensorIR blocks from different primfuncs

Hi everyone,

Does anyone know an elegant way of combining/fusing two blocks from different TIR primfuncs?

Here is an example with two primfuncs, one and two. I want to connect one's output buffer B to two's input buffer C in a new primfunc, i.e.:

from tvm.script import tir as T


@T.prim_func
def one(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16,))
    B = T.match_buffer(b, (16,))
    with T.block("root"):
        for i in T.serial(16):
            B[i] = A[i]


@T.prim_func
def two(c: T.handle, d: T.handle) -> None:
    C = T.match_buffer(c, (16,))
    D = T.match_buffer(d, (16,))
    with T.block("root"):
        for i in T.serial(16):
            D[i] = C[i]

Expected result (something like this):

@T.prim_func
def one_two(a: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, (16,))
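    # B is no longer a parameter: it is an intermediate buffer that carries
    # one's output into two (two's C is replaced by B)
    B = T.alloc_buffer((16,))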
    D = T.match_buffer(d, (16,))

    with T.block("root_one"):
        for i in T.serial(16):
            B[i] = A[i]
    with T.block("root_two"):
        for i in T.serial(16):
            D[i] = B[i]

Thanks

CC @Hzfengsy @junrushao @yuchenj @cyx

We have a Relax pass called FuseTIR that can do this.


The implementation in the unity branch is at https://github.com/apache/tvm/blob/unity/src/relax/transform/fuse_tir.cc

Besides, I wrote something simple years ago but did not upstream it, as there was no use case at the time: Concatenate PrimFuncs (Hzfengsy/tvm@5a5a56e on GitHub)
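To make the idea behind such a concatenation concrete without relying on any TVM API, here is a toy model in plain Python. ToyPrimFunc, concat and run are hypothetical names invented for this sketch, not TVM classes or functions, and a "block" is reduced to an element-wise copy. It only mirrors the transformation asked for above: one's output buffer B is bound to two's input buffer C, that buffer turns into an internal allocation, and block names get a suffix so they stay unique.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ToyPrimFunc:
    """Hypothetical stand-in for a TIR PrimFunc (not a TVM class).

    params: buffers passed in by the caller (like match_buffer parameters)
    allocs: buffers allocated inside the function (like T.alloc_buffer)
    blocks: (block_name, dst, src) triples; each block copies src into dst element-wise
    """
    name: str
    params: list[str]
    allocs: list[str] = field(default_factory=list)
    blocks: list[tuple[str, str, str]] = field(default_factory=list)


def concat(f: ToyPrimFunc, g: ToyPrimFunc, bind: tuple[str, str], name: str) -> ToyPrimFunc:
    """Run f, then g, connecting f's output buffer bind[0] to g's input buffer bind[1].

    The connected buffer stops being a parameter and becomes an internal allocation.
    """
    out_f, in_g = bind

    def rename(buf: str) -> str:
        # Every access to g's input buffer now refers to f's output buffer.
        return out_f if buf == in_g else buf

    # Suffix block names with the source function's name to keep them unique.
    f_blocks = [(f"{blk}_{f.name}", dst, src) for blk, dst, src in f.blocks]
    g_blocks = [(f"{blk}_{g.name}", rename(dst), rename(src)) for blk, dst, src in g.blocks]
    params = [p for p in f.params if p != out_f] + [p for p in g.params if p != in_g]
    return ToyPrimFunc(name, params, f.allocs + g.allocs + [out_f], f_blocks + g_blocks)


def run(func: ToyPrimFunc, buffers: dict[str, list[float]], size: int = 16) -> dict[str, list[float]]:
    """Interpret a ToyPrimFunc; internal buffers are zero-initialised lists of length size."""
    env = dict(buffers)
    for buf in func.allocs:
        env[buf] = [0.0] * size
    for _, dst, src in func.blocks:
        for i in range(size):
            env[dst][i] = env[src][i]
    return env


one = ToyPrimFunc("one", ["A", "B"], blocks=[("root", "B", "A")])
two = ToyPrimFunc("two", ["C", "D"], blocks=[("root", "D", "C")])
one_two = concat(one, two, bind=("B", "C"), name="one_two")

print(one_two.params)  # ['A', 'D']
print(one_two.allocs)  # ['B']
print(one_two.blocks)  # [('root_one', 'B', 'A'), ('root_two', 'D', 'B')]

out = run(one_two, {"A": [float(i) for i in range(16)], "D": [0.0] * 16})
print(out["D"] == out["A"])  # True
```

A real pass has to deal with much more (buffer shapes and dtypes, block read/write regions, name collisions between the two functions), none of which this sketch models.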


@Hzfengsy, the pass you wrote is very useful, thanks! My use case is to compose a primfunc from multiple TIR primfuncs that I defined in TVMScript, very similar to the use case in Relax. However, my current flow still relies on Relay.

Any chance this feature will make it into main?

CC: @tqchen

Hopefully this feature will become available as we bring unity to main (see [DISCUSS] TVM Core Strategy for Emerging Needs). In the meantime, feel free to try out the unity branch; it syncs with main periodically, so it contains all the Relay features there.