Motivation
As machine learning models, especially large language models, continue to grow in scale, there is an increasing demand for ML compiler runtimes to integrate more deeply with the Python ecosystem. Python-based frameworks like PyTorch provide rich operator libraries, including features like distributed communication via torch.distributed, which scales efficiently across GPUs and nodes. These resources are already widely adopted and well supported, making them ideal candidates for reuse within compiler runtimes.
In TVM, computational graphs are described using Relax in IRModules. While TVMScript allows us to express Relax functions in a Python-like syntax, these functions are not directly executable in Python. To run a Relax function, one must compile the entire IRModule and load the resulting executable via the virtual machine (VM).
To better leverage Python's runtime environment and enrich TVM's flexibility, we propose adding native support for Python functions in IRModules and TVMScript on platforms with Python support. These Python functions, marked with the @py_func decorator, can be executed directly in Python, using standard PyTorch tensors as inputs and outputs. Like Relax functions, they represent computational graphs, but with the added benefit of direct, step-by-step execution in Python. Unlike Relax functions, which must be compiled before they can run, Python functions are not compiled and instead run directly in the Python environment.
Beyond reuse of Python and PyTorch implementations, supporting Python functions in TVMScript significantly enhances the debugging experience. Traditional compilers treat computation graphs as monolithic entities, making it difficult to inspect intermediate tensor values. As models grow in complexity, this limitation becomes more pronounced. With Python functions, debugging is as simple as inserting a print statement. Users can also make quick, manual edits to Python functions and immediately observe the results—greatly improving the development and debugging workflow.
Key Designs
Cross-level calls
Python functions in Relax are designed to be cross-level, meaning they can interoperate with Relax functions, TIR functions, and TVM packed functions. This two-way interoperability allows:
- Python functions to invoke Relax/TIR/packed functions.
- Relax functions to invoke Python functions via R.call_py_func.
To support this, we use DLPack for seamless conversion between TVM NDArrays and PyTorch Tensors, enabling data to flow between different runtime environments with minimal overhead.
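A minimal sketch of these conversions, assuming the standard torch.utils.dlpack helpers and TVM's tvm.nd.from_dlpack / NDArray.to_dlpack APIs (the helper names here are ours, not a fixed interface):

import torch
import tvm
from torch.utils import dlpack as torch_dlpack

def torch_to_tvm(tensor: torch.Tensor) -> tvm.nd.NDArray:
    # Zero-copy handoff: export the torch tensor as a DLPack capsule and
    # wrap it as a TVM NDArray that shares the same memory.
    return tvm.nd.from_dlpack(torch_dlpack.to_dlpack(tensor.contiguous()))

def tvm_to_torch(array: tvm.nd.NDArray) -> torch.Tensor:
    # The reverse direction: re-import the TVM NDArray into PyTorch.
    return torch_dlpack.from_dlpack(array.to_dlpack())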
Just-in-time (JIT) compilation
If an IRModule includes any Python functions, we delay the compilation of TIR and Relax functions using JIT compilation. This means:
- TIR and Relax functions are not compiled when the TVMScript is parsed.
- Compilation happens only when we instantiate the IRModule, at which point:
- TIR functions are compiled and stored in the instantiated module.
- A Relax VM is created to execute compiled Relax functions.
This JIT strategy allows more flexible, late-stage modifications and integration with the Python runtime.
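As a minimal sketch of the intended timing (using the IRModuleWithPyFunc example defined in the next section; x and w are PyTorch tensors as shown there):

# Parsing the TVMScript only records the functions; nothing is compiled yet.
mod_class = IRModuleWithPyFunc
# Instantiation triggers the JIT step: TIR functions are compiled for the
# target derived from the device, and a Relax VM is created.
instance = mod_class(tvm.cuda())
# The Python function itself runs directly in the Python interpreter.
out = instance.main(x, w)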
Conversion between Relax functions and Python functions
Because both Relax and Python functions describe computational graphs, we introduce a new IRModule printer that converts Relax functions into Python functions. This allows users to:
- Avoid writing Python functions manually.
- Convert Relax IR into readable and executable Python code.
- Debug or deploy intermediate-stage Relax programs directly in Python/PyTorch.
During this conversion:
- High-level Relax operators (e.g., R.nn.relu) are mapped to corresponding PyTorch APIs (e.g., F.relu); a sketch of such an operator map follows this list.
- call_tir and Relax function calls are handled by converting PyTorch tensors to/from DLPack format and passing them to the compiled functions.
- call_dps_packed is executed by retrieving the packed function via tvm.get_global_func and invoking it with DLPack-wrapped tensors.
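As an illustration, the operator map could be a small table from Relax operator names to the PyTorch expressions emitted by the printer (a hypothetical, abbreviated sketch; the real table would cover many more operators and their attributes):

# Hypothetical, abbreviated operator map used by the Relax-to-Python printer.
# Each Relax op name maps to the PyTorch expression printed in the generated
# Python function (which imports torch and torch.nn.functional as F).
RELAX_TO_TORCH_OP_MAP = {
    "relax.nn.relu": "F.relu",
    "relax.nn.softmax": "F.softmax",
    "relax.add": "torch.add",
    "relax.matmul": "torch.matmul",
    "relax.permute_dims": "torch.permute",
}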
A key feature is that this conversion can happen at any stage of the compilation process (a usage sketch follows this list). For example:
- At early stages, users can convert a Relax function to Python to test against PyTorch implementations.
- At later stages, when most of the module is lowered to TIR, the same conversion allows testing or deployment using the PyTorch runtime.
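A hypothetical usage sketch of this staged conversion, reusing the proposed as_pyfunc printer interface and the MyModule example shown later in this post:

import tvm
from tvm import relax

target = tvm.target.Target.from_device(tvm.cuda())
# Early stage: high-level Relax ops are printed as PyTorch calls such as F.relu.
print(MyModule.as_pyfunc("main"))
# Later stage: after running the default lowering pipeline, the printed
# function calls the lowered TIR kernels through call_tir instead.
lowered_mod = relax.get_default_pipeline(target)(MyModule)
print(lowered_mod.as_pyfunc("main"))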
In the future, we may also implement tracing Python functions back to Relax functions using parts of the PyTorch infrastructure (such as FX or exported programs). As a first step, however, this back-tracing is out of scope for now.
Example
Below is an example of an IRModule in TVMScript that contains a Python function.
@I.ir_module
class IRModuleWithPyFunc(BasePyModule):
    """Example IRModule with Python function.

    The base class BasePyModule implements the logic of cross-function calls
    and JIT compilation in Python.
    We only allow Python functions in IRModules that subclass the BasePyModule.
    """

    @I.py_func
    def main(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        n = x.shape[0]
        lv = self.call_tir(self.matmul, [x, w], out_sinfo=R.Tensor((n, 20), "float32"))
        lv1 = F.relu(lv)
        lv2 = self.call_dps_packed("my_softmax", [lv1, 1], out_sinfo=R.Tensor((n, 20), "float32"))
        lv3 = self.my_identity_func(lv2)
        gv = lv3
        return gv

    @T.prim_func
    def matmul(
        var_A: T.handle,
        var_B: T.handle,
        var_C: T.handle,
    ):
        n = T.int32()
        A = T.match_buffer(var_A, (n, 16), "float32")
        B = T.match_buffer(var_B, (16, 20), "float32")
        C = T.match_buffer(var_C, (n, 20), "float32")
        for i, j, k in T.grid(n, 20, 16):
            with T.block("block"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    @R.function
    def my_identity_func(x: R.Tensor(("n", 20), "float32")) -> R.Tensor(("n", 20), "float32"):
        return x

    @R.function
    def my_relax_func(
        x: R.Tensor(("n", 16), "float32"), w: R.Tensor((16, 20), "float32")
    ) -> R.Tensor(("n", 20), "float32"):
        cls = IRModuleWithPyFunc
        n = T.int64()
        with R.dataflow():
            lv = R.call_py_func(cls.main)
        return x
The code below shows a basic usage of the Python function in this IRModule. We first instantiate the IRModule with the CUDA device, and then pass in PyTorch tensors to the main function.
device = tvm.cuda()
py_mod = IRModuleWithPyFunc(device) # Create python module, with TIR func compiled
n = 233
x = torch.randn(n, 16).to(torch.float32)
w = torch.randn(16, 20).to(torch.float32)
out = py_mod.main(x, w)
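Note that main calls call_dps_packed("my_softmax", ...), so a packed function with that global name must be registered before the call. A hypothetical registration in Python (the real implementation could also live in C++) might look like:

import torch
import tvm
from torch.utils import dlpack as torch_dlpack

@tvm.register_func("my_softmax")
def my_softmax(x: tvm.nd.NDArray, axis: int, out: tvm.nd.NDArray):
    # Destination-passing style: read the input, write the result into `out`.
    x_torch = torch_dlpack.from_dlpack(x.to_dlpack())
    out_torch = torch_dlpack.from_dlpack(out.to_dlpack())
    out_torch.copy_(torch.nn.functional.softmax(x_torch, dim=axis))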
The Python function above can also be generated from an equivalent Relax function with our printer:
@I.ir_module
class MyModule:
    @R.function
    def main(
        x: R.Tensor(("n", 16), "float32"), w: R.Tensor((16, 20), "float32")
    ) -> R.Tensor(("n", 20), "float32"):
        cls = MyModule
        n = T.int64()
        with R.dataflow():
            lv = R.call_tir(cls.matmul, [x, w], out_sinfo=R.Tensor((n, 20), "float32"))
            lv1 = R.nn.relu(lv)
            lv2 = R.call_dps_packed(
                "my_softmax", [lv1, R.prim_value(1)], out_sinfo=R.Tensor((n, 20), "float32")
            )
            gv = lv2
            R.output(gv)
        return gv

    ...
print(MyModule.as_pyfunc("main"))  # Prints the Python function as shown in "IRModuleWithPyFunc" above
Under the hood, the BasePyModule base class of this IRModule is defined as follows:
class BasePyModule:
    def __init__(
        self,
        ir_mod: tvm.IRModule,
        device: tvm.runtime.Device,
        target: Optional[tvm.target.Target] = None,
    ):
        self.compiled_tir_funcs = {}
        self.extern_funcs = {}
        self.tir_func_names = []
        self.relax_func_names = []
        self.relax_vm = None

        # Compile all the TIR functions in the class.
        if target is None:
            target = tvm.target.Target.from_device(device)
        # Apply pass that updates all TIR functions to be public, with global symbols attached.
        # ir_mod = VisibilityUpdater()(ir_mod)
        for gv, func in ir_mod.functions_items():
            if isinstance(func, tir.PrimFunc):
                self.tir_func_names.append(gv.name_hint)
            elif isinstance(func, relax.Function):
                self.relax_func_names.append(gv.name_hint)

        # Compile the Relax and TIR functions in the IRModule.
        # TIR scheduling will be done with dlight rules in the relax pipeline.
        exec = tvm.compile(
            ir_mod,
            target=target,
            relax_pipeline=relax.get_default_pipeline(target),
            tir_pipeline=tir.get_default_tir_pipeline(target),
        )
        self.relax_vm = relax.VirtualMachine(exec, device)

        # Register the wrapped function to the class,
        # so that it can be called like a normal python function
        # with torch tensor arguments and return values.
        for func_name in self.relax_func_names:

            def _wrap_relax_func(*args, _name=func_name):
                # Bind the function name through a default argument so that each
                # wrapper calls its own Relax function rather than the last one
                # in the loop.
                # Convert args to tvm ndarray with dlpack...
                # args = ...
                out = self.relax_vm[_name](*args)
                # Convert out to torch tensor...
                # out = ...
                return out

            setattr(self, func_name, _wrap_relax_func)

        # Lookup compiled TIR functions from the VM
        for func_name in self.tir_func_names:
            self.compiled_tir_funcs[func_name] = self.relax_vm[func_name]

    def call_tir(self, tir_func, args, out_sinfo):
        out = (
            [torch.empty(out_sinfo.shape, dtype=out_sinfo.dtype)]
            if not isinstance(out_sinfo, list)
            else [torch.empty(sinfo.shape, dtype=sinfo.dtype) for sinfo in out_sinfo]
        )
        if not isinstance(tir_func, tir.PrimFunc):
            raise ValueError(f"Input function {tir_func} is not a tir.PrimFunc")
        func = self.compiled_tir_funcs[tir_func.__name__]
        # logic that converts args and out to tvm ndarray with dlpack...
        func(*args, *out)
        return out

    def call_dps_packed(self, func_name, args, out_sinfo):
        out = (
            [torch.empty(out_sinfo.shape, dtype=out_sinfo.dtype)]
            if not isinstance(out_sinfo, list)
            else [torch.empty(sinfo.shape, dtype=sinfo.dtype) for sinfo in out_sinfo]
        )
        if func_name not in self.extern_funcs:
            func = tvm.get_global_func(func_name)
            self.extern_funcs[func_name] = func
        else:
            func = self.extern_funcs[func_name]
        # logic that converts args and out to tvm ndarray with dlpack...
        func(*args, *out)
        return out
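The elided conversion logic in call_tir and call_dps_packed could reuse the DLPack helpers sketched in the Key Designs section, for example (a hypothetical helper, reusing the torch_to_tvm name from that sketch):

import torch

def _call_packed_with_torch(func, args, out):
    # Convert torch tensors to TVM NDArrays via DLPack (zero copy), keep
    # non-tensor arguments (e.g., integers) unchanged, then invoke the
    # destination-passing-style compiled/packed function.
    tvm_args = [torch_to_tvm(a) if isinstance(a, torch.Tensor) else a for a in args]
    tvm_out = [torch_to_tvm(o) for o in out]
    func(*tvm_args, *tvm_out)
    # Because DLPack shares memory, results written into tvm_out are already
    # visible through the torch tensors in `out`.
    return out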
Engineering Plan
Here are the engineering milestones to support Python functions in TVMScript and IRModule.
- M0. TVMScript parser enhancement
  - M0a. Python functions with the decorator @I.py_func.
    - We store two "formats" for Python functions in the IRModule (see the sketch after this list):
      - The first format is a raw string, which is used for TVMScript printing.
      - The second format is a captured PackedFunc in the TVM runtime, used for cross-function calls (being called from other Relax functions). (low priority)
    - Each Python function is represented as an ExternFunc node in TVM, where the raw string and the captured PackedFunc are stored in the attrs field as attributes.
  - M0b. IRModule subclassing the BasePyModule: class IRModuleWithPyFunc(BasePyModule).
- M1. Complete BasePyModule
  - M1a. Format conversion between Torch tensors and TVM NDArray through DLPack.
- M2. TVMScript printer for IRModules with Python functions
  - Print the IRModule in the proper format as shown in the examples.
  - Need a high-level operator map from Relax ops to PyTorch ops.
  - Need to handle symbolic shapes (e.g., n = x.shape[0]) in the example above.
- M3. Introduce R.call_py_func to Relax (lowest priority)
  - The behavior of this primitive at runtime is to invoke the corresponding PackedFunc of the specified Python function.
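As a rough illustration of the M0a storage scheme (the attribute keys and construction below are hypothetical, not a settled design), the parser might attach the two formats to the ExternFunc node roughly like this:

from tvm import relax

# Hypothetical: represent the parsed Python function as an ExternFunc and
# attach the raw source string (and, later, a captured PackedFunc) as attrs.
py_func_source = "def main(self, x, w):\n    ..."
ext_func = relax.ExternFunc("main")
ext_func = ext_func.with_attr("py_func_source", py_func_source)
# ext_func = ext_func.with_attr("py_func_packed", captured_packed_func)  # low priority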