Relax Python Module Design

We have now finished the M0 and M1 parts of the plan!

You can check the PR here.

Below is an example of an IRModule containing a Python function (`pyfunc`):

import tvm
from tvm import relax, tir
from tvm.relax.base_py_module import BasePyModule
from tvm.script import ir as I, relax as R, tir as T
from tvm.runtime import Device
import torch


@I.ir_module
class IRModuleWithPyFunc(BasePyModule):
    """Example IRModule mixing a Python function, a TIR PrimFunc and a Relax function.

    The base class BasePyModule implements the logic of cross-function calls
    and JIT compilation in Python.
    We only allow Python functions (``@I.pyfunc``) in IRModules that subclass
    BasePyModule.

    Members:
        python_add: Python function that adds two tensors by dispatching to
            ``add_tir`` through ``self.call_tir``.
        add_tir: TIR PrimFunc computing ``out[i] = x[i] + y[i]`` on
            float32 buffers of shape (5,).
        main_relax: Relax function computing ``R.add(x, y)`` on float32
            tensors of shape (5,).
    """

    @I.pyfunc
    def python_add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Python function that can be called from Relax functions.

        Adds ``x`` and ``y`` by invoking the ``add_tir`` PrimFunc via
        ``self.call_tir``.

        Args:
            x: First operand. Presumably must be a float32 tensor of shape (5,)
                to match ``add_tir``'s buffers. NOTE(review): not validated here.
            y: Second operand, with the same expectations as ``x``.

        Returns:
            The elementwise sum, converted back to a ``torch.Tensor``.
        """
        # Convert inputs to TVM NDArrays via DLPack
        x_tvm = self._convert_pytorch_to_tvm(x)
        y_tvm = self._convert_pytorch_to_tvm(y)

        # Call the compiled TIR function. add_tir writes its result into a
        # caller-supplied output buffer (var_out), so out_sinfo describes that
        # output (shape (5,), float32); presumably call_tir uses it to allocate
        # the buffer.
        # NOTE(review): main() passes torch tensors directly to call_tir, so the
        # explicit conversions above (and below) may be redundant. Confirm
        # against BasePyModule.call_tir's accepted argument and return types.
        result = self.call_tir(self.add_tir, [x_tvm, y_tvm], 
                             out_sinfo=R.Tensor((5,), "float32"))

        # Convert result back to original format (torch.Tensor, per the
        # return annotation)
        return self._convert_tvm_to_pytorch(result)

    @T.prim_func
    def add_tir(
        var_x: T.handle,
        var_y: T.handle,
        var_out: T.handle,
    ):
        # Elementwise add over fixed-size (5,) float32 buffers. Destination-
        # passing style: the result is written into var_out; nothing is returned.
        x = T.match_buffer(var_x, (5,), "float32")
        y = T.match_buffer(var_y, (5,), "float32")
        out = T.match_buffer(var_out, (5,), "float32")

        # Inside a prim_func, `range` is TVMScript loop syntax, presumably
        # equivalent to T.serial(5). TODO confirm against the TVMScript docs.
        for i in range(5):
            out[i] = x[i] + y[i]

    @R.function
    def main_relax(x: R.Tensor((5,), "float32"), 
                   y: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
        # Same computation as add_tir, expressed with the high-level Relax op.
        return R.add(x, y)


def main():
    """Demonstrate DLPack conversion and cross-function calls on the example module.

    Returns:
        A 3-tuple ``(module, dlpack_results, cross_call_results)``:

        * ``module``: the ``IRModuleWithPyFunc`` instance.
        * ``dlpack_results``: ``(x_torch, y_torch, x_tvm, y_tvm, x_back, y_back)``.
          These are the two random inputs, their TVM conversions, and the
          tensors obtained by converting those conversions back to PyTorch.
        * ``cross_call_results``: ``(tir_result, relax_result, python_result)``.
          These are the sums computed via ``call_tir("add_tir", ...)``,
          ``main_relax`` and ``python_add``, respectively.
    """
    module = IRModuleWithPyFunc()

    # Random float32 operands of length 5, matching add_tir's fixed (5,) buffers.
    operands = (
        torch.randn(5, dtype=torch.float32),
        torch.randn(5, dtype=torch.float32),
    )

    # DLPack round trip: convert both operands to TVM first, then both back.
    to_tvm = module._convert_pytorch_to_tvm
    to_torch = module._convert_tvm_to_pytorch
    tvm_operands = tuple(to_tvm(tensor) for tensor in operands)
    restored = tuple(to_torch(array) for array in tvm_operands)

    # Compute the same addition through each cross-function call path:
    # TIR (by name), Relax, then the Python function.
    lhs, rhs = operands
    cross_call_results = (
        module.call_tir("add_tir", [lhs, rhs], out_sinfo=R.Tensor((5,), "float32")),
        module.main_relax(lhs, rhs),
        module.python_add(lhs, rhs),
    )

    return module, operands + tvm_operands + restored, cross_call_results


if __name__ == "__main__":
    # Script entry point: run the demo. The returned tuples are not used here.
    _ = main()



# Example usage with verification:
# result = main()
# assert result is not None, "Function should return results"
# module, dlpack_results, cross_call_results = result
# assert len(dlpack_results) == 6, "DLPack results should contain 6 elements"
# assert len(cross_call_results) == 3, "Cross-call results should contain 3 elements"
2 Likes