Foundation models are important workloads. By pushing local and server LLM inference to the extreme in the TVM stack, I believe we can resolve long-standing pain points and make TVM THE deep learning compiler for general scenarios, which is necessary for us to stay competitive (and alive) and keep delivering productivity.
To name a few:
1. Out-of-the-box kernel performance
Triton has gradually become the go-to solution for people writing custom ops on CUDA. TorchInductor uses Triton by default to generate code for long-tail operators (reductions/elementwise) and adds Triton GEMM kernels to its search space when more tuning is enabled.
LLM inference typically has two phases, prefill and decoding: prefill is bound by GEMMs, while decoding is bound by (quantized) GEMVs, which are representative of GEMM ops and long-tail ops respectively. The PyTorch team tried to use TVM as a backend, but MetaSchedule took too long to tune and generated sub-optimal programs; given a fixed amount of compilation time, the resulting performance was poor.
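For intuition, here is a minimal sketch (plain NumPy with hypothetical shapes, not TVM code) of why the two phases stress different kernels: prefill pushes the whole prompt through as a compute-bound GEMM, while decoding processes one token at a time as a memory-bound GEMV.

```python
import numpy as np

hidden, seq_len = 4096, 512            # hypothetical model/config sizes
W = np.random.randn(hidden, hidden).astype(np.float32)

# Prefill: the whole prompt goes through at once -> GEMM, compute-bound.
prefill_act = np.random.randn(seq_len, hidden).astype(np.float32)
prefill_out = prefill_act @ W          # (512, 4096) x (4096, 4096)

# Decoding: one new token per step -> GEMV, bound by streaming W from memory.
decode_act = np.random.randn(1, hidden).astype(np.float32)
decode_out = decode_act @ W            # (1, 4096) x (4096, 4096)
```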
We also see an interesting future for other, less popular backends, especially those equipped with unified memory that lets larger models fit. It is necessary to build up kernel-scheduling knowledge on these backends and to fully leverage TVM's strength in transferring schedules between backends.
2. More advanced operator fusion
TVM has generally applied only vertical fusion at the graph level, but horizontal fusion (e.g. merging 3 GEMMs into one larger GEMM) also pays off in LLMs.
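A minimal NumPy sketch of horizontal fusion, assuming the common case of the three Q/K/V projections sharing one input (shapes are hypothetical): three GEMVs against separate weights collapse into one GEMV against the concatenated weight matrix, so the activation is read once and a single larger kernel is launched.

```python
import numpy as np

hidden = 512
x = np.random.randn(1, hidden).astype(np.float32)
Wq, Wk, Wv = (np.random.randn(hidden, hidden).astype(np.float32) for _ in range(3))

# Unfused: three separate GEMV launches, each re-reading the activation.
q, k, v = x @ Wq, x @ Wk, x @ Wv

# Horizontally fused: one GEMV against the concatenated weight matrix.
W_qkv = np.concatenate([Wq, Wk, Wv], axis=1)   # (hidden, 3 * hidden)
qkv = x @ W_qkv
q2, k2, v2 = np.split(qkv, 3, axis=1)

assert np.allclose(q, q2, rtol=1e-4) and np.allclose(k, k2, rtol=1e-4)
```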
Meanwhile, being able to fuse GEMM (Conv) → (norm) → GEMM (Conv) chains and generate efficient code is important for attention ops. Instead of relying on CUTLASS/FT on CUDA, we can transfer such patterns to other backends.
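Attention is a concrete instance of that GEMM → norm → GEMM chain; below is a minimal NumPy sketch (single head, hypothetical sizes) of the pattern a fused kernel would cover in one pass.

```python
import numpy as np

seq_len, head_dim = 128, 64
Q = np.random.randn(seq_len, head_dim).astype(np.float32)
K = np.random.randn(seq_len, head_dim).astype(np.float32)
V = np.random.randn(seq_len, head_dim).astype(np.float32)

# GEMM 1: attention scores.
scores = Q @ K.T / np.sqrt(head_dim)

# Normalization in between: row-wise softmax.
scores = scores - scores.max(axis=-1, keepdims=True)
probs = np.exp(scores)
probs /= probs.sum(axis=-1, keepdims=True)

# GEMM 2: weighted sum of values. A fused kernel keeps the intermediate
# score matrix in on-chip memory instead of round-tripping through DRAM.
out = probs @ V
```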
There have been several lines of research on fusion algorithms in DL compilers, most of them search-based. Here we focus only on default patterns that generally work across workloads, but we already know there are missing pieces for us to fill in.
3. Distributed inference
No matter how good the quantization scheme and memory-planning algorithms are, people are always greedy for the size of the model they can run. Even with 4-bit quantization, 70B Llama2 still requires 2x 40GB A100s simply to hold the model.
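As a rough back-of-the-envelope sketch (weights only; quantization scales, the KV cache, and activations all add on top of this):

```python
# Weight memory for a 4-bit quantized 70B model, weights only.
params = 70e9
bits_per_param = 4
weight_gb = params * bits_per_param / 8 / 1e9
print(f"~{weight_gb:.0f} GB of weights")   # ~35 GB before any runtime overhead
```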
Instead of swapping the model between much slower storage and GPU memory, unified memory offers an interesting alternative (a 64GB MacBook Pro can serve 70B Llama2 on its own).
Distributed inference is an orthogonal solution, and the two can be combined.
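A minimal sketch of why distributed inference helps, simulating two "devices" as plain NumPy arrays (hypothetical shapes, no real communication): column-wise tensor-parallel sharding halves the weight memory each device must hold, and an all-gather reassembles the output.

```python
import numpy as np

hidden = 1024
x = np.random.randn(1, hidden).astype(np.float32)
W = np.random.randn(hidden, hidden).astype(np.float32)

# Column-wise (tensor-parallel) sharding: each device holds half of W,
# halving per-device weight memory.
W0, W1 = np.split(W, 2, axis=1)

# Each device computes its partial output; an all-gather concatenates them.
out = np.concatenate([x @ W0, x @ W1], axis=1)

assert np.allclose(out, x @ W, rtol=1e-4)
```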
4. More hackable infrastructure and fewer new concepts in general
We have seen some recent DL compilers written purely in Python (TorchInductor, Hidet), which gives engineers a much easier debugging and hacking experience.
We have also seen projects like llama.cpp and llama.c, which implement the whole model and its kernels in plain C++/C. People are actively contributing to them, and I believe one important reason is that they are straightforward to understand by reading through the code, so people with little knowledge of how DL compilers generally work can still hack on and debug the infra.
Being able to import and modify the model, insert new layers, substitute generated kernels with alternative shader-language implementations, and change operator fusion as smoothly as people picture it in their heads, while also providing reasonable debugging tools such as setting breakpoints and inspecting intermediate outputs in Python (as programmers have done since their first day of learning to program), can bring more people to contribute to our stack.