[Discussion] TIR macros

Issue: Hand-written PrimFuncs may contain repeated chunks of code, or long pieces of code that make the PrimFunc harder to read.

When working on custom implementations of some operators for our targets, we ended up with prim funcs that are long and hard to read. The main contributor was code that wasn't particularly complex in terms of logic, but took many lines. For example:

T.call_packed(
    "kernel",
    T.tvm_stack_make_array(
        T.address_of(A[i, j], ...),
        T.tvm_stack_make_shape(...),
        0,
        1,
        2,
        3,
        dtype=...
    ),
    # More "make array" code...
    ...
)

The single call_packed with all of its arguments took over 50 lines. There were also common blocks (T.block) repeated in a few places.

Normally, you’d extract them into a subroutine, but calling subroutines in TIR can result in code like the one shown above.

Idea: Designate some functions as “macros” and include them in other TIR code

No actual calls would ever be generated to the macros. Instead, they would be pasted (inlined) into the TIR.

For the example above, we could have

@T.macro
def call_kernel(param):  # No type annotations required
    T.call_packed(
        "kernel",
        T.tvm_stack_make_array(
            T.address_of(param, ...),
            T.tvm_stack_make_shape(...),
            0,
            1,
            2,
            3,
            dtype=...
        ),
        # More "make array" code...
        ...
    )

and then in the original prim func:

T.include(call_kernel, A[i, j])

The macros would not need to have their parameters annotated; the arguments passed at the inclusion would be pasted into the macro's body as-is. Macros would not be allowed to reference undefined variables (to avoid accidental binding at the inclusion site).
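To illustrate the second rule, here is a hypothetical pair of macros under this proposal (not working code today): the first only touches its own parameters and would be accepted, while the second references a name that is neither a parameter nor defined inside the macro and would be rejected.

@T.macro
def ok(dst, src):
    dst[0] = src[0]      # only uses its parameters: allowed

@T.macro
def not_ok(dst):
    dst[0] = scale * 2   # `scale` is undefined inside the macro: rejected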

I have started working on a prototype, so I'm volunteering to implement this. I'd like to hear everyone's thoughts on it.

This looks like an interesting extension. My major concern is how we would properly parse these macros: Python (which TVMScript should have round-trip compatibility with) does not have an official syntax for macros, and there might be unbound variables in the macro body.

It would be kind of like the C preprocessor: once a macro is included in other TIR, it cannot be outlined again. So,

@T.macro
def foo(...):
    something(...)

@T.prim_func
def bar(...):
    T.include(foo, ...)

would become

@T.prim_func
def bar(...):
    something(...)

The macros would only serve as a tool helping with organizing the original source. Once it’s applied, it cannot be undone.

I had a feeling that this is already supported through TVMScript meta-programming; perhaps @junrushao can chime in as well.

Also, @Hzfengsy and @Lunderberg, since they were involved in that too.


They are supported by the parser via meta-programming. For example, the TVMScript parser finds the Python method my_dequantize_op, offloads it to the Python interpreter for execution, and embeds the execution result back into the TVMScript:

def my_dequantize_op(a: tvm.tir.PrimExpr):
  return (a >> 4) & 15

@T.prim_func
def my_prim_func(...):
  A[...] = my_dequantize_op(B[...])

It solves the problem of TVMScript parsing, but on the printer side, it will be expanded into a gigantic PrimFunc:

@T.prim_func
def my_prim_func(...):
  A[...] = (B[...] >> 4) & 15

Right, but the my_dequantize_op function cannot be written in TIR script. So you have to mix representations if you want to do it this way.

My goal is to allow snippets of TIR script to be expressed as macros.

There are two scenarios you might be interested in:

Case 1. Efficiently constructing an IR (TIR/Relax). In this case, it is possible to mix the TVMScript format with the plain Python interpreter via TVMScript's IRBuilder, for example to assemble a TIR loop or block (a sketch follows below): https://github.com/apache/tvm/blob/main/tests/python/unittest/test_tvmscript_ir_builder_tir.py

Case 2. Roundtripping the macro itself. In this case, we will need certain constructs in TIR to support printing a macro.
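For Case 1, here is a minimal sketch of mixing the IRBuilder with plain Python, using the API exercised in the linked test file (treat it as an illustration rather than a verified snippet):

from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import tir as T

with IRBuilder() as ib:              # collects IR as it is built
    with T.prim_func():              # open a PrimFunc frame
        T.func_name("main")
        with T.serial(128) as i:     # a TIR loop built via the IRBuilder
            # plain Python runs freely here and decides what to emit
            for k in range(2):
                T.evaluate(i + k)

func = ib.get()                      # the assembled tir.PrimFunc
print(func.script())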

What I’m working on doesn’t seem complicated so far. When I have a prototype PR, you can take a look and see what you think.


This might also have some overlap with the primfunc-to-primfunc calls I’m working on in this PR. The TIR macros could then be represented as internal function calls within an IRModule, which would be round-trip-able through TVMScript.

@I.ir_module
class my_module:
    @T.prim_func
    def main(...):
        T.func_attr({'global_symbol': 'main'})
        A[...] = my_module.my_dequantize_op(B[...])

    @T.prim_func
    def my_dequantize_op(a: T.int32) -> T.int32:
        return (a >> 4) & 15

Granted, we’d also need to add a function attribute to require inlining of these functions during lowering, as the lowered TIR probably shouldn’t contain the function calls, but that should be pretty straightforward (and is something I’m planning on implementing anyways).
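Building on the example above, the callee might be marked with something like the following, where "tir.force_inline" is just a placeholder name, since the attribute doesn't exist yet:

@T.prim_func
def my_dequantize_op(a: T.int32) -> T.int32:
    # Hypothetical attribute asking the lowering flow to inline every call
    # to this function; the real attribute name is still to be decided.
    T.func_attr({"tir.force_inline": True})
    return (a >> 4) & 15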

Here’s what I got so far:

import tvm
from tvm.script import tir as T

@T.macro
def foo1(a, b):
    with T.block("block1"):
        a[0] = b[0]

@T.prim_func
def bar(A: T.Buffer((2,), "int32"), B: T.Buffer((2,), "int32")):
    T.include(foo1, A, B)

print(bar)
print(foo1)

Output:

$ python3 mini.py

# from tvm.script import tir as T

@T.prim_func
def main(a: T.Buffer((2,), "int32"), b: T.Buffer((2,), "int32")):
    with T.block("block1"):
        T.reads(b[0])
        T.writes(a[0])
        a[0] = b[0]
<tvm.script.parser.core.doc_core.Module object at 0x7f34098be490>

The macro is really a doc.AST that gets processed when T.include is encountered. It only needs to be syntactically correct, and it doesn't have any of the constraints that T.prim_func has, including required parameter types or type annotations. In principle, this would also be expected to work:

@T.macro
def foo1(a, b, name):
    with T.block(name):
        a[0] = b[0]
[...]
T.include(foo1, A, B, "some_block")

but it fails now because it’s a really simple prototype.


Do we have an inliner for PrimFuncs?

If a PrimFunc has a block with a certain name and it gets inlined more than once, the target PrimFunc will end up with multiple blocks with the same name. I'm not sure what the behavior of that is when you do schedule.get_block("duplicated_name"). Part of my solution would allow giving blocks different names (hence the example with the block name passed as an argument). Is the duplicate-block behavior defined?
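For example, under the name-as-argument scheme, including the same macro twice could look like this (hypothetical usage, reusing the foo1 macro with the name parameter from above):

@T.prim_func
def bar(A: T.Buffer((2,), "int32"), B: T.Buffer((2,), "int32")):
    # Each inclusion supplies its own block name, so no duplicates appear.
    T.include(foo1, A, B, "block_left")
    T.include(foo1, A, B, "block_right")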

At the moment, we do not. I expect that it will be pretty straightforward to implement, using tvm::tir::Specialize (func.specialize in Python) to inline the arguments, then replacing the function call with the body of the resulting specialized function. I was picturing this running relatively late in the lowering flow, after the blocks have been removed, but it probably should have some de-duplication if used on a PrimFunc containing blocks.
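For reference, a minimal sketch of just the specialize step, using the existing PrimFunc.specialize API (the call-site replacement is the part that would still need to be written):

import tvm
from tvm.script import tir as T

@T.prim_func
def callee(n: T.int32):
    T.evaluate(n)

# Bind the parameter to the concrete argument from the call site.
n_param = callee.params[0]
specialized = callee.specialize({n_param: tvm.tir.IntImm("int32", 16)})

# `specialized.body` is the statement that would replace the call.
print(specialized.script())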

Ah, got it. So where the existing meta-programming functionality works for generating PrimExpr, this would be used to allow generation of Stmt as well. In that case, I’m wondering if we could use the same machinery as the existing meta-programming.

A 3-line git diff:
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
index f81f9bd9ea..920ba636ef 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -435,6 +435,8 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
         res.__enter__()
     elif isinstance(res, PrimExpr):
         T.evaluate(res)
+    elif isinstance(res, tvm.tir.Stmt):
+        tvm.script.ir_builder.tir._ffi_api.AddToParent(res)
     elif isinstance(res, (int, bool)):
         T.evaluate(tvm.tir.const(res))
     elif isinstance(res, tvm.relay.Call) and not res.args:
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 154a1ab3b0..3b8068e40a 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -672,6 +672,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread)
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch").set_body_typed(Prefetch);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.AddToParent").set_body_typed(AddToParent);
 
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr);

This way, any function that returns a Stmt would be treated as though that statement appears at that location.

def foo(a, b):
    @T.prim_func
    def dummy_func():
        with T.block("block1"):
            a[0] = b[0]

    return dummy_func.body

@T.prim_func
def bar(A: T.Buffer((2,), "int32"), B: T.Buffer((2,), "int32")):
    foo(A, B)

Not sure whether that simplifies the @T.macro definition, but I like the idea of it being a shorthand to generate a function that returns a Stmt. (Also, this doesn’t preserve which sections of the final PrimFunc were generated by the macro, so the printed TVMScript would contain everything explicitly.)

I got a draft PR with the first simple prototype: https://github.com/apache/tvm/pull/15238.


I really liked this idea, and I tried to make T.macro be syntactic sugar for your variant, but I ran into some anomalous behavior:

Source:

import tvm
import tvm.relay
from tvm.script import tir as T

def foo(a, b, n):
    @T.prim_func
    def dummy_func():
        with T.block(n):
            a[0] = b[0]

    print(dummy_func.body)   # extra print
    return dummy_func.body


@T.prim_func
def bar(A: T.Buffer((2,), "int32"), B: T.Buffer((2,), "int32")):
    foo(A, B, "block1")

print(bar)

Result (with your patch):

with T.block("block1"):
    T.reads()
    T.writes()
    A = T.Buffer((2,), "int32")
    B = T.Buffer((2,), "int32")
    A[0] = B[0]
# from tvm.script import tir as T

@T.prim_func
def main(A: T.Buffer((2,), "int32"), B: T.Buffer((2,), "int32")):
    with T.block("block1"):
        T.reads()
        T.writes()
        A[0] = B[0]

The issue is that T.reads and T.writes are both empty, and the body of the dummy function shows assignments to A and B following the reads/writes. I'm guessing this is because the dummy PrimFunc is evaluated early, before parameter substitution, and the parser inserts empty reads/writes.

I liked your idea because it would avoid manual evaluation of the parameter values, but I'm not completely sure how much work it would be to address the issues I ran into.

I have updated the draft PR. Now passing strings works as desired. I think the only thing left is pretty-printing macros instead of <doc.Module object ...>.

I have finished the implementation and have a PR for review: https://github.com/apache/tvm/pull/15260. The draft has been closed.
