Unified Python-First Compilation Flow through tvm.compile

One recent direction we have taken is to focus on a clean, modular compiler that has the following properties.

  • Bring models into an IRModule that contains both Relax and tensor program (TIR) functions.
  • Enable customizable pipelines that rewrite the graph and partially lower operators.
  • Empower downstream projects to add the customizations they need to any part of the compiler pipeline while maintaining performance.

As of now we have streamlined the approach for Relax. It is a good time to revisit how we can think about the overall project holistically and capture the lessons we have learned in the past. The overall goals include:

  • G0: Python-first development
    • Enable everyone to customize the compilation pipeline and build on top of it by copy-paste editing in Python.
    • Provide a clear, first-class, round-trippable IR that embeds as a Pythonic DSL at every stage.
  • G1: Universal deployment experience with full coverage. It is important for the final execution to cover all runtime features by default when possible (e.g. Tuple, Function, NDArray). The API should feel the same across Python and possibly non-Python environments.
  • G2: Streamlined experience with the right expectations. Although there are many ways to customize, there should be one clear, streamlined flow that everyone can follow.

The expectation is that we cannot develop compilers for all anticipated use-cases, but we would still like to provide a meaningful basis that we can iterate on, customize, and build on top of.

Proposed API

Based on these goals, let us use the following API:

class Executable:
    """The executable object emitted by tvm.compile."""

    def jit(
        self, fcompile=None, addons=None, **kwargs
    ) -> tvm.runtime.Module:
        """Just-in-time compile and link the modules.

        The Executable returned by tvm.compile may not be directly
        runnable, as it may contain cuda source files and objects that
        are yet to be compiled and linked.
        """
        if self._jitted_mod is None:
            # run jit compilation/linking and cache the result
            ...
        return self._jitted_mod

    def __getitem__(self, name):
        return self.jit()[name]

    def export_library(
        self,
        file_name: str,
        fcompile: Optional[Union[str, Callable]] = None,
        workspace_dir: Optional[str] = None,
        **kwargs,
    ) -> Any:
        """Export the executable to a library which can then be loaded back."""
        pass
        

# exposed as tvm.compile
def compile(
    mod: IRModule,
    target: Optional[Target] = None,
    *,
    relax_pipeline: Optional[Union[Pass, Callable, Literal["auto"]]] = "auto",
    tir_pipeline: Optional[Union[Pass, Callable, Literal["auto"]]] = "auto",
    system_lib_prefix: Optional[str] = None,
) -> runtime.Executable:
    """Compile an IRModule into a runtime executable.

    Parameters
    ----------
    mod:
        Input IRModule.

    target:
        The target we are interested in. For a multi-target build,
        targets can be directly annotated inside the IRModule.

    relax_pipeline:
        The pipeline to apply to Relax functions.
        Users can pass in None to indicate a minimum build;
        in that case they are expected to run passes beforehand.
        "auto" mode picks the target-dependent pipeline based on
        the target setup.

    tir_pipeline:
        The pipeline to apply to TIR functions.
        Users can pass in None to indicate a minimum build;
        in that case they are expected to run passes beforehand.
        "auto" mode picks the target-dependent pipeline based on
        the target setup.

    system_lib_prefix:
        Provide the prefix to trigger a system lib build
        with a particular prefix key. Used for bundling libraries into
        the target system lib, e.g. for wasm.
    """
    # pseudo logic
    if contains_relax(mod):
        # will run the tir build internally as well
        return relax.build(
            mod, target,
            relax_pipeline=relax_pipeline,
            tir_pipeline=tir_pipeline,
        )
    return tir.build(mod, target, tir_pipeline=tir_pipeline)
    
def example():
    @tvm.script.ir_module
    class MyModule:
        @T.prim_func
        def add_one(
            X: T.Buffer((4,), "float32"),
            Y: T.Buffer((4,), "float32"),
        ):
            for i in range(4):
                Y[i] = X[i] + 1

        @R.function
        def main(x: R.Tensor((4,), "float32")):
            cls = MyModule
            with R.dataflow():
                y = R.call_tir(cls.add_one, [x], R.Tensor((4,), "float32"))
                R.output(y)
            return y

    # jit can be skipped as it is performed automatically once
    # the user starts to access functions via getitem
    lib = tvm.compile(MyModule, "llvm")
    x = tvm.nd.array(np.arange(4, dtype="float32"))
    y = tvm.nd.empty((4,), dtype="float32")
    # each function in the module maps to a function at runtime,
    # so the TIR function can be accessed directly
    lib["add_one"](x, y)

    # to access the relax function, wrap the library in an executor
    vm = tvm.runtime.Executor(lib, tvm.cpu())
    print(vm["main"](x).numpy())
    # the public TIR function continues to be accessible
    vm["add_one"](x, y)

The key takeaways include:

  • G0: All passes and compilation steps, including TIR passes, are constructed through Python to enable customization and debugging (see the pipeline sketch after this list).
  • G1: At runtime, there is a one-to-one mapping between the functions in MyModule (the script form) and the functions accessible through the runtime. The signature maps to the function signature of the resulting runtime function. The experience is the same across runtimes:
    • TIR functions can be accessed immediately.
    • High-level Relax functions require creating an executor with a list of devices before they can be accessed.
  • G2: tvm.compile is the unified entry point for all compilation flows, including TIR and Relax functions. All external usage and tutorials are built around tvm.compile.
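
To make the G0 point concrete, here is a minimal sketch of a custom pipeline passed through the proposed relax_pipeline hook, reusing MyModule from the example above and written with passes that exist in tvm.relax.transform today; the exact pass list is illustrative rather than a recommended default.

import tvm
from tvm import relax

# a copy-paste-able pipeline that users can edit directly in Python
my_relax_pipeline = tvm.ir.transform.Sequential(
    [
        relax.transform.LegalizeOps(),   # lower high-level ops into TIR functions
        relax.transform.FoldConstant(),  # simple graph-level constant folding
    ]
)

lib = tvm.compile(MyModule, "llvm", relax_pipeline=my_relax_pipeline)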

General Executor Interface and Partial Graph AOT

As of now the Relax VM is used as the single backing executor for high-level Relax functions, mainly because of its full coverage. One valid question is what other possible alternatives would look like. First of all, the executable export mechanism still enables building other execution mechanisms, following a general rule of thumb:

  • For model executables, expose a create_executor function that takes a list of devices and returns a runtime.Module with the necessary functions (as sketched below).
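
A minimal sketch of that convention, assuming a hypothetical exported executable that follows the rule of thumb above; the create_executor name and the file name are illustrative, not settled API.

import numpy as np
import tvm

# hedged sketch: any executor implementation can follow the same convention
exec_mod = tvm.runtime.load_module("model_exec.so")
# create_executor takes a list of devices and returns a runtime.Module
executor = exec_mod["create_executor"]([tvm.cpu(0)])
x = tvm.nd.array(np.arange(4, dtype="float32"))
print(executor["main"](x).numpy())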

For cases where we would still like to further compile away the extra interpretation overhead to reduce cost, we can do partial graph AOT inside the current programming model:

@tvm.script.ir_module
class Before:
    @R.function
    def main(x: R.Tensor((4,), "float32")):
        cls = Before
        v0 = alloc_storage()
        v1 = alloc_storage()
        v2 = alloc_storage()
        cls.tir_fn1(x, v0)
        cls.tir_fn2(v0, v1)
        cls.tir_fn3(v1, v2)
        return v2
              
@tvm.script.ir_module
class After:
    @R.function
    def graph_init():
        v0 = alloc_storage()
        v1 = alloc_storage()
        v2 = alloc_storage()
        return (v0, v1, v2)

    @R.function
    def main(x: R.Tensor((4,), "float32")):
        cls = After
        # storage is created once and kept around across calls
        storage_set = R.call_packed(
            "alloc_cached_storage_set", cls.graph_init,
            key="mymod"
        )
        # need to enhance the TIR PackedFunc to natively handle
        # runtime.Array so elements of storage_set can be accessed;
        # the call has near zero overhead and most of the cost
        # inside can be optimized away
        cls.partial_graph_aot_fn1_fn2_fn3(x, storage_set)
        # the final output lives in the last storage slot
        return storage_set[2]

Partial graph AOT means the VM can still be used for allocation and caching, which need more advanced handling, while the general execution steps after allocation can be handled through TIR, which generally does not do any allocation.

This approach allows us to gradually move parts of the graph into a more compiled mode if necessary. However, as of now the Relax VM itself is performant enough for most cases, and for GPU runtimes we have CUDAGraph to handle those specific cases. So partial graph AOT mainly serves as a guideline for potential future needs.

Notably, once the model is in near partial graph AOT form, the result can likely be consumed by downstream projects to create a more stripped-down version that manually performs storage allocation and execution separately if needed.

Awesome! It would be great to see discussions about whether we can adapt tvm.compile to work with torch.compile, converting Torch FX subgraphs or other IRs into Relax.
