[RFC] TensorIR: A schedulable IR for TVM

Thanks for the proposal! This definitely opens more opportunities for performance optimization. Two questions for clarification:

  1. IIUC, based on the proposal and discussion, we will have both TE and TIR, but TE is more like a frontend wrapper of TIR to serve users who prefer to write a high-level DSL. Then, what will we do with the TE schedule primitives? Intuitively, we should still keep them; otherwise TE writers will have no way to schedule their computes, because they know nothing about TIR and blocks.

  2. Does this proposal support dynamic shape (i.e., Any)? For example, can we have something like:

    @tvm.hybrid.script
    def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
        C = tir.match_buffer(c, (1024, 1024), "float32")
        A = tir.match_buffer(a, (1024, Any), "float32")
        B = tir.match_buffer(b, (Any, 1024), "float32")
        reducer = tir.comm_reducer(lambda x, y: x + y, tir.float32(0))
    
        with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "C") as [vi, vj, vk]:
            reducer.step(C[vi, vj], A[vi, vk] * B[vk, vj])
    
    s = tir.create_schedule(matmul)
    update = s.get_block("C")
    i, j, k = s.get_axes(update)
    i_o, i_i = s.split(i, bn)
    j_o, j_i = s.split(j, bn)
    k_o, k_i = s.split(k, 4)
    

    In this case, the length of vk (or k) is Any. Can we still apply split to it with a fixed factor?

Thanks for the explanation. The relation between TE and the new TIR is now clearer to me.

Thanks for the clarification. It would be nice if we could use various methods to create tensor programs and use the new TIR to schedule them.

Good questions!

  1. We would like users to adopt the TensorIR schedule rather than the TE schedule once we fully upstream TensorIR, for three reasons:

    1. Just as you mentioned, TE is a frontend wrapper, and it directly generates TIR with blocks. In a sense, TE is syntactic sugar for defining TIR.
    2. Most of the schedule primitives in TensorIR are very similar to those in TE. The cost of learning the TensorIR schedule is extremely low (perhaps just a day).
    3. All primitives are based on blocks (there is no stage concept in the TensorIR schedule), so it is hard to keep the TE schedule alongside blocks.
  2. Dynamic shapes are not supported yet. However, thanks to our decoupled primitives, it is feasible to support them later.

Would love to see dynamic shapes supported; otherwise a large set of models can’t be backed by the new TensorIR. :smiley:

So the scenario is that you can choose either TE or TIR to write a compute, but if you choose TE, you have to first lower it to TIR and then apply schedule primitives?

IIUC, this is nontrivial, because the TIR was not written by a human, and you may need to first print it out to figure out how to schedule it. It sounds more straightforward to keep the TE schedule as syntactic sugar; at least you can get a sense of what the schedule looks like by tracing the Python code.

Because there is a 1-1 mapping between te.Stage and Block, it should actually not be hard to use the TIR schedule to schedule a PrimFunc generated from a TE compute, either by getting a block via its name or by programmatically traversing the blocks just as we do with stages (a minimal sketch follows). But I agree that we can keep te.schedule for a bit.
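
A minimal sketch, using the schedule API names from the examples above; the helper that lowers a TE compute into a blocked PrimFunc, written as te.create_prim_func here, is an assumption for illustration:

import tvm
from tvm import te, tir

# Define the compute in TE as usual.
A = te.placeholder((1024, 1024), name="A")
B = te.placeholder((1024, 1024), name="B")
k = te.reduce_axis((0, 1024), name="k")
C = te.compute((1024, 1024),
               lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
               name="C")

# Assumed helper: lower the TE compute into a TIR PrimFunc with blocks.
func = te.create_prim_func([A, B, C])

# Schedule it with the TensorIR schedule: each te.Stage maps 1-1 to a
# block, so the stage named "C" becomes the block named "C".
s = tir.create_schedule(func)
update = s.get_block("C")
i, j, k_axis = s.get_axes(update)
i_o, i_i = s.split(i, 32)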

Thanks for the clarification. Makes sense to me.

Thanks for the proposal! Just curious about schedule primitives like cache_write and cache_read, since there are no stages in TensorIR.

Thanks for your reply! @MinminSun

The cache_read/cache_write API accepts a Buffer and a new scope as input, performs some checks to ensure it is safe to read/write the Buffer through the cache, and creates new blocks to do the cache transfer.
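
A rough usage sketch following that description (the exact signature is not confirmed; A, B, and C stand for the Buffers of the matmul example above, and how they are obtained from the schedule is glossed over):

# Schedule the matmul example from earlier in the thread.
s = tir.create_schedule(matmul)

# Read A and B through "shared"-scope cache buffers; the primitive checks
# that the rewrite is safe and inserts new blocks that perform the copy.
A_shared = s.cache_read(A, "shared")
B_shared = s.cache_read(B, "shared")

# Accumulate C in a "local" buffer; a new block writes the result back to
# the original buffer.
C_local = s.cache_write(C, "local")

# The new cache blocks can then be scheduled like any other block, e.g.
# moved under a loop of the compute block with compute_at.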

Thanks for this RFC, I think it’s a great idea and will help solve a number of issues I’ve been facing recently. I’m particularly interested in what ‘tensorize’ will look like for this new IR. Could you give a snippet as an example?

I’m also interested in what the interaction of this will be with the loop partition pass. Will this mean that each partitioned loop will then be individually schedulable?

Thank you for your interest.

Tensorize in TensorIR is completely different from the TE one. In TensorIR, we use two functions (desc_func and intrin_func) to define an intrinsic. Here is an example of an intrinsic (note that TensorIR is still WIP, so the API may change).

@tvm.hybrid.script
def desc_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [16, 16])
    B = tir.match_buffer(b, [16, 16])
    C = tir.match_buffer(c, [16, 16])

    with tir.block([16, 16, tir.reduce_axis(0, 16)], "root") as [vi, vj, vk]:
        for i, j, k in tir.grid(16, 16, 16):
            with tir.block([16, 16, tir.reduce_axis(0, 16)], "update") as [vii, vjj, vkk]:
                tir.bind(vii, vi + i)
                tir.bind(vjj, vj + j)
                tir.bind(vkk, vk + k)
                C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]


@tvm.hybrid.script
def intrin_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [16, 16])
    B = tir.match_buffer(b, [16, 16])
    C = tir.match_buffer(c, [16, 16])

    with tir.block([16, 16, tir.reduce_axis(0, 16)], "root") as [vi, vj, vk]:
        tir.evaluate(tir.tvm_mma_sync(C.data, C.elem_offset // 256,
                                      A.data, A.elem_offset // 256,
                                      B.data, B.elem_offset // 256,
                                      C.data, C.elem_offset // 256,
                                      dtype="handle"))

Tensorize will match the sub-AST (usually a block) against the desc_func, and then replace it with the intrin_func.
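
For completeness, here is a hypothetical usage sketch. Only create_schedule, get_block, get_axes, and split appear in the examples above; reorder and the tensorize(loop, desc_func, intrin_func) signature are assumptions, and layout details (e.g. whether B is transposed) are glossed over:

s = tir.create_schedule(matmul)          # the matmul example from earlier
update = s.get_block("C")
i, j, k = s.get_axes(update)

# Tile the 1024x1024x1024 iteration space into 16x16x16 sub-tiles.
i_o, i_i = s.split(i, 16)
j_o, j_i = s.split(j, 16)
k_o, k_i = s.split(k, 16)
s.reorder(i_o, j_o, k_o, i_i, j_i, k_i)  # assumed reorder primitive

# Match the sub-AST rooted at i_i against desc_func and, if it matches,
# replace it with intrin_func.
s.tensorize(i_i, desc_func, intrin_func)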

TensorIR is at the schedule level and has no coupling with low-level passes. However, we can schedule each loop directly and add primitives as you want. :slight_smile:

Thanks for this explanation. I’m interested in whether it might be possible to match tensor intrinsics with variable size. For example, Arm SVE introduces vector instructions of variable size.

Technically, it should be possible to support that. However, due to time constraints, we have not supported it yet.

Thanks for the proposal! Looks quite interesting!

Out of curiosity,

  1. In the concat example you’ve shown, the original stage is represented as three blocks that seem to be assigning to the same buffer. I’m curious: if we want to move the concat (using compute_at, if possible) to a consumer of the concat’s output (to some loop of the consumer), how could it be done? Will it create multiple blocks there as well?

  2. Since the proposed TensorIR scopes scheduling transformations in terms of blocks, is there a prospect of representing a full Relay graph in TensorIR?

For concat, we could introduce a reverse inlining primitive that inlines elementwise operations (after the concat) back into the concat, which should be helpful in many cases; a rough sketch of how such a primitive might be used follows.
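
A sketch under assumptions: suppose an elementwise relu block consumes the concat output in the same PrimFunc, and call the hypothetical primitive reverse_compute_inline. Neither the PrimFunc name nor the primitive name below comes from the proposal itself.

s = tir.create_schedule(fused_concat_relu)   # hypothetical PrimFunc: concat followed by relu
relu_block = s.get_block("relu")

# Assumed primitive: inline the elementwise relu back into its producers,
# i.e. fold the relu computation into each of the three concat blocks so
# that no separate relu loop nest remains.
s.reverse_compute_inline(relu_block)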

While it is possible to represent a full graph, we still imagine Relay being very useful as a coarse-grained representation for graph-level optimization. That suggests a continued effort on a multi-level representation (Relay and TIR).

Thanks for the clarification! I concur that such a primitive should be useful and would allow more flexible compute movements.

Regarding the full graph, I agree that Relay (along with its optimizations) is very useful. I was wondering whether there would be a benefit to lowering the full graph to TensorIR after Relay optimization rather than lowering each primitive function separately. I guess this has to do with how AutoTVM/Ansor will explore schedules, but I have a feeling that exploration could be scoped via the “blocks”, which would otherwise lead to an explosion of the search space. (Looking at this from an AoT angle.)

Moreover, maybe that could lay a foundation for inter-primitive-function optimizations later.

This is the right way to go. However, I have two concerns:

  1. How do we fuse ops as much as possible? Fusion is basically a copy-propagation optimization in compilers, which relies on data-flow analysis, but TVM still lacks program analysis today.
  2. TE tensorize cannot handle some complex pattern matching (see https://github.com/apache/incubator-tvm/pull/1053). Can we do 100% pattern matching in TIR?

@xqdan Thank you for the valuable feedback! Fusion can be done automatically with some analysis provided in Ansor.

Do you have any other kind of analysis in mind that might be potentially useful?

Is fusion in Ansor based on TIR? For other transforms, you may check out here; that’s what we’ve done in AKG. I can explain some of it if you are interested.
