Reproducing old VTA schedule

Hello,

TL;DR: I cannot make the tensorization part of the old Simple Matrix Multiply tutorial for VTA work in modern TVM.

I am trying to bring back support for the VTA by porting the old VTA specific code to the modern version of TVM. So far I have: ported the graph_pack function to Relax, ported many of the old VTA specific transformations for TensorIR to its modern equivalent and managed to reproduce the Get Started with VTA tutorial (i.e. I can offload the supported element-wise operations to VTA).

Now I am in the process of reproducing the Simple Matrix Multiply tutorial, but I am stuck. Below is what I think the equivalent code to the one in the tutorial should be (with the annotations removed, the intrinsic mocked, and the memory scopes all set to global).

from tvm import ir
from tvm import tir
from tvm import te
from tvm.script import tir as T

@T.prim_func
def gemm_desc(
        local_wgt_buffer: T.Buffer((16, 16), "int8"),
        local_inp_buffer: T.Buffer((1, 16), "int8"),
        out: T.Buffer((1, 16), "int32")
    ) -> None:
    T.func_attr({"tir.noalias": T.bool(True)})
    with T.block("root"):
        T.reads(out[0:1, 0:16], local_inp_buffer[0:1, 0:16], local_wgt_buffer[0:16, 0:16])
        T.writes(out[0:1, 0:16])
        for i, j, k in T.grid(1, 16, 16):
            with T.block("out"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(out[v_i, v_j], local_inp_buffer[v_i, v_k], local_wgt_buffer[v_j, v_k])
                T.writes(out[v_i, v_j])
                out[v_i, v_j] += T.Cast("int32", local_inp_buffer[v_i, v_k]) \
                    * T.Cast("int32", local_wgt_buffer[v_j, v_k])

@T.prim_func
def gemm_intrin(
        local_wgt_buffer: T.Buffer((16, 16), "int8"),
        local_inp_buffer: T.Buffer((1, 16), "int8"),
        out: T.Buffer((1, 16), "int32")
    ) -> None:
    T.func_attr({"tir.noalias": T.bool(True)})
    with T.block("root"):
        T.reads(out[0:1, 0:16], local_inp_buffer[0:1, 0:16], local_wgt_buffer[0:16, 0:16])
        T.writes(out[0:1, 0:16])
        T.evaluate(T.call_extern("int32", "VTADoTheThing",
            out.data, local_inp_buffer.data, local_wgt_buffer.data))

tir.TensorIntrin.register("gemm_intrin", gemm_desc, gemm_intrin)

BATCH, BLOCK_IN, BLOCK_OUT = 1, 16, 16
O, N, M = 1, 256, 256
o, n, m = O//BATCH, N//BLOCK_IN, M//BLOCK_OUT
out_shape = (o, m, BATCH, BLOCK_OUT)

K = te.reduce_axis((0, n), name="K")
k = te.reduce_axis((0, BLOCK_IN), name="k")

A = te.placeholder((o, n, BATCH, BLOCK_IN), name="A", dtype="int8")
B = te.placeholder((m, n, BLOCK_OUT, BLOCK_IN), name="B", dtype="int8")
C = te.compute(out_shape,
    lambda I, J, i, j: te.sum(
        A[I, K, i, k].astype("int32") * B[J, K, j, k].astype("int32"),
        axis=[K, k],
    ),
    name="C")
D = te.compute(out_shape, lambda I, J, i, j: C(I, J, i, j).astype("int8"), name="D")

gemm = te.create_prim_func([A, B, D]).with_attr({"global_symbol": "gemm"})
mod = ir.IRModule({"gemm": gemm})

sch = tir.Schedule(mod)
sch.work_on('gemm')
C_block = sch.get_block("C")
A_cache = sch.reindex_cache_read(C_block, 0, "global", lambda I, J, i, j, K, k: (I, K, i, k))
B_cache = sch.reindex_cache_read(C_block, 1, "global", lambda I, J, i, j, K, k: (J, K, j, k))
I, J, i, j, K, k = sch.get_loops(C_block)
sch.compute_at(A_cache, K)
sch.compute_at(B_cache, K)
sch.decompose_reduction(C_block, K)
sch.reorder_block_iter_var(C_block, (4, 0, 1, 2, 3, 5))
sch.mod.show()
sch.tensorize(C_block, "gemm_intrin")

The code above fails on the last line with the following error.

tvm.tir.schedule.schedule.ScheduleError: ScheduleError: An error occurred in the schedule primitive 'tensorize'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def gemm(var_A: T.handle, var_B: T.handle, var_D: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        A = T.match_buffer(var_A, (1, 16, 1, 16), "int8")
        B = T.match_buffer(var_B, (16, 16, 16, 16), "int8")
        D = T.match_buffer(var_D, (1, 16, 1, 16), "int8")
        with T.block("root"):
            T.reads()
            T.writes()
            C = T.alloc_buffer((1, 16, 1, 16), "int32")
            A_global = T.alloc_buffer((1, 16, 1, 16), "int8")
            B_global = T.alloc_buffer((16, 16, 16, 16), "int8")
            for I_1 in range(1):
                for J in range(16):
                    for i in range(1):
                        for j in range(16):
                            with T.block("C_init"):
                                v_I = T.axis.spatial(1, I_1)
                                v_J = T.axis.spatial(16, J)
                                v_i = T.axis.spatial(1, i)
                                v_j = T.axis.spatial(16, j)
                                T.reads()
                                T.writes(C[v_I, v_J, v_i, v_j])
                                C[v_I, v_J, v_i, v_j] = 0
                            for K in range(16):
                                for ax0 in range(16):
                                    with T.block("A_global"):
                                        v_I = T.axis.spatial(1, 0)
                                        v_i = T.axis.spatial(1, 0)
                                        v_K = T.axis.spatial(16, K)
                                        v_k = T.axis.spatial(16, ax0)
                                        T.reads(A[v_I, v_K, v_i, v_k])
                                        T.writes(A_global[v_I, v_K, v_i, v_k])
                                        A_global[v_I, v_K, v_i, v_k] = A[v_I, v_K, v_i, v_k]
                                for ax0 in range(16):
                                    with T.block("B_global"):
                                        v_J = T.axis.spatial(16, J)
                                        v_j = T.axis.spatial(16, j)
                                        v_K = T.axis.spatial(16, K)
                                        v_k = T.axis.spatial(16, ax0)
                                        T.reads(B[v_J, v_K, v_j, v_k])
                                        T.writes(B_global[v_J, v_K, v_j, v_k])
                                        B_global[v_J, v_K, v_j, v_k] = B[v_J, v_K, v_j, v_k]
                                for k in range(16):
                                    # tir.Block#0
                                    with T.block("C_update"):
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v_K = T.axis.reduce(16, K)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v_I = T.axis.spatial(1, I_1)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v_J = T.axis.spatial(16, J)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v_i = T.axis.spatial(1, i)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v_j = T.axis.spatial(16, j)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v_k = T.axis.reduce(16, k)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        T.reads(C[v_I, v_J, v_i, v_j], A_global[v_I, v_K, v_i, v_k], B_global[v_J, v_K, v_j, v_k])
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        T.writes(C[v_I, v_J, v_i, v_j])
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        C[v_I, v_J, v_i, v_j] = C[v_I, v_J, v_i, v_j] + T.Cast("int32", A_global[v_I, v_K, v_i, v_k]) * T.Cast("int32", B_global[v_J, v_K, v_j, v_k])
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            for I_1 in range(1):
                for J in range(16):
                    for i in range(1):
                        for j in range(16):
                            with T.block("D"):
                                v_I = T.axis.spatial(1, I_1)
                                v_J = T.axis.spatial(16, J)
                                v_i = T.axis.spatial(1, i)
                                v_j = T.axis.spatial(16, j)
                                T.reads(C[v_I, v_J, v_i, v_j])
                                T.writes(D[v_I, v_J, v_i, v_j])
                                D[v_I, v_J, v_i, v_j] = T.Cast("int8", C[v_I, v_J, v_i, v_j])
Error message: The stmt tir.Block#0 doesn't match the tensor intrin
The pattern attempting to be matched:
with T.block("C_update", no_realize=True):
    v_K = T.axis.reduce(16)
    v_I = T.axis.spatial(1)
    v_J = T.axis.spatial(16)
    v_i = T.axis.spatial(1)
    v_j = T.axis.spatial(16)
    v_k = T.axis.reduce(16)
    C = T.Buffer((1, 16, 1, 16), "int32")
    A_global = T.Buffer((1, 16, 1, 16), "int8")
    B_global = T.Buffer((16, 16, 16, 16), "int8")
    T.reads(C[v_I, v_J, v_i, v_j], A_global[v_I, v_K, v_i, v_k], B_global[v_J, v_K, v_j, v_k])
    T.writes(C[v_I, v_J, v_i, v_j])
    C[v_I, v_J, v_i, v_j] = C[v_I, v_J, v_i, v_j] + T.Cast("int32", A_global[v_I, v_K, v_i, v_k]) * T.Cast("int32", B_global[v_J, v_K, v_j, v_k])
Does not match the tensorize description:
with T.block("root", no_realize=True):
    out = T.Buffer((1, 16), "int32")
    local_inp_buffer = T.Buffer((1, 16), "int8")
    local_wgt_buffer = T.Buffer((16, 16), "int8")
    T.reads(out[0, 0:16], local_inp_buffer[0, 0:16], local_wgt_buffer[0:16, 0:16])
    T.writes(out[0, 0:16])
    for i, j, k in T.grid(1, 16, 16):
        with T.block("out"):
            v_i = T.axis.spatial(1, 0)
            v_j, v_k = T.axis.remap("SR", [j, k])
            T.reads(out[0, v_j], local_inp_buffer[0, v_k], local_wgt_buffer[v_j, v_k])
            T.writes(out[0, v_j])
            out[0, v_j] = out[0, v_j] + T.Cast("int32", local_inp_buffer[0, v_k]) * T.Cast("int32", local_wgt_buffer[v_j, v_k])
CompareBufferRegion buffer extent mismatch: lhs->region[i + offset]=range(min=v_j, ext=1)Range(000001AE7CE6C600) vs rhs->region[i]=I.Range(0, 16)
BlockNode write buffers do not match: op->writes=[C[v_I, v_J, v_i, v_j]] vs rhs->writes=[out[0, 0:16]]

For readability I also put below the output of the sch.mod.show() on the penultimate line.

@I.ir_module
class Module:
    @T.prim_func
    def gemm(A: T.Buffer((1, 16, 1, 16), "int8"), B: T.Buffer((16, 16, 16, 16), "int8"), D: T.Buffer((1, 16, 1, 16), "int8")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        C = T.alloc_buffer((1, 16, 1, 16), "int32")
        A_global = T.alloc_buffer((1, 16, 1, 16), "int8")
        B_global = T.alloc_buffer((16, 16, 16, 16), "int8")
        for I_1, J, i, j in T.grid(1, 16, 1, 16):
            with T.block("C_init"):
                v_I, v_J, v_i, v_j = T.axis.remap("SSSS", [I_1, J, i, j])
                T.reads()
                T.writes(C[v_I, v_J, v_i, v_j])
                C[v_I, v_J, v_i, v_j] = 0
            for K in range(16):
                for ax0 in range(16):
                    with T.block("A_global"):
                        v_I = T.axis.spatial(1, 0)
                        v_i = T.axis.spatial(1, 0)
                        v_K, v_k = T.axis.remap("SS", [K, ax0])
                        T.reads(A[v_I, v_K, v_i, v_k])
                        T.writes(A_global[v_I, v_K, v_i, v_k])
                        A_global[v_I, v_K, v_i, v_k] = A[v_I, v_K, v_i, v_k]
                for ax0 in range(16):
                    with T.block("B_global"):
                        v_J, v_j, v_K, v_k = T.axis.remap("SSSS", [J, j, K, ax0])
                        T.reads(B[v_J, v_K, v_j, v_k])
                        T.writes(B_global[v_J, v_K, v_j, v_k])
                        B_global[v_J, v_K, v_j, v_k] = B[v_J, v_K, v_j, v_k]
                for k in range(16):
                    with T.block("C_update"):
                        v_K, v_I, v_J, v_i, v_j, v_k = T.axis.remap("RSSSSR", [K, I_1, J, i, j, k])
                        T.reads(C[v_I, v_J, v_i, v_j], A_global[v_I, v_K, v_i, v_k], B_global[v_J, v_K, v_j, v_k])
                        T.writes(C[v_I, v_J, v_i, v_j])
                        C[v_I, v_J, v_i, v_j] = C[v_I, v_J, v_i, v_j] + T.Cast("int32", A_global[v_I, v_K, v_i, v_k]) * T.Cast("int32", B_global[v_J, v_K, v_j, v_k])
        for I_1, J, i, j in T.grid(1, 16, 1, 16):
            with T.block("D"):
                v_I, v_J, v_i, v_j = T.axis.remap("SSSS", [I_1, J, i, j])
                T.reads(C[v_I, v_J, v_i, v_j])
                T.writes(D[v_I, v_J, v_i, v_j])
                D[v_I, v_J, v_i, v_j] = T.Cast("int8", C[v_I, v_J, v_i, v_j])

I understand that many things have changed between the old and new versions of TVM and I have tried in many other ways to reach the same schedule as the tutorial but, of course, none of them worked. The alternative approach that brought me furthest was to define the intrinsic to take vectors as input instead of single-row matrices, fuse the i and j indices, and tensorize before using the compute_at method. Having reached that point, I was not able to use compute_at to bring the data loading inside the K loop.

I do not understand what I am doing wrong. Is it still possible to do this kind of scheduling on the modern version of TVM? Do you have any advice on how to proceed?


As a side note, I have tried to make it work by defining the intrinsic in a slightly different way, because I have noticed that the tensorize method seems to handle unit loops differently from non-unit loops. To better explain myself, consider the code below.

from typing import Tuple

from tvm import tir
from tvm.script import tir as T

def make_vta_gemm_intrinsic(n: int) -> Tuple[tir.PrimFunc, tir.PrimFunc]:
    if n < 1:
        raise ValueError("n must be greater than 1")

    @T.prim_func
    def vta_gemm_desc(
            A: T.Buffer((16, 16), "int8"),
            B: T.Buffer((n, 16), "int8"),
            C: T.Buffer((n, 16), "int32")
        ) -> None:

        with T.block("root"):
            T.reads(C[0:n, 0:16], B[0:n, 0:16], A[0:16, 0:16])
            T.writes(C[0:n, 0:16])
            for i, j, k in T.grid(n, 16, 16):
                with T.block("C"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    T.reads(C[vi, vj], B[vi, vk], A[vj, vk])
                    T.writes(C[vi, vj])
                    C[vi, vj] += B[vi, vk].astype("int32") * A[vj, vk].astype("int32")

    @T.prim_func
    def vta_gemm_intrin(
            A: T.Buffer((16, 16), "int8"),
            B: T.Buffer((n, 16), "int8"),
            C: T.Buffer((n, 16), "int32")
        ) -> None:
        T.func_attr({"tir.noalias": T.bool(True)})
        with T.block("root"):
            T.reads(A[0:16, 0:16], B[0:n, 0:16], C[0:n, 0:16])
            T.writes(C[0:n, 0:16])
            T.evaluate(T.call_extern("int32", "SomethingCompute", A.data, B.data, C.data))

    return vta_gemm_desc, vta_gemm_intrin

def test_vta_gemm_intrin(n: int) -> None:
    if n < 1:
        raise ValueError("n must be greater than 1")

    @T.prim_func
    def before(
            A: T.Buffer((128, 128), "int8"),
            B: T.Buffer((128, 128), "int8"),
            C: T.Buffer((128, 128), "int32"),
        ) -> None:
        # with T.block("root")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                C[vi, vj] += B[vi, vk].astype("int32") * A[vj, vk].astype("int32")

    sch = tir.Schedule(before)
    i, j, k = sch.get_loops(sch.get_block("update"))
    i0, i1 = sch.split(i, (128//n, n))
    j0, j1 = sch.split(j, (128//16, 16))
    k0, k1 = sch.split(k, (128//16, 16))
    sch.reorder(i0, j0, k0, i1, j1, k1)
    sch.mod.show()
    sch.tensorize(i1, "test_vta_gemm_intrin%d" % n)
    sch.mod.show()

n = 4; tir.TensorIntrin.register("test_vta_gemm_intrin%d" % n, *make_vta_gemm_intrinsic(n)); test_vta_gemm_intrin(n)
n = 2; tir.TensorIntrin.register("test_vta_gemm_intrin%d" % n, *make_vta_gemm_intrinsic(n)); test_vta_gemm_intrin(n)
n = 1; tir.TensorIntrin.register("test_vta_gemm_intrin%d" % n, *make_vta_gemm_intrinsic(n)); test_vta_gemm_intrin(n)

It fails on the last line with the following error.

tvm.tir.schedule.schedule.ScheduleError: ScheduleError: An error occurred in the schedule primitive 'tensorize'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A_handle: T.handle, B_handle: T.handle, C_handle: T.handle):
        T.func_attr({"global_symbol": "before"})
        A = T.match_buffer(A_handle, (128, 128), "int8")
        B = T.match_buffer(B_handle, (128, 128), "int8")
        C = T.match_buffer(C_handle, (128, 128), "int32")
        with T.block("root"):
            T.reads()
            T.writes()
            for i_0 in range(128):
                for j_0 in range(8):
                    for k_0 in range(8):
                        for i_1 in range(1):
                            for j_1 in range(16):
                                for k_1 in range(16):
                                    with T.block("update"):
                                        vi = T.axis.spatial(128, i_0 + i_1)
                                        vj = T.axis.spatial(128, j_0 * 16 + j_1)
                                        vk = T.axis.reduce(128, k_0 * 16 + k_1)
                                        T.reads(C[vi, vj], B[vi, vk], A[vj, vk])
                                        T.writes(C[vi, vj])
                                        C[vi, vj] = C[vi, vj] + T.Cast("int32", B[vi, vk]) * T.Cast("int32", A[vj, vk])
Error message: The stmt tir.For#0 doesn't match the tensor intrin
The pattern attempting to be matched:
for k_1 in range(16):
    j_1 = T.int32()
    with T.block("update"):
        vj_i = T.axis.spatial(16, j_1)
        vk_i = T.axis.reduce(16, k_1)
        C = T.Buffer((128, 128), "int32")
        vi_o = T.int32()
        vj_o = T.int32()
        B = T.Buffer((128, 128), "int8")
        vk_o = T.int32()
        A = T.Buffer((128, 128), "int8")
        T.reads(C[vi_o, vj_o * 16 + vj_i], B[vi_o, vk_o * 16 + vk_i], A[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
        T.writes(C[vi_o, vj_o * 16 + vj_i])
        C[vi_o, vj_o * 16 + vj_i] = C[vi_o, vj_o * 16 + vj_i] + T.Cast("int32", B[vi_o, vk_o * 16 + vk_i]) * T.Cast("int32", A[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
Does not match the tensorize description:
for k in range(16):
    j = T.int32()
    with T.block("C"):
        vi = T.axis.spatial(1, 0)
        vj = T.axis.spatial(16, j)
        vk = T.axis.reduce(16, k)
        C = T.Buffer((1, 16), "int32")
        B = T.Buffer((1, 16), "int8")
        A = T.Buffer((16, 16), "int8")
        T.reads(C[0, vj], B[0, vk], A[vj, vk])
        T.writes(C[0, vj])
        C[0, vj] = C[0, vj] + T.Cast("int32", B[0, vk]) * T.Cast("int32", A[vj, vk])
CompareArray array size mismatch. lhs.size()=2 vs rhs.size()=3
BlockRealizeNode iter_values do not match: op->iter_values=[j_1, k_1] vs rhs->iter_values=[0, j, k]

I do not understand why the unit case should work differently from the other two. Is this a bug or am I doing something wrong?

OK, I think I have managed to make the schedule work by: doing the operations in a different order, using an intrinsic that takes a vector as input instead of a 1xN matrix, and using fuse to remove the unit loop. Below is the snippet of code that I am using.

sch = tvm.tir.Schedule(mod)
sch.work_on('gemm')
C_block = sch.get_block("C")
I, J, i, j, K, k = sch.get_loops(C_block)
sch.reorder(K, I, J, i, j, k)
A_cache = sch.cache_read(C_block, 0, env.acc_scope)
B_cache = sch.cache_read(C_block, 1, env.acc_scope)
sch.set_scope(C_block, 0, env.acc_scope)
C_init = sch.decompose_reduction(C_block, K)
ij = sch.fuse(i, j)
sch.compute_at(A_cache, K)
sch.compute_at(B_cache, K)
sch.mod['gemm'].show(ir_prefix="IR")
sch.annotate(sch.get_loops(A_cache)[-1], env.dma_copy, True)
sch.annotate(sch.get_loops(B_cache)[1], env.dma_copy, True)
sch.annotate(sch.get_loops(sch.get_block("D"))[0], env.dma_copy, True)
sch.tensorize(ij, "test_vta_gemm_intrin1_scoped")

It is still not clear to me why the tensorize method fails if I do not use fuse together with the ad hoc intrinsic.