(F1) Template Metaprogramming
Users should be able to use variables from an outer scope in a TVMScript
function or class. The parsed result should be identical to the same function or
class with each such variable replaced by its value. For instance:
```python
from tvm.script import tir as T


@T.prim_func
def matmul(
    A: T.Buffer[(128, 128)],
) -> None:
    ...


def gen_matmul(n, m):
    @T.prim_func
    def f(A: T.Buffer[(n, m)]):
        ...

    return f


f = gen_matmul(n=128, m=128)  # `f` should be identical to `matmul`
```
This is already partially supported by https://github.com/apache/tvm/pull/11097,
which handles PrimExpr values captured from an outer function. With the new
parser, we want to support this feature in more places and for more variable types.
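Because the parser works from source text rather than from the compiled function, n and m in the annotation above are never evaluated by Python itself. The parser therefore has to look them up in the scope where the function is defined. The sketch below shows that mechanism in plain Python, under stated assumptions: capture_outer_vars and resolve_shape are hypothetical names (not TVM APIs), and string annotations stand in for the annotation source text the parser reads.

```python
import inspect


def capture_outer_vars(func):
    # Hypothetical stand-in for a parser entry point such as T.prim_func.
    # A decorator is called from the scope that defines `func`, so the
    # caller's frame holds every outer variable the function may refer to.
    frame = inspect.currentframe().f_back
    try:
        # Locals shadow globals, mirroring Python's own name resolution.
        func.captured_env = {**frame.f_globals, **frame.f_locals}
    finally:
        del frame  # break the reference cycle through the frame object
    return func


def resolve_shape(func, param):
    # Evaluate the annotation's source text in the captured environment,
    # turning outer variables (n, m) into concrete values.
    return eval(func.__annotations__[param], func.captured_env)


@capture_outer_vars
def matmul(A: "(128, 128)"):
    pass


def gen_matmul(n, m):
    @capture_outer_vars
    def f(A: "(n, m)"):
        pass

    return f


f = gen_matmul(n=128, m=128)
print(resolve_shape(f, "A"), resolve_shape(matmul, "A"))  # (128, 128) (128, 128)
```

After substitution, both functions carry the same concrete shape, which is the "identical parsed result" property described above.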
(F2) Rank-polymorphism
Users should be able to write a single function that handles input buffers of
different ranks (different numbers of dimensions). For example, a user should be
able to write a generic broadcast-add function:
```python
def broadcast_add(a, b, c):
    @T.prim_func
    def f(
        A: T.BufferFrom(a),
        B: T.BufferFrom(b),
        C: T.BufferFrom(c),
    ) -> None:
        for i, i_a, i_b in T.some_broadcast_method(A.shape, B.shape):
            with T.block():
                C[*i] = A[*i_a] + B[*i_b]

    return f


broadcast_add(
    a=Buffer((128, 1), "float32"),
    b=Buffer((1, 128), "float32"),
    c=Buffer((128, 128), "float32"),
)
```
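T.some_broadcast_method is a placeholder, so its exact semantics are open. One plausible reading, assumed here, is NumPy-style broadcasting: shapes are right-aligned and size-1 dimensions are stretched. The hypothetical broadcast_indices below sketches, in plain Python, the (i, i_a, i_b) triples such a method might yield; note that it works for any input ranks, which is what makes the loop body rank-polymorphic.

```python
import itertools


def broadcast_indices(shape_a, shape_b):
    """Yield (i, i_a, i_b): an output index plus the matching input indices,
    assuming NumPy-style broadcasting (right-aligned, size-1 dims stretched)."""
    rank = max(len(shape_a), len(shape_b))
    # Left-pad the shorter shape with 1s so both shapes have the same rank.
    pa = (1,) * (rank - len(shape_a)) + tuple(shape_a)
    pb = (1,) * (rank - len(shape_b)) + tuple(shape_b)
    out_shape = []
    for da, db in zip(pa, pb):
        if da != db and 1 not in (da, db):
            raise ValueError(f"incompatible dimensions {da} and {db}")
        out_shape.append(max(da, db))
    for i in itertools.product(*(range(d) for d in out_shape)):
        # Use 0 on every stretched (size-1) dim, then drop the padded
        # leading dims that the input buffer does not actually have.
        i_a = tuple(0 if d == 1 else x for x, d in zip(i, pa))[rank - len(shape_a):]
        i_b = tuple(0 if d == 1 else x for x, d in zip(i, pb))[rank - len(shape_b):]
        yield i, i_a, i_b


# Usage: a broadcast add over plain dicts keyed by index tuples.
A = {(r, 0): r for r in range(2)}       # shape (2, 1)
B = {(0, s): 10 * s for s in range(3)}  # shape (1, 3)
C = {i: A[i_a] + B[i_b] for i, i_a, i_b in broadcast_indices((2, 1), (1, 3))}
print(C[(1, 2)])  # 21
```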
(F3) Sugar: TE Compute in TIR
Users should be able to replace boilerplate code with a function call that is
expanded into a large chunk of code during parsing. For example, we may want to
use TE’s compute-like syntax to replace nested loops:
```python
@T.prim_func
def te_compute_sugar(
    A: T.Buffer[(128, 128)],
    B: T.Buffer[(128, 128)],
) -> None:
    ...
    C = T.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
    ...


# expands to ====>


@T.prim_func
def te_compute_expanded(
    A: T.Buffer[(128, 128)],
    B: T.Buffer[(128, 128)],
) -> None:
    ...
    for i in range(128):
        for j in range(128):
            with T.block("..."):
                C[i, j] = A[i, j] + B[i, j]
    ...
```
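One way such sugar could be expanded is by tracing: call the lambda once with one symbolic loop variable per dimension, record the expression it builds, and emit a loop nest around it. The sketch below is a plain-Python illustration of that idea and assumes nothing about how the actual parser implements T.compute. Expr, SymBuffer and expand_compute are hypothetical, and the output is TVMScript-like text rather than IR nodes.

```python
import inspect


class Expr:
    """Toy symbolic expression that only records its printed form."""

    def __init__(self, text):
        self.text = text

    def __add__(self, other):
        return Expr(f"{self.text} + {other}")

    def __str__(self):
        return self.text


class SymBuffer:
    """Toy buffer: indexing it yields an Expr such as 'A[i, j]'."""

    def __init__(self, name):
        self.name = name

    def __getitem__(self, idx):
        idx = idx if isinstance(idx, tuple) else (idx,)
        return Expr(f"{self.name}[{', '.join(map(str, idx))}]")


def expand_compute(out, shape, fcompute, block="compute"):
    """Trace fcompute with one symbolic loop variable per dimension, then
    emit the equivalent loop nest as TVMScript-like source text."""
    names = list(inspect.signature(fcompute).parameters)
    if len(names) != len(shape):
        raise ValueError("fcompute needs one parameter per dimension")
    body = fcompute(*(Expr(n) for n in names))
    lines = [
        "    " * depth + f"for {n} in range({extent}):"
        for depth, (n, extent) in enumerate(zip(names, shape))
    ]
    indent = "    " * len(shape)
    lines.append(f'{indent}with T.block("{block}"):')
    lines.append(f"{indent}    {out}[{', '.join(names)}] = {body}")
    return "\n".join(lines)


A, B = SymBuffer("A"), SymBuffer("B")
print(expand_compute("C", (128, 128), lambda i, j: A[i, j] + B[i, j]))
```

Running it prints the same loop nest as te_compute_expanded, with "compute" as the block name.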
(F4) Interleave host program and TVMScript program to customize metaprogramming
As an escape hatch from writing code that is parsed (or evaluated) by the
TVMScript parser, users should be able to write imperative code that constructs
IR nodes directly and embed it inside regular TVMScript. This gives users the
ultimate tool when TVMScript isn’t expressive enough for their use cases. For
example, python/tvm/topi/vision/nms.py#L380-L431 contains blocks of repetitive
code that compute the coordinates of the four corners of a bounding box. This
can be simplified as follows:
```python
# Before, without IRBuilder interleaving
@T.prim_func
def nms(...):
    ...
    for i in range(batch_size):
        ...
        a_l = min(
            output[batch_idx, box_a_idx, box_start_idx],
            output[batch_idx, box_a_idx, box_start_idx + 2],
        )
        a_t = min(
            output[batch_idx, box_a_idx, box_start_idx + 1],
            output[batch_idx, box_a_idx, box_start_idx + 3],
        )
        a_r = max(
            output[batch_idx, box_a_idx, box_start_idx],
            output[batch_idx, box_a_idx, box_start_idx + 2],
        )
        a_b = max(
            output[batch_idx, box_a_idx, box_start_idx + 1],
            output[batch_idx, box_a_idx, box_start_idx + 3],
        )
        ...
        for k in range(j):
            check_iou = ...
            ...
            if check_iou > 0:
                # b_l: left, b_t: top, b_r: right, b_b: bottom
                b_l = min(
                    output[batch_idx, box_b_idx, box_start_idx],
                    output[batch_idx, box_b_idx, box_start_idx + 2],
                )
                b_t = min(
                    output[batch_idx, box_b_idx, box_start_idx + 1],
                    output[batch_idx, box_b_idx, box_start_idx + 3],
                )
                b_r = max(
                    output[batch_idx, box_b_idx, box_start_idx],
                    output[batch_idx, box_b_idx, box_start_idx + 2],
                )
                b_b = max(
                    output[batch_idx, box_b_idx, box_start_idx + 1],
                    output[batch_idx, box_b_idx, box_start_idx + 3],
                )
                ...
```
```python
# With IRBuilder interleaving:
from tvm.script import tir as T


def get_box_coordinates(output, batch_idx, box_idx, box_start_idx):
    """Executed by the Python interpreter, not parsed as TVMScript."""
    box_l = T.min(
        output[batch_idx, box_idx, box_start_idx],
        output[batch_idx, box_idx, box_start_idx + 2],
    )  # type(box_l) is PrimExpr
    ...  # Repeat for other coordinates
    return box_l, box_t, box_r, box_b


@T.prim_func(capture=[get_box_coordinates])
def nms(...):
    ...
    for i in range(batch_size):
        ...
        a_l, a_t, a_r, a_b = get_box_coordinates(output, batch_idx, box_a_idx, box_start_idx)
        ...
        for k in range(j):
            check_iou = ...
            ...
            if check_iou > 0:
                b_l, b_t, b_r, b_b = get_box_coordinates(output, batch_idx, box_b_idx, box_start_idx)
                ...
```
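The key point of the interleaved version is that get_box_coordinates is ordinary Python that builds IR expressions and returns them, so each call site receives its own nodes built from its own arguments. The self-contained sketch below illustrates this without TVM. Expr, var, BufferView, ir_min and ir_max are hypothetical toy stand-ins for PrimExpr nodes, buffer loads and T.min/T.max; they are not TVM's API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Expr:
    """Toy IR node standing in for a PrimExpr: an op name plus operands."""

    op: str
    args: tuple = ()

    def __add__(self, other):
        return Expr("add", (self, other))


def var(name):
    return Expr("var", (name,))


class BufferView:
    """Toy buffer: indexing it builds a load node instead of reading data."""

    def __init__(self, name):
        self.name = name

    def __getitem__(self, idx):
        return Expr("load", (self.name,) + tuple(idx))


def ir_min(a, b):
    return Expr("min", (a, b))


def ir_max(a, b):
    return Expr("max", (a, b))


def get_box_coordinates(output, batch_idx, box_idx, box_start_idx):
    """Plain Python run by the host interpreter; returns four IR nodes."""

    def coord(k):
        offset = box_start_idx + k if k else box_start_idx
        return output[batch_idx, box_idx, offset]

    return (
        ir_min(coord(0), coord(2)),  # left
        ir_min(coord(1), coord(3)),  # top
        ir_max(coord(0), coord(2)),  # right
        ir_max(coord(1), coord(3)),  # bottom
    )


# Each call site gets its own expressions, built from its own arguments.
output = BufferView("output")
batch_idx, box_start_idx = var("batch_idx"), var("box_start_idx")
a_l, a_t, a_r, a_b = get_box_coordinates(output, batch_idx, var("box_a_idx"), box_start_idx)
b_l, b_t, b_r, b_b = get_box_coordinates(output, batch_idx, var("box_b_idx"), box_start_idx)
```

In the proposal, the parser would play the role of the last few lines: when it reaches the call inside nms, it runs the helper and splices the returned expressions into the surrounding function.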