Describing Tensorization Intrins in TIR

Hello,

I am trying to learn how to use TIR for scheduling operations and am running into some issues with how to describe the intrinsics. I want to tensorize a 16x16x16 matmul block. This is the intrin I defined:

from tvm.script import tir as T


def get_intrin_gemm_tir(
    dim_i: int,
    dim_k: int,
    dim_j: int,
):

    # Description: the computation pattern that tensorize will try to match.
    @T.prim_func
    def my_matmul_desc(a: T.handle, b: T.handle, d: T.handle) -> None:
        A = T.match_buffer(a, (dim_i, dim_k), dtype="int8", scope="local.spad")
        B = T.match_buffer(b, (dim_k, dim_j), dtype="int8", scope="local.spad_w")
        D = T.match_buffer(d, (dim_i, dim_j), dtype="int8", scope="local.acc")

        with T.block("root"):
            T.reads(A[0:16, 0:16], B[0:16, 0:16], D[0:16, 0:16])
            T.writes(D[0:16, 0:16])

            for i, j, k in T.grid(dim_i, dim_j, dim_k):
                with T.block("res"):
                    vii, vjj, vkoi = T.axis.remap("SSR", [i, j, k])
                    D[vii, vjj] = D[vii, vjj] + A[vii, vkoi] * B[vjj, vkoi]

    # Implementation: what the matched block gets replaced with (a placeholder extern call here).
    @T.prim_func
    def my_matmul_impl(a: T.handle, b: T.handle, d: T.handle) -> None:
        sai = T.int32()
        sak = T.int32()
        sbk = T.int32()
        sbj = T.int32()
        sdi = T.int32()
        sdj = T.int32()

        A = T.match_buffer(a, (dim_i, dim_k), strides=[sai, sak], dtype="int8", scope="local.spad")
        B = T.match_buffer(b, (dim_k, dim_j), strides=[sbk, sbj], dtype="int8", scope="local.spad_w")
        D = T.match_buffer(d, (dim_i, dim_j), strides=[sdi, sdj], dtype="int8", scope="local.acc")

        with T.block("root"):
            T.reads(A[0:16, 0:16], B[0:16, 0:16], D[0:16, 0:16])
            T.writes(D[0:16, 0:16])
            T.evaluate(
                T.call_extern("test", dtype="" )
            )

    return my_matmul_desc, my_matmul_impl

But when I run this on my schedule I get this error:

Error message: The stmt tir.Block#0 doesn't match the tensor intrin
The pattern attempting to be matched:
with T.block("res", no_realize=True):
    v_i_i = T.axis.spatial(16)
    v_j_i = T.axis.spatial(16)
    v_k_o_i = T.axis.reduce(16)
    res_local_acc = T.Buffer((128, 64), "int8", scope="local.acc")
    v_i_o = T.int32()
    v_j_o = T.int32()
    a_in_local_spad = T.Buffer((128, 256), "int8", scope="local.spad")
    v_k_o_o = T.int32()
    b_in_local_spad_w = T.Buffer((256, 64), "int8", scope="local.spad_w")
    bias_in_local_acc = T.Buffer((128, 64), "int32", scope="local.acc")
    T.reads(res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i], a_in_local_spad[v_i_o * 16 + v_i_i, v_k_o_o * 16 + v_k_o_i], b_in_local_spad_w[v_k_o_o * 16 + v_k_o_i, v_j_o * 16 + v_j_i], bias_in_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i])
    T.writes(res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i])
    res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i] = res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i] + (a_in_local_spad[v_i_o * 16 + v_i_i, v_k_o_o * 16 + v_k_o_i] * b_in_local_spad_w[v_k_o_o * 16 + v_k_o_i, v_j_o * 16 + v_j_i] + T.Cast("int8", bias_in_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i]))
Does not match the tensorize description:
with T.block("root", no_realize=True):
    A = T.Buffer((16, 16), "int8", scope="local.spad")
    B = T.Buffer((16, 16), "int8", scope="local.spad_w")
    D = T.Buffer((16, 16), "int8", scope="local.acc")
    T.reads(A[0:16, 0:16], B[0:16, 0:16], D[0:16, 0:16])
    T.writes(D[0:16, 0:16])
    for i, j, k in T.grid(16, 16, 16):
        with T.block("res"):
            vii, vjj, vkoi = T.axis.remap("SSR", [i, j, k])
            T.reads(D[vii, vjj], A[vii, vkoi], B[vjj, vkoi])
            T.writes(D[vii, vjj])
            D[vii, vjj] = D[vii, vjj] + A[vii, vkoi] * B[vjj, vkoi]
CompareBufferRegion buffer extent mismatch: lhs->region[i + offset]=T.Range(v_i_o * 16 + v_i_i, 1) vs rhs->region[i]=T.Range(0, 16)
BlockNode write buffers do not match: op->writes=[res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i]] vs rhs->writes=[D[0:16, 0:16]]

I don’t know why the buffers like res_local_acc are so large here. If I look at the block that I want to tensorize, it looks like this:

for k_o_1, j_1, i_1 in T.grid(16, 16, 16):
    with T.block("res"):
        v_i_i, v_j_i, v_k_o_i = T.axis.remap("SSR", [i_1, j_1, k_o_1])
        T.reads(res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i], a_in_local_spad[v_i_o * 16 + v_i_i, v_k_o_o * 16 + v_k_o_i], b_in_local_spad_w[v_k_o_o * 16 + v_k_o_i, v_j_o * 16 + v_j_i], bias_in_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i])
        T.writes(res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i])
        res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i] = res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i] + (a_in_local_spad[v_i_o * 16 + v_i_i, v_k_o_o * 16 + v_k_o_i] * b_in_local_spad_w[v_k_o_o * 16 + v_k_o_i, v_j_o * 16 + v_j_i] + T.Cast("int8", bias_in_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i]))

It only uses a 16x16 region of each buffer. What am I doing wrong here?

For completeness, here is the code to generate the schedule:

import numpy as np
import tvm

from tvm import te
from tvm import topi

from experiments_intrins import *

dim_I = 128
dim_K = 256
dim_J = 64

factor = 16

inp_shape =  (dim_I, dim_K)
wght_shape = (dim_K, dim_J)
out_shape =  (dim_I, dim_J)

# calculate inp @ wght + bias = out
inp = te.placeholder(inp_shape, dtype="int8", name="a_in")
wght = te.placeholder(wght_shape, dtype="int8", name="b_in")
bias = te.placeholder(out_shape, dtype="int32", name="bias_in")
rk = te.reduce_axis((0, dim_K), name="k_o")

res = te.compute(
    out_shape,
    lambda i, j: te.sum(
        inp[i, rk].astype(ENV.inp_dtype) * wght[rk, j].astype(ENV.inp_dtype)
        + bias[i, j].astype(ENV.inp_dtype),
        axis=[rk],
    ),

    name="res",
    tag="dense",
)

dense_op = res.op
data, weight, bias_op = res.op.input_tensors

func = te.create_prim_func([inp, wght, bias, res])
sch = tvm.tir.Schedule(func)

# Extract loops from the schedule
block_res = sch.get_block("res")
i, j, ko = sch.get_loops(block_res)

cdata = sch.cache_read(block_res, read_buffer_index=0, storage_scope="local.spad")
cweight = sch.cache_read(block_res, read_buffer_index=1, storage_scope="local.spad_w")
cbias = sch.cache_read(block_res, read_buffer_index=2, storage_scope="local.acc")
#bias_op.set_scope(block_res, buffer_index="bias_in", storage_scope="local.acc")
res_local = sch.cache_write(block_res, write_buffer_index=0, storage_scope="local.acc")

i_outer, i_inner = sch.split(i, factors=[None, factor])
j_outer, j_inner = sch.split(j, factors=[None, factor])
k_outer, k_inner = sch.split(ko, factors=[None, factor])

sch.reorder(i_outer, j_outer, k_outer,
    k_inner, j_inner, i_inner)

block_mm = sch.blockize(k_inner)

sch.reverse_compute_at("res_local.acc", k_outer)
sch.compute_at("b_in_local.spad_w", k_outer)
sch.compute_at("a_in_local.spad", k_outer)
sch.compute_at("bias_in_local.acc", k_outer)

print(sch.mod.script())

gemm_intrin_desc, gemm_intrin_impl = get_intrin_gemm_tir(factor, factor, factor)
tvm.tir.TensorIntrin.register("gemm_op", gemm_intrin_desc, gemm_intrin_impl)

res_block = sch.get_block("res")
sch.tensorize(res_block, "gemm_op")

print(sch.mod.script())

There are two problems:

  1. Please call decompose_reduction before tensorize, since your tensor intrinsic only does the accumulation and no initialization (see the sketch after this list).

  2. Your tensor intrinsic only supports C += A * B, but not C += A * B + bias. To be honest, it is quite strange that you define the computation as sum(inp[i, rk] * wght[rk, j] + bias[i, j]), which means you add bias[i, j] K times (once per reduction step).
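
For the first point, the call would look roughly like this, reusing the block and loop names from the script above (a sketch only; the loop at which the init block is placed may need adjusting for your schedule):

# Split the initialization out of the reduction block, then tensorize the remaining update block.
init_block = sch.decompose_reduction(block_mm, k_outer)
sch.tensorize(sch.get_block("res"), "gemm_op")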

Thanks for the input! I tried adding the decompose_reduction operation, and now I get a slightly different error:

CompareBufferRegion buffer extent mismatch: lhs->region[i + offset]=T.Range(v_i_o * 16 + v_i_i, 1) vs rhs->region[i]=T.Range(0, 16)
BlockNode write buffers do not match: op->writes=[res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i]] vs rhs->writes=[D[0:16, 0:16]]

Do the TIR primitives really expect me to specify a detailed access pattern to the buffer, or is there still something else wrong?

Regarding the bias: the hardware I am using takes care of it by preloading the bias into the accumulator and then accumulating values on top of it, so I am not sure how relevant the exact bias operation in my schedule is. But I’ll look into adding a separate te.compute stage for it, thanks for the hint!
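
Concretely, what I have in mind for the separate stage is roughly this (an untested sketch building on the placeholders from the first script; the int32 accumulation dtype is an assumption that matches the schedule further below):

# Pure matmul reduction, without the bias inside the sum
rk = te.reduce_axis((0, dim_K), name="k_o")
res = te.compute(
    out_shape,
    lambda i, j: te.sum(
        inp[i, rk].astype("int32") * wght[rk, j].astype("int32"), axis=[rk]
    ),
    name="res",
)

# Bias added in a separate elementwise stage
bias_add = te.compute(
    out_shape,
    lambda i, j: res[i, j] + bias[i, j],
    name="bias_add",
)

func = te.create_prim_func([inp, wght, bias, bias_add])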

Finally, since initialization and bias are handled implicitly by overwriting the accumulator memory, can I represent this in an intrinsic? This should eliminate the need for the decompose_reduction if I understand its use correctly.

I’ve experimented a bit with an updated version of the bias operation. Here is the resulting schedule:

@I.ir_module
class Module:
    @T.prim_func
    def main(a_in: T.Buffer((128, 256), "int8"), b_in: T.Buffer((256, 64), "int8"), bias_in: T.Buffer((128, 64), "int32"), bias_add: T.Buffer((128, 64), "int32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        res = T.alloc_buffer((128, 64), "int32")
        a_in_local_spad = T.alloc_buffer((128, 256), "int8", scope="local.spad")
        b_in_local_spad_w = T.alloc_buffer((256, 64), "int8", scope="local.spad_w")
        bias_add_local_acc = T.alloc_buffer((128, 64), "int32", scope="local.acc")
        for i_0, j_0, k_o_0 in T.grid(8, 4, 16):
            for ax0, ax1 in T.grid(16, 16):
                with T.block("b_in_local.spad_w"):
                    v0 = T.axis.spatial(256, k_o_0 * 16 + ax0)
                    v1 = T.axis.spatial(64, j_0 * 16 + ax1)
                    T.reads(b_in[v0, v1])
                    T.writes(b_in_local_spad_w[v0, v1])
                    b_in_local_spad_w[v0, v1] = b_in[v0, v1]
            for ax0, ax1 in T.grid(16, 16):
                with T.block("a_in_local.spad"):
                    v0 = T.axis.spatial(128, i_0 * 16 + ax0)
                    v1 = T.axis.spatial(256, k_o_0 * 16 + ax1)
                    T.reads(a_in[v0, v1])
                    T.writes(a_in_local_spad[v0, v1])
                    a_in_local_spad[v0, v1] = a_in[v0, v1]
            for k_o_1 in range(16):
                for j_1, i_1 in T.grid(16, 16):
                    with T.block("res"):
                        v_i = T.axis.spatial(128, i_0 * 16 + i_1)
                        v_j = T.axis.spatial(64, j_0 * 16 + j_1)
                        v_k_o = T.axis.reduce(256, k_o_0 * 16 + k_o_1)
                        T.reads(a_in_local_spad[v_i, v_k_o], b_in_local_spad_w[v_k_o, v_j])
                        T.writes(res[v_i, v_j])
                        with T.init():
                            res[v_i, v_j] = 0
                        res[v_i, v_j] = res[v_i, v_j] + T.Cast("int32", a_in_local_spad[v_i, v_k_o]) * T.Cast("int32", b_in_local_spad_w[v_k_o, v_j])
                for ax0, ax1 in T.grid(16, 16):
                    with T.block("bias_add"):
                        v_i = T.axis.spatial(128, i_0 * 16 + ax0)
                        v_j = T.axis.spatial(64, j_0 * 16 + ax1)
                        T.reads(res[v_i, v_j], bias_in[v_i, v_j])
                        T.writes(bias_add_local_acc[v_i, v_j])
                        bias_add_local_acc[v_i, v_j] = res[v_i, v_j] + bias_in[v_i, v_j]
            for ax0, ax1 in T.grid(16, 16):
                with T.block("bias_add_local.acc"):
                    v0 = T.axis.spatial(128, i_0 * 16 + ax0)
                    v1 = T.axis.spatial(64, j_0 * 16 + ax1)
                    T.reads(bias_add_local_acc[v0, v1])
                    T.writes(bias_add[v0, v1])
                    bias_add[v0, v1] = bias_add_local_acc[v0, v1]

I want to tensorize the reduction and bias operation into my intrinsic, i.e. this part here:

        for k_o_1 in range(16):
            for j_1, i_1 in T.grid(16, 16):
                with T.block("res"):
                    v_i = T.axis.spatial(128, i_0 * 16 + i_1)
                    v_j = T.axis.spatial(64, j_0 * 16 + j_1)
                    v_k_o = T.axis.reduce(256, k_o_0 * 16 + k_o_1)
                    T.reads(a_in_local_spad[v_i, v_k_o], b_in_local_spad_w[v_k_o, v_j])
                    T.writes(res[v_i, v_j])
                    with T.init():
                        res[v_i, v_j] = 0
                    res[v_i, v_j] = res[v_i, v_j] + T.Cast("int32", a_in_local_spad[v_i, v_k_o]) * T.Cast("int32", b_in_local_spad_w[v_k_o, v_j])
            for ax0, ax1 in T.grid(16, 16):
                with T.block("bias_add"):
                    v_i = T.axis.spatial(128, i_0 * 16 + ax0)
                    v_j = T.axis.spatial(64, j_0 * 16 + ax1)
                    T.reads(res[v_i, v_j], bias_in[v_i, v_j])
                    T.writes(bias_add_local_acc[v_i, v_j])
                    bias_add_local_acc[v_i, v_j] = res[v_i, v_j] + bias_in[v_i, v_j]

How can I describe this? I found that I can’t use blockize on k_o_1 because it only works when there is a single child block, and I also can’t tensorize over a loop that contains multiple blocks. How do I handle this?

For completeness, here is my attempt at declaring the intrinsics:

def get_intrin_gemm_tir(
    dim_i: int,
    dim_k: int,
    dim_j: int,
):
    @T.prim_func
    def my_matmul_desc(a: T.handle, b: T.handle, res: T.handle, bias_in: T.handle, res_bias_local: T.handle) -> None:
        A = T.match_buffer(a, (dim_i, dim_k), dtype="int8", scope="local.spad")
        B = T.match_buffer(b, (dim_k, dim_j), dtype="int8", scope="local.spad_w")
        Bias_in = T.match_buffer(bias_in, (dim_i, dim_j), dtype="int32", scope="local.acc")
        Res     = T.match_buffer(res,     (dim_i, dim_j), dtype="int32", scope="local.acc")
        Res_bias_local = T.match_buffer(res_bias_local, (dim_i, dim_j), dtype="int32", scope="local.acc")

        for ko in range(16):
            for j, i in T.grid(16, 16):
                with T.block("res"):
                    vii, vjj, vkoi = T.axis.remap("SSR", [i, j, ko])
                    T.reads(A[0:16, 0:16], B[0:16, 0:16], Res[0:16, 0:16])
                    T.writes(Res[0:16, 0:16])
                    with T.init():
                        Res[vii, vjj] = 0
                    Res[vii, vjj] = Res[vii, vjj] + (A[vii, vkoi] * B[vjj, vkoi])
            for ax0, ax1 in T.grid(16, 16):
                with T.block("bias_add"):
                    axi, axj = T.axis.remap("SS", [ax0, ax1])
                    T.reads(Res[0:16, 0:16], Bias_in[0:16, 0:16])
                    T.writes(Res_bias_local[0:16, 0:16])
                    Res_bias_local[axi, axj] = Res[axi, axj] + Bias_in[axi, axj]



    @T.prim_func
    def my_matmul_impl(a: T.handle, b: T.handle, res: T.handle, bias_in: T.handle, res_bias_local: T.handle) -> None:

        A = T.match_buffer(a, (dim_i, dim_k), dtype="int8", scope="local.spad")
        B = T.match_buffer(b, (dim_k, dim_j), dtype="int8", scope="local.spad_w")
        Bias_in = T.match_buffer(bias_in, (dim_i, dim_j), dtype="int32", scope="local.acc")
        Res     = T.match_buffer(res,     (dim_i, dim_j), dtype="int32", scope="local.acc")
        Res_bias_local = T.match_buffer(res_bias_local, (dim_i, dim_j), dtype="int32", scope="local.acc")

        with T.block("root"):
            T.reads(A[0:16, 0:16], B[0:16, 0:16], Res[0:16, 0:16], Bias_in[0:16, 0:16])
            T.writes(Res_bias_local[0:16, 0:16])
            T.evaluate(
                T.call_extern("test", dtype="" )
            )

    return my_matmul_desc, my_matmul_impl

@Hzfengsy Do you have an idea how I can model this? It is supposed to be a regular WS (weight-stationary) dataflow on a systolic array. I tried using compute_inline to move the “bias_add” into the “res” block, but that doesn’t work. The error message is: “The consumer block tir.Block#0 to be inlined is required to have only a single producer block, and the producer block should be a complete block who has only a single consumer”. I think this is because the reduction block is not a complete block.

Using decompose_reduction also doesn’t yield the result I would like here, since the init step is handled implicitly through the dedicated accumulator memory. The most straightforward approach seems to be to modify the init statement, but I can’t find any documentation on how to do that.