[RFC] TensorIR: A schedulable IR for TVM

Thank you for your interest.

Tensorization in TensorIR works completely differently from tensorize in TE. In TensorIR, an intrinsic is defined by two functions: desc_func, which describes the computation the intrinsic performs, and intrin_func, which provides its implementation. Here is an example of an intrinsic (note that TensorIR is still a work in progress, so the API may change).

```python
@tvm.hybrid.script
def desc_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Description of the intrinsic: the computation it performs.
    A = tir.match_buffer(a, [16, 16])
    B = tir.match_buffer(b, [16, 16])
    C = tir.match_buffer(c, [16, 16])

    # Outer block covering the whole 16x16x16 region; vk is a reduction axis.
    with tir.block([16, 16, tir.reduce_axis(0, 16)], "root") as [vi, vj, vk]:
        for i, j, k in tir.grid(16, 16, 16):
            # Inner block: one multiply-accumulate per (i, j, k).
            with tir.block([16, 16, tir.reduce_axis(0, 16)], "update") as [vii, vjj, vkk]:
                tir.bind(vii, vi + i)
                tir.bind(vjj, vj + j)
                tir.bind(vkk, vk + k)
                C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]
```
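In plain Python, the computation desc_func describes is a 16x16x16 multiply-accumulate in which B is indexed as B[j, k], i.e. C += A * B^T. The sketch below is only a reference model for reading the block: the name `desc_reference` is hypothetical (not part of TVM), and it uses nested lists instead of TVM buffers.

```python
def desc_reference(A, B, C, n=16):
    """Reference semantics of desc_func: C[i][j] += sum_k A[i][k] * B[j][k].

    A, B and C are n x n nested lists; C is updated in place, mirroring the
    reduction over the reduce axis vk in the block.
    """
    for i in range(n):          # spatial axis (vi)
        for j in range(n):      # spatial axis (vj)
            for k in range(n):  # reduction axis (vk)
                C[i][j] += A[i][k] * B[j][k]
    return C


if __name__ == "__main__":
    n = 16
    A = [[float(i == j) for j in range(n)] for i in range(n)]  # identity
    B = [[float(i * n + j) for j in range(n)] for i in range(n)]
    C = [[0.0] * n for _ in range(n)]
    desc_reference(A, B, C)
    # With A = I, the update yields C = B^T.
    assert all(C[i][j] == B[j][i] for i in range(n) for j in range(n))
```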


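```python
@tvm.hybrid.script
def intrin_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Implementation of the intrinsic: a single Tensor Core MMA call.
    A = tir.match_buffer(a, [16, 16])
    B = tir.match_buffer(b, [16, 16])
    C = tir.match_buffer(c, [16, 16])

    with tir.block([16, 16, tir.reduce_axis(0, 16)], "root") as [vi, vj, vk]:
        # Argument order is (D, d_idx, A, a_idx, B, b_idx, C, c_idx), computing
        # D = A * B + C; passing C as both D and C accumulates in place.
        tir.evaluate(tir.tvm_mma_sync(C.data, C.elem_offset // 256,
                                      A.data, A.elem_offset // 256,
                                      B.data, B.elem_offset // 256,
                                      C.data, C.elem_offset // 256,
                                      dtype="handle"))
```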
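intrin_func passes `X.elem_offset // 256` as the fragment index for each operand. A 16x16 fragment holds 16 * 16 = 256 elements, so when fragments are stored back to back, dividing a fragment's starting element offset by 256 gives its index. The helper below is a hypothetical illustration of that arithmetic, not a TVM function; its alignment check makes explicit an assumption that the plain floor division silently relies on.

```python
FRAG_M, FRAG_N = 16, 16
FRAG_ELEMS = FRAG_M * FRAG_N  # 256 elements per 16x16 fragment


def fragment_index(elem_offset: int, frag_elems: int = FRAG_ELEMS) -> int:
    """Map a buffer's element offset to the index of the fragment it starts.

    Mirrors `elem_offset // 256` in intrin_func, but rejects offsets that do
    not fall on a fragment boundary instead of silently truncating them.
    """
    if elem_offset % frag_elems != 0:
        raise ValueError(
            f"offset {elem_offset} is not aligned to a {frag_elems}-element fragment"
        )
    return elem_offset // frag_elems


# Four fragments stored contiguously start at offsets 0, 256, 512, 768.
offsets = [k * FRAG_ELEMS for k in range(4)]
print([fragment_index(o) for o in offsets])  # [0, 1, 2, 3]
```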

Tensorize matches a sub-AST (usually a block) against desc_func and then replaces it with intrin_func.
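To make the match-then-replace idea concrete, here is a toy sketch over a made-up expression tree. Everything in it (`Node`, `tensorize`, the op names) is hypothetical. TVM's real implementation is considerably more involved (for instance, it has to match blocks whose variables and buffers are named differently from those in desc_func); here, matching is plain structural equality provided by frozen dataclasses.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Node:
    """A minimal immutable IR node: an op name plus child nodes.

    frozen=True gives value-based equality, i.e. structural comparison.
    """
    op: str
    children: Tuple["Node", ...] = ()


def tensorize(tree: Node, desc: Node, intrin: Node) -> Node:
    """Replace every subtree structurally equal to `desc` with `intrin`."""
    if tree == desc:
        return intrin  # matched the description: swap in the intrinsic
    # No match here: rebuild this node with (possibly rewritten) children.
    return Node(tree.op, tuple(tensorize(c, desc, intrin) for c in tree.children))


# desc: a 16x16x16 update loop nest; intrin: a single MMA call.
desc = Node("for16", (Node("for16", (Node("for16", (Node("C+=A*B"),)),)),))
intrin = Node("mma_sync")

program = Node("seq", (Node("init_C"), desc))
print(tensorize(program, desc, intrin))
```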

TensorIR works at the schedule level and is not coupled to any low-level passes. However, you can schedule each loop directly and add new primitives as you need them. :slight_smile:
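As a rough illustration of why loop scheduling and tensorize go together, the pure-Python sketch below computes a matmul with each loop split into an outer loop and an inner loop of extent 16, so that the three inner loops form exactly the 16x16x16 pattern desc_func describes; that inner nest is then delegated to a single call. `mma_tile` is a hypothetical stand-in for the hardware intrinsic written in plain Python, the B[j][k] indexing follows desc_func, and this is not TVM scheduling code.

```python
TILE = 16


def mma_tile(A, B, C, i0, j0, k0, t=TILE):
    """Hypothetical stand-in for the intrinsic: one t x t x t update
    C[i][j] += A[i][k] * B[j][k] on the tile starting at (i0, j0, k0)."""
    for i in range(i0, i0 + t):
        for j in range(j0, j0 + t):
            acc = C[i][j]
            for k in range(k0, k0 + t):
                acc += A[i][k] * B[j][k]
            C[i][j] = acc


def matmul_tiled(A, B, n, t=TILE):
    """C = A * B^T with each loop split into (outer, inner of extent t).

    Only the outer loops remain; each inner t x t x t nest is replaced by
    one call to mma_tile, which is the effect tensorize aims for."""
    assert n % t == 0, "this sketch assumes n is a multiple of the tile size"
    C = [[0.0] * n for _ in range(n)]
    for io in range(0, n, t):
        for jo in range(0, n, t):
            for ko in range(0, n, t):
                mma_tile(A, B, C, io, jo, ko, t)
    return C


def matmul_naive(A, B, n):
    """Untiled reference: C[i][j] = sum_k A[i][k] * B[j][k]."""
    return [[sum(A[i][k] * B[j][k] for k in range(n)) for j in range(n)]
            for i in range(n)]


if __name__ == "__main__":
    n = 64
    A = [[float((i + k) % 7) for k in range(n)] for i in range(n)]
    B = [[float((j * k) % 5) for k in range(n)] for j in range(n)]
    assert matmul_tiled(A, B, n) == matmul_naive(A, B, n)
```

The split is what exposes a sub-nest matching desc_func: an unsplit loop of extent 64 would not match a 16x16x16 description.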
