Author: @cyx @leshenj15 @junrushao
Prototype: https://github.com/apache/tvm/pull/15388
Why Torchy?
LLM workloads have had a significant impact on TVM Unity and are attracting users to it. TVM Unity enables universal deployment of LLMs on different GPU models, and has been significantly faster than other established open-source solutions such as Exllama and llama.cpp.
Engineering Challenge
As compiler engineers using TVM, we find major obstacles to delivering model coverage on top of the existing infrastructure. These obstacles limit engineering agility, time-to-market and eventually business impact. For example:
- It took an engineering week to implement, and an extra week to debug, a GPT2-XL variant without any performance optimization, a task that is usually expected to take 1-2 engineering hours;
- Implementations of Falcon and MPT, two top players (sorted by MMLU score) in the commercial LLM world, were proposed but eventually not delivered, and one of the key negative factors was engineering complexity;
- Multiple key operators, for example, layer normalization used in Dolly and StableLM, are defined inconsistently across two layers of TVM, tvm.relax and tvm.topi, which costs extra debugging time.
Assessment: Missing Key Features
In our analysis, the engineering inefficiency is mainly due to the following key missing features, making it impractical for even TVM experts to productively implement deep learning workloads in TVM:
- Model definition is significantly harder and more error-prone in TVM. For example, Llama’s attention takes 203 lines in TVM, while it takes around 10 lines in PyTorch with einsum.
- Correctness debugging is significantly more challenging than in PyTorch. For example, debugging an attention layer strictly requires a process that couples with multiple layers of TVM, including but not limited to: converting PyTorch parameters to TVM Relax, constructing an IRModule from relax.testing.Module, finding TVM compilation targets, feeding all of them to TVM to build, and converting the inputs/outputs between TVM and Torch. Note that this process does not involve breakpoint-style debugging or printing intermediate values, which makes debugging less intuitive.
- Operator consistency. Significant effort was taken to ensure operators are defined consistently across layers, requiring expertise to debug TVM’s different layers including TVMScript, Relax, TOPI, TE and TIR. There is no easy path for ordinary engineers to debug operators when misuse happens.
Proposed Project and UX Goal
This design doc proposes a small-scale engineering project to systematically address the aforementioned issues by properly categorizing and exposing a limited subset of APIs. It aims to provide:
- APIs identical to PyTorch’s for building deep learning workloads, particularly LLMs, meaning PyTorch users are not required to understand any TVM concept to build LLMs;
- Seamless integration and debugging experience with PyTorch;
- Minimal boilerplate code and less potential for engineering inconsistency across layers of the TVM stack.
Relation with Existing Efforts
This project is complementary to all existing efforts across the TVM stack:
- Importers (StableHLO / PyTorch / ONNX): While importers strive to be as sound as possible, it is practically non-trivial to guarantee that importing keeps working on the TVM side after any upstream change, particularly in the representation of dynamic shape workloads, graph breaks, and frontend-specific IR constructs that are challenging to optimize away in one shot. This project gives a fallback mechanism to quickly implement a workload when importers cannot secure the engineering timeline.
- TVMScript. TVMScript is the basis of the TVM compiler stack: a roundtrippable text representation of any TVM IR. This project allows users to easily define and export a TVM IR. In this sense, TVMScript is one level lower than this project in the compilation workflow.
- Existing relax.testing.nn. This project will replace the existing temporary, testing-purpose nn.Module, making it accessible to engineers who are not familiar with TVM so that they can productively define deep learning workloads.
Engineering Goal
We separate the project into three relatively independent sub-projects.
Infra of PyTorch-like Python API
Introduce a small set of key APIs for defining deep learning workloads with agility. The APIs will mimic PyTorch and be as easy to debug as possible. More specifically:
- Catering to PyTorch users: The APIs should be identical to PyTorch’s and guarantee that PyTorch-defined models, unless they contain rare constructs like control flow, aliasing or inplace mutation, can be copy-pasted into this infra with minor modification;
- Agility: The lines of code (LOC) should be as close as possible to Huggingface’s PyTorch implementation, since the API is designed to be as close as possible to PyTorch;
- Debuggability: Allow easy plug-in with PyTorch through a JIT API that converts our nn.Module to a torch-in-torch-out callable in pure Python.
Infra of Operators across TVM Stack
Any operator defined in tvm.topi can be directly exposed to tvm.relax and the proposed nn.Module, with proper sanity checking, type checking, shape inference and legalization in both the C++ and Python APIs. More specifically:
- An operator schema system that defines the input/output types of operators;
- Automatic generation of StructInfo inference and legalization logic based on the TOPI definition;
- C++ and Python API generation and integration with nn.Module.
Delivering LLM Architectures
Deliver the following LLM architectures using the proposed APIs. Each model should take fewer than 100 LOC more than Huggingface’s implementation:
- Llama and Llama-2
- GPTNeoX (RedPajama, StableLM and Dolly)
- Falcon
Design
The design of the proposed system comprises two major components.
Tensor, Parameter and nn.Module
Tensor. A wrapper on top of relax.Expr, providing more convenient access to shape and dtype information. A Tensor is always symbolic and not bound to any concrete values. It supports dynamic shapes out of the box, based on Relax.
class Tensor:
    _expr: relax.Expr

    @property
    def shape(self) -> list[int | tir.Var]: ...

    @property
    def dtype(self) -> str: ...
Parameter. A parameter is a subclass of Tensor which could be optionally bound to a concrete NDArray.
class Parameter(Tensor):
    _data: Optional[tvm.runtime.NDArray]

    def to(self, dtype: str) -> None:
        """Convert to a certain dtype"""
Module. The base class for neural networks. Users are expected to subclass it to build their models. Modules can nest within each other in a tree structure using regular attribute assignment.
class Module:
    def state_dict(self) -> dict[str, Parameter]:
        """Collect all parameters"""

    def load_state_dict(self, state_dict: dict[str, Parameter]) -> None:
        """Set the value of all parameters"""

    def to(self, dtype: str) -> None:
        """Convert all parameters to a certain dtype recursively"""

    def export_to_tvm(self, spec) -> tuple[IRModule, list[tuple[str, Parameter]]]:
        """Export the nn.Module to TVM's IRModule and return all its parameters"""

    def jit(self, spec, target: tvm.Target, device: str):
        """Convert the nn.Module to a torch-in-torch-out module"""
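As a rough illustration of how regular attribute assignment can yield a parameter tree, the sketch below collects parameters under dotted names, similar in spirit to PyTorch's state_dict. All class names here (ToyParameter, ToyModule, ToyLinear, ToyMLP) are hypothetical stand-ins written in pure Python; they are not the prototype's classes and do not touch TVM.

```python
from typing import Dict, Iterator, Optional, Tuple


class ToyParameter:
    """Hypothetical stand-in for Parameter: shape, dtype, optional bound data."""

    def __init__(self, shape: Tuple[int, ...], dtype: str = "float32") -> None:
        self.shape = shape
        self.dtype = dtype
        self.data: Optional[list] = None  # stays unbound in this sketch


class ToyModule:
    """Hypothetical stand-in for Module: children are plain attributes."""

    def named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, ToyParameter]]:
        # Walk instance attributes in assignment order and recurse into submodules.
        for name, value in vars(self).items():
            full_name = f"{prefix}.{name}" if prefix else name
            if isinstance(value, ToyParameter):
                yield full_name, value
            elif isinstance(value, ToyModule):
                yield from value.named_parameters(full_name)

    def state_dict(self) -> Dict[str, ToyParameter]:
        return dict(self.named_parameters())

    def to(self, dtype: str) -> None:
        # Recursively retag every parameter with the requested dtype.
        for _, param in self.named_parameters():
            param.dtype = dtype


class ToyLinear(ToyModule):
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        self.weight = ToyParameter((out_features, in_features))
        if bias:
            self.bias = ToyParameter((out_features,))


class ToyMLP(ToyModule):
    def __init__(self, hidden: int) -> None:
        self.up = ToyLinear(hidden, 4 * hidden, bias=False)
        self.down = ToyLinear(4 * hidden, hidden)


mlp = ToyMLP(8)
mlp.to("float16")
names = sorted(mlp.state_dict())
print(names)
```

Dotted names keep the mapping to a PyTorch checkpoint straightforward, which is one plausible reason to mirror the state_dict naming scheme; whether the prototype uses exactly this scheme is an assumption.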
Effect. A non-user facing class that encloses potential side effects, for example, IO, impure external function callings, inplace mutation, etc.
class Effect:
    def emit_init(self, name: str, builder: relax.BlockBuilder) -> list[relax.DataflowVar]:
        """Emit the global effect initialization logic, for example, zeroing KVCache"""

    def create(self, name: str) -> list[relax.Var]:
        """Create the effect as a list of input parameters of relax.Function"""

    def finalize(self) -> list[relax.Var]:
        """Return the effect variables that should be returned from relax.Function"""
Lowering to TVM IRModule. All parameters and effects will be inserted into explicit inputs/outputs to each relax.Function.
@I.ir_module
class TVM_IRModule:
    @R.function
    def relax_func_0(
        arg_0, arg_1,  # inputs specified by users
        effect_0, effect_1,  # effect variables
        param_0, param_1,  # parameters
    ) -> (
        out_0, out_1,  # outputs specified by users
        effect_0_prime, effect_1_prime  # effect variables
    ): ...

    @R.function
    def relax_func_1(
        arg_0, arg_1, arg_2,  # inputs specified by users
        effect_0, effect_1,  # the same set of effects/parameters
        param_0, param_1,
    ): ...
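The calling convention in the listing above can be sketched in plain Python. The helper lowered_signature is a hypothetical name, and it only assembles the ordered argument and return name lists; it is not the prototype's export logic. The single output out_0 used for relax_func_1 is an assumption, since the listing does not show that function's return values.

```python
from typing import List, Sequence, Tuple


def lowered_signature(
    inputs: Sequence[str],
    outputs: Sequence[str],
    effects: Sequence[str],
    params: Sequence[str],
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: order names as user inputs, effects, parameters."""
    args = [*inputs, *effects, *params]
    # Every effect is threaded through: it comes in as an argument and the
    # updated value is returned after the user-specified outputs.
    rets = [*outputs, *(f"{effect}_prime" for effect in effects)]
    return args, rets


# The same effects and parameters are shared by every exported function.
effects = ["effect_0", "effect_1"]
params = ["param_0", "param_1"]

args_0, rets_0 = lowered_signature(["arg_0", "arg_1"], ["out_0", "out_1"], effects, params)
# Assumption: relax_func_1 returns one user output.
args_1, rets_1 = lowered_signature(["arg_0", "arg_1", "arg_2"], ["out_0"], effects, params)

print(args_0, rets_0)
print(args_1, rets_1)
```

Making effects explicit inputs and outputs keeps each function free of hidden state, so the caller decides where state such as a KVCache lives between calls.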
Operator, Parameter Schemas
Parameter schema. A parameter schema indicates the type of a parameter, with built-in type checking and conversion logic along the path Python → Relax → TE/TOPI.
class ParamDef:
    name: str
    py_nn_sig: str  # Type signature in nn.Module
    py_rx_sig: str  # Type signature in relax
    py_rx_ty: str  # Type checking logic in Relax python package
    c_rx_sig: str  # Type signature in Relax C++ package
    c_rx_ty: str  # Type checking logic in Relax C++ package
    c_rx2te: str  # Conversion logic from Relax -> TE/TOPI lowering
Built-in schemas are provided for ordinary types, for example,
class Tensor(ParamDef): ...
class Int(ParamDef): ...
class Float(ParamDef): ...
class Array(ParamDef): ...
Operator schema. An operator schema lists the operator's input parameters and outputs. It also carries attributes that provide fallback mechanisms when the operator schema infra does not generate the desired struct_info inference or legalization out of the box.
class OpSchema:
    name: str
    # parameters
    params: list[ParamDef]
    ret: list[ParamDef]
    # definition in TOPI
    topi_func: str
    # fallback mechanisms for struct-info inference
    sinfo_override: str
    sinfo_fallback: str
    # fallback mechanisms for legalization (relax -> tir lowering)
    lower_override: str
    lower_fallback: str
    # extra attributes
    attrs: list[tuple[str, str]]
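To make the code-generation idea more concrete, here is a minimal sketch of rendering an nn.Module-level Python stub from a schema. It uses trimmed-down dataclasses carrying only a few of the fields above; the render_nn_stub function, the softmax entry and its field values are illustrative assumptions, not the prototype's generator or its actual schema entries.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ParamDef:
    """Trimmed-down parameter schema: only the fields this sketch needs."""

    name: str
    py_nn_sig: str  # type signature as it would appear in nn.Module code


@dataclass
class OpSchema:
    """Trimmed-down operator schema."""

    name: str
    params: List[ParamDef]
    ret: List[ParamDef]
    topi_func: str
    attrs: List[Tuple[str, str]] = field(default_factory=list)


def render_nn_stub(op: OpSchema) -> str:
    """Hypothetical generator: emit a one-line Python stub for an operator."""
    args = ", ".join(f"{p.name}: {p.py_nn_sig}" for p in op.params)
    ret_types = [r.py_nn_sig for r in op.ret]
    # A single return value is annotated directly, several as a tuple.
    ret = ret_types[0] if len(ret_types) == 1 else "tuple[" + ", ".join(ret_types) + "]"
    return f"def {op.name}({args}) -> {ret}: ..."


# Illustrative entry; the field values are assumptions made for this example.
softmax = OpSchema(
    name="softmax",
    params=[ParamDef("x", "Tensor"), ParamDef("axis", "int")],
    ret=[ParamDef("out", "Tensor")],
    topi_func="topi.nn.softmax",
)

stub = render_nn_stub(softmax)
print(stub)
```

The same schema object could in principle drive the Relax Python and C++ signatures through the other ParamDef fields, which is the point of keeping one source of truth per operator.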
Demo: Llama attention in 10 lines
We are able to implement Llama attention within 10 lines of code plus parameter declaration.
Parameter declaration. The following lines define the attention architecture, including QKV projection, output projection and the stateful KVCache.
class LlamaAttention(Module):
    def __init__(self, config: LlamaConfig, rotary_embedding: RotaryEmbedding):
        head_dim = config.hidden_size // config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.rotary_embedding = rotary_embedding
        self.q_proj = Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = Linear(config.hidden_size, config.hidden_size, bias=False)
        self.v_proj = Linear(config.hidden_size, config.hidden_size, bias=False)
        self.o_proj = Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_cache = KVCache(config.max_sequence_length, [config.num_attention_heads, head_dim])
        self.v_cache = KVCache(config.max_sequence_length, [config.num_attention_heads, head_dim])
Computation. The following lines define the computation in the attention architecture.
class LlamaAttention(Module):  # pylint: disable=too-many-instance-attributes
    ...

    def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var):
        (h, d), (b, s, _), t = (self.num_attention_heads, self.hidden_size // self.num_attention_heads), hidden_states.shape, total_seq_len
        q = reshape(self.q_proj(hidden_states), (b, s, h, d))
        k = reshape(self.k_proj(hidden_states), (b, s, h, d))
        v = reshape(self.v_proj(hidden_states), (b, s, h, d))
        q, k = self.rotary_embedding(q, k, t - s)
        k = self.k_cache.append(squeeze(k, axis=0), (t, b, h, d))
        v = self.v_cache.append(squeeze(v, axis=0), (t, b, h, d))
        attn_weights = matmul(q.permute_dims([0, 2, 1, 3]), k.permute_dims([1, 2, 3, 0])) / math.sqrt(d)
        attn_weights = attn_weights.maximum(tir.min_value(attn_weights.dtype)).minimum(attention_mask).softmax(axis=-1)
        attn_output = self.o_proj(matmul(attn_weights, v.permute_dims([1, 2, 0, 3])).permute_dims([0, 2, 1, 3]).reshape((b, s, h * d)))
        return attn_output
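The permutations in forward are easier to follow with the shapes written out. The sketch below is a pure-Python shape walk over symbolic dimension names, assuming q has shape (b, s, h, d) and the cached k and v have shape (t, b, h, d) as in the code above. The permute_dims and matmul functions here are local shape-only helpers written for this illustration, not the proposed operators.

```python
from typing import Sequence, Tuple

Shape = Tuple[str, ...]


def permute_dims(shape: Shape, axes: Sequence[int]) -> Shape:
    """Shape rule for a transpose: output axis i takes input axis axes[i]."""
    return tuple(shape[a] for a in axes)


def matmul(lhs: Shape, rhs: Shape) -> Shape:
    """Shape rule for batched matmul without broadcasting:
    (..., m, k) x (..., k, n) -> (..., m, n)."""
    assert lhs[:-2] == rhs[:-2], f"batch dims differ: {lhs} vs {rhs}"
    assert lhs[-1] == rhs[-2], f"contraction dims differ: {lhs} vs {rhs}"
    return lhs[:-1] + (rhs[-1],)


q: Shape = ("b", "s", "h", "d")  # after reshape of the query projection
k: Shape = ("t", "b", "h", "d")  # after appending to the key cache
v: Shape = ("t", "b", "h", "d")  # after appending to the value cache

# (b, h, s, d) x (b, h, d, t) -> (b, h, s, t)
attn_weights = matmul(permute_dims(q, [0, 2, 1, 3]), permute_dims(k, [1, 2, 3, 0]))

# (b, h, s, t) x (b, h, t, d) -> (b, h, s, d), then back to (b, s, h, d)
attn_output = permute_dims(
    matmul(attn_weights, permute_dims(v, [1, 2, 0, 3])), [0, 2, 1, 3]
)

print(attn_weights, attn_output)
```

The final reshape to (b, s, h * d) then restores the hidden size expected by o_proj.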