Background and Motivation
TVM is an end-to-end deep learning compiler with two levels of IR and optimization. It translates models from popular DL frameworks into Relay and optimizes the computation graph, after which it lowers each graph node into Tensor Expression (TE), performs another round of function-level optimization, and finally lowers it into TIR and generates backend code. In brief, the current workflow is
TF/PyTorch/ONNX -> Relay -> TE (based on TE schedule) -> TIR -> C++/CUDA
Currently, low-level optimization is done through TE scheduling, which has several limitations:
- It is based on an auxiliary data structure, the schedule tree. Schedule primitives operate on the schedule tree rather than directly on TIR itself, which makes the result of scheduling less intuitive.
- All primitives are coupled to the schedule tree representation, so it is hard to add new primitives and ensure the correctness of the transformation at the same time.
- Support for high-dimensional instructions and tensorization is limited: TE is no longer schedulable after tensorization, and the description of tensor intrinsics is not user-friendly.
Introduction
TensorIR is a brand-new low-level IR with full scheduling support. Here are some of its key features and novel techniques.
Core data structure: Block and Block Realize
To generalize high-dimensional tensor computation, we introduce a new statement structure, Block. A block can wrap any part of the IR to provide isolation. A block is the minimal unit for scheduling and tensorization; it stores all the fundamental information of a computation, including the block iteration vars, the regions the block reads and writes, the buffers allocated inside the block, and the block body statement. A block declares the computation through its iteration vars and their types.
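For concreteness, here is a minimal sketch of a single block in the hybrid syntax used throughout this RFC (an element-wise add; the buffers A, B, and C are assumed to have been declared with tir.match_buffer as in the example below):

    # A block bundles everything the scheduler needs:
    #   - block iteration vars and their domains: vi, vj over [0, 1024)
    #   - read regions:  A[vi, vj], B[vi, vj]
    #   - write region:  C[vi, vj]
    #   - the body statement: the store into C
    with tir.block([1024, 1024], "C") as [vi, vj]:
        C[vi, vj] = A[vi, vj] + B[vi, vj]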
Workload Example
- First of all, we define the workload (GEMM in this example) with the hybrid script. Note that the hybrid script directly generates a TIR program rather than TE stages. (We can auto-complete the loop nesting from the block definition by default; see the sketch after this snippet.)
@tvm.hybrid.script
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    C = tir.match_buffer(c, (1024, 1024), "float32")
    A = tir.match_buffer(a, (1024, 1024), "float32")
    B = tir.match_buffer(b, (1024, 1024), "float32")
    reducer = tir.comm_reducer(lambda x, y: x + y, tir.float32(0))
    with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "C") as [vi, vj, vk]:
        reducer.step(C[vi, vj], A[vi, vk] * B[vk, vj])
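The auto-completion mentioned above means this block is conceptually equivalent to nesting it under one loop per block iteration var with trivial bindings. A sketch based on the printed results later in this section (not literal printer output):

    for i, j, k in tir.grid(1024, 1024, 1024):
        with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "C") as [vi, vj, vk]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            tir.bind(vk, k)
            reducer.step(C[vi, vj], A[vi, vk] * B[vk, vj])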
- Then create the schedule from TIR and tile the loops using primitives.
bn = 32  # tile size; the printed result below corresponds to bn = 32
s = tir.create_schedule(matmul)
update = s.get_block("C")
i, j, k = s.get_axes(update)
i_o, i_i = s.split(i, bn)
j_o, j_i = s.split(j, bn)
k_o, k_i = s.split(k, 4)
s.reorder(i_o, j_o, k_o, k_i, i_i, j_i)
Result
def func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # function attr dict
    B = tir.match_buffer(b, [1024, 1024])
    A = tir.match_buffer(a, [1024, 1024])
    C = tir.match_buffer(c, [1024, 1024])
    reducer = tir.comm_reducer(lambda x, y: x + y, tir.float32(0))
    # body
    for i0_outer, i1_outer, i2_outer, i2_inner, i0_inner, i1_inner in tir.grid(32, 32, 256, 4, 32, 32):
        with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "C") as [vi, vj, vk]:
            tir.bind(vi, ((i0_outer*32) + i0_inner))
            tir.bind(vj, ((i1_outer*32) + i1_inner))
            tir.bind(vk, ((i2_outer*4) + i2_inner))
            reducer.step(C[vi, vj], (A[vi, vk]*B[vk, vj]))
- Vectorize and decompose reduction
s.vectorize(j_i)
s.decompose_reduction(update, j_o)
Result
@tvm.hybrid.script
def func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # function attr dict
    B = tir.match_buffer(b, [1024, 1024])
    A = tir.match_buffer(a, [1024, 1024])
    C = tir.match_buffer(c, [1024, 1024])
    # body
    for i0_outer in range(0, 32):
        for i1_outer_init, i0_inner_init in tir.grid(32, 32):
            for i1_inner_init in range(0, 32, annotation = {"loop_type":"vectorize"}):
                with tir.block([1024, 1024], "C_init") as [vi_init, vj_init]:
                    tir.bind(vi_init, ((i0_outer*32) + i0_inner_init))
                    tir.bind(vj_init, ((i1_outer_init*32) + i1_inner_init))
                    C[vi_init, vj_init] = tir.float32(0)
        for i1_outer, i2_outer, i2_inner, i0_inner in tir.grid(32, 256, 4, 32):
            for i1_inner in range(0, 32, annotation = {"loop_type":"vectorize"}):
                with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "C_update") as [vi, vj, vk]:
                    tir.bind(vi, ((i0_outer*32) + i0_inner))
                    tir.bind(vj, ((i1_outer*32) + i1_inner))
                    tir.bind(vk, ((i2_outer*4) + i2_inner))
                    C[vi, vj] = (C[vi, vj] + (A[vi, vk]*B[vk, vj]))
- Finally, build and run it, and check the result.
build_func = tvm.build(s.func, target=target)
build_func(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)
evaluator = build_func.time_evaluator(build_func.entry_name, ctx, number=1)
In this example, we change the IR imperatively at each step rather than waiting until the end of scheduling, as the TE schedule does. It is very important for users and developers to see directly what happens during scheduling. Also, at every stage of the schedule we can get a verifiable IR; that is the major improvement for both user experience and correctness proofs.
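For example, the intermediate IR can be printed after every primitive. A sketch, assuming the printer entry point tvm.hybrid.ashybrid proposed in the hybrid script RFC (the exact name is an assumption here):

    s = tir.create_schedule(matmul)
    update = s.get_block("C")
    i, j, k = s.get_axes(update)

    i_o, i_i = s.split(i, 32)
    print(tvm.hybrid.ashybrid(s.func))   # already valid, verifiable TIR at this point

    s.reorder(i_o, j, i_i, k)
    print(tvm.hybrid.ashybrid(s.func))   # and after every subsequent primitive as well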
Key Features
Independent scheduling based on IR itself
Unlike the TE schedule, TensorIR has a complete set of schedule primitives that operate on the IR itself and need no schedule tree or any other auxiliary data structure. We will introduce a brand-new set of schedule primitives with full backward compatibility with the TE schedule. This simplifies both the compilation workflow and its concepts.
TF/PyTorch/ONNX -> Relay -> TIR -> schedule -> TIR -> schedule -> TIR -> C++/CUDA
Now there are no stages during scheduling. Rather than lowering a schedule into TIR, we directly mutate the TIR itself. This also enables sequential scheduling (scheduling a single workload several times).
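A sketch of sequential scheduling under the same API as the example above (s.func is assumed to return the scheduled TIR function, as in the build call earlier):

    s1 = tir.create_schedule(matmul)
    block = s1.get_block("C")
    i, j, k = s1.get_axes(block)
    i_o, i_i = s1.split(i, 32)          # first round of scheduling

    s2 = tir.create_schedule(s1.func)   # the result is plain TIR, so schedule it again
    block2 = s2.get_block("C")
    # ... apply further primitives to block2 in the second round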
Stronger Expressiveness and Optimization Ability
TE has limited expressiveness, since each stage is defined by Stage = te.compute(lambda expr), while TensorIR is a full C++-like IR. We can write any program we want in TensorIR. Although not every program can be scheduled, far more workloads can be optimized with TensorIR.
One task it improves is concatenation:
TE:
B = te.compute((30,), lambda i: tvm.tir.if_then_else(i < 10, A0[i], tvm.tir.if_then_else(i < 20, A1[i - 10], A2[i - 20])))
TensorIR:
with tir.block([10]) as vi:
    B[vi] = A0[vi]
with tir.block([10]) as vi:
    B[vi + 10] = A1[vi]
with tir.block([10]) as vi:
    B[vi + 20] = A2[vi]
The critical improvement is performance: in TIR we optimize the program by eliminating the if branches entirely, which is impossible in the TE schedule.
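Because each copy is its own block, the three parts can also be scheduled independently. A sketch using the primitives shown earlier (the function name concat and the block names "B_0", "B_1", "B_2" are assumed for illustration):

    s = tir.create_schedule(concat)
    for name in ["B_0", "B_1", "B_2"]:
        block = s.get_block(name)
        (i,) = s.get_axes(block)
        s.vectorize(i)              # vectorize each branch-free copy on its own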
Memory and Execution Scope
Hardware accelerators, led by GPUs, increasingly use hierarchical architectures, including a memory hierarchy (global, shared, local/wmma on NVIDIA GPUs) and an execution hierarchy (SM, warp, thread on NVIDIA GPUs). TVM defines the memory hierarchy, and TensorIR provides corresponding native hierarchical blocks and execution scopes. Different execution scopes can access different memory scopes.
TensorIR natively supports hierarchy checks. We check memory access and thread binding, including warp-level instruction (wmma) validation, during scheduling. The following is an example of the GPU hierarchy.
for bx in range(0, 32, annotation = {"loop_type":"blockIdx.x"}):
    with block(exec_scope="gpu_block"):
        for ty in range(0, 32, annotation = {"loop_type":"threadIdx.y"}):
            with block(exec_scope="gpu_warp"):
                for tx in range(0, 32, annotation = {"loop_type":"threadIdx.x"}):
                    with block(exec_scope="gpu_thread"):
                        A[i] = B[i] + 1
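A sketch of how memory scope interacts with execution scope: a shared-memory buffer allocated at the gpu_block level is visible to all threads inside it. The tir.alloc_buffer(..., scope=...) spelling here is an assumption for illustration, not taken from the RFC text:

    for bx in range(0, 32, annotation = {"loop_type":"blockIdx.x"}):
        with block(exec_scope="gpu_block"):
            B_shared = tir.alloc_buffer((32,), "float32", scope="shared")  # block-level memory
            for tx in range(0, 32, annotation = {"loop_type":"threadIdx.x"}):
                with block(exec_scope="gpu_thread"):
                    B_shared[tx] = B[bx * 32 + tx]        # cooperative load into shared memory
            for tx in range(0, 32, annotation = {"loop_type":"threadIdx.x"}):
                with block(exec_scope="gpu_thread"):
                    A[bx * 32 + tx] = B_shared[tx] + 1    # compute from the shared copy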
High-dimensional Scheduling and Tensorization
As more and more backends provide tensor operators and instructions, the limitations of the TE schedule become apparent: it is hard to tensorize a complex stage. TensorIR chooses the block (a wrapper around a high-dimensional computation) as the minimal schedule unit, so we can natively tensorize a sub-program and even keep scheduling after tensorization.
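A sketch of block-level tensorization on the matmul example above; the tensorize primitive and the gemm_16x16x16_intrin tensor intrinsic are assumed names for illustration, not text from this RFC:

    s = tir.create_schedule(matmul)
    block = s.get_block("C")
    i, j, k = s.get_axes(block)
    i_o, i_i = s.split(i, 16)
    j_o, j_i = s.split(j, 16)
    k_o, k_i = s.split(k, 16)
    s.reorder(i_o, j_o, k_o, i_i, j_i, k_i)
    s.tensorize(i_i, gemm_16x16x16_intrin)  # replace the inner 16x16x16 sub-program with the intrinsic
    # The result is still TIR with blocks, so the outer loops can keep being
    # scheduled (split, reorder, bind, ...) after tensorization.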
Decouple Primitives and Correctness Proof
Since every primitive directly rewrites the AST (it works like a StmtMutator), the primitives are naturally decoupled from one another. Developers can add a schedule primitive as easily as adding a pass. It is also easy to ensure that the original and the scheduled programs behave the same, by proving the correctness of each primitive and checking the intermediate IR during scheduling.
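As a concrete illustration of checking intermediate IR, a verification step is just an ordinary function over the TIR AST. A toy sketch using the existing statement visitor (the specific check is illustrative only):

    import tvm
    from tvm import tir

    def check_constant_store_indices(func, buffer_len=1024):
        """Assert that every constant buffer-store index stays within buffer_len."""
        def visit(node):
            if isinstance(node, tir.BufferStore):
                for idx in node.indices:
                    if isinstance(idx, tir.IntImm):
                        assert 0 <= idx.value < buffer_len, "constant index out of bounds"
        tvm.tir.stmt_functor.post_order_visit(func.body, visit)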
Round-trip Python Hybrid Syntax (already upstreamed)
The hybrid script enables developers to write TensorIR in Python syntax, one of the most popular languages. It also provides a way to store the IR during or after scheduling. Please see the details in [RFC] Hybrid Script Support for TIR.
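A round-trip sketch; the printer and parser entry points (tvm.hybrid.ashybrid and tvm.hybrid.from_source) are assumed from the hybrid script RFC and may be spelled differently there:

    src = tvm.hybrid.ashybrid(matmul)         # TIR -> Python source text
    reparsed = tvm.hybrid.from_source(src)    # Python source text -> TIR
    tvm.ir.assert_structural_equal(matmul, reparsed)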
Migration Plan
- Upstream the TensorIR data structure.
- Upstream the TensorIR schedule primitives.
- Support AutoTVM/Ansor on TensorIR.
Co-authors: @spectrometerHBH @tqchen @junrushao
I appreciate the discussion and advice from @merrymercy.