Relax Python Module Design

Motivation

As machine learning models—especially large language models—continue to grow in scale, there is an increasing demand for ML compiler runtimes to integrate more deeply with the Python ecosystem. Python-based frameworks like PyTorch provide rich operator libraries, including features like distributed communication via torch.distributed, which scales efficiently across GPUs and nodes. These resources are already widely adopted and well-supported, making them ideal candidates for reuse within compiler runtimes.

In TVM, computational graphs are described using Relax in IRModules. While TVMScript allows us to express Relax functions in a Python-like syntax, these functions are not directly executable in Python. To run a Relax function, one must compile the entire IRModule and load the resulting executable via the virtual machine (VM).

To better leverage Python’s runtime environment and enrich TVM’s flexibility, we propose adding native support for Python functions in IRModules and TVMScript on platforms with Python support. These Python functions, marked with the @py_func decorator, can be executed directly in Python, using standard PyTorch tensors as inputs and outputs. Like Relax functions, they represent computational graphs, but with the added benefit of direct, step-by-step execution in Python. Unlike Relax functions, which must be compiled before they can run, Python functions are not compiled and instead run directly in the Python environment.

Beyond reuse of Python and PyTorch implementations, supporting Python functions in TVMScript significantly enhances the debugging experience. Traditional compilers treat computation graphs as monolithic entities, making it difficult to inspect intermediate tensor values. As models grow in complexity, this limitation becomes more pronounced. With Python functions, debugging is as simple as inserting a print statement. Users can also make quick, manual edits to Python functions and immediately observe the results—greatly improving the development and debugging workflow.

Key Designs

Cross-level calls

Python functions in Relax are designed to be cross-level, meaning they can interoperate with Relax functions, TIR functions, and TVM packed functions. This two-way interoperability allows:

  • Python functions to invoke Relax/TIR/packed functions.
  • Relax functions to invoke Python functions via R.call_py_func.

To support this, we use DLPack for seamless conversion between TVM NDArrays and PyTorch Tensors, enabling data to flow between different runtime environments with minimal overhead.
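
For reference, here is a minimal sketch of the DLPack round trip assumed here; the helper names torch_to_tvm and tvm_to_torch are illustrative, not part of the proposal.

import torch
import torch.utils.dlpack
import tvm


def torch_to_tvm(tensor: torch.Tensor) -> tvm.nd.NDArray:
    # Zero-copy handoff: export the torch tensor as a DLPack capsule and
    # wrap it as a TVM NDArray backed by the same memory.
    return tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(tensor))


def tvm_to_torch(array: tvm.nd.NDArray) -> torch.Tensor:
    # Reverse direction: TVM NDArray -> DLPack capsule -> torch tensor.
    return torch.utils.dlpack.from_dlpack(array.to_dlpack())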

Just-in-time (JIT) compilation

If an IRModule includes any Python functions, we delay the compilation of TIR and Relax functions using JIT compilation. This means:

  • TIR and Relax functions are not compiled when the TVMScript is parsed.
  • Compilation happens only when we instantiate the IRModule, at which point:
    • TIR functions are compiled and stored in the instantiated module.
    • A Relax VM is created to execute compiled Relax functions.

This JIT strategy allows more flexible, late-stage modifications and integration with the Python runtime.

Conversion between Relax functions and Python functions

Because both Relax and Python functions describe computational graphs, we introduce a new IRModule printer that converts Relax functions into Python functions. This allows users to:

  • Avoid writing Python functions manually.
  • Convert Relax IR into readable and executable Python code.
  • Debug or deploy intermediate-stage Relax programs directly in Python/PyTorch.

During this conversion:

  • High-level Relax operators (e.g., R.nn.relu) are mapped to corresponding PyTorch APIs (e.g., F.relu); see the operator-map sketch after this list.
  • call_tir and Relax function calls are handled by converting PyTorch tensors to/from DLPack format and passing them to compiled functions.
  • call_dps_packed is executed by retrieving the packed function via tvm.get_global_func and invoking it with DLPack-wrapped tensors.
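
To make the operator mapping concrete, below is a hypothetical fragment of the Relax-to-PyTorch operator map the printer could consult; the dictionary name and the exact set of entries are illustrative, not a settled registry API.

import torch
import torch.nn.functional as F

# Hypothetical map from Relax op names to PyTorch callables, consulted when
# printing high-level Relax operators as Python/PyTorch calls.
RELAX_TO_TORCH_OP = {
    "relax.nn.relu": F.relu,
    "relax.nn.softmax": F.softmax,
    "relax.nn.gelu": F.gelu,
    "relax.add": torch.add,
    "relax.matmul": torch.matmul,
    "relax.permute_dims": torch.permute,
}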

A key feature is that this conversion can happen at any stage of the compilation process. For example:

  • At early stages, users can convert a Relax function to Python to test against PyTorch implementations.
  • At later stages, when most of the module is lowered to TIR, the same conversion allows testing or deployment using the PyTorch runtime.
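
As a hedged sketch of this staged workflow, assuming the proposed as_pyfunc printer is also exposed on transformed IRModules (the "zero" pipeline name is only an example), and using the MyModule example from the Example section below:

# Early stage: print main as a Python function that mostly calls PyTorch ops,
# and compare its outputs against a reference PyTorch implementation.
early_src = MyModule.as_pyfunc("main")

# Later stage: run Relax passes first, then print the (mostly lowered) module;
# the printed Python function now routes through call_tir into compiled kernels.
lowered_mod = tvm.relax.get_pipeline("zero")(MyModule)
late_src = lowered_mod.as_pyfunc("main")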

In the future, we may also implement tracing Python functions back into Relax functions using parts of the PyTorch infrastructure (such as FX or exported programs). As a first step, however, this back-tracing is out of scope.

Example

Below is an example of an IRModule, written in TVMScript, that contains a Python function.

@I.ir_module
class IRModuleWithPyFunc(BasePyModule):
    """Example IRModule with Python function.
    The base class BasePyModule implements the logic of cross-function calls
    and JIT compilation in Python.
    We only allow Python functions in IRModules that subclass the BasePyModule.
    """

    @I.py_func
    def main(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        n = x.shape[0]
        lv = self.call_tir(self.matmul, [x, w], out_sinfo=R.Tensor((n, 20), "float32"))
        lv1 = F.relu(lv)
        lv2 = self.call_dps_packed("my_softmax", [lv1, 1], out_sinfo=R.Tensor((n, 20), "float32"))
        lv3 = self.my_identity_func(lv2)
        gv = lv3
        return gv

    @T.prim_func
    def matmul(
        var_A: T.handle,
        var_B: T.handle,
        var_C: T.handle,
    ):
        n = T.int32()
        A = T.match_buffer(var_A, (n, 16), "float32")
        B = T.match_buffer(var_B, (16, 20), "float32")
        C = T.match_buffer(var_C, (n, 20), "float32")
        for i, j, k in T.grid(n, 20, 16):
            with T.block("block"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    @R.function
    def my_identity_func(x: R.Tensor(("n", 20), "float32")) -> R.Tensor(("n", 20), "float32"):
        return x

    @R.function
    def my_relax_func(
        x: R.Tensor(("n", 16), "float32"), w: R.Tensor((16, 20), "float32")
    ) -> R.Tensor(("n", 20), "float32"):
        cls = IRModuleWithPyFunc
        n = T.int64()
        with R.dataflow():
            lv = R.call_py_func(cls.main, [x, w], out_sinfo=R.Tensor((n, 20), "float32"))
            R.output(lv)
        return lv

The code below shows basic usage of the Python function in this IRModule. We first instantiate the IRModule with the CUDA device, and then pass CUDA PyTorch tensors to the main function.

device = tvm.cuda()
py_mod = IRModuleWithPyFunc(device)  # Create python module, with TIR func compiled
n = 233
x = torch.randn(n, 16, dtype=torch.float32, device="cuda")
w = torch.randn(16, 20, dtype=torch.float32, device="cuda")
out = py_mod.main(x, w)

The Python function above can be generated from an equivalent Relax function with our printer:

@I.ir_module
class MyModule:
    @R.function
    def main(
        x: R.Tensor(("n", 16), "float32"), w: R.Tensor((16, 20), "float32")
    ) -> R.Tensor(("n", 20), "float32"):
        cls = MyModule
        n = T.int64()
        with R.dataflow():
            lv = R.call_tir(cls.matmul, [x, w], out_sinfo=R.Tensor((n, 20), "float32"))
            lv1 = R.nn.relu(lv)
            lv2 = R.call_dps_packed(
                "my_softmax", [lv1, R.prim_value(1)], out_sinfo=R.Tensor((n, 20), "float32")
            )
            gv = lv2
            R.output(gv)
        return gv

    ...


print(MyModule.as_pyfunc("main"))  # Prints a string like the "main" Python function shown in IRModuleWithPyFunc above

Under the hood, the BasePyModule base class is defined as follows:

from typing import Optional

import torch
import tvm
from tvm import relax, tir


class BasePyModule:
    def __init__(
        self,
        ir_mod: tvm.IRModule,
        device: tvm.runtime.Device,
        target: Optional[tvm.target.Target] = None,
    ):
        self.compiled_tir_funcs = {}
        self.extern_funcs = {}
        self.tir_func_names = []
        self.relax_func_names = []
        self.relax_vm = None

        # Compile all the TIR functions in the class.
        if target is None:
            target = tvm.target.Target.from_device(device)

        # Apply pass that updates all TIR functions to be public, with global symbols attached.
        # ir_mod = VisibilityUpdater()(ir_mod)

        for gv, func in ir_mod.functions_items():
            if isinstance(func, tir.PrimFunc):
                self.tir_func_names.append(gv.name_hint)
            elif isinstance(func, relax.Function):
                self.relax_func_names.append(gv.name_hint)

        # Compile the Relax and TIR functions in the IRModule.
        # TIR scheduling will be done with dlight rules in the relax pipeline.
        executable = tvm.compile(
            ir_mod,
            target=target,
            relax_pipeline=relax.get_default_pipeline(target),
            tir_pipeline=tir.get_default_tir_pipeline(target),
        )
        self.relax_vm = relax.VirtualMachine(executable, device)

        # Register the wrapped function to the class,
        # so that it can be called like a normal python function
        # with torch tensor arguments and return values.
        for func_name in self.relax_func_names:

            def _wrap_relax_func(*args, _func_name=func_name):
                # Bind func_name through a default argument so each wrapper
                # keeps its own name (avoids the late-binding closure pitfall).
                # Convert args to tvm ndarray with dlpack...
                # args = ...
                out = self.relax_vm[_func_name](*args)
                # Convert out to torch tensor...
                # out = ...
                return out

            setattr(self, func_name, _wrap_relax_func)

        # Lookup compiled TIR functions from the VM
        for func_name in self.tir_func_names:
            self.compiled_tir_funcs[func_name] = self.relax_vm[func_name]

    def call_tir(self, tir_func, args, out_sinfo):
        out = (
            [torch.empty(out_sinfo.shape, dtype=out_sinfo.dtype)]
            if not isinstance(out_sinfo, list)
            else [torch.empty(sinfo.shape, dtype=sinfo.dtype) for sinfo in out_sinfo]
        )

        if not isinstance(tir_func, tir.PrimFunc):
            raise ValueError(f"Input function {tir_func} is not a tir.PrimFunc")
        func = self.compiled_tir_funcs[tir_func.__name__]

        # logic that converts args and out to tvm ndarray with dlpack...

        func(*args, *out)
        return out

    def call_dps_packed(self, func_name, args, out_sinfo):
        out = (
            [torch.empty(out_sinfo.shape, dtype=out_sinfo.dtype)]
            if not isinstance(out_sinfo, list)
            else [torch.empty(sinfo.shape, dtype=sinfo.dtype) for sinfo in out_sinfo]
        )

        if func_name not in self.extern_funcs:
            func = tvm.get_global_func(func_name)
            self.extern_funcs[func_name] = func
        else:
            func = self.extern_funcs[func_name]

        # logic that converts args and out to tvm ndarray with dlpack...

        func(*args, *out)
        return out
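
For completeness, here is a hypothetical registration of the "my_softmax" packed function referenced in the example, written as a plain NumPy fallback purely for illustration; a real deployment would register an optimized kernel instead.

import numpy as np
import tvm


@tvm.register_func("my_softmax", override=True)
def my_softmax(x: tvm.nd.NDArray, axis: int, out: tvm.nd.NDArray):
    # Destination-passing style: the last argument is the preallocated output,
    # matching how call_dps_packed invokes the packed function above.
    data = x.numpy()
    shifted = np.exp(data - data.max(axis=axis, keepdims=True))
    out.copyfrom((shifted / shifted.sum(axis=axis, keepdims=True)).astype(data.dtype))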

Engineering Plan

Here are the engineering milestones to support Python functions in TVMScript and IRModule.

  • M0. TVMScript parser enhancement
    • M0a. Python functions with decorator @I.py_func.
      • We store two “formats” for Python functions in IRModule.
      • The first format is a raw string, which is used for TVMScript printing.
      • The second format is a captured PackedFunc in TVM runtime. This format is used for cross-function calls (being called from other Relax functions). (low priority)
      • Each Python function is represented as an ExternFunc node in TVM, where the raw string and the captured PackedFunc are stored as attributes in the attrs field (see the sketch after this plan).
    • M0b. IRModule subclassing the BasePyModule: class IRModuleWithPyFunc(BasePyModule).
  • M1. Complete BasePyModule
    • M1a. Format conversion between Torch tensors and TVM NDArray through DLPack.
  • M2. TVMScript printer for IRModules with Python functions
    • Print the IRModule in the proper format, as shown in the examples above.
    • Requires a high-level operator map from Relax ops to PyTorch ops.
    • Requires handling symbolic shapes (e.g., n = x.shape[0]) as in the example above.
  • M3. Introduce R.call_py_func to Relax (lowest priority)
    • The behavior of this primitive at runtime is to invoke the corresponding PackedFunc of the specified Python function.
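
As referenced in M0a, here is a hedged sketch of how the two stored formats might be retrieved from the IRModule; the attribute keys "py_source" and "py_packed_func" are placeholders, not a settled naming scheme.

# Hypothetical inspection of the ExternFunc node backing a @I.py_func,
# where ir_mod is the parsed IRModule containing the Python function "main".
gv = ir_mod.get_global_var("main")               # GlobalVar of the Python function
extern_func = ir_mod[gv]                         # stored as an ExternFunc node
py_source = extern_func.attrs["py_source"]       # raw string, used for TVMScript printing
py_packed = extern_func.attrs["py_packed_func"]  # captured PackedFunc (low priority)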

I like this design, and I think it is a great step toward collaborating with the existing Python ecosystem.

I had one quick question: F.relu will be translated into R.nn.relu, but when we execute it, will it call torch’s F.relu? And if we then compile it, will it be translated into Relax? Could you explain a bit more about what happens at the different stages? @MasterJH5574


If we run/execute the Python function, then PyTorch’s F.relu(lv) will be called. In the future if we support tracing these Python functions back to Relax functions, F.relu will be translated to R.nn.relu, and it will then be optimized with TVM. Does this sound clear to you?


Yes, that’s clear. Let’s see what magic we get once we land it.