[RFC] TensorIR: A schedulable IR for TVM

I’m still stuck on reverse_compute_at which seems like a long name, and is still a bit too magical for me to understand

Yeah I once had some discussion with @tqchen and @spectrometerHBH that reverse_compute_at is too long and tedious, and I agree it will be great to find some better names (we are always bad at it admittedly)…

Perhaps @spectrometerHBH could give a specific example of using reverse_compute_at to avoid duplicate splitting?

vmap lets you do something then “zoom” into the area, forget entirely about the outer dimension, and focus on that

Yeah vmap is a powerful tool - it sounds more like “batching” (either static or dynamic), but its functionality can certainly go beyond that by being applied multiple times.

One problem with introducing vmap to TensorIR is the semantics: do we consider it as adding a data parallel loop on top of the producer block of each output buffer? I believe @tqchen has more thoughts on this.
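To make the "data parallel loop" reading concrete, here is a minimal sketch in plain Python. It is neither JAX nor TensorIR code, and the helper name `naive_vmap` is hypothetical. It only illustrates that the per-example function is written without any knowledge of the batch axis, and that applying the wrapper twice adds two outer loops.

```python
def naive_vmap(fn):
    """Return a function that applies `fn` to each item along the leading axis."""
    def batched(xs):
        # The added outer loop is data parallel: iterations are independent.
        return [fn(x) for x in xs]
    return batched


def dot_self(vec):
    """Per-example computation, written with no knowledge of any batch axis."""
    return sum(v * v for v in vec)


batched_once = naive_vmap(dot_self)               # adds one outer loop
batched_twice = naive_vmap(naive_vmap(dot_self))  # applied twice: two outer loops

result_once = batched_once([[1, 2], [3, 4]])
result_twice = batched_twice([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
```

Whether TensorIR should model this as a loop over the producer block of each output buffer is exactly the open question; the sketch assumes that interpretation.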

write a separately scoped bit of code that doesn’t even know about the outer construction at all

The idea sounds pretty much like our “block isolation” construct. A “block” in TensorIR means doing isolated computation without having to know the outer scope.

The difference is that the with scope you mentioned does not effectively change the IR, so it is more like syntactic sugar hinting to the schedule class that some primitives should operate locally, while the Block in TensorIR is an IR construct that really enforces some restrictions.

CC @Hzfengsy @spectrometerHBH @tqchen would love to hear your opinions

Yeah, I think your explanation is a good summary. I see what you mean about the TensorIR blocks

My understanding though is that the user doesn’t actually write TensorIR (except maybe to start), they still schedule with a separate language? The blocks in TIR seem really nice, but I still worry that the scheduling code itself also needs some ability to abstract. For instance the example here, 4. Matrix Multiplication — Dive into Deep Learning Compiler 0.1 documentation . It doesn’t seem like this changes that too much? There are so many axes in scope in this function at once, and it seems very hard to separate them all from each other.

Thank you for such a valuable question.

Your understanding is correct. We still need a schedule language to schedule. That is because we need a simple API and abstraction for both human experts and automatic optimization (like AutoTVM, Ansor, and our new meta-schedule). Also, we try to preserve users’ habits, so we do not change the API too much.

The critical challenge you mentioned seems to be user experience, especially around loop axes. TensorIR uses an eager schedule, which means every schedule primitive changes the IR as soon as it executes. Users can then inspect the axes and the whole AST whenever they want. Here is a simple example:

import tvm
from tvm import te, tir

A = te.placeholder((128, 128), name="A")
Update = te.compute((128, 128), lambda i, j: A[i, j] + 1, name="update")

"""Create PrimFunc from TE compute for further scheduling"""
func = te.create_func(Update)

"""Create TensorIR schedule"""
s = tir.create_schedule(func)

print(tvm.script.asscript(func))
"""Output

@tvm.script.tir
def func(var_A: ty.handle, var_update: ty.handle) -> None:
    A = tir.match_buffer(var_A, [128, 128], elem_offset=0, align=128, offset_factor=1)
    update = tir.match_buffer(var_update, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with tir.block([], "root") as []:
        tir.reads([])
        tir.writes([])
        for i0, i1 in tir.grid(128, 128):
            with tir.block([128, 128], "update") as [i, j]:
                tir.bind(i, i0)
                tir.bind(j, i1)
                tir.reads([A[i:(i + 1), j:(j + 1)]])
                tir.writes([update[i:(i + 1), j:(j + 1)]])
                update[i, j] = (A[i, j] + tir.float32(1))

"""
update = s.get_block("update")
x, y = s.get_axes(update)
print(x)
"""Output

for i0 = 0 to 128
"""
xo, xi = s.split(x, factor=32)
print(xo, xi, sep="\n")
"""Output

for i0_outer = 0 to 4
for i0_inner = 0 to 32
"""
print(x)
"""Output

(nullptr)
"""
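The printed loops can be checked with a small plain-Python sketch (not the TVM API; the function names are hypothetical). Splitting the 128-iteration loop by a factor of 32 yields an outer loop of 4 and an inner loop of 32, and the original index is recovered as `i0_outer * 32 + i0_inner`, so the iteration order is unchanged. The sketch assumes the extent is divisible by the factor.

```python
def original_loop(extent):
    """Iteration order of `for i0 = 0 to extent`."""
    return [i0 for i0 in range(extent)]


def split_loop(extent, factor):
    """Iteration order after splitting by `factor` (extent divisible by factor)."""
    visited = []
    for i0_outer in range(extent // factor):
        for i0_inner in range(factor):
            # The original loop variable is recovered from the two new ones.
            visited.append(i0_outer * factor + i0_inner)
    return visited


before = original_loop(128)
after = split_loop(128, 32)
```

Because the schedule is eager, the old loop no longer exists in the IR after the split, which is why the stale handle `x` prints as `(nullptr)` above.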

Hi @Hzfengsy, Thanks for your answer, that makes a lot of sense.

However, what I am trying to communicate is that even with a new TIR and the ability to print it out at each step, the underlying scheduling language seems to be roughly the same in all of these examples (with the nice exception of removing the unnecessary s[C]). This is fine for the examples in this thread, which are all very small.

What I am suggesting is that this might be a perfect time to think about simplifying the scheduling language as well. For instance the example here: 4. Matrix Multiplication — Dive into Deep Learning Compiler 0.1 documentation is in some sense the simplest scheduling example of TVM, and it is still extremely complicated. It has roughly 20 local variables in scope all with one letter names and many overlapping mutations (without even using autotvm).

I am curious about any ideas the new approach has to make this more robust, easier to extend, and understandable to non-experts?

Thanks @srush! I totally agree!

Simplifying the scheduling language is definitely an important issue, especially when we want to reach a broader audience.

One idea for making this happen is to compose schedule “primitives” into higher-level “composite rules”. For example, we could make “tiling + cache_read + compute_at” a single rule that users can call directly: it works like a single schedule instruction, but is decomposed into many small “primitives” inside our system.
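A rough sketch of the composite-rule idea, in plain Python rather than the real schedule API. `ToySchedule` and `tile_and_cache` are hypothetical names, and the toy primitives only record what was called instead of transforming any IR.

```python
class ToySchedule:
    """Hypothetical stand-in for a schedule: it only records applied primitives."""

    def __init__(self):
        self.trace = []

    def split(self, loop, factor):
        self.trace.append(("split", loop, factor))
        return f"{loop}_outer", f"{loop}_inner"

    def cache_read(self, buffer, scope):
        self.trace.append(("cache_read", buffer, scope))
        return f"{buffer}_{scope}"

    def compute_at(self, block, loop):
        self.trace.append(("compute_at", block, loop))


def tile_and_cache(sch, loop, buffer, factor=32, scope="shared"):
    """Composite rule: one call for the user, several primitives underneath."""
    outer, inner = sch.split(loop, factor)
    cached = sch.cache_read(buffer, scope)
    sch.compute_at(cached, outer)
    return outer, inner, cached


sch = ToySchedule()
outer, inner, cached = tile_and_cache(sch, loop="i", buffer="A")
```

The user sees one call, while `sch.trace` still lists the individual primitives, so the system can inspect or replay them.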

Let’s take this a step further: if we generalize the idea so that our rules work with no human interaction at all, then we can actually build an automatic scheduling framework on top of it :slight_smile: This is the idea from the Ansor paper, and I have some ongoing work on it as well.

On a side note to this conversation about new primitives, will the new TensorIR include a “store_at” primitive, like the one present in Halide/Tiramisu? I just want to know if that’s something on the roadmap :slight_smile:

I am wondering whether TensorIR will be supported in IRBuilder, i.e. whether I can use the IRBuilder to emit the Block that contains the computation rule, and then directly call the newly proposed schedule primitives to schedule the block. From my understanding, TensorIR enhances TIR with the ability to be scheduled, so IRBuilder should support directly constructing schedulable TIR.

@Hzfengsy What do we mean by “check” here? Is it something like a boundary check? What if tensorized operators provided by a vendor have some alignment requirements (e.g. the start pointer must be a multiple of 8/16)?

This looks confusing to me because there is no example showing how to schedule the three blocks.

I’m curious whether TIR’s Block construct is strong enough to support holistic fusion like Rammer and HFuse, where different Blocks have heterogeneous workloads but we can fuse them into a single kernel. If so, we can make such fusion a primitive in the TIR schedule and enlarge the auto-scheduling search space.

Thanks, @yzh119. Currently, we have not considered cross-kernel scheduling in TensorIR. But it may be possible if we treat it as one large kernel. Could you please show an example? (e.g. the IR before and after the schedule)

Thanks for such a great suggestion. Yes, we do support IRBuilder for TensorIR. However, it is not recommended, because it is likely to generate illegal or opaque IR (which lacks some of the information). Besides, users have to provide many attributes/annotations (e.g. block read/write regions and block iter_var) if they want to use the native IRBuilder.

On the other hand, TVMScript can also represent ANY IR supported in TensorIR and TVM as far as I know. It provides extra syntax checks and sugar to make it easy to write a schedulable IR. So we strongly recommend you try TVMScript if possible. :slight_smile:

Yeah, to fully replace IRBuilder, TVM script is still missing some elements: meta programming and hygienic macros. Let’s consider supporting them after the upstreaming is done :slight_smile:

CC @tqchen

@junrushao If by “meta programming” you mean the ability to call a Python function from the script to generate other code, and embed the generated code into the calling context, then YES, we absolutely need this! I think it’s called “splicing” or “unquote” in the literature.

The lack of such a feature is what turned me away from hybrid script. Since there is basically no abstraction for composing multiple pieces of code into bigger code, I had to manually duplicate a lot of code. The resulting code is super ugly (just look at various shape functions). For example, tvm/_transform.py at 8e23806d2d522b71979d0a2730b38cc5c3bf6185 · apache/tvm · GitHub (this one I wrote)

For examples of the kind of programming style we’ve developed for ir builder, see topi/cuda/sort.py or topi/cuda/nms.py. cc @mbrookhart

@masahi Yeah. We should allow embedding IR fragments and functions that produce IR fragments into the script, as well as replacing some tokens with caller-specified IR fragments :slight_smile:
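As a loose analogy for the splicing being discussed, here is a plain-Python sketch. It is not TVMScript or hybrid script; the helper names `loop_nest` and `build_kernel` are hypothetical, and it works on source text rather than IR, so it is not hygienic. One function produces a code fragment, and another embeds that fragment into a larger definition.

```python
def loop_nest(names_and_extents, body):
    """Return source lines for nested `for` loops around `body` (a list of lines)."""
    lines = []
    indent = ""
    for name, extent in names_and_extents:
        lines.append(f"{indent}for {name} in range({extent}):")
        indent += "    "
    lines.extend(indent + line for line in body)
    return lines


def build_kernel(name, loops, body):
    """Splice a generated loop nest into a function definition and compile it."""
    source = [f"def {name}(a, out):"]
    source.extend("    " + line for line in loop_nest(loops, body))
    namespace = {}
    exec("\n".join(source), namespace)
    return namespace[name], "\n".join(source)


add_one, add_one_src = build_kernel(
    "add_one", [("i", 2), ("j", 3)], ["out[i][j] = a[i][j] + 1"]
)
a = [[0, 1, 2], [3, 4, 5]]
out = [[0] * 3 for _ in range(2)]
add_one(a, out)
```

The same `loop_nest` helper can be reused for other bodies, which is the kind of composition that avoids duplicating code by hand.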

Just curious: when will such a big feature land in mainline? Is there an initial plan for it? Thanks, I can’t wait to use it. Haha.

The upstreaming is in progress. You can track it at [RFC] TensorIR Scheduling Tracking Issue · Issue #7527 · apache/tvm · GitHub.

As you mentioned, the project is huge, so it still needs some time (maybe 2-3 months) to finish. But we are trying our best. :slight_smile:

One issue with the old schedule ops is that we cannot get accurate bounds with InferBound. What will this be like in the new schedule system? Thanks.

Is this for merging two reduction stages into one stage? Thank you very much!

It’s not. merge_reduction is designed for merging the init part and the update part into one block with an if branch.
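A plain-Python sketch of what that merge could look like, based only on the one-line description above. The function names are hypothetical, and the real primitive operates on TensorIR blocks rather than Python loops. The first version keeps init and update separate; the second folds the init into the update loop behind an if branch on the reduction index.

```python
def reduce_separate(matrix):
    """Row sums with a separate init part and update part."""
    out = [None] * len(matrix)
    for i in range(len(matrix)):
        out[i] = 0  # init part
    for i in range(len(matrix)):
        for k in range(len(matrix[i])):
            out[i] = out[i] + matrix[i][k]  # update part
    return out


def reduce_merged(matrix):
    """Same reduction with the init folded into the update loop via an if branch."""
    out = [None] * len(matrix)
    for i in range(len(matrix)):
        for k in range(len(matrix[i])):
            if k == 0:
                out[i] = 0  # init runs only on the first reduction step
            out[i] = out[i] + matrix[i][k]
    return out


data = [[1, 2, 3], [4, 5, 6]]
separate_result = reduce_separate(data)
merged_result = reduce_merged(data)
```

The merged form assumes each row has at least one element, since the init only runs when the reduction loop executes.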

Hi, I am new to the compiler stack and have a few questions about TensorIR:

  1. Instead of lowering TE to TIR, what is the process for going from Relay directly to TIR? It would be helpful if you have an example or tutorial.
  2. You mentioned that “Now, there is no stage during the schedule. Rather than lowering the schedule into TIR, we directly mutate the TIR itself.” Does this mean there is no longer a concept of “lower”, since there are no stages?
  3. Do you have any details on the differences between the new TensorIR and the former TIR (as in TVM 0.6)?

I am looking forward to your reply. Thank you!