[meta-RFC] Vector length agnostic (VLA) vectorization

This Meta-RFC proposes steps to enable VLA support in TVM by first implementing SVE support in codegen and then adding generic VLA support to TIR.

For an explanation and examples of SVE, see the introduction to SVE2 and this tutorial.

Note that the approach taken here is different to the one outlined in the previous RFC and accompanying prototype, so this set of RFCs here is intended to replace the previous one.

We’ve broken the work down into two steps:

RFC1: Introduction of CodeGenAArch64 TIR backend

Link: https://github.com/apache/tvm-rfcs/pull/94

As a first step, we’ll be introducing a new TIR backend to generate AArch64 SVE specific LLVM IR.

We would add an option to the target to indicate that the inner loops in TIR with “vectorize” annotation should be turned into VLA LLVM vectors using the AArch64 backend and then into SVE assembly. In the future, if the LLVM IR generated for VLA is generic enough and support for VLA has matured, we could move it into CodeGenLLVM.

In terms of adding a new backend to do something quite generic, there are a few reasons to initially keep this support in an AArch64-specific backend. LLVM support for VLA is known to have several issues (surfaced by the work on introducing SVE into Halide), so an AArch64-specific backend is a good place to work around these issues until the VLA support matures. In addition, there is a pipeline of new features being added to AArch64 that would need (at least initially) special handling through AArch64-specific intrinsics.

In contrast to the old RFC, we have chosen to start by adding the support to the backend, since that lets us pipeclean the backend and gain confidence before making more invasive changes to other parts of the TVM stack.

RFC2: Vector length agnostic intrinsics and predication support in TVM

Link: (coming soon!)

As a next step, we will introduce two TIR intrinsics:

  • tir.vscale() - a compile-time unknown representing the vector length
  • tir.get_active_lane_mask() - an intrinsic to generate the predicate

These intrinsics would allow us to generate code as shown in the following example:

@main = primfn(A_1: handle, B_1: handle, m: int32, n: int32) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [200, 200], []),
             B: Buffer(B_2: Pointer(float32), float32, [200, 200], [])}
  buffer_map = {A_1: A, B_1: B} {
  realize(compute: Buffer(compute_1: Pointer(float32), float32, [200, 200], []), [0:200, 0:200], True {
    for (i.outer: int32, 0, 20) {
      for (j.outer: int32, 0, (floordiv(199, tir.vscale() * 4) + 1)) {
        for (i.inner: int32, 0, 10) "unroll" {
          let pred = tir.get_active_lane_mask((j.outer * tir.vscale() * 4), 200)
          compute[(i.inner + (i.outer*10)), ramp((j.outer* tir.vscale() * 4), 1, tir.vscale() * 4), predicate=pred] = (A[(i.inner + (i.outer*10)), ramp((j.outer * tir.vscale() * 4), 1, tir.vscale() * 4), predicate=pred] + B[(i.inner + (i.outer*10)), ramp((j.outer * tir.vscale() * 4), 1, tir.vscale() * 4), predicate=pred])
        }
      }
    }
  })
}
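To make the predicate semantics concrete, here is a plain-Python model (illustration only, not TVM code) of what `tir.get_active_lane_mask(base, limit)` computes for a vector of `vl` lanes: lane `i` is active iff `base + i < limit`, so only the trailing vector of a row is partially active.

```python
def get_active_lane_mask(base, limit, vl):
    """Model of the predicate: lane i is active iff base + i < limit."""
    return [base + i < limit for i in range(vl)]

# With vscale = 2 (i.e. vl = vscale * 4 = 8 lanes) and 200 columns,
# the last vector of a row is only partially active:
print(get_active_lane_mask(196, 200, 8))
# [True, True, True, True, False, False, False, False]
```

This is why no scalar epilogue loop is needed in the TIR above: the out-of-bounds lanes are simply masked off.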

These new intrinsics would allow us to do two things:

  1. tir.vscale() to be used as an argument for splitting / tiling loops
  2. tir.get_active_lane_mask() to be used to generate a predicate: a bit mask computed from the loop induction variables and the loop bounds
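As a quick sanity check on point 1, the following plain-Python sketch (not a TVM API) models the outer trip count produced by splitting a loop of extent `n` by the scalable factor `vscale * 4`, matching the bound `floordiv(199, tir.vscale() * 4) + 1` in the TIR example; the same TIR covers the whole iteration space for any hardware value of vscale.

```python
def trip_count(n, vscale, base_lanes=4):
    # Outer iterations after splitting a loop of extent n by a scalable
    # factor of vscale * base_lanes; with n = 200 this is the TIR bound
    # floordiv(199, vscale * 4) + 1.
    vl = vscale * base_lanes
    return (n - 1) // vl + 1

# SVE implementations allow vector lengths from 128 to 2048 bits,
# i.e. vscale from 1 to 16 for 4 x float32 base vectors:
for vscale in (1, 2, 4, 8, 16):
    assert trip_count(200, vscale) * vscale * 4 >= 200  # full coverage
```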

These would then get lowered to the llvm.vscale and llvm.get.active.lane.mask LLVM intrinsics, along with llvm.masked.* intrinsics for predicate support.

Why do we need to introduce changes to TIR if we have a specialised backend?

Constraining the support to backend only would have several limitations:

  1. We can’t use this approach for instructions that rely on tensorize, such as the SVE equivalent of MMLA and dot instructions.
    • We need the above TIR intrinsics to define the tensorization pattern with scalable vectors
    • Usually, more than one loop-nest is needed (i.e. not just innermost) to define the pattern
  2. We wouldn’t be able to let the tuner decide between scalable and fixed length vectors.