[RFC] Hybrid Script Support for TIR

Python is the most popular language for deep learning frameworks due to its great flexibility and rich ecosystem. This RFC plans to utilize a subset of Python AST that can express every TIR node. The new dialect will serve as a way to construct and inspect the IR in Python.

Motivation

ML compilation is an open research area, while its great value has already led to quick transfer to production. We believe Hybrid Script will enable more ML scientists and engineers to quickly implement prototypes of new ML algorithms, which will increase the rate of innovation. For developers of the TVM compiler stack, having a readable and writable text format of IR will make it easier to implement and test data structure transformations and optimizations.

Design overview: Round trip IR for any stage of compilation

# opt_gemm.py
# after normalize schedule
class Module:
    def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "mmult", "tir.noalias": True})
        A_1 = tir.buffer_bind(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
        B_1 = tir.buffer_bind(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
        C_1 = tir.buffer_bind(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
        # body
        tir.attr(C_1, "realize_scope", "")
        tir.realize(C_1[0:1024, 0:1024])
        for x in tir.range(0, 1024):
            for y in tir.range(0, 1024):
                C_1[x, y] = tir.float32(0)
                for k in tir.range(0, 1024):
                    C_1[x, y] = (C_1[x, y] + (A_1[x, k]*B_1[k, y]))
# after lower
class Module:
    def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "mmult", "tir.noalias": True})
        A_1 = tir.buffer_bind(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
        B_1 = tir.buffer_bind(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
        C_1 = tir.buffer_bind(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
        # body
        for x in tir.range(0, 1024):
            for y in tir.range(0, 1024):
                C_1.data[((x*1024) + y)] = tir.float32(0)
                for k in tir.range(0, 1024):
                    C_1.data[((x*1024) + y)] = (tir.load("float32", C_1.data, ((x*1024) + y)) + (tir.load("float32", A_1.data, ((x*1024) + k))*tir.load("float32", B_1.data, ((k*1024) + y))))
                    
# host module
class Module:
    def mmult(args: ty.handle, arg_type_ids: ty.handle, num_args: ty.int32, out_ret_value: ty.handle, out_ret_tcode: ty.handle) -> ty.int32:
        # function attr dict
        tir.func_attr({"target": meta[Target][0], "tir.noalias": True, "global_symbol": "mmult", "tir.is_entry_func": True, "calling_conv": 1})
        # body
        assert (num_args == 3), "mmult: num_args should be 3"
        arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle")
        arg0_code: ty.int32 = tir.load("int32", arg_type_ids, 0)
        arg1: ty.handle = tir.tvm_struct_get(args, 1, 12, dtype="handle")
        arg1_code: ty.int32 = tir.load("int32", arg_type_ids, 1)
        arg2: ty.handle = tir.tvm_struct_get(args, 2, 12, dtype="handle")
        arg2_code: ty.int32 = tir.load("int32", arg_type_ids, 2)
        A: ty.handle = tir.tvm_struct_get(arg0, 0, 1, dtype="handle")
        tir.attr(A, "storage_alignment", 128)
        arg0_shape: ty.handle = tir.tvm_struct_get(arg0, 0, 2, dtype="handle")
        arg0_strides: ty.handle = tir.tvm_struct_get(arg0, 0, 3, dtype="handle")
        dev_id: ty.int32 = tir.tvm_struct_get(arg0, 0, 9, dtype="int32")
        B: ty.handle = tir.tvm_struct_get(arg1, 0, 1, dtype="handle")
        tir.attr(B, "storage_alignment", 128)
        arg1_shape: ty.handle = tir.tvm_struct_get(arg1, 0, 2, dtype="handle")
        arg1_strides: ty.handle = tir.tvm_struct_get(arg1, 0, 3, dtype="handle")
        C: ty.handle = tir.tvm_struct_get(arg2, 0, 1, dtype="handle")
        tir.attr(C, "storage_alignment", 128)
        arg2_shape: ty.handle = tir.tvm_struct_get(arg2, 0, 2, dtype="handle")
        arg2_strides: ty.handle = tir.tvm_struct_get(arg2, 0, 3, dtype="handle")
        assert ((((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or (arg0_code == 4)), "mmult: Expect arg[0] to be pointer"
        assert ((((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or (arg1_code == 4)), "mmult: Expect arg[1] to be pointer"
        assert ((((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or (arg2_code == 4)), "mmult: Expect arg[2] to be pointer"
        assert (2 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 2"
        assert (2 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 2"
        assert (((tir.tvm_struct_get(arg0, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg0, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1))), "arg0.dtype is expected to be float32"
        assert (1024 == tir.cast("int32", tir.load("int64", arg0_shape, 0))), "Argument arg0.shape[0] has an unsatisfied constraint"
        assert (1024 == tir.cast("int32", tir.load("int64", arg0_shape, 1))), "Argument arg0.shape[1] has an unsatisfied constraint"
        if not (tir.isnullptr(arg0_strides, dtype="bool")):
            assert ((1 == tir.cast("int32", tir.load("int64", arg0_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg0_strides, 0)))), "arg0.strides: expected to be compact array"
            tir.evaluate(0)
        assert (tir.uint64(0) == tir.tvm_struct_get(arg0, 0, 8, dtype="uint64")), "Argument arg0.byte_offset has an unsatisfied constraint"
        assert (1 == tir.tvm_struct_get(arg0, 0, 10, dtype="int32")), "Argument arg0.device_type has an unsatisfied constraint"
        assert (2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 2"
        assert (2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 2"
        assert (((tir.tvm_struct_get(arg1, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg1, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1))), "arg1.dtype is expected to be float32"
        assert (1024 == tir.cast("int32", tir.load("int64", arg1_shape, 0))), "Argument arg1.shape[0] has an unsatisfied constraint"
        assert (1024 == tir.cast("int32", tir.load("int64", arg1_shape, 1))), "Argument arg1.shape[1] has an unsatisfied constraint"
        if not (tir.isnullptr(arg1_strides, dtype="bool")):
            assert ((1 == tir.cast("int32", tir.load("int64", arg1_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg1_strides, 0)))), "arg1.strides: expected to be compact array"
            tir.evaluate(0)
        assert (tir.uint64(0) == tir.tvm_struct_get(arg1, 0, 8, dtype="uint64")), "Argument arg1.byte_offset has an unsatisfied constraint"
        assert (1 == tir.tvm_struct_get(arg1, 0, 10, dtype="int32")), "Argument arg1.device_type has an unsatisfied constraint"
        assert (dev_id == tir.tvm_struct_get(arg1, 0, 9, dtype="int32")), "Argument arg1.device_id has an unsatisfied constraint"
        assert (2 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 2"
        assert (2 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 2"
        assert (((tir.tvm_struct_get(arg2, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg2, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1))), "arg2.dtype is expected to be float32"
        assert (1024 == tir.cast("int32", tir.load("int64", arg2_shape, 0))), "Argument arg2.shape[0] has an unsatisfied constraint"
        assert (1024 == tir.cast("int32", tir.load("int64", arg2_shape, 1))), "Argument arg2.shape[1] has an unsatisfied constraint"
        if not (tir.isnullptr(arg2_strides, dtype="bool")):
            assert ((1 == tir.cast("int32", tir.load("int64", arg2_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg2_strides, 0)))), "arg2.strides: expected to be compact array"
            tir.evaluate(0)
        assert (tir.uint64(0) == tir.tvm_struct_get(arg2, 0, 8, dtype="uint64")), "Argument arg2.byte_offset has an unsatisfied constraint"
        assert (1 == tir.tvm_struct_get(arg2, 0, 10, dtype="int32")), "Argument arg2.device_type has an unsatisfied constraint"
        assert (dev_id == tir.tvm_struct_get(arg2, 0, 9, dtype="int32")), "Argument arg2.device_id has an unsatisfied constraint"
        tir.attr(0, "compute_scope", "mmult_compute_")
        for x in tir.range(0, 1024):
            for y in tir.range(0, 1024):
                C[((x*1024) + y)] = tir.float32(0)
                for k in tir.range(0, 1024):
                    C[((x*1024) + y)] = tir.call_llvm_pure_intrin(tir.uint32(97), tir.uint32(3), tir.load("float32", A, ((x*1024) + k)), tir.load("float32", B, ((k*1024) + y)), tir.load("float32", C, ((x*1024) + y)), dtype="float32")

The basic design goal is that the printer can print IR as a readable Python script, which the parser can parse to construct an equivalent IR object, at any stage of compilation.

The overall design for supporting all IR variants is to build a one-to-one correspondence between the scoping structure of the IR and the tree structure of the Python AST.

Ideally, each kind of IRNode corresponds to an AST node in the Python AST. Generally, for an arbitrary node named xxxNode, we can use tir.xxx() in the Python script to represent it, and recursively print the function call arguments according to the node's constructor, since we want no information loss in the round trip.

scope handlers: For a StmtNode with a body, we can use with tir.xxx() or for xxx in tir.yyy() to represent it in the Python script, and we call these nodes and functions scope handlers.

intrins: The remaining IRNodes (StmtNodes without a body, PrimExprNodes and more) are called intrins.

In principle, we can represent any IR tree using the two rules above, like

with tir.let():
    with tir.assert(...):
        with tir.assert(...):
            with tir.attr(...):
                tir.store(...)
                tir.evaluate(...)

but this sacrifices readability and writability, and some information is not suitable to be printed in such a format.

For example, many nodes like AllocateNode take a Var as a constructor parameter. Ideally we want to print a Var using only its name_hint, like

with tir.allocate(packedB, "float32", [16, 16, 16], True):
   ...

But we would lose the dtype information of packedB in the above example. We'd better use

packedB = tir.var("handle")
with tir.allocate(packedB, "float32", [16, 16, 16], True):

Now, packedB = tir.var("handle") doesn't actually correspond to an IRNode; we call such statements and functions special stmts. The declarations of Buffer, Var and the DictAttr of PrimFunc are all handled by special stmts.
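To make the round-trip goal concrete, here is a toy sketch in plain Python that does not use TVM. It assumes a hypothetical two-node IR (`For`, `Store`) in place of real TIR nodes: the printer renders the body-carrying `For` as a `for x in tir.range(...)` scope handler and the body-less `Store` as a `tir.store(...)` intrin call, and the parser rebuilds the same tree with the standard `ast` module, so `parse(print_stmt(ir)) == [ir]` can be checked directly.

```python
import ast
from dataclasses import dataclass

# Hypothetical toy IR standing in for TIR nodes (not the real tvm.tir classes).
@dataclass(frozen=True)
class Store:          # StmtNode without a body -> printed as an intrin
    buffer: str
    index: str
    value: str

@dataclass(frozen=True)
class For:            # StmtNode with a body -> printed as a scope handler
    loop_var: str
    begin: int
    end: int
    body: tuple

def print_stmt(stmt, indent=0):
    """Print a toy IR statement as Python-syntax text."""
    pad = "    " * indent
    if isinstance(stmt, For):
        lines = [f"{pad}for {stmt.loop_var} in tir.range({stmt.begin}, {stmt.end}):"]
        lines += [print_stmt(s, indent + 1) for s in stmt.body]
        return "\n".join(lines)
    return f"{pad}tir.store({stmt.buffer!r}, {stmt.index!r}, {stmt.value!r})"

def callee(node):
    """Return 'tir.xxx' for a call written as tir.xxx(...), else None."""
    if (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)):
        return f"{node.func.value.id}.{node.func.attr}"
    return None

def parse_stmt(node):
    """Rebuild a toy IR statement from one Python AST statement."""
    if isinstance(node, ast.For) and callee(node.iter) == "tir.range":
        begin, end = (ast.literal_eval(a) for a in node.iter.args)
        return For(node.target.id, begin, end,
                   tuple(parse_stmt(s) for s in node.body))
    if isinstance(node, ast.Expr) and callee(node.value) == "tir.store":
        return Store(*(ast.literal_eval(a) for a in node.value.args))
    raise SyntaxError(f"unsupported statement: {ast.dump(node)}")

def parse(src):
    return [parse_stmt(s) for s in ast.parse(src).body]

ir = For("x", 0, 4, (For("y", 0, 4, (Store("C", "x*4 + y", "0.0"),)),))
text = print_stmt(ir)
assert parse(text) == [ir]   # printing then parsing loses no information
```

A full implementation would need to cover every TIR node and its types; the sketch only shows why a one-to-one mapping between IR scopes and AST nesting makes the round trip lossless.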

Design Feature 1: Concise Scoping

Motivation

If we follow the default scoping rule above naively, we will see many unnecessary indents in the text form.

with tir.attr(0, "compute_scope", "mmult_compute_"):
    with tir.attr(packedB, "storage_scope", "global"):
        with tir.attr(packedB, "storage_alignment", 128):
            if tir.isnullptr(packedB, dtype="bool"):
                tir.evaluate(tir.tvm_throw_last_error(dtype="int32"))

To provide better readability and writability for Hybrid Script, we provide a concise scoping rule for the printer and parser.

Solution

  • Printer: For a scope handler node, if it is the last stmt of its parent stmt's scope, then we can print its body without explicit indentation.

  • Parser: If we encounter a stmt corresponding to a scope handler node but not in a With context, the rest of the current scope is parsed as its body.

tir.attr(0, "compute_scope", "mmult_compute_")
tir.attr(packedB, "storage_scope", "global")
tir.attr(packedB, "storage_alignment", 128)
if tir.isnullptr(packedB, dtype="bool"):
    tir.evaluate(tir.tvm_throw_last_error(dtype="int32"))
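The parser rule can be sketched with the standard `ast` module (Python 3.9+ for `ast.unparse`). This is a toy, not the TVM parser: `SCOPE_HANDLERS` is a hypothetical set of handlers that allow the concise form, and the output is nested `(label, body)` pairs instead of TIR nodes. The final check confirms that the concise spelling and the fully indented spelling of the example above produce the same tree.

```python
import ast

# Hypothetical: with-scope handlers that permit the concise form.
SCOPE_HANDLERS = {"tir.attr", "tir.allocate", "tir.realize"}

def callee(node):
    """Return 'tir.xxx' for a call written as tir.xxx(...), else None."""
    if (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)):
        return f"{node.func.value.id}.{node.func.attr}"
    return None

def parse_block(stmts):
    """Turn a list of Python statements into nested (label, body) pairs."""
    result = []
    for i, stmt in enumerate(stmts):
        if isinstance(stmt, ast.Expr) and callee(stmt.value) in SCOPE_HANDLERS:
            # Concise form: the rest of the current block becomes the body.
            result.append((callee(stmt.value), parse_block(stmts[i + 1:])))
            return result
        if isinstance(stmt, ast.With):
            # Explicit form: the indented block is the body.
            label = callee(stmt.items[0].context_expr)
            result.append((label, parse_block(stmt.body)))
        elif isinstance(stmt, ast.If):
            # The else branch is ignored in this toy.
            result.append(("if", parse_block(stmt.body)))
        else:
            result.append((ast.unparse(stmt), []))
    return result

concise = '''
tir.attr(0, "compute_scope", "mmult_compute_")
tir.attr(packedB, "storage_scope", "global")
tir.attr(packedB, "storage_alignment", 128)
if tir.isnullptr(packedB, dtype="bool"):
    tir.evaluate(tir.tvm_throw_last_error(dtype="int32"))
'''
nested = '''
with tir.attr(0, "compute_scope", "mmult_compute_"):
    with tir.attr(packedB, "storage_scope", "global"):
        with tir.attr(packedB, "storage_alignment", 128):
            if tir.isnullptr(packedB, dtype="bool"):
                tir.evaluate(tir.tvm_throw_last_error(dtype="int32"))
'''
# Both spellings should yield the same tree.
assert parse_block(ast.parse(concise).body) == parse_block(ast.parse(nested).body)
```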

Design Feature 2: Extensibility

Motivation

According to our classification of nodes, nodes within the same category usually follow similar printing and parsing rules. As the IR evolves, new nodes will be added, so it would be convenient if we could register each one into the printer and parser with a few lines of code.

Solution

All the functions (scope handlers, intrins, special stmts) are registered inside class Registry. Registry holds different dictionaries mapping function names to actual function entries for different categories of nodes.

The registry mechanism automatically handles argument parsing and passing, with error reporting for missing arguments. The registered function only needs to consider how to process the arguments.
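A minimal sketch of such a registry, assuming one plain dictionary per category and `inspect.signature` for argument binding. `Registry`, `register` and `call` are illustrative names rather than the actual TVM classes, and the `ramp` handler returns a list instead of building a `tvm.tir.Ramp`. Binding the call-site arguments against the handler's signature is what lets the registry report a missing argument before the handler body runs.

```python
import inspect

class Registry:
    """Toy registry: one dict per category, mapping name -> handler function."""

    def __init__(self):
        self.tables = {"scope_handler": {}, "special_stmt": {}, "intrin": {}}

    def register(self, category):
        def decorator(func):
            self.tables[category][func.__name__] = func
            return func
        return decorator

    def call(self, category, name, *args, **kwargs):
        func = self.tables[category].get(name)
        if func is None:
            raise NameError(f"unknown {category} '{name}'")
        try:
            # Match the call-site arguments against the handler's signature,
            # so a missing or unexpected argument is reported here.
            bound = inspect.signature(func).bind(*args, **kwargs)
        except TypeError as err:
            raise TypeError(f"{category} '{name}': {err}") from None
        bound.apply_defaults()
        return func(*bound.args, **bound.kwargs)


registry = Registry()

@registry.register("intrin")
def ramp(base, stride, lanes):
    # Stand-in for tvm.tir.Ramp: returns the lane values as a plain list.
    return [base + i * stride for i in range(lanes)]

print(registry.call("intrin", "ramp", 0, 2, 4))   # [0, 2, 4, 6]
try:
    registry.call("intrin", "ramp", 0, 2)
except TypeError as err:
    print(err)   # reports the missing 'lanes' argument
```

The handler itself stays free of argument checking, which matches the division of labor described above.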

According to our classification above,

  • Scope handlers: use register_scope_handler(scope_name, concise) to register

    @register_scope_handler("with_scope", concise=False)
    def let(parser, node, var, value, body):
        """ With scope handler function let(var, value, body) """
        return tvm.tir.LetStmt(var, value, body)

    As mentioned, scope handlers are stmts with a body. We further classify them into two categories:

    • With scope handler (scope_name="with_scope"):

      Since we want to support concise scoping (concise=True), some with scope handlers can appear in two formats:

      with tir.xxx() # With in Python AST
      tir.xxx()      # Expr(Call) in Python AST
      

      If we ban concise scoping (concise=False), then the node can only be represented as

      with tir.xxx() # With in Python AST
      
    • For scope handler (scope_name="for_scope"):

      @register_scope_handler("for_scope")
      def range(parser, node, begin, end, for_type="serial"):
          """ For scope handler function range(begin, end, for_type)"""
          ana = tvm.arith.Analyzer()
          extent = end if begin == 0 else ana.simplify(end - begin)
          loop_var_name = node.target.id
          loop_var = tvm.te.var(loop_var_name, dtype="int32")
      
          parser.scope_emitter.new_scope()
          parser.scope_emitter.update_symbol(loop_var_name, loop_var)
          body = get_body(parser, node)
          parser.scope_emitter.pop_scope()
      
          for_type_dict = {"serial": 0, "parallel": 1, "vectorized": 2, "unrolled": 3}
          return tvm.tir.For(loop_var, begin, extent, for_type_dict[for_type], 0, body)
      

      Their common parsing behavior is

      1. define a set of loop vars of current scope
      2. parse body
      3. return a stmt
  • Special stmts: use register_special_stmt to register

    Special stmts can appear in two formats. They don't correspond to nodes in IR directly.

    • target = tir.xxx(), like packedB = tir.var("handle")

      @register_special_stmt
      def var(parser, node, dtype):
          return te.var(parser._assign_target, dtype)
    • tir.xxx(), like tir.func_attr({"global_symbol": "default_function", "tir.noalias": True})

      @register_special_stmt
      def func_attr(parser, node, dict_attr):
          parser.dict_attr = dict_attr
      
  • Intrin: use register_intrin to register

    @register_intrin
    def ramp(base, stride, lanes):
        lanes = lanes.value if not isinstance(lanes, int) else lanes
        return tvm.tir.Ramp(base, stride, lanes)

    Intrins appear as tir.xxx().
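The `range` handler relies on `parser.scope_emitter` with `new_scope`, `update_symbol` and `pop_scope`. One plausible shape for that object is a stack of dictionaries, sketched below; only those three method names come from the example above, while `ScopeEmitter`, the `lookup` helper and the tuples standing in for `tvm.te.var` objects are assumptions. The usage lines also show the role of a special stmt such as `packedB = tir.var("handle")`: it adds a symbol without emitting an IR node.

```python
class ScopeEmitter:
    """Hypothetical stack of symbol tables; inner scopes shadow outer ones."""

    def __init__(self):
        self.symbols = [{}]          # function-level scope

    def new_scope(self):
        self.symbols.append({})

    def pop_scope(self):
        self.symbols.pop()

    def update_symbol(self, name, value):
        self.symbols[-1][name] = value

    def lookup(self, name):          # hypothetical helper, innermost scope first
        for table in reversed(self.symbols):
            if name in table:
                return table[name]
        raise NameError(f"undefined symbol '{name}'")


emitter = ScopeEmitter()
# Special stmt `packedB = tir.var("handle")`: binds a name, emits no IR node.
emitter.update_symbol("packedB", ("var", "packedB", "handle"))
# Entering `for x in tir.range(0, 16):` opens a scope for the loop var.
emitter.new_scope()
emitter.update_symbol("x", ("var", "x", "int32"))
assert emitter.lookup("packedB")[2] == "handle"   # outer symbol still visible
emitter.pop_scope()                                 # loop body finished
```

After `pop_scope`, the loop variable is no longer visible, which mirrors steps 1 to 3 of the for-scope parsing behavior.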

Please share your comments :slight_smile:


Thanks for the proposal, @spectrometerHBH! Dropping a few thoughts and questions here.

If I understand correctly, the proposal is to use Python as a text format, meaning we not only convert Python into TIR, but we also convert TIR into Python. Could you motivate the second part? TIR already prints a C-like text format, so users can inspect the IR.

If we don't need the ability to convert from TIR into Python, then we would no longer need to maintain a one-to-one correspondence between the scoping structure of the IR and the tree structure of the Python AST. This way, we could use a decorator approach and could add more syntactic sugar that's geared towards productivity.

It would be interesting to hear @jroesch's thoughts, as I'm aware that he's pushing on a similar effort to make it easier to generate Relay implementations in Python. It might make sense to codesign these two interfaces.

I believe @spectrometerHBH has already implemented both the parser and text printer with strict 1-1 correspondence. I am interested to see if we can do the same thing for relay, which will be super useful for developers IMO.


Thank you for your comment, @weberlo!

As for the motivation for converting TIR into Python, one scenario I can come up with is that developers can dump TIR into a Python script in the middle of lowering, modify the script by hand, and then parse the script back into TIR to finish the remaining lowering phases, which I think may help with debugging and testing.

Also, I'm not clear about your claim that we would no longer need to maintain a one-to-one correspondence and could use a decorator approach. Can you please show some examples?


I also think TIR -> Python would be a good idea.

Especially if we want to debug some aspects of schedules, being able to generate the Python code and use the Python debugger to set corresponding breakpoints would simplify the process.

I am not aware of an "easy" solution to debug schedules from the Python-driven compilation scripts/process.


@spectrometerHBH

Would we need to add hooks to allow users to edit TIR in the middle of the pipeline? I'm not sure exactly what this process looks like, though maybe it will become clearer once there are tests in the accompanying PR showing use cases.

I think this was me misunderstanding Hybrid Script, because it does look like it already uses a Python decorator. Even so, I was saying that, hypothetically, if we didn't need to maintain a TIR -> Python conversion, we could add higher-level constructs to Hybrid Script that would then be elaborated into lower-level TIR constructs.

Another question I have after reading more is when would we want to ban concise scoping? To me, it seems like we would always want it.

@aca88

Is Hybrid Script executable in Python? I thought the proposal is that we're just using Python syntax and parsing it into TIR. I know there's software emulation for TE, but there's no mention of it in this RFC.

Thanks for the RFC and all the hard work.

I agree at a high level that letting users quickly write TIR is an important problem and worth building a solution for, but I'm not sure this is the right design.

To me a feature like this should be purely about providing useful syntax for users to write down TIR programs, but there are a couple of challenges with the proposed design. Currently it seems that this:

  1. provides users with a brand new DSL with seemingly complex scoping rules designed to fit into Python's relatively broken scoping rules
  2. provides a mechanism for extending the DSL with concepts that are not needed from my PoV, i.e., special statements
  3. allows people to overload behaviors in that DSL to customize the elaboration and pretty printing of TIR.

The challenge here is we are mixing the definition of a new DSL, its syntax, and its semantics all at once, and I think a different design could better tease these apart and clarify some of the challenges I raise.

For example, a straightforward decorator based on AST rewriting avoids nearly all the problems raised by the above RFC: no need for with, concise scoping or special statements.

Even in a decorator-based approach we can maintain a 1-1 correspondence between the printable syntax and the user input. In both approaches, if you desugar or transform the underlying AST in any way, you will get a different pretty-printed program. Resugaring is a hard problem that does not have a solution even in the academic literature.

The other challenge is that this approach (as far as I can tell) allows people to mix arbitrary Python and the DSL, which further necessitates including Python to run the compiler in an ahead-of-time mode. I know you may not have considered this (as it is a deployment rather than a UX concern), but this is a long-term worry of mine as we maintain and grow TVM into more production environments.

A final concern is the overuse of registries and decorators in TVM, which often results in code that is challenging to read and debug due to its non-local nature.


Yeah, I think the registry may not be the central part of the design, and it can be replaced with a global dict or something similar that includes all the intrins / special statements, so that people don't have to search the entire codebase for a certain definition.

Great discussions. I also want to share some of my thoughts. First of all, I believe the current approach is indeed a decorator-based approach (rather than a staging DSL). So the content of the printed Module should be decorated by @tvm.hybrid.module. When there is a mix of Python code that does not match the defined syntax, an error will be raised.

In terms of use cases, there are two ways to view the hybrid script:

  • V0: As an effective DSL to write TIR-related functions in Python and ingest them.
  • V1: As a faithful Python-AST-based syntax for TIR itself (that can be read back).

My take is that the hybrid script aims to achieve the goal of V1. This means there is no DSL that stages to TIR; rather, the language itself is TIR, and the mechanisms (special stmts and with scopes) are introduced to support the TIR-specific constructs (AttrStmt, Assert) that come with scoping.

I agree that a registry-based approach can be a bit tricky. In this case it seems to be a mechanism to decouple implementation and extension. To address the concern, we can simply force all the registration to happen in a single place, e.g. (hybrid.scope_handler).


It's also worth thinking about how to make things convenient for people who write shape functions in the new TIR design. Ideally we need syntax that is as simple as possible. Will be happy to see a few examples :slight_smile:


@weberlo @jroesch

No, Hybrid Script is not executable in Python, and we don't actually execute Hybrid Script to get TIR. We parse the hybrid script into a Python AST, and construct the corresponding TIR by visiting the Python AST.

We want to clarify that the intention of Hybrid Script is to serve as a faithful text format repr of any TIR function, rather than a DSL that ingests into TIR but cannot necessarily represent every TIR function.

We add @tvm.hybrid.script decorator over a class/function (which is the Hybrid Script for IRModule/PrimFunc), and if we reference that class/function in Python code, we will get the parsed TIR object.

There are two uses of Hybrid Script I can see:

  • One important use of Hybrid Script is to inspect IR with more flexibility. For example, this is the printed Hybrid Script of opt_gemm.py after schedule normalizing phase.
@tvm.hybrid.script
class Module:
    def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "mmult", "tir.noalias": True})
        A_1 = tir.buffer_bind(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
        B_1 = tir.buffer_bind(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
        C_1 = tir.buffer_bind(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
        # body
        tir.attr(C_1, "realize_scope", "")
        tir.realize(C_1[0:1024, 0:1024])
        for x in tir.range(0, 1024):
            for y in tir.range(0, 1024):
                C_1[x, y] = tir.float32(0)
                for k in tir.range(0, 1024):
                    C_1[x, y] = (C_1[x, y] + (A_1[x, k]*B_1[k, y]))

mod = Module() # parse the hybrid script and get an IRModule it represents
func = mod["mmult"] # get the PrimFunc
func = tvm.build(func)

We can make any modifications we want to this script and read it back to get the modified TIR. I think this process is mainly used in internal development. Users who only want to use TVM to compile programs don't have to care about the internal state of TIR, so something like a "hook" may be unnecessary.

  • Another important use of Hybrid Script is to provide high level tensor program descriptions for user usage, like te.compute(), in the future.

If we merely intend to inspect the IR, making Hybrid Script readable and parsable is sufficient. But to achieve that goal we don't have to be limited to Python syntax. An important reason to use Python syntax is that we want Hybrid Script to follow users' habits of writing a normal program.

te.compute() is influenced by the functional programming pipeline design in Halide's algorithm description. It is an ideal model for describing the common tensor programs we've encountered, but its expressiveness is limited.

As we can see in the above opt_gemm example, the Hybrid Script is close to what normal users write if they want to declare a GEMM operator. It is more complicated than a simple te.compute((M, N), lambda x, y: te.sum(A[x,k] * B[k,y], axis=k), name='C'), but it also breaks the limits of te.compute.

for x in tir.range(0, 1024):
    for y in tir.range(0, 1024):
        C_1[x, y] = tir.float32(0)
        for k in tir.range(0, 1024):
            C_1[x, y] = (C_1[x, y] + (A_1[x, k]*B_1[k, y]))

@jroesch Hence, we would like to introduce concise scoping and special statements to make Hybrid Script writable.

  1. For example, if we want to represent an AllocateNode, we can simply do
with tir.allocate(tir.var(name_hint="packedB", dtype="handle"), "float32", [16, 16, 16], True):

But normally we want to use a single name_hint to reference a Var. Thus we can print

packedB = tir.var("handle") # Stmt in Python AST, 
                            # but doesn't correspond to a StmtNode in TIR
with tir.allocate(packedB, "float32", [16, 16, 16], True):

which is more natural.

  2. For example, if we want to represent a BufferRealizeNode, we have to print the Buffer and bounds (Array of Range). We can simply do
with tir.realize(tir.Buffer([16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1), 
                 [tir.Range(0,16), tir.Range(0,14), tir.Range(0,14), tir.Range(0,32), tir.Range(0,16), tir.Range(0,16)]):

It is a faithful repr of BufferRealizeNode. But normally we want to write

Conv_1 = tir.buffer_bind(Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1)
with tir.realize(Conv_1[0:16, 0:14, 0:14, 0:32, 0:16, 0:16]):

The above two examples can serve as a motivation for special statements.

@weberlo, as for concise scoping, it intends to make Hybrid Script more writable as well.

Actually in the current implementation, LetStmt and AssertStmt are not allowed in concise scoping because their concise modes are handled differently.

We use var: type=Expr rather than tir.let() to represent LetStmt.

We use assert condition, msg rather than tir.Assert() to represent AssertStmt.

assert (num_args == 3), "mmult: num_args should be 3"
arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle")
arg0_code: ty.int32 = tir.load("int32", arg_type_ids, 0)
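A toy sketch of how a parser could map these two forms, using only the standard `ast` module (Python 3.9+ for `ast.unparse`). It is not the actual implementation: `fold_let_assert` is a hypothetical helper that treats `ast.AnnAssign` as a LetStmt and `ast.Assert` as an AssertStmt, takes the remaining statements of the block as the body (the same idea as concise scoping), and returns nested tuples instead of TIR nodes.

```python
import ast

def fold_let_assert(stmts):
    """Hypothetical helper: nest each AnnAssign/Assert over the rest of its block."""
    if not stmts:
        return None
    head, rest = stmts[0], stmts[1:]
    if isinstance(head, ast.AnnAssign) and head.value is not None:
        # `var: type = value`  ->  LetStmt(var, value, body=rest)
        return ("LetStmt", head.target.id, ast.unparse(head.annotation),
                ast.unparse(head.value), fold_let_assert(rest))
    if isinstance(head, ast.Assert):
        # `assert cond, msg`  ->  AssertStmt(cond, msg, body=rest)
        msg = ast.literal_eval(head.msg) if head.msg is not None else ""
        return ("AssertStmt", ast.unparse(head.test), msg, fold_let_assert(rest))
    # Any other statement: keep it and continue with the rest in sequence.
    return ("Stmt", ast.unparse(head), fold_let_assert(rest))

src = '''
assert (num_args == 3), "mmult: num_args should be 3"
arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle")
arg0_code: ty.int32 = tir.load("int32", arg_type_ids, 0)
'''
tree = fold_let_assert(ast.parse(src).body)
# tree: AssertStmt(..., body=LetStmt(arg0, ..., body=LetStmt(arg0_code, ..., None)))
```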

Besides the above examples, scoping is also necessary because TIR itself contains certain nodes with special scopes, and we use the with syntax to denote these new scope boundaries.

We clarify that register_hook is only used for internal development and we do not allow user to customize it.

The registered functions are looked up when we visit a Call in the Python AST. If we cannot find the function, we may fall back to the intrinsic Ops registered in the C++ backend.

We use the registry to register the representations of new TIR nodes as TIR evolves in the future.

Thanks for this proposal. What is the relationship between this TIR hybrid script and the existing hybrid script? Though it adds a new feature that allows users to dump TIR into Python code, the syntax looks more complicated to me than the existing hybrid script. From a TE developer's perspective, the syntactic sugar provided by hybrid script is one big reason to choose it over IR builder. How complicated would it be to write operators such as NMS using this new design?


Thank you for your comments, @kevinthesun!

Properly speaking, TIR hybrid script is not a superset of the existing hybrid script. The latter supports software emulation, but the former doesn't.

Current TIR hybrid script serves as a faithful text format of any TIR function that can be printed out and parsed back. Hence, in order to express all the information in TIR, it is more complicated than existing hybrid script.

At present, TIR hybrid script cannot be used as a way to write high-level tensor operators and then schedule them, since the current schedule doesn't operate on TIR directly. Instead, the current schedule has to be combined with the te DSL to generate TIR. But we are working on new scheduling primitives to break this limit.

OK. Current te compute + schedule supports "regular" te operators, while hybrid script supports more general programming-style operators such as NMS. The new TIR hybrid script mainly aims to extend te + schedule, for example by dumping TIR back to Python code. Is this accurate?


Why do we call this "hybrid script"? I think the original name comes from the fact that the old hybrid script can be both run directly as Python and parsed as TIR. But this version is only a text format for TIR.

If this proposal is only a text format, why do we embed it into Python's AST? To my understanding, Python's AST is not stable. We cannot run static checks on it either. In contrast, a standalone text format without Python can also achieve the goals of round trip, inspection, debugging, and manual tweaking.


Thank you for your comment, @merrymercy!

We call this "hybrid script" because we intend to replace the old hybrid script with the current one. Maybe we can consider another name if appropriate @tqchen

It is indeed a text format, but we also hope to enhance the syntax in the future to make it as easy to use as the old hybrid script and te.compute. Hence, we chose Python, which is popular in the ML community.

I agree that Python's AST is not stable, and we should keep that in mind in later development. The syntax we use now is a really simple subset of Python, which in my opinion is unlikely to change abruptly.

I do believe using the Python AST is useful, as we can view it (and build features around it) in several ways:

  • V0: Faithful text format for the TIR. In this case, there is some value in embedding things in Python:
    • User can directly write AST rather than strings in python to construct these IRs.
    • The code works with the python syntax highlighter.
  • V1: Enhanced syntax sugars: After we achieve the goal of V0, we can bring in some enhanced syntax sugars that are more concise. The code below shows a hypothetical example:
@tvm.hybrid.script
def myfunc(a, b):
    A = tir.buffer_bind(a, (100, 100), "float32")
    B = tir.buffer_bind(b, (100, 100), "float32")

    for i, j in grid(100, 100):
        B[i, j] = A[i, j] + 1

Notably, the grid construct is a sugar that expands to two for loops, and we might omit other things like realize etc.

The syntax of V1 is intended to make it easy for developers to concisely construct TIR. There is a fixed rule to lower these sugars to a well-defined TIR program. Importantly, V1 needs to be consistent with V0: any V0 program is also a valid V1 program. We can also develop smart printers to potentially print out sugared versions, which are more concise.
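As an illustration of such a fixed lowering rule, the sketch below rewrites the hypothetical `grid(...)` sugar from the example above into nested `tir.range` loops using `ast.NodeTransformer` (Python 3.9+ for `ast.unparse`). `GridDesugar` and `desugar` are made-up names; this only demonstrates the idea and is not an existing TVM pass. Loops that do not iterate over `grid(...)` are left unchanged.

```python
import ast

class GridDesugar(ast.NodeTransformer):
    """Hypothetical pass: rewrite `for i, j in grid(M, N):` into nested tir.range loops."""

    def visit_For(self, node):
        self.generic_visit(node)          # desugar inner loops first
        it = node.iter
        if not (isinstance(it, ast.Call) and isinstance(it.func, ast.Name)
                and it.func.id == "grid" and isinstance(node.target, ast.Tuple)):
            return node
        names = [t.id for t in node.target.elts]
        if len(names) != len(it.args):
            raise SyntaxError("grid: one loop variable per extent is required")
        body = node.body
        # Wrap from the innermost loop outwards, so the first name is outermost.
        for name, extent in reversed(list(zip(names, it.args))):
            tir_range = ast.Call(
                func=ast.Attribute(value=ast.Name(id="tir", ctx=ast.Load()),
                                   attr="range", ctx=ast.Load()),
                args=[ast.Constant(value=0), extent], keywords=[])
            body = [ast.For(target=ast.Name(id=name, ctx=ast.Store()),
                            iter=tir_range, body=body, orelse=[])]
        return body[0]


def desugar(src):
    tree = ast.fix_missing_locations(GridDesugar().visit(ast.parse(src)))
    return ast.unparse(tree)


print(desugar("for i, j in grid(100, 100):\n    B[i, j] = A[i, j] + 1\n"))
# for i in tir.range(0, 100):
#     for j in tir.range(0, 100):
#         B[i, j] = A[i, j] + 1
```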

  • V2: Potential meta-programming. We might want to develop meta-programming models to construct a collection of hybrid scripts, and use some of the Python features for macro expansion.

To sum up, while the text repr is one of the goals of the hybrid script and serves as an anchor for the design, our ultimate goal should be building a single, unified infrastructure that serves as:

  • A text repr for the IR.
  • Sugars (DSLs) to construct the IR and compute themselves
  • Potential user interface for specifying the program.

From my past experience, relying on the full feature set of the Python AST is not a sustainable approach. For example, the AST in Python 3.8 differs from 3.7 for no reason, which crashed all my scripts... It is definitely important to find some stable subset.


One way to achieve this is to define a stable subset of AST objects, build a canonicalization parser that canonicalizes the AST into the stable one, and then implement the parser on top of the stable AST.
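A sketch of that idea, assuming a made-up tuple format as the stable representation: `canonicalize` accepts only a small expression subset and raises on anything else, so the rest of a parser would never touch version-specific `ast` classes. The one version difference handled here is real: up to Python 3.8 a subscript's index is wrapped in an `ast.Index` node, while 3.9+ stores the expression directly. It assumes Python 3.8+, where literals parse as `ast.Constant`.

```python
import ast

# Binary operators allowed in the (hypothetical) stable subset.
_BINOPS = {ast.Add: "add", ast.Sub: "sub", ast.Mult: "mul",
           ast.FloorDiv: "floordiv", ast.Mod: "mod"}

def _dotted(node):
    """'tir.float32' for Attribute(Name('tir'), 'float32'); 'C_1' for Name('C_1')."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        return f"{_dotted(node.value)}.{node.attr}"
    raise SyntaxError(f"unsupported callee: {type(node).__name__}")

def canonicalize(node):
    """Map a Python expression AST onto a small, version-independent tuple form."""
    if isinstance(node, ast.Expression):
        return canonicalize(node.body)
    if isinstance(node, ast.Constant):
        return ("const", node.value)
    if isinstance(node, ast.Name):
        return ("var", node.id)
    if isinstance(node, ast.BinOp) and type(node.op) in _BINOPS:
        return ("binop", _BINOPS[type(node.op)],
                canonicalize(node.left), canonicalize(node.right))
    if isinstance(node, ast.Subscript):
        index = node.slice
        if type(index).__name__ == "Index":   # Python <= 3.8 wraps the index
            index = index.value
        elems = index.elts if isinstance(index, ast.Tuple) else [index]
        return ("load", _dotted(node.value), [canonicalize(e) for e in elems])
    if isinstance(node, ast.Call):
        # Keyword arguments are omitted in this sketch.
        return ("call", _dotted(node.func), [canonicalize(a) for a in node.args])
    raise SyntaxError(f"outside the stable subset: {type(node).__name__}")

expr = ast.parse("C_1[x, y] + A_1[x, k] * B_1[k, y]", mode="eval")
print(canonicalize(expr))
```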
