[Design] Torchy: Productive Model Definition in TVM Unity

Author: @cyx @leshenj15 @junrushao

Prototype: https://github.com/apache/tvm/pull/15388

Why Torchy?

LLM workloads have brought significant impact and attracted users to TVM Unity. It enables universal deployment of LLMs across different GPU models, and has been significantly faster than other established open-source solutions such as ExLlama and llama.cpp.

Engineering Challenge

As compiler engineers using TVM, we find it a major obstacle to deliver model coverage on top of the existing infrastructure, which limits engineering agility, time-to-market value, and eventually business impact. For example:

  • It took an engineering week to implement, and an extra week to debug, a GPT2-XL variant without any performance optimization, a task usually expected to take 1-2 engineering hours;
  • Implementations of Falcon and MPT, two top players (sorted by MMLU score) in the commercial LLM world, were proposed but never eventually delivered, and one of the key negative factors was engineering complexity.
  • Multiple key operators, for example, the layer normalization used in Dolly and StableLM, are defined inconsistently across TVM's two layers, tvm.relax and tvm.topi, taking extra time to debug.

Assessment: Missing Key Features

In our analysis, the engineering inefficiency is mainly due to the following key missing features, which make it impractical for even TVM experts to productively implement deep learning workloads in TVM:

  • Model definition is significantly harder and more error-prone in TVM; for example, Llama's attention takes 203 lines in TVM while it takes around 10 lines in PyTorch with einsum.
  • Correctness debugging is significantly more challenging compared with PyTorch. For example, to debug an attention layer, a process coupled with multiple layers of TVM is strictly required, including but not limited to: converting PyTorch parameters to TVM Relax, constructing an IRModule from relax.testing.Module, finding TVM compilation targets, properly feeding them together to TVM to build, and converting inputs/outputs between TVM and Torch (see the sketch after this list). Note that this process does not involve breakpoint-style debugging or printing intermediate values, making debugging less intuitive.
  • Operator consistency. Significant effort was taken to ensure operators are defined consistently across layers, requiring expertise to debug TVM's different layers, including TVMScript, Relax, TOPI, TE and TIR. There is no shortcut for ordinary engineers to properly debug operators when misuses happen.
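
To make the friction concrete, the sketch below shows roughly the glue code this process requires before any actual debugging can begin; the helper name run_in_tvm and the exact build flow are illustrative of unity-era APIs, not a verbatim recipe.

import torch
import tvm
from tvm import relax

def run_in_tvm(mod: tvm.IRModule, torch_inputs):
  """Illustrative helper: compile an IRModule, then run it on Torch tensors."""
  # Find a compilation target and build the IRModule into a VM executable.
  target = tvm.target.Target("llvm")
  executable = relax.build(mod, target)
  vm = relax.VirtualMachine(executable, tvm.cpu())
  # Convert PyTorch tensors to TVM NDArrays, run, and convert the output back.
  tvm_inputs = [tvm.nd.array(t.numpy()) for t in torch_inputs]
  out = vm["main"](*tvm_inputs)
  return torch.from_numpy(out.numpy())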

Proposed Project and UX Goal

This design doc proposes a small-scale engineering project to systematically address the aforementioned issues by properly categorizing and exposing a limited subset of APIs, which aims to provide:

  • APIs identical to PyTorch's for building deep learning workloads, particularly LLMs, meaning PyTorch users are not required to understand any TVM concept to build up LLMs;
  • Seamless integration and debugging experience with PyTorch;
  • Minimal boilerplate code and reduced potential for engineering inconsistency across layers of the TVM stack.

Relation with Existing Efforts

This project is complementary to all existing efforts across the TVM stack:

  • Importers (StableHLO / PyTorch / ONNX): While striving to be as sound as possible, it is practically non-trivial to guarantee that importing works consistently on the TVM side upon any upstream change, particularly around the representation of dynamic-shape workloads, graph breaks, and frontend-specific IR constructs that are challenging to optimize away in one shot. This project gives a fallback mechanism to quickly implement a workload when importers fail, securing the engineering timeline.
  • TVMScript. TVMScript is the basis of the TVM compiler stack: a round-trippable text representation of any TVM IR. This project allows users to easily define and export a TVM IR; in this sense, TVMScript sits one level lower than this project in the compilation workflow.
  • Existing relax.testing.nn. This project will replace the existing temporary, testing-purposed nn.Module, making it truly accessible for engineers who are not familiar with TVM to productively define deep learning workloads.

Engineering Goal

We separate the project into three relatively independent subprojects.

Infra of PyTorch-like Python API

Introduce a small set of key APIs to define deep learning workloads with agility; the APIs mimic PyTorch and are designed to be as easily debuggable as possible. More specifically,

  • Catering to PyTorch users: The APIs should be identical to PyTorch's and guarantee that PyTorch-defined models, if they do not use rare constructs like control flow, aliasing or in-place mutation, can be copy-pasted to this infra with minor modification;
  • Agility: The lines of code (LOC) should be as close as possible to Hugging Face's PyTorch implementation, as the API is designed to mirror PyTorch;
  • Debuggability: Allow easy plug-in with PyTorch via a JIT API that converts our nn.Module to a torch-in-torch-out callable in pure Python, as sketched below.
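
As a concrete sketch of this debugging experience (the spec argument, the returned callable's interface, and names like forward_spec and reference_pytorch_model are assumptions for illustration, not the final API):

import torch
import tvm

# `model` is an instance of the proposed nn.Module, e.g. the LlamaAttention
# from the demo section below.
compiled = model.jit(spec=forward_spec, target=tvm.target.Target("cuda"), device="cuda")

# torch-in-torch-out: plug directly into an existing PyTorch test harness.
x = torch.randn(1, 32, 4096, device="cuda", dtype=torch.float16)
torch.testing.assert_close(compiled(x), reference_pytorch_model(x))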

Infra of Operators across TVM Stack

Any operator defined in tvm.topi can be directly exposed to tvm.relax and the proposed nn.Module, with proper sanity checking, type checking, shape inference, and legalization in both the C++ and Python APIs. More specifically,

  • An operator schema system that defines the input/output types of operators;
  • Automatic generation of struct_info inference and legalization logic based on the TOPI definition (see the sketch after this list);
  • C++ and Python API generation and integration with nn.Module.
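
To illustrate the second bullet, below is a minimal sketch of a legalization rule as it could be generated from a TOPI definition; the function name legalize_add is illustrative, while call_te is the existing relax.BlockBuilder mechanism such generated code would build on.

from tvm import relax, topi

def legalize_add(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr:
  # call_te traces topi.add into a TIR PrimFunc, adds it to the IRModule,
  # and emits a call_tir to it from Relax.
  return bb.call_te(topi.add, call.args[0], call.args[1])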

Delivering LLM Architectures

Deliver the following LLM architectures using the proposed APIs, where each model should take fewer than 100 LOC more than Hugging Face's implementation:

  • Llama and Llama-2
  • GPTNeoX (RedPajama, StableLM and Dolly)
  • Falcon

Design

The design of the proposed system comprises two major components.

Tensor, Parameter and nn.Module

Tensor. A wrapper on top of relax.Expr, providing more convenient access to shape and dtype information. A Tensor is always symbolic and never bound to any concrete values. It supports dynamic shapes out of the box, based on Relax.

class Tensor:
  _expr: relax.Expr
  @property
  def shape(self) -> list[int | tir.Var]: ...
  @property
  def dtype(self) -> str: ... 
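
Hypothetical usage, assuming a placeholder-style constructor: a shape may mix static ints with symbolic tir.Var, which is how dynamic sequence lengths are represented.

from tvm import tir

seq_len = tir.Var("seq_len", "int64")
x = Tensor.placeholder([1, seq_len, 4096], dtype="float16")  # hypothetical constructor
assert x.shape == [1, seq_len, 4096]
assert x.dtype == "float16"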

Parameter. A parameter is a subclass of Tensor that can optionally be bound to a concrete NDArray.

class Parameter(Tensor):
  _data: Optional[tvm.runtime.NDArray]

  def to(self, dtype: str) -> None:
    """Convert to a certain dtype"""

Module. The base class for neural networks. Users are expected to subclass it to build their models. Modules can nest within each other in a tree structure using regular attribute assignment.

class Module:
  def state_dict(self) -> dict[str, Parameter]:
    """Collect all parameters"""
  def load_state_dict(self, state_dict: dict[str, Parameter]) -> None:
    """Set the value of all parameters"""
  def to(self, dtype: str) -> None:
    """Convert all parameters to a certain dtype recursively"""

  def export_to_tvm(self, spec) -> tuple[IRModule, list[tuple[str, Parameter]]]:
    """Export the nn.Module to TVM's IRModule and return all its parameters"""
  def jit(self, spec, target: tvm.Target, device: str):
    """Convert the nn.Module to a torch-in-torch-out module"""

Effect. A non-user-facing class that encloses potential side effects, for example, IO, impure external function calls, in-place mutation, etc.

class Effect:
  def emit_init(self, name: str, builder: relax.BlockBuilder) -> list[relax.DataflowVar]:
    """Emit the global effect initialization logic, for example, zeroing KVCache"""
  def create(self, name: str) -> list[relax.Var]:
    """Create the effect as a list of input parameters of relax.Function"""
  def finalize(self) -> list[relax.Var]:
    """Return the effect variables that should be returned from relax.Function"""

Lowering to TVM IRModule. All parameters and effects are inserted as explicit inputs/outputs of each relax.Function.

@I.ir_module
class TVM_IRModule:
  @R.function
  def relax_func_0(
    arg_0, arg_1,       # inputs specified by users
    effect_0, effect_1, # effect variables
    param_0, param_1,   # parameters
  ) -> (
    out_0, out_1,                  # outputs specified by users
    effect_0_prime, effect_1_prime # effect variables
  ): ...

  @R.function
  def relax_func_1(
    arg_0, arg_1, arg_2, # inputs specified by users
    effect_0, effect_1,  # the same set of effects/parameters
    param_0, param_1,
  ): ...

Operator and Parameter Schemas

Parameter schema. A parameter schema indicates a parameter's type, with built-in type checking and conversion logic along the Python → Relax → TE/TOPI path.

class ParamDef:
  name: str
  py_nn_sig: str # Type signature in nn.Module
  py_rx_sig: str # Type signature in relax
  py_rx_ty: str  # Type checking logic in Relax python package
  c_rx_sig: str  # Type signature in Relax C++ package
  c_rx_ty:  str  # Type checking logic in Relax C++ package
  c_rx2te:  str  # Conversion logic from Relax -> TE/TOPI lowering

Built-in schemas are provided for ordinary types, for example,

class Tensor(ParamDef): ...
class Int(ParamDef): ...
class Float(ParamDef): ...
class Array(ParamDef): ...

Operator schema. An operator's schema specifies its input parameters and outputs. In addition, it carries attributes that provide effective fallback mechanisms when the operator schema infra doesn't generate the desired struct_info inference or legalization out of the box.

class OpSchema:
  name: str
  # parameters
  params: list[ParamDef]
  ret: list[ParamDef]
  # definition in TOPI
  topi_func: str
  # fallback mechanisms for struct-info inference
  sinfo_override: str
  sinfo_fallback: str
  # fallback mechanisms for legalization (relax -> tir lowering)
  lower_override: str
  lower_fallback: str
  # extra attributes
  attrs: list[(str, str)]
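
A hypothetical schema instance for an element-wise op; the field values are illustrative of the intent rather than actual registry entries.

add_schema = OpSchema(
  name="add",
  params=[Tensor("a"), Tensor("b")],
  ret=[Tensor("out")],
  topi_func="topi.add",        # the TOPI compute that legalization lowers to
  sinfo_override="",           # empty: use auto-generated struct_info inference
  sinfo_fallback="broadcast",  # fall back to broadcast shape rules
  lower_override="",
  lower_fallback="call_te",    # fall back to generic call_te-based lowering
  attrs=[],
)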

Demo: Llama attention in 10 lines

We are able to implement Llama attention within 10 lines of code, plus parameter declarations.

Parameter declaration. The following lines define the attention architecture, including QKV projection, output projection and the stateful KVCache.

class LlamaAttention(Module):
  def __init__(self, config: LlamaConfig, rotary_embedding: RotaryEmbedding):
    head_dim = config.hidden_size // config.num_attention_heads
    self.hidden_size = config.hidden_size
    self.num_attention_heads = config.num_attention_heads
    self.rotary_embedding = rotary_embedding
    self.q_proj = Linear(config.hidden_size, config.hidden_size, bias=False)
    self.k_proj = Linear(config.hidden_size, config.hidden_size, bias=False)
    self.v_proj = Linear(config.hidden_size, config.hidden_size, bias=False)
    self.o_proj = Linear(config.hidden_size, config.hidden_size, bias=False)
    self.k_cache = KVCache(config.max_sequence_length, [config.num_attention_heads, head_dim])
    self.v_cache = KVCache(config.max_sequence_length, [config.num_attention_heads, head_dim])

Computation. The following lines define the computation of the attention architecture.

class LlamaAttention(Module):  # pylint: disable=too-many-instance-attributes
  ...
  def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var):
    (h, d), (b, s, _), t = (self.num_attention_heads, self.hidden_size // self.num_attention_heads), hidden_states.shape, total_seq_len
    q = reshape(self.q_proj(hidden_states), (b, s, h, d))
    k = reshape(self.k_proj(hidden_states), (b, s, h, d))
    v = reshape(self.v_proj(hidden_states), (b, s, h, d))
    q, k = self.rotary_embedding(q, k, t - s)
    k = self.k_cache.append(squeeze(k, axis=0), (t, b, h, d))
    v = self.v_cache.append(squeeze(v, axis=0), (t, b, h, d))
    attn_weights = matmul(q.permute_dims([0, 2, 1, 3]), k.permute_dims([1, 2, 3, 0]),) / math.sqrt(d)
    attn_weights = attn_weights.maximum(tir.min_value(attn_weights.dtype)).minimum(attention_mask).softmax(axis=-1)
    attn_output = self.o_proj(matmul(attn_weights, v.permute_dims([1, 2, 0, 3])).permute_dims([0, 2, 1, 3]).reshape((b, s, h * d)))
    return attn_output

A full-featured Llama-2 implementation in only 200 lines of code, based on this project: https://github.com/mlc-ai/mlc-llm/pull/631
