 # [RFC]: Improve quantized convolution through mmla instruction

## Introduction and motivation

This RFC is the third in a series of optimizations to enhance quantized convolution on Arm architectures.

In this RFC we add support for the Matrix Multiply Accumulate instruction: https://developer.arm.com/docs/ddi0602/g/simd-and-floating-point-instructions-alphabetic-order/smmla-vector-signed-8-bit-integer-matrix-multiply-accumulate-vector

This instruction is optional from Armv8.2-A onward and mandatory from Armv8.6-A onward.

## Overview of the Matrix Multiply Instruction

Let’s briefly look at how `smmla` works. While we will add support for `ummla` as well, for the remainder of this RFC we will only refer to the signed version.

The following picture briefly shows how `smmla` works:

In the picture, `vec_a` and `vec_b` are `int8x16` (or `.16b`) registers, while `vec_c` is an `int32x4` (or `.4s`) register. You may notice that this is quite different from dot-product: with dot-product we compute a `1x4` sub-row of the output tile, while with `smmla` we compute a 2D `2x2` sub-tile of the final result. This needs proper handling during the tiling phase of the algorithm.
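To make the semantics concrete, here is a minimal NumPy sketch of what a single `smmla` does (the function name and shapes are illustrative, not the TVM or Arm intrinsic API): each 16-byte operand is viewed as a `2x8` int8 matrix, and the instruction accumulates their `2x2` product into the `int32x4` destination.

```python
import numpy as np

def smmla(vec_c, vec_a, vec_b):
    """Reference model of smmla: vec_c += A @ B^T on 2x8 int8 operands."""
    a = vec_a.reshape(2, 8).astype(np.int32)  # first source: 2x8 matrix A
    b = vec_b.reshape(2, 8).astype(np.int32)  # second source: 2x8 matrix B
    c = vec_c.reshape(2, 2)                   # destination: 2x2 int32 accumulator
    # Each output element is the dot-product of a row of A with a row of B
    return (c + a @ b.T).reshape(4)
```

For example, with `vec_a = [0..15]`, `vec_b` all ones, and a zero accumulator, the result is `[28, 28, 92, 92]` (row sums of `A` duplicated across the two columns), which matches the `2x2` sub-tile structure described above.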

## GEMM implementation through `smmla`

Now that we have enough understanding of how `smmla` works, we can discuss how we decided to add support for it.

Let’s reiterate once more how the general GEMM algorithm (`C=A*B`) works:

1. Subdivide matrix `A` in adjacent tiles (a process called packing or interleaving)
2. Interleave and transpose matrix `B` (this can be done offline)
3. Run GEMM to produce a `C_interleaved` version of the output
4. Unpack `C_interleaved` to obtain `C`
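The four steps above can be sketched in NumPy on small matrices (tile sizes and helper names here are purely illustrative; the RFC's actual tiles are `8x12`):

```python
import numpy as np

# Hypothetical tile sizes for illustration only
TILE_ROWS_A = 2  # rows of A per tile
TILE_ROWS_B = 2  # columns of B per tile (rows after transposition)

def gemm_tiled(A, B):
    M, K = A.shape
    _, N = B.shape
    # 1. Subdivide A into adjacent tiles (packing / interleaving)
    A_int = A.reshape(M // TILE_ROWS_A, TILE_ROWS_A, K)
    # 2. Interleave and transpose B (this can be done offline)
    B_int_t = B.T.reshape(N // TILE_ROWS_B, TILE_ROWS_B, K)
    # 3. Run GEMM to produce an interleaved version of the output:
    #    C_int[x, y, r, c] = sum_k A_int[x, r, k] * B_int_t[y, c, k]
    C_int = np.einsum("xrk,yck->xyrc", A_int, B_int_t)
    # 4. Unpack C_interleaved to obtain C
    return C_int.transpose(0, 2, 1, 3).reshape(M, N)
```

Running this against a plain `A @ B` confirms the pack/compute/unpack round-trip is exact.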

The tiling is usually chosen to maximize register utilization. In this case, the best tiling to exploit register allocation is still an `8x12` tile (like the one we used for dot-product). The difficulty with `smmla` is that `C_interleaved` (i.e., the tiled version of the output) will be composed of `2x2` sub-tiles that need to be extracted when unpacking. The following picture tries to show the situation:

Please note that `A_tile` and `B_tile` (and thus `C_tile` as well) are shown in their native layout. The problem is that the four elements of each generated sub-tile (e.g., the red `2x2` sub-tile) are contiguous in memory. This needs to be expressed at the compute level in order to unpack properly.

### The compute node

Given the above, the final compute node to produce `C_interleaved` and unpack the result is the following:

```python
C_interleaved = te.compute(
    (batches, M_padded // tile_rows_A, N_transformed, 4, 6, 2, 2),
    lambda b, x, y, w, z, s, t: te.sum(
        A_interleaved[b, x, k // tile_cols_A, 2 * w + s, idxm(k, tile_cols_A)].astype("int32")
        * B_interleaved_t[y, k // tile_cols_B, 2 * z + t, idxm(k, tile_cols_B)].astype("int32"),
        axis=k,
    ),
    name="C_interleaved",
)

C = te.compute(
    (batches, M, N),
    lambda b, x, y: C_interleaved[
        b,
        x // tile_rows_A,
        y // tile_rows_B,
        idxm(x, tile_rows_A) // 2,
        idxm(y, tile_rows_B) // 2,
        idxm(idxm(x, tile_rows_A), 2),
        idxm(idxm(y, tile_rows_B), 2),
    ].astype(out_dtype),
    name="C",
)
```

We are simply expressing in the computation that the output `8x12` tile is really a `4x6x2x2` tile (i.e., an `8x12` tile composed of `2x2` sub-tiles). This is also taken into account when unpacking `C_interleaved`.
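A quick way to convince yourself the unpack indexing is correct: the mapping from an output coordinate `(x, y)` to the `4x6x2x2` layout (as written in the `C` compute above) must visit every sub-tile slot exactly once. A plain-Python sketch (the helper is hypothetical, mirroring the `idxm`-based indices):

```python
# Output tile: 8 rows x 12 columns, as in the RFC
tile_rows_A, tile_rows_B = 8, 12

def unpack_index(x, y):
    """Map (x, y) within the 8x12 tile to its 4x6x2x2 coordinates."""
    return (
        (x % tile_rows_A) // 2,  # which sub-tile row (0..3)
        (y % tile_rows_B) // 2,  # which sub-tile column (0..5)
        (x % tile_rows_A) % 2,   # row within the 2x2 sub-tile
        (y % tile_rows_B) % 2,   # column within the 2x2 sub-tile
    )
```

Enumerating all 96 positions of the tile yields 96 distinct `(w, z, s, t)` tuples, i.e., the unpack is a bijection between the `8x12` layout and the `4x6x2x2` one.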

### The tensorization rule

Once we have our `C_interleaved` in the right form, it is fairly simple to tensorize on the inner `2x2` sub-tile and unroll the outer dimensions to have a final `8x12` output tile.

The following snippet of code is an extract of the tensorization rule we are using:

```python
# Load in vec_b the two rows of B
# vec_b = [0, 2, 4, 6, 8, 10, 12, 14;
#          1, 3, 5, 7, 9, 11, 13, 15]

# Execute the matrix multiplication via (s/u)mmla:
# vec_c = [a*0 + b*2 + c*4 + d*6 + e*8 + f*10 + g*12 + h*14;
#          a*1 + b*3 + c*5 + d*7 + e*9 + f*11 + g*13 + h*15;
#          i*0 + j*2 + k*4 + l*6 + m*8 + n*10 + o*12 + p*14;
#          i*1 + j*3 + k*5 + l*7 + m*9 + n*11 + o*13 + p*15]
vmmla = tvm.tir.call_llvm_intrin(
    "int32x4", llvm_intrin, tvm.tir.const(3, "uint32"), vec_c, vec_a, vec_b,
)
# Store the result
ib.emit(outs.vstore([0, 0], vmmla))
```

This maps closely to the earlier picture showing how `smmla` works. It is also instructive to show how we use the intrinsic within the schedule:

```python
mmla = mmla_2x2_int8_int8_int32(in_type)
xi_inner, yi_inner = C_interleaved.op.axis[5:7]
k_outer, k_inner = s[C_interleaved].split(k, 8)
s[C_interleaved].reorder(
    b_outer_gemm_fused, inner_gemm, k_outer, xi, yi, xi_inner, yi_inner, k_inner
)
s[C_interleaved].tensorize(xi_inner, mmla)
s[C_interleaved].unroll(xi)
s[C_interleaved].unroll(yi)
```

Other than splitting the reduction axis by 8 (i.e., the reduction dimension of `mmla`), we simply apply the intrinsic and unroll the outer tile dimensions.

### Testing and performance

While this instruction is not yet supported by real hardware, at Arm we have been able to run and verify it on an internal cycle-accurate simulator. For instance, on a `[75x75x80] * [3x3x80x192]` convolution (the heaviest layer of `inceptionV3`), we saw a 35% improvement compared to the dot-product implementation.

### PR

The PR for this RFC is available here: https://github.com/apache/incubator-tvm/pull/6802


Appreciate the detailed explanation! Regarding the tile size selection, is there any specific reason that we use an 8x12 register allocation strategy? Why not use something like 8x8 or 4x4? @giuseros @anijain2305