Phasing out Legacy Components

Over the past year, the community has worked hard to introduce, and transition to, a more flexible and productive flow for ML compilers. One lesson we learned is that it is hard to build a silver bullet for everything. Additionally, given that the project runs on time and energy contributed by community volunteers, it is hard to build and maintain a single compiler pipeline that aims to fit all purposes, in our case all backends and use cases. Engineering complexity inevitably grows as the combination of models and backends expands and we try to fit everything into a single pipeline.

However, this does not render ML compilers useless. In fact, ML compilers are becoming increasingly important as new workloads, hardware primitives, and vertical use cases arise. Instead of relying on a single universal pipeline, the development and continuous improvement of vertical ML compilers should become part of the ML engineering process. By enabling such productive ML compiler development, we can also afford to bring up vertical-specific compiler optimizations for key use cases like LLMs and image detection models.

With that goal in mind, we still need to answer the question of what common infrastructure we should provide that can be shared across vertical flows. The answer has evolved since the project started. Over the past year, the community has converged on the following pattern:

  • Every program is encapsulated by an IRModule, with python-first printing/parsing support via TVMScript
  • Optimizations and lowering are implemented as composable transformations on the IRModule
  • A universal runtime mechanism (through TVM FFI) that naturally maps an IRModule to runnable components across different environments

Throughout all these flows, TVMScript serves as a common tool to inspect and communicate the intermediate steps. By adopting this common flow, different optimizations and vertical compilers can be built more organically. Importantly, it also allows us to strengthen the core while letting downstream projects add necessary customizations and reuse the existing pipelines when needed. Moving toward the lightweight flow also brings extra benefits for testing: because most optimizations and importers are tested via structural equality, we benefit from reduced test time and more unit-level correctness checks.
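To make the testing point concrete, here is a toy, pure-Python sketch (it does not use TVM): a pass is a function from module to module, and the test compares its output against a hand-written expected module by structural equality instead of compiling and running anything. In TVM the corresponding check would be `tvm.ir.assert_structural_equal` on two IRModules written in TVMScript; the node classes and the `fold_constants` pass below are hypothetical stand-ins.

```python
from dataclasses import dataclass
from typing import Dict, Union


# Toy IR nodes; frozen dataclasses give structural (field-by-field) equality.
@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Add:
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Const, Var, Add]
Module = Dict[str, Expr]  # function name -> body, standing in for an IRModule


def fold_expr(e: Expr) -> Expr:
    """Recursively replace Add(Const, Const) with a single Const."""
    if isinstance(e, Add):
        lhs, rhs = fold_expr(e.lhs), fold_expr(e.rhs)
        if isinstance(lhs, Const) and isinstance(rhs, Const):
            return Const(lhs.value + rhs.value)
        return Add(lhs, rhs)
    return e


def fold_constants(mod: Module) -> Module:
    """A module-to-module pass: builds a new module and leaves the input untouched."""
    return {name: fold_expr(body) for name, body in mod.items()}


def test_fold_constants():
    before = {"main": Add(Var("x"), Add(Const(1), Const(2)))}
    expected = {"main": Add(Var("x"), Const(3))}
    # Structural comparison: no compilation, no device, runs in microseconds.
    assert fold_constants(before) == expected


test_fold_constants()
```

Because the expected output is itself just IR, such a test pins down exactly what the pass should produce and fails with a precise diff rather than a numerical mismatch.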

Most new development now centers on the new modular flow. In the meantime, we have kept the legacy components around for one year. We have started to see challenges as some components fall out of maintenance due to a lack of development. Additionally, because of the way some legacy components are structured, many of their tests require integration testing (rather than structural equality), which consumes significant CI time and resources.

This post calls for us to move away from legacy components towards the new flow, specifically:

  • Move away from Relay toward Relax as the graph IR
  • Use TensorIR for tensor program scheduling instead of te/schedule
    • te remains a useful tool for constructing tir.PrimFunc, but not necessarily for the scheduling part.
    • Use dlight-style IRModule⇒IRModule transforms for rule-based scheduling that are compatible with the modular flow.
  • Use MetaSchedule for autotuning instead of autotvm and auto_scheduler
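The rule-based scheduling idea can be sketched as an ordered list of rules, each of which either recognizes a function and returns a scheduled version or declines; the first rule that matches wins, and a fallback catches the rest. This toy Python model only mirrors the shape of dlight (whose entry point is `dlight.ApplyDefaultSchedule`, taking rules such as `dlight.gpu.Matmul()` and `dlight.gpu.Fallback()`, and whose real rules inspect the TIR blocks); the `Rule` classes and dict-based "functions" here are hypothetical stand-ins, not TVM objects.

```python
from typing import Dict, List, Optional

Func = Dict[str, object]  # toy stand-in for a tir.PrimFunc: {"name", "kind", "schedule"}


class Rule:
    """Returns a scheduled copy of `func`, or None if the rule does not apply."""

    def apply(self, func: Func, target: str) -> Optional[Func]:
        raise NotImplementedError


class MatmulRule(Rule):
    def apply(self, func, target):
        if func["kind"] != "matmul":
            return None
        return {**func, "schedule": f"matmul-rule@{target}"}


class ReductionRule(Rule):
    def apply(self, func, target):
        if func["kind"] != "reduction":
            return None
        return {**func, "schedule": f"reduction-rule@{target}"}


class FallbackRule(Rule):
    def apply(self, func, target):
        # Always applies: a simple, safe schedule for anything else.
        return {**func, "schedule": f"fallback-rule@{target}"}


def apply_default_schedule(mod: Dict[str, Func], target: str, rules: List[Rule]) -> Dict[str, Func]:
    """IRModule => IRModule: schedule each function with the first matching rule."""
    out = {}
    for name, func in mod.items():
        for rule in rules:
            scheduled = rule.apply(func, target)
            if scheduled is not None:
                out[name] = scheduled
                break
        else:
            out[name] = func  # no rule matched: leave the function unscheduled
    return out


mod = {
    "mm": {"name": "mm", "kind": "matmul", "schedule": None},
    "sum": {"name": "sum", "kind": "reduction", "schedule": None},
    "add": {"name": "add", "kind": "elementwise", "schedule": None},
}
scheduled = apply_default_schedule(mod, "cuda", [MatmulRule(), ReductionRule(), FallbackRule()])
```

Because the whole thing is a module-to-module transform, it composes with other passes and can be tested structurally, unlike an imperative te/schedule script tied to one operator.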

We encourage community contributions that center on the new flow, including improving frontends and modular optimizations based on the new approach. Importantly, these improvements carry less testing overhead and less technical coupling in general, as most of them can be structured as structural-equality tests through TVMScript and the IRModule⇒IRModule mechanism, without introducing new mechanisms. Feel free to share thoughts.

As we gradually phase out the legacy components, they will remain available through release branches, which can continue to take maintenance patches. Coming back to the broader context, the field of ML/AI is moving even faster, and we have gone through several major changes in the recent wave of GenAI. These challenges are unique and call for ML projects to reinvent themselves to stay relevant, or risk becoming irrelevant. After one more year of development and learning with the new flow, now is the right time for us to start the move.


Completely agree with these perspectives. Another observation I have is that projects built on TVM are often not straightforward; they typically require hacking the underlying TVM code. For example, in the Ladder project (based on Welder), we added MFMA and HIP code generation support to TVM and introduced our own FuseOps pass in C++. In the BitBLAS project, we introduced two or three additional ugly schedules to hack certain layout operations in order to achieve better performance (Slides).

We also realized that relying on schedules made it difficult to describe some operators, such as FlashAttention, T-MAC, and Stream-K. Therefore, we designed some syntactic sugar that lets TIR be used as a Triton-like language (turning schedules into annotations, such as Pipeline and Layout Transform). For example, dequantize GEMM is hard to describe in Triton, but with our syntactic sugar we can dispatch the dequantize part to thread-level programming instead of Triton-like block-level programming:

```python
# Shapes, dtypes, block sizes, and the helpers make_swizzle_layout and
# _tir_packed_to_unsigned_convert are defined outside this snippet.
@T.prim_func
def main(
        A: T.Buffer(A_shape, dtypeAB),
        B: T.Buffer(B_shape, storage_dtype),
        C: T.Buffer((M, N), dtypeC),
):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
        A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
        B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
        B_local = T.alloc_fragment([8], storage_dtype, "local")
        B_dequantize_local = T.alloc_fragment([16], dtypeAB, "local")
        B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, dtypeAB)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

        T.annotate_layout(
            {
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            }
        )

        # Improve L2 Cache
        T.use_swizzle(panel_size=10)

        t = T.thread_binding(0, threads, thread="threadIdx.x")

        T.clear(C_local)

        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            T.copy(A[by * block_M, k * block_K], A_shared)

            for i, j in T.Parallel(block_N, block_K // num_elems_per_byte):
                B_shared[i, j] = B[bx * block_N + i, k * block_K // num_elems_per_byte + j]

            for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 4)):
                for v in T.vectorized(0, 4):
                    vi = (i * threads * 4 + t * 4 + v) // (block_K // num_elems_per_byte)
                    vj = (i * threads * 4 + t * 4 + v) % (block_K // num_elems_per_byte)
                    B_local[v] = B_shared[vi, vj]
                for v in T.serial(0, 8):
                    B_dequantize_local[v] = _tir_packed_to_unsigned_convert("int", 8)(
                        num_bits,
                        B_local[v // 2],
                        v % 2,
                        dtype=dtypeAB,
                    )
                for v in T.vectorized(0, 8):
                    vi = (i * threads * 8 + t * 8 + v) // (block_K)
                    vj = (i * threads * 8 + t * 8 + v) % (block_K)
                    B_dequantize_shared[vi, vj] = B_dequantize_local[v]
            T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
        T.copy(C_local, C[by * block_M, bx * block_N])
```

(It's great that T.Parallel can automatically map thread bindings and apply vectorization while still building on the TIR schedule transformation infrastructure; likewise, T.Pipelined builds on the software pipelining pass and T.annotate_layout on the LayoutTransformation pass.)

Anyway, the issue with all the projects I have been involved in is that they rely on different versions (or modifications) of TVM, and since the changes were often made as hotfixes to release quickly (some hacks may be ugly and inelegant), it is difficult to merge them upstream. One idea I have is that third-party developers continue to maintain their own versions of TVM for development, but use a unified IRModule (TIR) and Relax as the interface. However, I encountered some problems while trying this approach, such as conflicts when loading DLLs from different versions of TVM, and I don't know whether it is a valuable path.


Thanks @LeiWang1999, I think the main goal here is to ensure that the IR remains the common shared part.

Different projects can define their own transformations while leveraging the main codebase. That would enable us to reuse different tuners and IR transformations out of tree. We have been using this pattern for Relax passes in LLM compilation, and we can likely also use it for TIR compilation to some extent. This is also the main benefit of the modular flow in the new approach.

The way it works is to define a customized TIR lowering pass:

IRModule => My-TIR-Lowering(can be in 3rdparty) => Common TIR lowering pipeline

Then ensure that the build process can leverage the common representation. Changes to the common IR and the unified runtime themselves still need to happen upstream, but those are usually less frequent than changes to transformations.
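A toy, pure-Python model of the `IRModule => My-TIR-Lowering => Common TIR lowering pipeline` flow: the downstream pass is an ordinary module-to-module function defined out of tree, and it is composed in front of an unchanged common pipeline. In TVM the same roles would be played by a pass created with `tvm.tir.transform.prim_func_pass` (or `tvm.ir.transform.module_pass`) and composed with `tvm.transform.Sequential`; the pass names and the list-of-strings "module" below are hypothetical.

```python
from functools import reduce
from typing import Callable, Dict, List

Module = Dict[str, List[str]]  # toy IRModule: function name -> list of "statements"
Pass = Callable[[Module], Module]


def sequential(passes: List[Pass]) -> Pass:
    """Compose module-to-module passes left to right (like a Sequential)."""
    return lambda mod: reduce(lambda m, p: p(m), passes, mod)


# --- upstream: the common lowering pipeline, shared by every project ---
def lower_intrinsics(mod: Module) -> Module:
    return {f: [s.replace("exp(", "__expf(") for s in body] for f, body in mod.items()}


def simplify(mod: Module) -> Module:
    return {f: [s for s in body if s != "nop"] for f, body in mod.items()}


common_pipeline = sequential([lower_intrinsics, simplify])


# --- downstream (3rdparty): a project-specific lowering, defined out of tree ---
def my_tir_lowering(mod: Module) -> Module:
    """Hypothetical custom step: lower a project-specific construct into common IR."""
    return {f: ["nop" if s == "annotate_layout" else s for s in body] for f, body in mod.items()}


build_pipeline = sequential([my_tir_lowering, common_pipeline])

mod = {"main": ["annotate_layout", "y = exp(x)"]}
lowered = build_pipeline(mod)
```

The downstream project only prepends its own passes; the common pipeline and the IR it consumes stay untouched, which is what keeps upstream changes limited to the shared IR and runtime.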


@tqchen, thanks! This is exactly what we are expecting. However, the last time I tried to bring my own tuner into mlc-llm, I encountered an issue:

```python
import tvm  # upstream

relax_mod = relax_transform(relax_mod)

import welder
relax_mod = welder.tune(relax_mod)
# something bad happened
```

The problem was that importing welder also imports its own version of TVM, which then invokes load_dlls (for example, to load libcutlass). This ends up overwriting the upstream cutlass library and leads to bugs.

That is right. In such a case, we need to ensure that downstream projects are structured to depend on the same libtvm: both projectA and projectB depend on the same upstream TVM (via include dependency) and build new optimization transformations on top of it.

That does mean we need to restructure the projects instead of simply making in-place modifications; for example, MLC LLM adds customized passes and runtime functions on top while taking TVM as a dependency.

Our hope is that by updating the upstream APIs to be more modular, such transformations can happen more organically.

LLMs are fundamentally transforming the paradigm of ML deployment and compilation. Simultaneously, the increasing complexity of ML optimization pipelines has rendered many legacy components inadequate for meeting rapidly evolving requirements.

On the other hand, the open-source community faces a shortage of volunteers willing to maintain these codebases consistently. Consequently, we must prioritize and concentrate our efforts on key strategic approaches to address these challenges effectively.

For most common use cases, the Unity flow can effectively replace legacy components, incorporating features such as static shape auto-tuning and BYOC capabilities. While we acknowledge that some niche scenarios (e.g., microTVM) may not be fully supported initially, we can address these later if strong demand persists.

In summary, I concur that the time has come to gradually phase out legacy components. This strategic move will serve two crucial purposes:

  1. Clean up the codebase: By removing outdated or redundant elements, we can significantly reduce complexity and improve maintainability.

  2. Unify our focus: Concentrating our efforts on the new unity flow will allow for more efficient development and innovation.


One suggestion that I have for TVM is to add a cleaner exit from the stack.

For example, for OpenCL/CUDA targets, what do I do if I just want the generated kernels?

Note: there is a way to print the source for OpenCL, but unfortunately I have not found a way to get the work-group/thread-block sizes and dimensions, which are needed to launch the kernels. Surely those parameters were tuned.

@varunnaw Good point. In my project we use this approach to retrieve attributes, including the dynamic shared memory size and block/grid information (we add these attributes in a TVM pass), which might be helpful to you.

Why is this important?

When users integrate the TVM runtime with third-party frameworks like PyTorch, using DLPack can introduce significant runtime overhead for small data shapes, such as GEMV and small-batch GEMV on data-center GPUs. In our benchmarks, we observed delays of around 10 to 50 us. For more details, please refer to this discussion: Strange overhead of tvm.runtime.ndarray.from_dlpack - Apache TVM Discuss.

These overheads arise not only from the ctypes cost of initializing a TVMValue from DLPack, but also from occasional calls to cudaSetDevice during the conversion, which are also costly.
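When chasing overheads in the 10 to 50 us range, it helps to time the conversion call in isolation, many times, and report a robust statistic rather than timing a single end-to-end launch. A minimal, framework-agnostic sketch using only the Python standard library; plugging in the real conversion (for example, wrapping a `tvm.runtime.ndarray.from_dlpack` call on a small tensor in the lambda) is left as an assumption about how you would use it.

```python
import statistics
import time
from typing import Callable


def per_call_overhead_us(fn: Callable[[], object], warmup: int = 100,
                         iters: int = 1000, repeats: int = 5) -> float:
    """Median (over `repeats`) of the mean per-call wall time of `fn`, in microseconds."""
    for _ in range(warmup):  # absorb lazy initialization and cold caches
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        samples.append((time.perf_counter() - start) / iters * 1e6)
    return statistics.median(samples)


# Trivial stand-in as a baseline; replace the lambda with the conversion under test.
baseline = per_call_overhead_us(lambda: None)
```

This measures host-side cost only, which is where the ctypes and device-setting overheads described above show up; subtracting the baseline removes the Python call overhead of the harness itself.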

Moreover, when we want to extract the generated code for other uses, TVM doesn't provide a tool to automatically extract the blockDim, gridDim, and shared memory usage (which we need to initialize dynamic shared memory); maybe we can learn a possible solution from the link I posted.

Thanks!

I'm not familiar with the BitBLAS project, so please correct me if I am wrong: in the code you showed, the IRModule pass that retrieves the thread-block dimensions is get_annotated_device_mod. I'm confused by how the CUDA source wrapper is initialized: an IRModule plus a source string is passed? Don't you typically get the source after building the module?

Also, do you initialize the TileDevice class with remote.cl() or remote.cuda(), just as the TVM examples do?

Here’s a python script that prints the source for a single conv2d (I omitted tuning for brevity). I still don’t know how to get work group sizing though. Do you have any advice on how to use your method in BitBlas here?

```python
import numpy as np
import tvm
from tvm import relay

target = tvm.target.Target("opencl", host="llvm -mtriple=aarch64-linux-android")
dtype = "float16"

input_shape = (1, 25, 25, 64)   # NHWC
filter_shape = (3, 3, 64, 96)   # HWIO
filter_data = np.random.rand(*filter_shape).astype(dtype)

data = relay.var("input", shape=input_shape, dtype=dtype)
weight = relay.var("weight", shape=filter_shape, dtype=dtype)
conv = relay.nn.conv2d(
    data, weight, padding=(0, 0),
    data_layout="NHWC", kernel_layout="HWIO", out_dtype=dtype,
)

# relay.build expects an IRModule; bind the weight as a constant parameter.
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))
params = {"weight": tvm.nd.array(filter_data)}

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target, params=params)

# The first imported module of the compiled library holds the OpenCL kernels.
print(lib.get_lib().imported_modules[0].get_source())
```

Is BYOC the only option for adding graph substitutions? If the substitution is just for one operator, can this be implemented by adding an entirely new operator?

@varunnaw, such an approach only works for a single IRModule rather than an end-to-end module. We modified the pass lower_device_kernel_launch (https://github.com/LeiWang1999/tvm/blob/bitblas_tl/src/tir/transforms/lower_device_kernel_launch.cc#L244-L245) to inject these attributes. If you want to extract them from the source, I guess we can modify codegen_c to emit some extra comments based on these attributes:

```cuda
extern "C" __global__ void __launch_bounds__(128) main_kernel(half* __restrict__ A, half* __restrict__ B, half* __restrict__ C) {
  // thread_extent [32, 2, 2]
  // block_dims [2, 2, 1]
  // dynamic_size_in_bytes [16384]
  // ... kernel body ...
}
```
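If the codegen emits comments like the ones in that kernel, a downstream runtime wrapper can recover the launch configuration from the source string without going back to TVM. A minimal sketch in Python; it assumes exactly the `// key [v0, v1, ...]` comment format shown there, which comes from the modified pass rather than stock TVM codegen. Judging by `__launch_bounds__(128)`, `thread_extent` in that snippet holds the per-block thread dimensions.

```python
import math
import re
from typing import Dict, List

# Matches comment lines such as "// block_dims [2, 2, 1]" in the generated source.
_ATTR_RE = re.compile(r"//\s*(\w+)\s*\[([\d,\s]+)\]")


def parse_launch_attrs(source: str) -> Dict[str, List[int]]:
    """Collect every `// key [ints]` annotation into {key: [ints]}."""
    attrs = {}
    for key, values in _ATTR_RE.findall(source):
        attrs[key] = [int(v) for v in values.split(",") if v.strip()]
    return attrs


src = '''
extern "C" __global__ void __launch_bounds__(128) main_kernel(half* __restrict__ A, half* __restrict__ B, half* __restrict__ C) {
  // thread_extent [32, 2, 2]
  // block_dims [2, 2, 1]
  // dynamic_size_in_bytes [16384]
}
'''
attrs = parse_launch_attrs(src)
threads_per_block = math.prod(attrs["thread_extent"])  # consistent with __launch_bounds__(128)
```

Emitting the values as structured comments keeps the generated source self-describing, so any host wrapper (torch extension, plain CUDA driver code) can read them without linking against TVM.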

TVM has indeed made progress in supporting GenAI and has also performed well on mobile. As one of the downstream players in the ecosystem, Arm China relies on TVM's frontend parsers to import various model formats, such as TensorFlow, PyTorch, Caffe, etc. Internally, we depend heavily on Relay and have written some passes to perform customized operations.

The AI field is indeed rapidly evolving, and we understand the idea of phasing out legacy components. However, due to our internal dependencies, we have concerns and doubts about the current front-end support for Relax.

  1. From the information we have gathered so far, it seems that Relax may not yet fully replace the front-end capabilities of Relay.
  2. If there is a gradual phase-out of Relay, we would like to know if there is a proven and viable plan for this transition.

Here are some considerations that we can take together regarding the front end.

Frontends of interest evolve over time. For example, the latest PyTorch stack has moved to FX graphs (torch.fx) and TorchInductor, and the respective frontend needs to be updated accordingly. For new frontend needs, bringing them to Relax would enable a clear focus, and it also unblocks the issue of dynamic shapes, so we can concentrate our efforts there. That is why such a conversation is important: it lets us establish that focus.

We can possibly keep certain importer modules and data structures a bit longer if there is community volunteer effort maintaining them. We need to address the testing issue by moving from execution tests to structural tests and running execution tests nightly, where the model gets imported and then translated to Relax for structural testing. We encourage such efforts to start working on frontend translation directly into Relax when possible.
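One way to split the suites, sketched with the standard `unittest` module: structural tests always run, while execution tests that import a model and run it end to end are skipped unless a nightly flag is set. The environment variable name `TVM_NIGHTLY_TESTS` and the test names are hypothetical; projects using pytest would typically do the same with a marker.

```python
import os
import unittest
from typing import Mapping


def nightly_only(env: Mapping[str, str] = os.environ):
    """Decorator: skip the test unless the (hypothetical) nightly flag is set."""
    enabled = env.get("TVM_NIGHTLY_TESTS") == "1"
    return unittest.skipUnless(enabled, "execution test: runs in the nightly job only")


class FrontendTests(unittest.TestCase):
    def test_import_structural(self):
        # Per-commit: compare the imported module against an expected module.
        expected = {"main": ["conv2d", "relu"]}
        imported = {"main": ["conv2d", "relu"]}  # stand-in for the frontend's output
        self.assertEqual(imported, expected)

    @nightly_only()
    def test_import_and_execute(self):
        pass  # placeholder: import the full model, compile, run, and compare numerics
```

Per-commit CI then only pays for the cheap structural checks, while the nightly job sets the flag and exercises the full import-compile-run path.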

Coming back to the broader context, this is indeed a hard tradeoff we need to make. The real impact falls on volunteer developers, and we face a real risk of being burdened by a lack of maintenance and slow development, and of the project not surviving in a fast, competitive landscape. That is why it is important to have this conversation and move in this direction. It would also give us a clear call to focus on some of the latest frontend needs through Relax development. Would love to see ideas around them and to work together on some of these directions!

To keep things supported, we should cut release branches that can continue to take maintenance patches on the related components. We can also account for them in community development and contributions.


Here at Infineon, one of the key “deciders” for TVM was the availability of the mature (Relay-based) backends for ARM embedded HW (and other COTS targets). Reading between the lines (“release branch, maintenance patch”), it seems that these will effectively be orphaned: no access to new frontends, little or no scope for active enhancement/extension PRs, and loss of connection to the TVM community mainstream for anyone still working with them.

Is there any likelihood of these being ported to Relax (ideally by their contributors)?

Without these, TVM would become something of a “non-starter” for our productive use. Dependable and properly maintainable backends for the mainstream ARM compute IP are the “must-have”. For our own in-house HW we'd just have to grit our teeth, write off our TVM investment, and suffer through hacking up TFLM/PyTorch Edge with in-house performance hacks.

Let us look into some of the frontend needs. One thing we can do is align most of the Relax and Relay ops, so that we can try using GenAI tools to bring some of the Relay frontends to Relax.
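Part of what makes such a port mechanical is a one-to-one op correspondence. A toy sketch of the idea: a table from registered Relay op names to their Relax counterparts, used to translate a flat list of calls and to report anything that still needs a hand-written converter. The table is a small illustrative subset, not the full mapping, and the call-list format and `my.custom_op` are hypothetical.

```python
from typing import Dict, List, Tuple

# Illustrative subset: registered Relay op name -> registered Relax op name.
RELAY_TO_RELAX: Dict[str, str] = {
    "nn.conv2d": "relax.nn.conv2d",
    "nn.relu": "relax.nn.relu",
    "nn.softmax": "relax.nn.softmax",
    "add": "relax.add",
    "multiply": "relax.multiply",
}

Call = Tuple[str, List[str]]  # (op name, argument names): a flattened toy graph


def translate(calls: List[Call]) -> Tuple[List[Call], List[str]]:
    """Rename ops that have a direct counterpart; collect the ones that don't."""
    translated, missing = [], []
    for op, args in calls:
        if op in RELAY_TO_RELAX:
            translated.append((RELAY_TO_RELAX[op], args))
        else:
            missing.append(op)
    return translated, missing


graph = [("nn.conv2d", ["x", "w"]), ("nn.relu", ["t0"]), ("my.custom_op", ["t1"])]
relax_calls, todo = translate(graph)
```

In practice even mapped ops may need attribute conversion; the table only removes the naming part of the work and makes the remaining gaps explicit.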

I'm currently refactoring our project following the methodology discussed in this thread: using TVM's core infrastructure by taking TVM as an include dependency and linking against the TVM shared libraries.

Example of a CMakeLists.txt that works with TVM:

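```cmake
cmake_minimum_required(VERSION 3.21)
project(TileLang C CXX)

set(CMAKE_CXX_STANDARD 17)

# Define TVM root directory
set(TVM_ROOT ${PROJECT_SOURCE_DIR}/3rdparty/tvm)

# Include directories: the project's own headers plus TVM's public headers
# and the 3rdparty headers they depend on (dlpack, dmlc-core)
include_directories(
    ${PROJECT_SOURCE_DIR}/include
    ${TVM_ROOT}/include
    ${TVM_ROOT}/3rdparty/dlpack/include
    ${TVM_ROOT}/3rdparty/dmlc-core/include
)

# Route dmlc logging through TVM's logging header, as TVM's own build does
add_compile_definitions(DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)

# Source files for the project
file(GLOB_RECURSE TileLang_SOURCES
    ${PROJECT_SOURCE_DIR}/src/transform/*.cpp
    ${PROJECT_SOURCE_DIR}/src/op/*.cc
    ${PROJECT_SOURCE_DIR}/src/codegen/*.cc
)

# Create shared library
add_library(TileLang SHARED ${TileLang_SOURCES})

# Link against the prebuilt libtvm so TileLang and upstream TVM share one
# libtvm at runtime (assumes TVM was built under ${TVM_ROOT}/build)
find_library(TVM_LIBRARY NAMES tvm PATHS ${TVM_ROOT}/build NO_DEFAULT_PATH REQUIRED)
target_link_libraries(TileLang PUBLIC ${TVM_LIBRARY})
```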

I think the key part of this pipeline is ensuring that the TVM-based implementation allows developers to write their own passes (from the C++ side). I'm not sure how we can still bind our own C++ transformations and op definitions to Python with TVM FFI. Do we have any example projects or guidelines for this? I'll continue exploring to achieve a cleaner design.

I think it is possible; mlc_llm should serve as an example.

Here is an example of binding a global function: https://github.com/mlc-ai/mlc-llm/blob/main/cpp/serve/radix_tree.cc#L822


Thanks, Tianqi, it works for me. However, there are some other issues that may need further discussion:

Some passes require access to functions that are not available in the include directory, for example those declared in src/arith/ir_mutator_with_analyzer.h and src/runtime/cuda/cuda_module.h. One workaround might be to add the necessary dependencies directly in CMakeLists.txt, but I wonder whether we have experience with, or a recommended approach for, handling this issue.