[RFC] TensorIR: A schedulable IR for TVM

Background and Motivation

TVM is an end-to-end deep learning compiler with two levels of IR and optimization. TVM translates popular DL frameworks into Relay and optimizes the computation graph, after which it lowers each graph node into Tensor Expression (TE), performs another round of function-level optimization, and finally lowers it into TIR and generates backend code. In brief, the current workflow is:

TF/PyTorch/ONNX -> Relay -> TE (based on TE schedule) -> TIR -> C++/CUDA

Currently, low-level optimization is done through TE scheduling, which has several limitations:

  • Based on an accessory data structure, the schedule tree. Schedule primitives operate on the schedule tree rather than on TIR itself, which makes the scheduling result less intuitive.
  • All primitives are coupled to the schedule tree representation, so it is not easy to add new primitives and ensure the correctness of the transformation at the same time.
  • Limited support for high-dimension instructions and tensorization. TE is not schedulable after tensorization, and the description of tensor intrinsics is not user-friendly.

Introduction

TensorIR is a brand new low-level IR with full scheduling support. Here are some key features and novel techniques.

Core data structure: Block and Block Realize

To generalize the high-dimension tensor expression, we introduce a new statement structure, Block. A block can wrap any part of the IR to provide isolation. A block is the minimal unit for scheduling and tensorization; it stores all the fundamental information of a computation, including the block iteration vars, the regions the block reads and writes, the buffers allocated inside the block, and the critical body statement. A block declares the computation it performs through its iteration vars and their types.
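
For illustration, a minimal standalone block might look like the sketch below. The tir.reads/tir.writes lines are assumptions written out only to show the information a block stores; the exact sugar follows the hybrid script used in the workload example, and the buffers A and B are assumed to be declared elsewhere.

with tir.block([128, 128], "B") as [vi, vj]:
    tir.reads(A[vi, vj])      # region read by this block
    tir.writes(B[vi, vj])     # region written by this block
    B[vi, vj] = A[vi, vj] * 2.0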

Workload Example

  1. First of all, we define the workload (gemm in this example) with hybrid script. Note that the hybrid script directly generates a TIR program rather than TE stages. (The loop nesting can be auto-completed from the block definition by default.)
@tvm.hybrid.script
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    C = tir.match_buffer(c, (1024, 1024), "float32")
    A = tir.match_buffer(a, (1024, 1024), "float32")
    B = tir.match_buffer(b, (1024, 1024), "float32")
    reducer = tir.comm_reducer(lambda x, y: x + y, tir.float32(0))

    with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "C") as [vi, vj, vk]:
        reducer.step(C[vi, vj], A[vi, vk] * B[vk, vj])
  2. Then create the schedule from TIR and tile the loops using primitives.
s = tir.create_schedule(matmul)
update = s.get_block("C")
i, j, k = s.get_axes(update)
i_o, i_i = s.split(i, bn)
j_o, j_i = s.split(j, bn)
k_o, k_i = s.split(k, 4)
s.reorder(i_o, j_o, k_o, k_i, i_i, j_i)

Result

def func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # function attr dict
    B = tir.match_buffer(b, [1024, 1024])
    A = tir.match_buffer(a, [1024, 1024])
    C = tir.match_buffer(c, [1024, 1024])
    reducer = tir.comm_reducer(lambda x, y: x + y, tir.float32(0))
    # body
    for i0_outer, i1_outer, i2_outer, i2_inner, i0_inner, i1_inner in tir.grid(32, 32, 256, 4, 32, 32):
        with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "C") as [vi, vj, vk]:
            tir.bind(vi, ((i0_outer*32) + i0_inner))
            tir.bind(vj, ((i1_outer*32) + i1_inner))
            tir.bind(vk, ((i2_outer*4) + i2_inner))
            reducer.step(C[vi, vj], (A[vi, vk]*B[vk, vj]))
  3. Vectorize the inner loop and decompose the reduction.
s.vectorize(j_i)
s.decompose_reduction(update, j_o)

Result

@tvm.hybrid.script
def func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # function attr dict
    B = tir.match_buffer(b, [1024, 1024])
    A = tir.match_buffer(a, [1024, 1024])
    C = tir.match_buffer(c, [1024, 1024])
    # body
    for i0_outer in range(0, 32):
        for i1_outer_init, i0_inner_init in tir.grid(32, 32):
            for i1_inner_init in range(0, 32, annotation = {"loop_type":"vectorize"}):
                with tir.block([1024, 1024], "C_init") as [vi_init, vj_init]:
                    tir.bind(vi_init, ((i0_outer*32) + i0_inner_init))
                    tir.bind(vj_init, ((i1_outer_init*32) + i1_inner_init))
                    C[vi_init, vj_init] = tir.float32(0)
        for i1_outer, i2_outer, i2_inner, i0_inner in tir.grid(32, 256, 4, 32):
            for i1_inner in range(0, 32, annotation = {"loop_type":"vectorize"}):
                with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "C_update") as [vi, vj, vk]:
                    tir.bind(vi, ((i0_outer*32) + i0_inner))
                    tir.bind(vj, ((i1_outer*32) + i1_inner))
                    tir.bind(vk, ((i2_outer*4) + i2_inner))
                    C[vi, vj] = (C[vi, vj] + (A[vi, vk]*B[vk, vj]))
  4. Finally, build and run it.
build_func = tvm.build(s.func, target=target)
build_func(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)
evaluator = build_func.time_evaluator(build_func.entry_name, ctx, number=1)

In this example, we are imperatively changing the IR rather than waiting until the end of scheduling to produce it, as TE does. It is very important for users and developers to directly see what happens during scheduling. Also, at every stage of the schedule we get a verifiable IR, which is a major improvement for both user experience and correctness proofs.
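
For example, the schedule state is an ordinary TIR function at every point, so it can be printed and inspected between any two primitives. A small sketch reusing the matmul schedule above (printing a PrimFunc is assumed here to show its current text form):

s = tir.create_schedule(matmul)
update = s.get_block("C")
i, j, k = s.get_axes(update)
i_o, i_i = s.split(i, 32)
print(s.func)   # the IR after the first split is already a complete, verifiable function
j_o, j_i = s.split(j, 32)
print(s.func)   # and so is the IR after every following primitive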

Key Features

Independent scheduling based on IR itself

Different from the TE schedule, TensorIR has a complete set of scheduling algorithms that do not need a schedule tree or any extra data structure. We will introduce a brand new set of schedule primitives with full backward compatibility with the TE schedule. This simplifies the compilation workflow and its concepts:

TF/PyTorch/ONNX -> Relay -> TIR -> schedule -> TIR -> schedule -> TIR -> C++/CUDA

Now, there are no stages during scheduling. Rather than lowering a schedule into TIR, we directly mutate the TIR itself. This also enables sequential scheduling (scheduling a single workload several times).
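
A hedged sketch of sequential scheduling under the API used above: the second scheduling round simply starts from the TIR produced by the first one.

# First round: tile the i axis.
s1 = tir.create_schedule(matmul)
block = s1.get_block("C")
i, j, k = s1.get_axes(block)
i_o, i_i = s1.split(i, 32)

# Second round: create a new schedule from the already-scheduled TIR and keep going.
s2 = tir.create_schedule(s1.func)
block = s2.get_block("C")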

Stronger Expressiveness and Optimization Ability

TE has limited expressiveness since each stage is defined by Stage = te.compute(lambda expr), while TensorIR is a full C++-like IR. You can write any program you want in TensorIR. Although not all programs can be scheduled, there are still many more workloads that can be optimized with TensorIR.

One of the improved tasks is concatenation:

TE

B = te.compute((30,), lambda i: tvm.tir.if_then_else(i < 10, A0[i], tvm.tir.if_then_else(i < 20, A1[i - 10], A2[i - 20])))

TensorIR:

with tir.block([10]) as vi:
    B[vi] = A0[vi]
with tir.block([10]) as vi:
    B[vi + 10] = A1[vi]
with tir.block([10]) as vi:
    B[vi + 20] = A2[vi]

The critical improvement is performance: in TIR we can optimize the program by eliminating the if branches, which is impossible with the TE schedule.
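
For illustration, the three blocks above can be lowered into three branch-free loop nests, which is the effect the optimization relies on (a sketch of the lowered shape, not literal compiler output):

for i in range(0, 10):
    B[i] = A0[i]
for i in range(0, 10):
    B[i + 10] = A1[i]
for i in range(0, 10):
    B[i + 20] = A2[i]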

Memory and Execution Scope

Hardware accelerators, led by GPUs, increasingly use hierarchical architectures, including a memory hierarchy (global, shared, local/wmma on NVIDIA GPUs) and an execution hierarchy (SM, warp, thread on NVIDIA GPUs). TVM defines the memory hierarchy, and TensorIR provides the corresponding native hierarchical blocks and execution scopes. Different execution scopes can access different memory scopes.

TensorIR natively supports hierarchy checks. We check memory accesses and thread bindings, including warp-level instruction (wmma) validation, during scheduling. The following is an example of the GPU hierarchy.

for bx in range(0, 32, annotation = {"loop_type":"blockIdx.x"}):
    with block(exec_scope="gpu_block"):
        for ty in range(0, 32, annotation = {"loop_type":"threadIdx.y"}):
            with block(exec_scope="gpu_warp"):
                for tx in range(0, 32, annotation = {"loop_type":"threadIdx.x"}):
                    with block(exec_scope="gpu_thread"):
                        A[i] = B[i] + 1

High-dimension Scheduling and Tensorization

As more and more backends provide tensor operators and instructions, the limitations of the TE schedule become apparent: it is hard to tensorize a complex stage. TensorIR chooses the block (a wrapper over a high-dimension computation) as the minimum schedule unit, so we can natively tensorize a sub-program and even keep scheduling after tensorization.
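
A hedged sketch of what block-level tensorization could look like with the schedule API used above; wmma_intrin stands for a tensor intrinsic (e.g. one described in hybrid script) and is an assumption for illustration, not a final API.

s = tir.create_schedule(matmul)
update = s.get_block("C")
i, j, k = s.get_axes(update)
i_o, i_i = s.split(i, 16)
j_o, j_i = s.split(j, 16)
k_o, k_i = s.split(k, 16)
s.reorder(i_o, j_o, k_o, i_i, j_i, k_i)
s.tensorize(i_i, wmma_intrin)  # replace the inner 16x16x16 sub-program with the intrinsic
# The result is still a block, so it can be scheduled further after tensorization.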

Decouple Primitives and Correctness Proof

Since every primitive directly rewrites the AST (it works like a StmtMutator), primitives are easily decoupled from one another. Developers can add a schedule primitive as easily as adding a pass. It is also easy to ensure that the original program and the scheduled program behave the same, by proving the correctness of each primitive and checking the intermediate IR during scheduling.

Round-trip Python Hybrid Syntax (Already upstream)

Hybrid script enables developers to write TensorIR in Python syntax, one of the most popular languages. It also provides a way to store the IR during or after scheduling. Please see the details in [RFC] Hybrid Script Support for TIR.

Migration Plan

  1. Upstream the TensorIR data structure.
  2. Upstream the TensorIR schedule primitives.
  3. Support AutoTVM/Ansor on TensorIR.

Co-authors: @spectrometerHBH @tqchen @junrushao

I appreciate the discussion and advice from @merrymercy.

36 Likes

It’s really great work! @Hzfengsy

  1. Have you tried tensorizing intrinsics (e.g. a TensorCore schedule) using this new IR? I remember that supporting tensorize is also one of your initial motivations.

I’m curious whether this would require different programming logic to write such a schedule. Before seeing any example, I find it hard to imagine what the schedule would look like.

  2. The if_then_else case is really a strong point!

The original implementation of te.compute makes the spatial iterators follow the shape of the output Tensor, which limits us when writing more flexible computations with TVM.

But it seems this would break the original design of Stage. Or do we still have the concept of Stage in the new TensorIR? How do we schedule them?

  3. Looking forward to upgrading Ansor based on the new TensorIR.

Thanks!

1 Like

Great work! I believe this will make “scheduling” more flexible and intuitive!

However, will this increase the coupling between the schedule and the lowering passes, which may lead to an increase in the complexity of the lowering passes?

By the way, I’m also looking forward to learning how auto-scheduling will work on top of TensorIR.

Thanks!

Thanks for your reply! @jcf94

A1. We have tried tensorizing intrinsics using this new IR, and we are working on a TensorCore demo. Our design is really close to the original tensorize programming logic; it only differs in how the description and implementation of the HW intrinsic are declared (we can use Hybrid Script to program them more conveniently).

A2. There is no concept of Stage in TensorIR. What you see in Hybrid Script contains the full IR information, and our schedule transforms the IR directly. The basic schedule unit is the Block.

A3. Ansor + TensorIR is on our roadmap.

1 Like

Thanks for your reply! @ds1231h

At the moment, we first transform TIR with blocks into TIR without blocks in one pass (the latter is the current TIR in the main repo) and utilize the current TIR lowering passes to do codegen. We can gradually refactor this lowering process in the future.

Hi,

Even though I don’t think I understood everything, I like the idea of solving some of the limitations of te.compute. Since te.compute is a central part of the TVM stack, changing it requires a lot of work and understanding. So thank you all for continuing such development.

Q1: I was wondering how this fits into the Relay -> Topi -> TE -> TIR flow. Take FuseOps as a specific case: AFAIK the FuseOps pass creates the multi-stage operator based on the te.computes. Since you mentioned that there would be no notion of a stage in the new TensorIR, how would FuseOps work? And, more generally, how would TVM’s philosophy of “defining a compute rule and a separate schedule” change?

Q2: What exactly do you mean by “not all programs can be scheduled”? Maybe an example?

Q3: You mentioned “new scheduling primitives”; could you maybe give a list?

EDIT: Q4: Expected timeline of the steps?

1 Like

Well-received with thanks!

Thank you for your interest.

A1: Current op fusion is based on stages, but the critical point is fusing injective computations. We can also inline injective computations via traverse_inline, so there is no doubt that FuseOps still works. As for the philosophy, I think there are only a few changes. TIR is not only an IR but can also serve as a computation declaration, and we provide a very user-friendly API (as easy as TE) to define compute rules.

A2: TIR is a general IR that can represent almost every program. It is really hard to schedule an arbitrary program, but we guarantee that TIR can schedule all programs that TE can schedule.

A3: Most of the primitives are similar to the TE ones. For now, the only two new primitives are decompose_reduction and merge_reduction.

A4: Upstreaming is planned for Oct and Nov. Ansor support is WIP.

I hope this answers your questions.

Thank you for this proposal! This work does make scheduling much easier. I have a concern about using this way to write a tensor expression: it looks more complicated than tvm.compute when defining matmul. We need to define some buffers and create a block with the corresponding shape dimensions. It would be helpful if you could add a conv2d example that could replace the existing topi.nn.conv2d definition, to better understand what a developer would need to write.

Another question is about representing ops written in a generic programming style, such as shape functions. Since these programs don’t fit into TVM scheduling, I assume it would still be more convenient to use the existing te hybrid script to create these ops?

2 Likes

Thanks for your reply! @kevinthesun

In the original TE programming model, we also have to declare buffers and create a lambda expression whose iter_vars have the correct shape dimensions. If we take this additional information into account, TE programming is close to Hybrid Script programming in complexity.
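
For comparison, the same matmul written with the existing TE API (standard TVM code, shown only to compare the amount of boilerplate):

import tvm
from tvm import te

A = te.placeholder((1024, 1024), name="A")
B = te.placeholder((1024, 1024), name="B")
k = te.reduce_axis((0, 1024), name="k")
C = te.compute((1024, 1024),
               lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
               name="C")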

Currently, we cannot replace existing topi operators, since they are represented by Stage/Op and optimized by the te schedule, while Hybrid Script is parsed into TIR directly.

If we don’t have to schedule the PrimFunc, we don’t have to declare blocks in TIR. Actually, the TE hybrid script is, to a large extent, also a text representation of TIR: loop and condition statements directly represent the IR structure, and sugar such as a variable being translated into an Array of size 1 eases the format of Store & Load. At the moment we haven’t introduced sugar to simplify such Load & Store, but the rest of the writing is largely similar.
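
A hedged sketch of such a block-free PrimFunc in hybrid script, assuming plain loops and buffer stores are accepted outside blocks:

@tvm.hybrid.script
def add_one(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (1024,), "float32")
    B = tir.match_buffer(b, (1024,), "float32")
    # No block is declared: the function is not meant to be scheduled, so loops and
    # Store/Load statements map directly onto the IR.
    for i in range(0, 1024):
        B[i] = A[i] + 1.0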

TIR and TE do not conflict with each other. TE is still a useful DSL for stitching fragments of TIR together to form a PrimFunc.

We could still define a TE-based DSL (backed by TIR) that enables primitives like compute and hybrid calls to stitch together a dataflow graph to form a PrimFunc, and then use the TIR for scheduling.

5 Likes

Thanks for the proposal! This definitely opens more opportunities for performance optimization. Two questions for clarification:

  1. IIUC, based on the proposal and discussion, we will have both TE and TIR, but TE is more like a frontend wrapper of TIR to serve users who prefer to write a high-level DSL. Then, what will we do with the TE schedule primitives? Intuitively, we should still keep them; otherwise TE writers will have no way to schedule their computes, because they know nothing about TIR and blocks.

  2. Does this proposal support dynamic shape (i.e., Any)? For example, can we have something like:

    @tvm.hybrid.script
    def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
        C = tir.match_buffer(c, (1024, 1024), "float32")
        A = tir.match_buffer(a, (1024, Any), "float32")
        B = tir.match_buffer(b, (Any, 1024), "float32")
        reducer = tir.comm_reducer(lambda x, y: x + y, tir.float32(0))
    
        with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "C") as [vi, vj, vk]:
            reducer.step(C[vi, vj], A[vi, vk] * B[vk, vj])
    
    s = tir.create_schedule(matmul)
    update = s.get_block("C")
    i, j, k = s.get_axes(update)
    i_o, i_i = s.split(i, bn)
    j_o, j_i = s.split(j, bn)
    k_o, k_i = s.split(k, 4)
    

    In this case, the length of vk (or k) is Any. Can we still apply split to it with a fixed factor?

1 Like

Thanks for the explanation. The relation between te and the new tir is now clearer to me.

Thanks for the clarification. It would be nice if we could use various methods to create tensor programs and use the new tir to schedule them.

Good questions!

  1. As far as we are concerned, we would like to let users use the TensorIR schedule rather than the TE schedule once we fully upstream TensorIR, for three reasons:

    1. Just as you mentioned, TE is a frontend wrapper, and it directly generates TIR with blocks. In a sense, TE is more like sugar for defining TIR.
    2. Most of the schedules and primitives in TensorIR are very similar to those in TE. The cost of learning the TensorIR schedule is extremely low (maybe just one day).
    3. All primitives are based on the block (there is no stage concept in the TensorIR schedule), so it is hard to keep the TE schedule working alongside blocks.
  2. Dynamic shapes are not supported yet. However, thanks to our decoupled primitives, it is feasible to support them later.

Would love to see dynamic shapes supported, otherwise a large set of models can’t be backed by the new TensorIR. :smiley:

3 Likes

So the scenario is: you can choose to use TE or TIR to write a compute, but if you choose TE, you have to first lower it to TIR and then apply schedule primitives?

IIUC, it seems to me that this is nontrivial, because the TIR was not written by a human, and you may need to first print it out to figure out how to schedule it. It sounds more straightforward to keep the TE schedule as syntactic sugar; at least you can get a sense of how the schedule looks by tracing the Python code.

3 Likes

Because there is a 1-1 mapping between te.Stage and Block, it should actually not be hard to use the tir schedule on a PrimFunc generated from a te compute (either by getting a block via its name, or by traversing the blocks programmatically the way we traverse stages today). But I agree that we can keep te.schedule for a bit.
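
A hypothetical sketch of that flow (te_to_primfunc stands for an assumed TE-to-PrimFunc conversion helper; the block name follows the te.compute name thanks to the 1-1 mapping):

func = te_to_primfunc([A, B, C])   # assumed helper turning a TE compute into a TIR PrimFunc
s = tir.create_schedule(func)
block = s.get_block("C")           # look the block up by its te.compute name
i, j, k = s.get_axes(block)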

Thanks for the clarification. Makes sense to me.

Thanks for the proposal! Just curious about schedule primitives like cache_write and cache_read, since there are no stages in TensorIR.