[RFC] TensorIR: A schedulable IR for TVM

@merrymercy Good question! Here’s an example of TIR’s schedule.

s = tir.create_schedule(original_func)

update = s.get_block("C")
i, j, k = s.get_axes(update)
i_o, i_i = s.split(i, bn)
j_o, j_i = s.split(j, bn)
k_o, k_i = s.split(k, 4)
s.reorder(i_o, j_o, k_o, k_i, i_i, j_i)
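Here is a plain-Python sketch (no TVM calls) of the loop nest that the schedule above describes. It assumes square N x N matrices with N divisible by both bn and 4. N, bn, and the two helper functions are illustrative assumptions, not part of the RFC. Splitting rewrites each index as outer * factor + inner, and the reorder only changes the order in which iterations are visited, so the scheduled nest should compute exactly what the naive nest does.

```python
def matmul_naive(A, B, N):
    """Reference matmul: C[i][j] = sum_k A[i][k] * B[k][j]."""
    C = [[0.0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            for k in range(N):
                C[i][j] += A[i][k] * B[k][j]
    return C


def matmul_scheduled(A, B, N, bn):
    """Same computation after split(i, bn), split(j, bn), split(k, 4)
    and reorder(i_o, j_o, k_o, k_i, i_i, j_i)."""
    assert N % bn == 0 and N % 4 == 0, "sketch assumes exact splits"
    C = [[0.0] * N for _ in range(N)]
    for i_o in range(N // bn):
        for j_o in range(N // bn):
            for k_o in range(N // 4):
                for k_i in range(4):
                    for i_i in range(bn):
                        for j_i in range(bn):
                            # Recover the original indices from the split axes.
                            i = i_o * bn + i_i
                            j = j_o * bn + j_i
                            k = k_o * 4 + k_i
                            C[i][j] += A[i][k] * B[k][j]
    return C


N, bn = 8, 4
A = [[float(r * N + c) for c in range(N)] for r in range(N)]
B = [[float((r + c) % 3) for c in range(N)] for r in range(N)]
assert matmul_scheduled(A, B, N, bn) == matmul_naive(A, B, N)
```

Only the iteration order changes. For a fixed (i, j), k is still visited in ascending order, which is why the results match exactly.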

TIR’s schedule is not totally stateless. Scope information and dependency-graph information are actively maintained in class Schedule during the scheduling process, so we don’t recompute them each time we apply a new primitive. After lowering to TIR without blocks, we no longer maintain this information, since that IR is not schedulable.

All in all, it is good to run the benchmark to compare them in practice. I hope I understood your question correctly. :smile:


When I read this RFC, I was confused because I read that the compilation flow currently goes from

TF/PyTorch/ONNX -> Relay -> TE (based on TE schedule) -> TIR -> C++/CUDA

And then I read:

TensorIR is a brand new low-level IR with full scheduling support. Here are some […]

And then I see references to TIR throughout the rest of the RFC, and it is unclear to me whether these refer to the new TensorIR or the old TIR.

Can we clarify both here and in the code base moving forward whether we are referring to TensorIR or the old TIR? I think there are two ways of doing this going forward and we should explicitly pick one:

  1. Have TIR refer to the old version of TIR, and always use TensorIR when talking about the new IR.
  2. Be clear that TensorIR moving forward will be replacing the old TIR, so that any reference to TIR could be referring to the new or old version (and should be clarified if not obvious from context).

Someone more knowledgeable than me should pick one of those (or correct me and/or point out other options if I’ve gotten anything wrong here).


Yes, the ambiguity is something I was struggling with too when having a conversation. May I ask what the “T” in the old TIR stands for? TVM?

TensorIR can be viewed as a major feature enhancement (upgrade) to the TIR in master. That is why TensorIR and TIR are used interchangeably, as they are intended to be.

Some of the elements, like multidimensional buffer load and TVM script, are already present as part of the unified IR effort.

The upgrade will happen quite naturally: TIR will continue to support current code as it is, and it will gain new scheduling capabilities through the new block constructs.

Hi,

I was wondering about the status of this RFC. Is there any PR or work in progress available?

Thanks

Hi TVM,

This idea is so cool, I think it is going to make it possible for mortals to use TVM effectively.

I have a couple of questions about the snippet of the scheduling language.

The three issues I have when programming TVM are:

  • Too many variables in scope with meaningless names, and forgetting where they came from.
  • Losing track of which axes need to be split identically
  • Not understanding the semantics of compute_at and how it tells which axes line up.

It seems like maybe this fixes a couple of these. However the following still bugs me a bit:

s = tir.create_schedule(matmul)
update = s.get_block("C")
i, j, k = s.get_axes(update)
i_o, i_i = s.split(i, bn)
j_o, j_i = s.split(j, bn)
k_o, k_i = s.split(k, 4)
s.reorder(i_o, j_o, k_o, k_i, i_i, j_i)

Curious why:

  • There are strings for get_block?
  • Why not have split go into a named out/in tuple to discourage this _ style naming. It gets so messy so quickly.
  • Does this propose to fix the issue of having to repeat identical splits for things like shared and local buffers that are scheduled later in the code (in order for compute_at to work)?

Thanks so much! Love the library, want to get to a point where I can teach it to my students effectively. /Sasha


Hey @srush,

Thanks for your valuable feedback! Please allow me to try to explain the design rationale below:

Q1. There are strings for get_block?

Yeah. One design principle we hold for TensorIR is that everything needed for scheduling is contained in the TensorIR Python syntax, so there is no longer any “mysteriously hidden” information. Given a TensorIR script in Python, we can schedule on it.

In this particular case, the string is the name of the block in the text format.

Here is the example:

@tvm.script.tir
def matmul(x: ty.handle, y: ty.handle, z: ty.handle) -> None:
    X = tir.match_buffer(x, [128, 128], "float32")
    Y = tir.match_buffer(y, [128, 128], "float32")
    Z = tir.match_buffer(z, [128, 128], "float32")
    # ⬇️  name of the block is "Z"
    with tir.block([128, 128, tir.reduce_axis(0, 128)], "Z") as [i, j, k]:
        with tir.init():
            Z[i, j] = tir.float32(0)
        Z[i, j] = Z[i, j] + (X[i, k] * Y[k, j])

"""Print TensorIR in python syntax"""
print(tvm.script.asscript(matmul))
"""
Create a schedule
  Schedule is TensorIR => TensorIR
  so we can print the latest TensorIR at any step of scheduling
"""
sch = tvm.tir.create_schedule(matmul)
ax0, ax1 = sch.split(...)
"""Print TensorIR at any step"""
print(tvm.script.asscript(sch.func))
sch.fuse(...)
print(tvm.script.asscript(sch.func))
"""
We can also print loops and blocks, in a syntax like:
     for i in range(0, 100)
"""
print(ax0)
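Because each scheduling step maps IR to IR, every intermediate version can be kept and printed. The plain-Python toy below models a function as a tuple of (loop name, extent) pairs. split, fuse, and show are hypothetical stand-ins, not TensorIR's implementation. The toy also shows how the _outer/_inner/_fused suffixes accumulate in loop names.

```python
def split(nest, name, factor):
    """Pure transform: replace loop `name` with name_outer / name_inner."""
    out = []
    for loop_name, extent in nest:
        if loop_name == name:
            assert extent % factor == 0, "sketch assumes an exact split"
            out += [(name + "_outer", extent // factor), (name + "_inner", factor)]
        else:
            out.append((loop_name, extent))
    return tuple(out)


def fuse(nest, outer, inner):
    """Pure transform: fuse two adjacent loops into a single loop."""
    names = [n for n, _ in nest]
    pos = names.index(outer)
    assert names[pos + 1] == inner, "can only fuse adjacent loops"
    fused = (f"{outer}_{inner}_fused", nest[pos][1] * nest[pos + 1][1])
    return nest[:pos] + (fused,) + nest[pos + 2:]


def show(nest):
    """Print a loop nest as Python-like syntax, one loop per nesting level."""
    return "\n".join("    " * d + f"for {n} in range(0, {e}):"
                     for d, (n, e) in enumerate(nest))


# Each step produces a new nest; earlier versions stay intact.
history = [(("i0", 128), ("i1", 128))]
history.append(split(history[-1], "i0", 32))
history.append(fuse(history[-1], "i0_inner", "i1"))
for step in history:
    print(show(step), end="\n\n")
```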

Q2. Why not have split go into a named out/in tuple to discourage this _ style naming. It gets so messy so quickly

Agreed. We do find the _ style naming annoying in our experiments, especially when scheduling gets complicated and it generates really horrible names like i0_outer_outer_outer_outer_i1_outer_outer_outer_outer_fused_i2_outer_outer_outer_outer_fused_i3_outer_outer_outer_outer_fused.

We have some internal strawman proposals for syntactic sugar, but have not yet converged on a perfect solution.

Solution 1. Allow splitting by multiple factors + accept customized naming of axes.

a, b, c, d = s.split(axes, factors=[None, 2, 4, 8], names=["a", "b", "c", "d"])

Note that None in the factors means letting the schedule infer that factor.
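The inference rule behind Solution 1 is simple enough to sketch. With exact splits, the single None factor is the extent divided by the product of the known factors. The helper below is a hypothetical plain-Python illustration, not a TVM API. It rejects non-divisible extents, which a real schedule would presumably have to handle some other way (for example, with padding or a bounds guard).

```python
from math import prod


def infer_split_factors(extent, factors):
    """Fill in at most one None factor so that the factors multiply to extent.

    Hypothetical helper illustrating the proposed semantics; it assumes the
    split must be exact (no padding or tail loop).
    """
    if factors.count(None) > 1:
        raise ValueError("at most one factor can be inferred")
    known = prod(f for f in factors if f is not None)
    if None not in factors:
        if known != extent:
            raise ValueError(f"factors multiply to {known}, expected {extent}")
        return list(factors)
    if extent % known != 0:
        raise ValueError(f"{extent} is not divisible by {known}")
    return [extent // known if f is None else f for f in factors]


factors = infer_split_factors(128, [None, 2, 4, 8])
# Pair the resolved factors with user-chosen axis names, as in names=[...].
named = dict(zip(["a", "b", "c", "d"], factors))
print(named)  # {'a': 2, 'b': 2, 'c': 4, 'd': 8}
```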

Solution 2. Allow splitting by multiple factors + einsum-like style API + get axes by name

i0, _, _, _ = s.split("i => i0, i1, i2, i3", factors=[None, 2, 4, 8])
# or allow retrieving an axis by name
i0 = s.get_axis(name="i0")
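The einsum-like spec in Solution 2 could be parsed with very little machinery. In this sketch, parse_split_spec and AxisRegistry are hypothetical names. The registry only stores a name-to-(source axis, factor) mapping to stand in for get_axis(name=...), and it performs no scheduling.

```python
import re


def parse_split_spec(spec):
    """Parse a hypothetical spec like "i => i0, i1, i2, i3".

    Returns (source_axis, [new_axis_names]).
    """
    match = re.fullmatch(r"\s*(\w+)\s*=>\s*(\w+(?:\s*,\s*\w+)*)\s*", spec)
    if match is None:
        raise ValueError(f"cannot parse split spec: {spec!r}")
    src = match.group(1)
    outs = [name.strip() for name in match.group(2).split(",")]
    if len(set(outs)) != len(outs):
        raise ValueError("new axis names must be unique")
    return src, outs


class AxisRegistry:
    """Hypothetical name -> (source axis, factor) table for get_axis(name=...)."""

    def __init__(self):
        self._axes = {}

    def split(self, spec, factors):
        src, names = parse_split_spec(spec)
        if len(names) != len(factors):
            raise ValueError(f"{len(names)} names but {len(factors)} factors")
        for name, factor in zip(names, factors):
            self._axes[name] = (src, factor)
        return names

    def get_axis(self, name):
        return self._axes[name]


reg = AxisRegistry()
i0, _, _, _ = reg.split("i => i0, i1, i2, i3", factors=[None, 2, 4, 8])
print(i0, reg.get_axis("i2"))  # i0 ('i', 4)
```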

We are open to new proposals too :slight_smile:

Q3. Does this propose to fix the issue of having to repeat identical splits for things like shared and local buffers that are scheduled later in the code (in order for compute_at to work)?

In short, yes, and it is solved by introducing a new scheduling primitive reverse_compute_at.

Definition of compute_at. Given a producer and a consumer, compute_at allows computing part of the producer’s region under one of the consumer’s loops.

Definition of reverse_compute_at. Given a producer and a consumer, reverse_compute_at allows computing part of the consumer’s region under one of the producer’s loops.

Our typical use case. We have a heavy producer, like conv2d, and a light consumer, like what is generated by cache_write or a small ReLU, and we want to fuse them for better locality.

Why compute_at requires duplicated splitting. With compute_at, we are moving the producer under a loop of the consumer. First, the user has to split the consumer; otherwise the user doesn’t even know which axis to compute at. Second, the user has to split the producer so that the other tiles are correctly positioned. That is why it is so lengthy and tedious.

Why reverse_compute_at avoids duplicated splitting. In this case, we only need to split the producer, then put the consumer under a specific loop of the producer. Then we don’t have to do any duplicate splitting :slight_smile:
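The loop structure can be shown on a toy elementwise pipeline in plain Python. In the fused version, the user splits only the producer loop. The consumer is then placed under the producer's outer loop and covers exactly the region that tile produced, with no separate split of the consumer. These functions and the consumer_region helper are hypothetical illustrations of the loop structure, not TVM code. In TensorIR the covered region would presumably be derived by the schedule rather than written by hand.

```python
def unfused(A):
    """Reference: run the whole producer, then the whole consumer."""
    n = len(A)
    B = [0.0] * n
    for i in range(n):  # heavy producer (stand-in for conv2d)
        B[i] = A[i] * 2.0 - 3.0
    C = [0.0] * n
    for i in range(n):  # light consumer (stand-in for ReLU)
        C[i] = max(B[i], 0.0)
    return C


def consumer_region(i_o, tile):
    """Hypothetical helper: consumer indices that read producer tile i_o.
    Written by hand here for a 1:1 elementwise consumer."""
    return range(i_o * tile, (i_o + 1) * tile)


def reverse_compute_at_fused(A, tile):
    """Only the producer loop is split; the consumer is attached under its
    outer loop i_o and covers just the region that tile produced."""
    n = len(A)
    assert n % tile == 0, "sketch assumes an exact split"
    B = [0.0] * n
    C = [0.0] * n
    for i_o in range(n // tile):
        for i_i in range(tile):  # producer tile
            i = i_o * tile + i_i
            B[i] = A[i] * 2.0 - 3.0
        for i in consumer_region(i_o, tile):  # consumer, while B's tile is hot
            C[i] = max(B[i], 0.0)
    return C


A = [float(x) for x in range(-8, 8)]
print(reverse_compute_at_fused(A, 4) == unfused(A))  # True
```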

CC: @tqchen @Hzfengsy @spectrometerHBH @vinx13 @masahi


Neat! This is a really good explanation. I think I get most everything you are explaining. (I’m still stuck on reverse_compute_at which seems like a long name, and is still a bit too magical for me to understand.)

In terms of splitting / reordering, my goofy thought is that my favorite construct for tensor abstractions is vmap in Jax. What makes programming tensors so hard is keeping 6 tensor dimensions in your head. vmap lets you do something, then “zoom” into the area, forget entirely about the outer dimension, and focus on that.

When writing TVM code for matrix multiply with double buffering, I would really like to 1) first decide on my tiling of the output, split, and assign to blocks and threads, and then 2) write a separately scoped bit of code that doesn’t even know about the outer construction at all. Ideally, I would create my outer scope and vmap in; all my buffers would then be automatically “reverse_compute_at” / vmapped, and then I would create my inner setup.

I don’t know if this totally works but this would be my ideal:

l_o = split_out(C, l) # only exposes the outer
n_o = split_out(C, n)

with prefix(l_o, n_o, tensors=[A, B], threads=[]) as s2:
     s2.cache_read(... ) # this cache read is now computed locally here
     l, n = s2.axes(C) # these are now the inner split axes
     m = s2.axes(A) # A's outer axes are now invisible
     s2.reorder(n, l) # only touches the visible axes. 

(Maybe instead of a with this is an inner function like in jax)
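The “zoom in” scoping can be illustrated without TVM or JAX. Write a kernel that only sees one output tile in local coordinates, and let an outer combinator map it over all tiles. This plain-Python sketch only shows the scoping. tile_map and tile_kernel are hypothetical names, and the sketch does not model threads, caching, or TVM code generation. It also maps over tiles rather than over a batch axis the way JAX's vmap does.

```python
def tile_map(kernel, tile_m, tile_n):
    """Hypothetical vmap-style combinator: lift a kernel written for a single
    (tile_m x tile_n) output tile to the whole output matrix."""
    def run(A, B):
        M, N = len(A), len(B[0])
        assert M % tile_m == 0 and N % tile_n == 0, "sketch assumes exact tiling"
        C = [[0.0] * N for _ in range(M)]
        for io in range(M // tile_m):
            for jo in range(N // tile_n):
                # Outer scope: pick the slices this tile needs.
                A_tile = A[io * tile_m:(io + 1) * tile_m]
                B_tile = [row[jo * tile_n:(jo + 1) * tile_n] for row in B]
                C_tile = kernel(A_tile, B_tile)
                # Scatter the tile result back to global coordinates.
                for ii in range(tile_m):
                    for jj in range(tile_n):
                        C[io * tile_m + ii][jo * tile_n + jj] = C_tile[ii][jj]
        return C
    return run


def tile_kernel(A_tile, B_tile):
    """Inner scope: only local indices, no knowledge of io/jo."""
    m, k, n = len(A_tile), len(B_tile), len(B_tile[0])
    return [[sum(A_tile[i][p] * B_tile[p][j] for p in range(k))
             for j in range(n)]
            for i in range(m)]


A = [[float(r + c) for c in range(3)] for r in range(4)]  # 4 x 3
B = [[float(r * c) for c in range(6)] for r in range(3)]  # 3 x 6
tiled = tile_map(tile_kernel, 2, 3)(A, B)
print(tiled == tile_kernel(A, B))  # True
```

The kernel is the same function in both calls. Only the outer combinator knows how the output is tiled.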


I’m still stuck on reverse_compute_at which seems like a long name, and is still a bit too magical for me to understand

Yeah I once had some discussion with @tqchen and @spectrometerHBH that reverse_compute_at is too long and tedious, and I agree it will be great to find some better names (we are always bad at it admittedly)…

Perhaps @spectrometerHBH could give a specific example of using reverse_compute_at to avoid duplicate splitting?

vmap lets you do something then “zoom” into the area, forget entirely about the outer dimension, and focus on that

Yeah vmap is a powerful tool - it sounds more like “batching” (either static or dynamic), but its functionality can certainly go beyond that by being applied multiple times.

One problem with introducing vmap to TensorIR is its semantics: do we consider it as adding a data-parallel loop on top of the producer block of each output buffer? I believe @tqchen has more thoughts on this.

write a separately scoped bit of code that doesn’t even know about the outer construction at all

The idea sounds pretty much like our “block isolation” construct. A “block” in the TensorIR means doing isolated computation without having to know the outer scope.

The difference is that the with scope you mentioned does not actually change the IR, so it is more like syntactic sugar that hints to the schedule class that some primitives should operate locally, while the Block in TensorIR is an IR construct that really enforces some restrictions.

CC @Hzfengsy @spectrometerHBH @tqchen would love to hear your opinions

Yeah, I think your explanation is a good summary. I see what you mean about the TensorIR blocks.

My understanding though is that the user doesn’t actually write TensorIR (except maybe to start), they still schedule with a separate language? The blocks in TIR seem really nice, but I still worry that the scheduling code itself also needs some ability to abstract. For instance the example here, 4. Matrix Multiplication — Dive into Deep Learning Compiler 0.1 documentation . It doesn’t seem like this changes that too much? There are so many axes in scope in this function at once, and it seems very hard to separate them all from each other.


Thank you for such a valuable question.

Your understanding is correct. We still need a schedule language to schedule, because we need a simple API and abstraction for both human experts and automatic optimization (like AutoTVM, Ansor, and our new meta-schedule). Also, we try to preserve users’ habits, so we do not change the API too much.

The critical challenge you mention seems to be user experience, especially around loop axes. TensorIR is an eager schedule, which means every schedule primitive changes the IR as soon as it executes. Users can then see the axes and the whole AST whenever they want. Here is a simple example:

import tvm
from tvm import te, tir

A = te.placeholder((128, 128), name="A")
Update = te.compute((128, 128), lambda i, j: A[i, j] + 1, name="update")

"""Create PrimFunc from TE compute for further scheduling"""
func = te.create_func(Update)

"""Create TensorIR schedule"""
s = tir.create_schedule(func)

print(tvm.script.asscript(func))
"""Output

@tvm.script.tir
def func(var_A: ty.handle, var_update: ty.handle) -> None:
    A = tir.match_buffer(var_A, [128, 128], elem_offset=0, align=128, offset_factor=1)
    update = tir.match_buffer(var_update, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with tir.block([], "root") as []:
        tir.reads([])
        tir.writes([])
        for i0, i1 in tir.grid(128, 128):
            with tir.block([128, 128], "update") as [i, j]:
                tir.bind(i, i0)
                tir.bind(j, i1)
                tir.reads([A[i:(i + 1), j:(j + 1)]])
                tir.writes([update[i:(i + 1), j:(j + 1)]])
                update[i, j] = (A[i, j] + tir.float32(1))

"""
update = s.get_block("update")
x, y = s.get_axes(update)
print(x)
"""Output

for i0 = 0 to 128
"""
xo, xi = s.split(x, factor=32)
print(xo, xi, sep="\n")
"""Output

for i0_outer = 0 to 4
for i0_inner = 0 to 32
"""
print(x)
"""Output

(nullptr)
"""

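The (nullptr) printed for x after the split follows from the schedule being eager. The split replaces the loop in the IR, so the old handle no longer points at anything. Below is a toy Python model of that behaviour. Loop and EagerSchedule are hypothetical classes that only mimic the printed output above; they are not TensorIR's actual classes.

```python
class Loop:
    """Hypothetical loop handle that becomes invalid once it is split."""

    def __init__(self, name, extent):
        self.name, self.extent, self.valid = name, extent, True

    def __str__(self):
        if not self.valid:
            return "(nullptr)"
        return f"for {self.name} = 0 to {self.extent}"


class EagerSchedule:
    """Toy eager schedule: each primitive rewrites the loop nest immediately."""

    def __init__(self, loops):
        self.loops = [Loop(name, extent) for name, extent in loops]

    def get_axes(self):
        return list(self.loops)

    def split(self, loop, factor):
        if not loop.valid:
            raise ValueError(f"stale loop handle: {loop.name}")
        assert loop.extent % factor == 0, "sketch assumes an exact split"
        outer = Loop(f"{loop.name}_outer", loop.extent // factor)
        inner = Loop(f"{loop.name}_inner", factor)
        pos = self.loops.index(loop)
        self.loops[pos:pos + 1] = [outer, inner]
        loop.valid = False  # the old loop no longer exists in the IR
        return outer, inner


s = EagerSchedule([("i0", 128), ("i1", 128)])
x, y = s.get_axes()
print(x)                 # for i0 = 0 to 128
xo, xi = s.split(x, 32)
print(xo, xi, sep="\n")  # two lines: the i0_outer and i0_inner loops
print(x)                 # (nullptr)
```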
Hi @Hzfengsy, Thanks for your answer, that makes a lot of sense.

However, what I am trying to communicate is that even with a new TIR and the ability to print it out at each step, the underlying scheduling language seems to be roughly the same in all of these examples (with the nice exception of removing the unnecessary s[C]). This is okay in all the examples in the thread that are very small.

What I am suggesting is that this might be a perfect time to think about simplifying the scheduling language as well. For instance the example here: 4. Matrix Multiplication — Dive into Deep Learning Compiler 0.1 documentation is in some sense the simplest scheduling example of TVM, and it is still extremely complicated. It has roughly 20 local variables in scope all with one letter names and many overlapping mutations (without even using autotvm).

I am curious about any ideas the new approach has to make this more robust, easier to extend, and understandable to non-experts?


Thanks @srush! I totally agree!

Simplifying the scheduling language is definitely an important issue, especially when we want to approach broader audience.

One way to make this happen is to compose schedule “primitives” into higher-level “composite rules”. For example, we could make “tiling + cache_read + compute_at” a single rule that users can call directly. It works like a single schedule instruction, but is decomposed into many small “primitives” inside our system.

Taking this a step further, if we generalize the idea so that our rules work with no human interaction at all, then we can actually build an automatic scheduling framework on top of it :slight_smile: This is the idea from the Ansor paper, and I have some ongoing work on it as well.
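A composite rule of this kind can be sketched as an ordinary function that calls several primitives, with the schedule recording a trace of what was applied. The TraceSchedule class and tile_and_cache rule below are hypothetical. They only record primitive names and arguments and do not transform any IR. A recorded trace of this form is the kind of object an automatic scheduler could generate or replay.

```python
class TraceSchedule:
    """Hypothetical schedule that only records which primitives were applied."""

    def __init__(self):
        self.trace = []

    def _record(self, prim, *args):
        self.trace.append((prim, args))

    def split(self, loop, factor):
        self._record("split", loop, factor)
        return f"{loop}_outer", f"{loop}_inner"

    def reorder(self, *loops):
        self._record("reorder", *loops)

    def cache_read(self, buffer, scope):
        self._record("cache_read", buffer, scope)
        return f"{buffer}_{scope}"

    def compute_at(self, block, loop):
        self._record("compute_at", block, loop)


def tile_and_cache(sch, i, j, tile, buffer):
    """Hypothetical composite rule: one call expands into several primitives."""
    io, ii = sch.split(i, tile)
    jo, ji = sch.split(j, tile)
    sch.reorder(io, jo, ii, ji)
    cached = sch.cache_read(buffer, "shared")
    sch.compute_at(cached, jo)
    return io, jo, ii, ji


sch = TraceSchedule()
tile_and_cache(sch, "i", "j", 32, "A")
for prim, args in sch.trace:
    print(prim, args)
```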


On a side note to this conversation about new primitives, will the new TensorIR include the primitive “store_at” (the one present in Halide/Tiramisu)? I just want to know if that’s on the roadmap :slight_smile:

I am wondering whether TensorIR will be supported in IRBuilder, i.e. whether I can use the IRBuilder to emit a Block that contains the computation rule, and then directly call the newly proposed schedule primitives to schedule that block. From my understanding, TensorIR enhances TIR with the ability to be scheduled, so this should be supported in IRBuilder to directly construct schedulable TIR.

@Hzfengsy What do we mean by “check” here? Is it something like a boundary check? What if tensorized operators provided by a vendor have some alignment requirements (e.g. the start pointer must be a multiple of 8/16)?

This looks confusing to me because there is no example showing how to schedule the three blocks.

I’m curious whether TIR’s Block construct is strong enough to support holistic fusion like Rammer and HFuse, where different blocks have heterogeneous workloads but we can fuse them into a single kernel. If so, we can make such fusion a primitive in the TIR schedule and enlarge the auto-scheduling search space.

Thanks, @yzh119. Currently, we have not considered the cross-kernel schedule in TensorIR. But it may be possible if we make it as one large kernel. Could you please show an example? (e.g. the IR before and after the schedule)

Thanks for such a great suggestion. Yes, we do support IRBuilder for TensorIR. However, it is not recommended, because it is likely to generate illegal or opaque IR (which lacks some of the information). Besides, there are many attributes/annotations (e.g. block read/write regions and block iter_vars) that users must provide if they use the native IRBuilder.

On the other hand, TVMScript can also represent ANY IR supported in TensorIR and TVM as far as I know. It provides extra syntax checks and sugar to make it easy to write a schedulable IR. So we strongly recommend you try TVMScript if possible. :slight_smile:

Yeah, to fully replace IRBuilder, there are still some missing elements in TVM script: metaprogramming and hygienic macros. Let’s consider supporting them after the upstreaming is done :slight_smile:

CC @tqchen