Python is the most popular language for deep learning frameworks due to its great flexibility and rich ecosystem. This RFC plans to utilize a subset of Python AST that can express every TIR node. The new dialect will serve as a way to construct and inspect the IR in Python.
Motivation
ML compilation is an open research area, while its great value has already led to quick transfer to production. We believe Hybrid Script will enable more ML scientists and engineers to quickly implement prototypes of new ML algorithms, which will increase the rate of innovation. For developers of the TVM compiler stack, having a readable and writable text format of IR will ease the difficulty of implementing and testing data structure transformations and optimizations.
Design overview: Round trip IR for any stage of compilation
# opt_gemm.py
# after normalize schedule
class Module:
def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None:
# function attr dict
tir.func_attr({"global_symbol": "mmult", "tir.noalias": True})
A_1 = tir.buffer_bind(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
B_1 = tir.buffer_bind(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
C_1 = tir.buffer_bind(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
# body
tir.attr(C_1, "realize_scope", "")
tir.realize(C_1[0:1024, 0:1024])
for x in tir.range(0, 1024):
for y in tir.range(0, 1024):
C_1[x, y] = tir.float32(0)
for k in tir.range(0, 1024):
C_1[x, y] = (C_1[x, y] + (A_1[x, k]*B_1[k, y]))
# after lower
class Module:
def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None:
# function attr dict
tir.func_attr({"global_symbol": "mmult", "tir.noalias": True})
A_1 = tir.buffer_bind(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
B_1 = tir.buffer_bind(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
C_1 = tir.buffer_bind(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
# body
for x in tir.range(0, 1024):
for y in tir.range(0, 1024):
C_1.data[((x*1024) + y)] = tir.float32(0)
for k in tir.range(0, 1024):
C_1.data[((x*1024) + y)] = (tir.load("float32", C_1.data, ((x*1024) + y)) + (tir.load("float32", A_1.data, ((x*1024) + k))*tir.load("float32", B_1.data, ((k*1024) + y))))
# host module
class Module:
def mmult(args: ty.handle, arg_type_ids: ty.handle, num_args: ty.int32, out_ret_value: ty.handle, out_ret_tcode: ty.handle) -> ty.int32:
# function attr dict
tir.func_attr({"target": meta[Target][0], "tir.noalias": True, "global_symbol": "mmult", "tir.is_entry_func": True, "calling_conv": 1})
# body
assert (num_args == 3), "mmult: num_args should be 3"
arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle")
arg0_code: ty.int32 = tir.load("int32", arg_type_ids, 0)
arg1: ty.handle = tir.tvm_struct_get(args, 1, 12, dtype="handle")
arg1_code: ty.int32 = tir.load("int32", arg_type_ids, 1)
arg2: ty.handle = tir.tvm_struct_get(args, 2, 12, dtype="handle")
arg2_code: ty.int32 = tir.load("int32", arg_type_ids, 2)
A: ty.handle = tir.tvm_struct_get(arg0, 0, 1, dtype="handle")
tir.attr(A, "storage_alignment", 128)
arg0_shape: ty.handle = tir.tvm_struct_get(arg0, 0, 2, dtype="handle")
arg0_strides: ty.handle = tir.tvm_struct_get(arg0, 0, 3, dtype="handle")
dev_id: ty.int32 = tir.tvm_struct_get(arg0, 0, 9, dtype="int32")
B: ty.handle = tir.tvm_struct_get(arg1, 0, 1, dtype="handle")
tir.attr(B, "storage_alignment", 128)
arg1_shape: ty.handle = tir.tvm_struct_get(arg1, 0, 2, dtype="handle")
arg1_strides: ty.handle = tir.tvm_struct_get(arg1, 0, 3, dtype="handle")
C: ty.handle = tir.tvm_struct_get(arg2, 0, 1, dtype="handle")
tir.attr(C, "storage_alignment", 128)
arg2_shape: ty.handle = tir.tvm_struct_get(arg2, 0, 2, dtype="handle")
arg2_strides: ty.handle = tir.tvm_struct_get(arg2, 0, 3, dtype="handle")
assert ((((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or (arg0_code == 4)), "mmult: Expect arg[0] to be pointer"
assert ((((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or (arg1_code == 4)), "mmult: Expect arg[1] to be pointer"
assert ((((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or (arg2_code == 4)), "mmult: Expect arg[2] to be pointer"
assert (2 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 2"
assert (2 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 2"
assert (((tir.tvm_struct_get(arg0, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg0, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1))), "arg0.dtype is expected to be float32"
assert (1024 == tir.cast("int32", tir.load("int64", arg0_shape, 0))), "Argument arg0.shape[0] has an unsatisfied constraint"
assert (1024 == tir.cast("int32", tir.load("int64", arg0_shape, 1))), "Argument arg0.shape[1] has an unsatisfied constraint"
if not (tir.isnullptr(arg0_strides, dtype="bool")):
assert ((1 == tir.cast("int32", tir.load("int64", arg0_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg0_strides, 0)))), "arg0.strides: expected to be compact array"
tir.evaluate(0)
assert (tir.uint64(0) == tir.tvm_struct_get(arg0, 0, 8, dtype="uint64")), "Argument arg0.byte_offset has an unsatisfied constraint"
assert (1 == tir.tvm_struct_get(arg0, 0, 10, dtype="int32")), "Argument arg0.device_type has an unsatisfied constraint"
assert (2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 2"
assert (2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 2"
assert (((tir.tvm_struct_get(arg1, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg1, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1))), "arg1.dtype is expected to be float32"
assert (1024 == tir.cast("int32", tir.load("int64", arg1_shape, 0))), "Argument arg1.shape[0] has an unsatisfied constraint"
assert (1024 == tir.cast("int32", tir.load("int64", arg1_shape, 1))), "Argument arg1.shape[1] has an unsatisfied constraint"
if not (tir.isnullptr(arg1_strides, dtype="bool")):
assert ((1 == tir.cast("int32", tir.load("int64", arg1_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg1_strides, 0)))), "arg1.strides: expected to be compact array"
tir.evaluate(0)
assert (tir.uint64(0) == tir.tvm_struct_get(arg1, 0, 8, dtype="uint64")), "Argument arg1.byte_offset has an unsatisfied constraint"
assert (1 == tir.tvm_struct_get(arg1, 0, 10, dtype="int32")), "Argument arg1.device_type has an unsatisfied constraint"
assert (dev_id == tir.tvm_struct_get(arg1, 0, 9, dtype="int32")), "Argument arg1.device_id has an unsatisfied constraint"
assert (2 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 2"
assert (2 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 2"
assert (((tir.tvm_struct_get(arg2, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg2, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1))), "arg2.dtype is expected to be float32"
assert (1024 == tir.cast("int32", tir.load("int64", arg2_shape, 0))), "Argument arg2.shape[0] has an unsatisfied constraint"
assert (1024 == tir.cast("int32", tir.load("int64", arg2_shape, 1))), "Argument arg2.shape[1] has an unsatisfied constraint"
if not (tir.isnullptr(arg2_strides, dtype="bool")):
assert ((1 == tir.cast("int32", tir.load("int64", arg2_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg2_strides, 0)))), "arg2.strides: expected to be compact array"
tir.evaluate(0)
assert (tir.uint64(0) == tir.tvm_struct_get(arg2, 0, 8, dtype="uint64")), "Argument arg2.byte_offset has an unsatisfied constraint"
assert (1 == tir.tvm_struct_get(arg2, 0, 10, dtype="int32")), "Argument arg2.device_type has an unsatisfied constraint"
assert (dev_id == tir.tvm_struct_get(arg2, 0, 9, dtype="int32")), "Argument arg2.device_id has an unsatisfied constraint"
tir.attr(0, "compute_scope", "mmult_compute_")
for x in tir.range(0, 1024):
for y in tir.range(0, 1024):
C[((x*1024) + y)] = tir.float32(0)
for k in tir.range(0, 1024):
C[((x*1024) + y)] = tir.call_llvm_pure_intrin(tir.uint32(97), tir.uint32(3), tir.load("float32", A, ((x*1024) + k)), tir.load("float32", B, ((k*1024) + y)), tir.load("float32", C, ((x*1024) + y)), dtype="float32")
The basic design goal is that the printer can print IR as a readable Python script, which the parser can parse to construct an equivalent IR object for any stage of compilation.
The overall design to support all IR variants is to build a one-to-one correspondence between the scoping structure of IR and the tree structure of Python AST.
Ideally, each kind of IRNode corresponds to an AST Node in the Python AST. Generally, for an arbitrary node named xxxNode
, we can use tir.xxx()
in Python script to represent it, and recursively print the function call arguments according to the node's constructor function, since we want no information loss in round trip.
scope handlers: For a StmtNode with a body, we can use with tir.xxx()/for xxx in tir.yyy()
to represent it in Python script, and we call these nodes and functions scope handlers.
intrins: The remaining IRNodes (StmtNodes without body, PrimExprNodes and more) are called intrins.
In principle, we can represent any IR tree using the above two rules, like
with tir.let():
with tir.assert(...):
with tir.assert(...):
with tir.attr(...):
tir.store(...)
tir.evaluate(...)
but it will sacrifice readability and writability, and some information is not suitable to be printed in such a format.
For example, many nodes like AllocateNode
include a Var
as its constructing parameter. Ideally we want to print a Var
only using its name_hint like
with tir.allocate(packedB, "float32", [16, 16, 16], True):
...
But we will lose the dtype information of packedB
in the above example. We'd better use
packedB = tir.var("handle")
with tir.allocate(packedB, "float32", [16, 16, 16], True):
Now, packedB = tir.var("handle")
doesn't actually correspond to an IRNode; we call such statements and functions special stmts. The declaration of Buffer
, Var
and DictAttr of PrimFunc are all handled by special stmts.
Design Feature 1: Concise Scoping
Motivation
If we follow the default scoping rule above naively, we will see many unnecessary indents in the text form.
with tir.attr(0, "compute_scope", "mmult_compute_"):
with tir.attr(packedB, "storage_scope", "global"):
with tir.attr(packedB, "storage_alignment", 128):
if tir.isnullptr(packedB, dtype="bool"):
tir.evaluate(tir.tvm_throw_last_error(dtype="int32"))
To provide better readability and writability for Hybrid Script, we provide a rule of concise scoping for the printer and parser.
Solution
-
Printer: For a scope handler node, if it is the last stmt of its parent stmt's scope, then we can print its body without explicit indent.
-
Parser: If we encounter a stmt corresponding to a scope handler node but not in a
With
context, the rest of the current scope is parsed as its body.
tir.attr(0, "compute_scope", "mmult_compute_")
tir.attr(packedB, "storage_scope", "global")
tir.attr(packedB, "storage_alignment", 128)
if tir.isnullptr(packedB, dtype="bool"):
tir.evaluate(tir.tvm_throw_last_error(dtype="int32"))
Design Feature 2: Extensibility
Motivation
According to our classification of nodes, nodes within the same category usually follow similar printing and parsing rules. As the IR evolves, there will be new nodes in the future, so it would be convenient if we could write a few lines of code to register each of them into our printer and parser.
Solution
All the functions (scope handlers, intrins, special stmts) are registered inside class Registry
. Registry
holds different dictionaries mapping function names to actual function entries for different categories of nodes.
The registry mechanism will automatically handle argument parsing and passing, with error reporting for missing arguments. The registered function only needs to consider how to process the arguments.
According to our classification above,
-
Scope handlers: use
register_scope_handler(scope_name, concise)
to register
@register_scope_handler("with_scope", concise=False) def let(parser, node, var, value, body): """ With scope handler function let(var, value, body) """ return tvm.tir.LetStmt(var, value, body)
As said, scope handlers are stmts with a body; we further classify them into 2 categories
-
With scope handler (
scope_name="with_scope"
): Since we want to support concise scoping (
concise=True
), some with scope handlers can appear in two formats
with tir.xxx() # With in Python AST
tir.xxx() # Expr(Call) in Python AST
If we ban concise scoping (
concise=False
), then the node can only be represented as
with tir.xxx() # With in Python AST
-
For scope handler (
scope_name="for_scope"
):@register_scope_handler("for_scope") def range(parser, node, begin, end, for_type="serial"): """ For scope handler function range(begin, end, for_type)""" ana = tvm.arith.Analyzer() extent = end if begin == 0 else ana.simplify(end - begin) loop_var_name = node.target.id loop_var = tvm.te.var(loop_var_name, dtype="int32") parser.scope_emitter.new_scope() parser.scope_emitter.update_symbol(loop_var_name, loop_var) body = get_body(parser, node) parser.scope_emitter.pop_scope() for_type_dict = {"serial": 0, "parallel": 1, "vectorized": 2, "Unrolled": 3, } return tvm.tir.For(loop_var, begin, extent, for_type_dict[for_type], 0, body)
Their common parsing behavior is
- define a set of loop vars of current scope
- parse body
- return a stmt
-
-
Special stmts: use
register_special_stmt
to register. Special stmts can appear in 2 formats. They don't correspond to nodes in IR directly.
-
target = tir.xxx()
, like packedB = tir.var("handle")
@register_special_stmt def var(parser, node, dtype): return te.var(parser._assign_target, dtype)
-
tir.xxx()
, like tir.func_attr({"global_symbol": "default_function", "tir.noalias": True})
@register_special_stmt def func_attr(parser, node, dict_attr): parser.dict_attr = dict_attr
-
-
Intrin: use
register_intrin
to register
@register_intrin def ramp(base, stride, lanes): lanes = lanes.value if not isinstance(lanes, int) else lanes return tvm.tir.Ramp(base, stride, lanes)
Intrin can appear as
tir.xxx()
Please share your comments