## Introduction and motivation

This RFC introduces the third set of optimizations to enhance quantized convolution on Arm architectures. To briefly summarize the previous ones:

- Basic Armv8-A convolution implementation (through GEMM): [RFC] Improve quantized convolution performance for armv8 architectures
- Dot product enhancements for Armv8.2-A: [RFC] Accelerate quantized convolution through dot-product

In this RFC we add support for the Matrix Multiply Accumulate instruction: https://developer.arm.com/docs/ddi0602/g/simd-and-floating-point-instructions-alphabetic-order/smmla-vector-signed-8-bit-integer-matrix-multiply-accumulate-vector

This instruction is optionally supported from Armv8.2-A onward, while it is mandatory from Armv8.6-A onward.

## Overview of the Matrix Multiply Instruction

Let’s briefly introduce how `smmla` works. While we will add support for `ummla` as well, for the remainder of this RFC we will only mention its signed version.

The following picture briefly shows how `smmla` works:

In the picture, `vec_a` and `vec_b` are `int8x16` (or `.16b`) registers while `vec_c` is an `int32x4` (or `.4s`) register. `smmla` reads `vec_a` as a `2x8` matrix and `vec_b` as the transpose of an `8x2` matrix, and accumulates their `2x2` product into `vec_c`. You may notice that this is quite different from dot-product: in dot-product we compute a `1x4` sub-row of the output tile, while with `smmla` we compute a 2D `2x2` sub-tile of the final result. This needs proper handling during the tiling phase of the algorithm.
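To make the semantics concrete, here is a plain-Python reference model of `smmla`. It is an illustrative sketch, not TVM or Arm code, and the `smmla` helper name is hypothetical. It follows the Arm definition: the 16 bytes of `vec_a` form a row-major `2x8` matrix, the 16 bytes of `vec_b` form the transpose of an `8x2` matrix, and the `2x2` int32 product is added to the accumulator.

```python
def smmla(acc, vec_a, vec_b):
    """Plain-Python model of SMMLA Vd.4S, Vn.16B, Vm.16B (illustrative only).

    vec_a: 16 int8 values read as a row-major 2x8 matrix A.
    vec_b: 16 int8 values read as a row-major 2x8 matrix, i.e. the
           transpose of an 8x2 matrix B (row j of vec_b is column j of B).
    acc:   4 int32 values read as a row-major 2x2 matrix C.
    Returns a new list holding C + A x B, laid out like acc.
    """
    assert len(vec_a) == 16 and len(vec_b) == 16 and len(acc) == 4
    assert all(-128 <= v <= 127 for v in vec_a + vec_b), "inputs must be int8"
    out = list(acc)
    for i in range(2):      # row of A
        for j in range(2):  # column of B
            out[2 * i + j] += sum(vec_a[8 * i + k] * vec_b[8 * j + k] for k in range(8))
    return out


# A = [[0..7], [8..15]]; column 0 of B is all ones, column 1 is all twos.
print(smmla([0, 0, 0, 0], list(range(16)), [1] * 8 + [2] * 8))  # [28, 56, 92, 184]
```

A single call therefore produces four int32 outputs arranged as a `2x2` block, rather than a row of four.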

## GEMM implementation through `smmla`

Now that we understand how `smmla` works, we can discuss how we decided to add support for it.

Let’s recap how the general GEMM algorithm (`C=A*B`) works:

- Subdivide matrix `A` into adjacent tiles (a process called packing or interleaving)
- Interleave and transpose matrix `B` (this can be done offline)
- Run GEMM to produce a `C_interleaved` version of the output
- Unpack `C_interleaved` to obtain `C`
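The four steps of the GEMM algorithm can be sketched in plain Python as follows. This is a simplified model, not TVM's implementation: the helper names (`pack_a`, `pack_b_t`, `packed_gemm`) are hypothetical, tiles are plain row-major (no `2x2` sub-tiling yet), padding is zero-filling, and the defaults `8`/`12`/`8` are assumed to mirror the `8x12` output tile and the reduction length of 8 used for `smmla`.

```python
def round_up(n, m):
    return (n + m - 1) // m * m


def pack_a(A, tile_rows, tile_cols):
    """Step 1: A (M x K) -> A_il[x][kt][r][c] = A[x*tile_rows + r][kt*tile_cols + c], zero-padded."""
    M, K = len(A), len(A[0])

    def get(i, k):
        return A[i][k] if i < M and k < K else 0

    return [[[[get(x * tile_rows + r, kt * tile_cols + c) for c in range(tile_cols)]
              for r in range(tile_rows)]
             for kt in range(round_up(K, tile_cols) // tile_cols)]
            for x in range(round_up(M, tile_rows) // tile_rows)]


def pack_b_t(B, tile_rows, tile_cols):
    """Step 2: B (K x N) -> B_il_t[y][kt][r][c] = B[kt*tile_cols + c][y*tile_rows + r] (transposed tiles)."""
    K, N = len(B), len(B[0])

    def get(k, j):
        return B[k][j] if k < K and j < N else 0

    return [[[[get(kt * tile_cols + c, y * tile_rows + r) for c in range(tile_cols)]
              for r in range(tile_rows)]
             for kt in range(round_up(K, tile_cols) // tile_cols)]
            for y in range(round_up(N, tile_rows) // tile_rows)]


def packed_gemm(A, B, tile_rows_a=8, tile_rows_b=12, tile_cols=8):
    A_il = pack_a(A, tile_rows_a, tile_cols)
    B_il_t = pack_b_t(B, tile_rows_b, tile_cols)
    # Step 3: one output tile per (x, y) pair, reducing over every K tile.
    C_il = [[[[sum(a_t[r][c] * b_t[s][c]
                   for a_t, b_t in zip(A_il[x], B_il_t[y]) for c in range(tile_cols))
               for s in range(tile_rows_b)] for r in range(tile_rows_a)]
             for y in range(len(B_il_t))] for x in range(len(A_il))]
    # Step 4: unpack C_il into C, dropping the padding.
    M, N = len(A), len(B[0])
    return [[C_il[i // tile_rows_a][j // tile_rows_b][i % tile_rows_a][j % tile_rows_b]
             for j in range(N)] for i in range(M)]
```

Transposing `B` during packing means both operands are read along the reduction dimension with unit stride, which is what the vector loads in the inner kernel expect.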

The tiling is usually chosen to maximize register utilization. In this case, the best tiling for register allocation is still an `8x12` tile (like the one we used for dot-product). The difficulty with `smmla` is that `C_interleaved` (i.e., the tiled version of the output) will be composed of `2x2` sub-tiles that need to be extracted when unpacking. The following picture tries to show the situation:

Please note that `A_tile` and `B_tile` (and thus `C_tile` as well) are shown in their native layout. The problem is that the four elements of each generated sub-tile (e.g., the red `2x2` sub-tile) are contiguous in memory. This needs to be expressed at the compute level so that we can then unpack properly.
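The layout difference can be illustrated with a small plain-Python sketch (hypothetical helpers, not TVM code). It compares the flat offset of element `(r, c)` of an `8x12` tile in a plain row-major layout with its offset in a row-major `4x6x2x2` array, which is the layout obtained by storing one `2x2` `smmla` accumulator (one `int32x4` register) at a time. The sub-tile `(1, 2)` used below is an arbitrary example.

```python
TILE_ROWS, TILE_COLS = 8, 12  # output tile computed per GEMM step


def offset_row_major(r, c):
    """Offset of C_tile[r][c] if the 8x12 tile were stored row by row."""
    return r * TILE_COLS + c


def offset_2x2_blocked(r, c):
    """Offset of C_tile[r][c] in a row-major 4x6x2x2 array indexed by
    (r // 2, c // 2, r % 2, c % 2)."""
    w, z, s, t = r // 2, c // 2, r % 2, c % 2
    return ((w * (TILE_COLS // 2) + z) * 2 + s) * 2 + t


def subtile_offsets(w, z, offset):
    """Sorted offsets of the four elements of the 2x2 sub-tile (w, z)."""
    return sorted(offset(2 * w + s, 2 * z + t) for s in range(2) for t in range(2))


# Sub-tile covering rows 2-3 and columns 4-5:
print(subtile_offsets(1, 2, offset_row_major))    # [28, 29, 40, 41]
print(subtile_offsets(1, 2, offset_2x2_blocked))  # [32, 33, 34, 35]
```

In the row-major layout the sub-tile is split across two rows, while in the blocked layout it occupies four consecutive slots, matching a single 128-bit store.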

### The compute node

Given the above, the final compute nodes that produce `C_interleaved` and unpack the result are the following:

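```
import tvm
from tvm import te

idxm = tvm.tir.indexmod
# batches, M, N, M_padded, N_transformed, tile_rows_A/B, tile_cols_A/B, out_dtype,
# A_interleaved and B_interleaved_t are defined earlier (not shown in this extract).
# K_padded is assumed to be the padded reduction length.
k = te.reduce_axis((0, K_padded), "k")

# Each 8x12 output tile is a 4x6 grid (w, z) of 2x2 sub-tiles (s, t)
C_interleaved = te.compute(
    (batches, M_padded // tile_rows_A, N_transformed, 4, 6, 2, 2),
    lambda b, x, y, w, z, s, t: te.sum(
        A_interleaved[b, x, k // tile_cols_A, 2 * w + s, idxm(k, tile_cols_A)].astype("int32")
        * B_interleaved_t[y, k // tile_cols_B, 2 * z + t, idxm(k, tile_cols_B)].astype("int32"),
        axis=k,
    ),
    name="C_interleaved",
)
# Unpack: locate the tile, the 2x2 sub-tile and the position within the sub-tile
C = te.compute(
    (batches, M, N),
    lambda b, x, y: C_interleaved[
        b,
        x // tile_rows_A,
        y // tile_rows_B,
        idxm(x, tile_rows_A) // 2,
        idxm(y, tile_rows_B) // 2,
        idxm(idxm(x, tile_rows_A), 2),
        idxm(idxm(y, tile_rows_B), 2),
    ].astype(out_dtype),
    name="C",
)
```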

We are simply expressing in the computation that the `8x12` output tile is really a `4x6x2x2` tile (i.e., an `8x12` tile composed of `2x2` sub-tiles). This is also taken into account when we unpack `C_interleaved`.
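As a sanity check of the index arithmetic, the two lambdas in the compute node can be mirrored in plain Python for a single `8x12` tile. This is an illustrative sketch: the batch and tile indices are dropped, `B_tile_t` is assumed to hold the 12 rows of the transposed `B` tile, and the helper names are hypothetical. Unpacking the `4x6x2x2` result with `(x // 2, y // 2, x % 2, y % 2)` should reproduce a plain matrix product.

```python
def c_interleaved_tile(A_tile, B_tile_t):
    """Mirror of the C_interleaved lambda for one tile:
    C_il[w][z][s][t] = sum_k A_tile[2*w + s][k] * B_tile_t[2*z + t][k]."""
    K = len(A_tile[0])
    return [[[[sum(A_tile[2 * w + s][k] * B_tile_t[2 * z + t][k] for k in range(K))
               for t in range(2)] for s in range(2)]
             for z in range(6)] for w in range(4)]


def unpack_tile(C_il):
    """Mirror of the C lambda within one tile:
    C[x][y] = C_il[x // 2][y // 2][x % 2][y % 2]."""
    return [[C_il[x // 2][y // 2][x % 2][y % 2] for y in range(12)] for x in range(8)]


# Usage: A_tile is 8 x K, B_tile_t is 12 x K (B already transposed).
K = 16
A_tile = [[(r * 5 + k) % 11 - 5 for k in range(K)] for r in range(8)]
B_tile_t = [[(c * 3 + 2 * k) % 13 - 6 for k in range(K)] for c in range(12)]
C_tile = unpack_tile(c_interleaved_tile(A_tile, B_tile_t))
```

Because `(2 * w + s, 2 * z + t)` and `(x // 2, x % 2)`, `(y // 2, y % 2)` are inverse mappings, every element of the tile is written exactly once and read back from the right place.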

### The tensorization rule

Once `C_interleaved` is in the right form, it is fairly simple to tensorize on the inner `2x2` sub-tile and unroll the outer dimensions to obtain a final `8x12` output tile.

The following snippet is an extract of the tensorization rule we are using:

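```
# Extract from the body of the tensor intrinsic: ib is the IR builder, ins/outs are
# the intrinsic's buffers, dtype_vec is the int8x16 (or uint8x16) vector type and
# llvm_intrin is the LLVM intrinsic name (e.g. "llvm.aarch64.neon.smmla").
# Load in vec_a the two rows of A
# vec_a = [a, b, c, d, e, f, g, h;
#          i, j, k, l, m, n, o, p]
vec_a = ins[0].vload([0, 0], dtype_vec)
# Load in vec_b the two rows of B (i.e., the two columns of the untransposed B)
# vec_b = [0, 2, 4, 6, 8, 10, 12, 14;
#          1, 3, 5, 7, 9, 11, 13, 15]
vec_b = ins[1].vload([0, 0], dtype_vec)
# Load the current 2x2 accumulator
vec_c = outs[0].vload([0, 0], "int32x4")
# Execute the matrix multiplication via (s/u)mmla, accumulating into vec_c:
# vec_c += [a*0 + b*2 + c*4 + d*6 + e*8 + f*10 + g*12 + h*14;
#           a*1 + b*3 + c*5 + d*7 + e*9 + f*11 + g*13 + h*15;
#           i*0 + j*2 + k*4 + l*6 + m*8 + n*10 + o*12 + p*14;
#           i*1 + j*3 + k*5 + l*7 + m*9 + n*11 + o*13 + p*15]
vmmla = tvm.tir.call_llvm_intrin(
    "int32x4", llvm_intrin, tvm.tir.const(3, "uint32"), vec_c, vec_a, vec_b
)
# Store the result
ib.emit(outs[0].vstore([0, 0], vmmla))
```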

This closely mirrors the earlier picture of how `smmla` works. It is also instructive to show how we use the intrinsic within the schedule:

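```
# Tensor intrinsic computing one 2x2 sub-tile with (s/u)mmla
mmla = mmla_2x2_int8_int8_int32(in_type)
# Output axes are (b, x, y, w, z, s, t): w, z index the 4x6 grid of sub-tiles
xi, yi = C_interleaved.op.axis[3:5]
xi_inner, yi_inner = C_interleaved.op.axis[5:7]
k = C_interleaved.op.reduce_axis[0]
# Split the reduction by 8, the reduction length of a single mmla
k_outer, k_inner = s[C_interleaved].split(k, 8)
# b_outer_gemm_fused and inner_gemm come from fusing the outer axes earlier (not shown)
s[C_interleaved].reorder(
    b_outer_gemm_fused, inner_gemm, k_outer, xi, yi, xi_inner, yi_inner, k_inner
)
# Replace the inner 2x2x8 computation with one mmla, then unroll the 4x6 sub-tile grid
s[C_interleaved].tensorize(xi_inner, mmla)
s[C_interleaved].unroll(xi)
s[C_interleaved].unroll(yi)
```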

Other than splitting the reduction axis by 8 (i.e., the reduction dimension of `mmla`), we can simply apply the intrinsic and unroll the outer tile dimensions.
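The effect of the split, reorder, tensorize and unroll steps on one `8x12` tile can be emulated with the loop nest below. It is a plain-Python sketch, assuming `K` is padded to a multiple of 8 and using a simple model of `smmla` (a hypothetical helper, not TVM's intrinsic). In the generated code the `xi`/`yi` loops are fully unrolled and each `acc[xi][yi]` lives in one `int32x4` register.

```python
def smmla(acc, vec_a, vec_b):
    """Plain-Python model of SMMLA: acc (2x2) += A (2x8) x B (8x2), B passed transposed."""
    return [acc[2 * i + j] + sum(vec_a[8 * i + k] * vec_b[8 * j + k] for k in range(8))
            for i in range(2) for j in range(2)]


def tensorized_tile(A_tile, B_tile_t):
    """Loop order k_outer, xi, yi, (xi_inner, yi_inner, k_inner folded into smmla)."""
    K = len(A_tile[0])
    assert K % 8 == 0, "K is assumed to be padded to a multiple of 8"
    acc = [[[0] * 4 for _ in range(6)] for _ in range(4)]  # acc[xi][yi]: one 2x2 sub-tile
    calls = 0
    for k_outer in range(K // 8):
        ks = range(8 * k_outer, 8 * k_outer + 8)  # the k_inner values of this step
        for xi in range(4):      # unrolled in the real schedule
            for yi in range(6):  # unrolled in the real schedule
                vec_a = [A_tile[2 * xi + s][k] for s in range(2) for k in ks]
                vec_b = [B_tile_t[2 * yi + t][k] for t in range(2) for k in ks]
                acc[xi][yi] = smmla(acc[xi][yi], vec_a, vec_b)
                calls += 1
    # Unpack each 2x2 accumulator back into the 8x12 tile.
    C = [[acc[x // 2][y // 2][2 * (x % 2) + y % 2] for y in range(12)] for x in range(8)]
    return C, calls


K = 16
A_tile = [[(r * 5 + k) % 11 - 5 for k in range(K)] for r in range(8)]
B_tile_t = [[(c * 3 + 2 * k) % 13 - 6 for k in range(K)] for c in range(12)]
C_tile, n_calls = tensorized_tile(A_tile, B_tile_t)
```

With `K = 16`, for example, the tile takes `2 x 24 = 48` `smmla` calls: one per `2x2` sub-tile for each chunk of 8 along the reduction.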

### Testing and performance

While this instruction is not yet supported by real hardware, at Arm we have been able to run and verify it on an internal cycle-accurate simulator. For instance, on a `[75x75x80] * [3x3x80x192]` convolution (the heaviest layer of `inceptionV3`), we saw a 35% improvement compared to the dot-product implementation.

### PR

The PR for this RFC is available here: https://github.com/apache/incubator-tvm/pull/6802