Understanding Tensorization Details

I am trying to understand the tensorize schedule primitive well enough to integrate a custom accelerator, but I am running into several problems and hope someone can clarify what is going on under the hood.

I am starting out with a simple matmul:

import tvm
from tvm import te

I = 32
K = 128
J = 64

factor = 16

a_shape = (I, K)
b_shape = (K, J)
c_shape = (I, J)

# calculate A @ B + D = C
a = te.placeholder(a_shape, dtype="int8", name="a_in")
b = te.placeholder(b_shape, dtype="int8", name="b_in")
c = te.placeholder(c_shape, dtype="int32", name="c_out")

k_o = te.reduce_axis((0, K), name="k_o")

res = te.compute(
    c_shape,
    lambda r_o, c_o: te.sum(
        a[r_o, k_o].astype(ENV.inp_dtype)
        * b[k_o, c_o].astype(ENV.inp_dtype),
        axis=k_o,
    ),
    name="res",
)

I then split each of the three axes into two levels:

sch = te.create_schedule(res.op)
outer_i, inner_i = sch[res].split(res.op.axis[0], factor=factor)
outer_j, inner_j = sch[res].split(res.op.axis[1], factor=factor)
outer_k, inner_k = sch[res].split(res.op.reduce_axis[0], factor=factor)
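To make the index arithmetic concrete, the loop nest produced by the three splits can be modeled in plain Python (a sketch, not TVM output). It assumes the default leaf order after splitting is outer_i, inner_i, outer_j, inner_j, outer_k, inner_k, and that the casts widen int8 to a 32-bit accumulator type; split_loop_nest is a hypothetical helper.

```python
import numpy as np

I, K, J = 32, 128, 64
factor = 16

def split_loop_nest(a, b):
    """Pure-Python model of the nest after splitting i, j and k by `factor`.

    Each original index is rebuilt as outer * factor + inner; the loop order is the
    assumed default after splitting (outer_i, inner_i, outer_j, inner_j, outer_k, inner_k).
    """
    a, b = a.tolist(), b.tolist()  # plain Python ints, so products cannot overflow int8
    res = [[0] * J for _ in range(I)]
    for outer_i in range(I // factor):
        for inner_i in range(factor):
            i = outer_i * factor + inner_i
            for outer_j in range(J // factor):
                for inner_j in range(factor):
                    j = outer_j * factor + inner_j
                    for outer_k in range(K // factor):
                        for inner_k in range(factor):
                            k = outer_k * factor + inner_k
                            res[i][j] += a[i][k] * b[k][j]
    return np.array(res, dtype=np.int32)

rng = np.random.default_rng(0)
a = rng.integers(-128, 128, size=(I, K)).astype(np.int8)
b = rng.integers(-128, 128, size=(K, J)).astype(np.int8)
ref = a.astype(np.int32) @ b.astype(np.int32)  # widen before multiplying, like the .astype casts
out = split_loop_nest(a, b)
print(np.array_equal(out, ref))  # True
```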

In the end, I want to use tensorize to map the three innermost loops onto the hardware. My initial understanding was that tensorize simply maps any generic “three nested loops, each with extent 16” structure, but that turned out not to be the case. If I reorder the loops like this

sch[res].reorder(outer_i, outer_j, outer_k,
                 inner_k, inner_j, inner_i)
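For comparison, a NumPy sketch of the reordered nest, in which the three innermost loops (inner_k, inner_j, inner_i) update one 16x16 output tile from a 16x16 tile of a and a 16x16 tile of b. tile_gemm and reordered_loop_nest are hypothetical helpers, and the int32 accumulation is an assumption about ENV.inp_dtype; the region passed to tile_gemm is the part the intrinsic is meant to replace.

```python
import numpy as np

I, K, J = 32, 128, 64
factor = 16

def tile_gemm(c_tile, a_tile, b_tile):
    """Work of the three innermost loops for one (outer_i, outer_j, outer_k):
    c_tile += a_tile @ b_tile, iterated in the order inner_k, inner_j, inner_i."""
    for inner_k in range(factor):
        for inner_j in range(factor):
            for inner_i in range(factor):
                c_tile[inner_i, inner_j] += int(a_tile[inner_i, inner_k]) * int(b_tile[inner_k, inner_j])

def reordered_loop_nest(a, b):
    """Model of the nest after reorder(outer_i, outer_j, outer_k, inner_k, inner_j, inner_i)."""
    res = np.zeros((I, J), dtype=np.int32)
    for outer_i in range(I // factor):
        rows = slice(outer_i * factor, (outer_i + 1) * factor)
        for outer_j in range(J // factor):
            cols = slice(outer_j * factor, (outer_j + 1) * factor)
            for outer_k in range(K // factor):
                ks = slice(outer_k * factor, (outer_k + 1) * factor)
                # Views into res, a and b: writes to the res view land in res
                tile_gemm(res[rows, cols], a[rows, ks], b[ks, cols])
    return res

rng = np.random.default_rng(0)
a = rng.integers(-128, 128, size=(I, K)).astype(np.int8)
b = rng.integers(-128, 128, size=(K, J)).astype(np.int8)
ref = a.astype(np.int32) @ b.astype(np.int32)
out = reordered_loop_nest(a, b)
print(np.array_equal(out, ref))  # True
```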

And use this intrinsic:

def intrin_gemm(
    n: int,
    c: int,
    m: int,
):
    """ GEMM of NxC and CxM matrices"""
    a_shape = (n, c)
    b_shape = (c, m)
    d_shape = (n, m)

    rc = te.reduce_axis((0, c), name="ric")
    a = te.placeholder(a_shape, dtype=ENV.inp_dtype, name="ifmap_tile")
    b = te.placeholder(b_shape, dtype=ENV.wgt_dtype, name="kernel_tile")
    d = te.placeholder(d_shape, ENV.inp_dtype, name="out_tile")
    Aa = tvm.tir.decl_buffer(a.shape, a.dtype, name="Ifmap_buf",   strides=[te.var("aa_s1"), te.var("aa_s2")])
    Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="Kernel_buf",  strides=[te.var("bb_s1"), te.var("bb_s2")])
    Dd = tvm.tir.decl_buffer(d.shape, d.dtype, name="Out_buf",     strides=[te.var("dd_s1"), te.var("dd_s2")])

    res = te.compute(
        d_shape,
        lambda no, mo: te.sum(
            a[no, rc].astype(ENV.inp_dtype)
            * b[rc, mo].astype(ENV.inp_dtype),
            axis=rc,
        ),
        name="res",
    )

    def intrin_func(ins, outs):
        ifm, ker = ins
        res = outs[0]
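        def _body():
            # Full (non-accumulating) tile computation: accelerator instructions still to be added
            ib = tvm.tir.ir_builder.create()
            return ib.get()

        def _reduce_reset():
            # Zero-initialize the output tile: still to be added
            ib = tvm.tir.ir_builder.create()
            return ib.get()

        def _reduce_update():
            # Accumulate into the output tile: still to be added
            ib = tvm.tir.ir_builder.create()
            return ib.get()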

        return _body(), _reduce_reset(), _reduce_update()

    return te.decl_tensor_intrin(res.op, intrin_func, binds={a: Aa, b: Bb, d: Dd})
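For the three functions returned by intrin_func, the tensorize tutorial describes the contract as: body when the tensorized region covers every reduce axis, otherwise reduce_reset followed by reduce_update. Since outer_k stays outside the tensorized loops in this schedule, the second path would apply. The NumPy model below illustrates that contract with hypothetical Python stand-ins for the three IR fragments; it says nothing about where TVM actually places the reset in the generated code.

```python
import numpy as np

I, K, J = 32, 128, 64
factor = 16

# Hypothetical NumPy stand-ins for the IR fragments returned by intrin_func.
def body(c_tile, a_tile, b_tile):
    """Whole reduction inside the intrinsic: overwrite the tile with a_tile @ b_tile."""
    c_tile[...] = a_tile.astype(np.int32) @ b_tile.astype(np.int32)

def reduce_reset(c_tile):
    """Zero the output tile before accumulation starts."""
    c_tile[...] = 0

def reduce_update(c_tile, a_tile, b_tile):
    """Accumulate one factor-wide slice of the reduction into the tile (in place)."""
    c_tile += a_tile.astype(np.int32) @ b_tile.astype(np.int32)

def tensorized_matmul(a, b):
    """Model of the schedule with outer_k outside the intrinsic (reset + update path)."""
    res = np.empty((I, J), dtype=np.int32)  # uninitialized on purpose: reset must zero it
    for outer_i in range(I // factor):
        rows = slice(outer_i * factor, (outer_i + 1) * factor)
        for outer_j in range(J // factor):
            cols = slice(outer_j * factor, (outer_j + 1) * factor)
            c_tile = res[rows, cols]  # a view, so writes land in res
            reduce_reset(c_tile)
            for outer_k in range(K // factor):
                ks = slice(outer_k * factor, (outer_k + 1) * factor)
                reduce_update(c_tile, a[rows, ks], b[ks, cols])
    return res

rng = np.random.default_rng(0)
a = rng.integers(-128, 128, size=(I, K)).astype(np.int8)
b = rng.integers(-128, 128, size=(K, J)).astype(np.int8)
ref = a.astype(np.int32) @ b.astype(np.int32)
out = tensorized_matmul(a, b)
print(np.array_equal(out, ref))  # True

# If the intrinsic covered the whole reduction (K == factor), body alone would suffice
tile = np.empty((factor, factor), dtype=np.int32)
body(tile, a[:factor, :factor], b[:factor, :factor])
```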

I get an error: Cannot bind a compact buffer res to a strided buffer res_slice with strides [64, 1]. So tensorize evidently needs some information about the buffer layout and access pattern. But according to the tensorize tutorial, TVM should be able to infer the strides on its own if I declare them as te.var, so why does that not happen here?
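The strides in the message can be reproduced by hand. A 16x16 tile of the 32x64 output (J = 64) keeps its parent's row stride, while a buffer declared without strides is compact, which for shape (16, 16) would mean strides [16, 1]. A small sketch of that arithmetic, with row_major_strides as a hypothetical helper and a NumPy view as a cross-check (NumPy reports strides in bytes, so they are divided by the item size):

```python
import numpy as np

def row_major_strides(shape):
    """Element strides of a compact (C-contiguous) array with the given shape."""
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return strides

I, J, factor = 32, 64, 16

# A tile sliced out of the (I, J) output keeps the parent's strides ...
slice_strides = row_major_strides((I, J))
# ... while a compact buffer of the tile's own shape assumes these
compact_strides = row_major_strides((factor, factor))

print(slice_strides)    # [64, 1]
print(compact_strides)  # [16, 1]

# Cross-check with a NumPy view (NumPy strides are in bytes)
res = np.zeros((I, J), dtype=np.int32)
tile = res[0:factor, 0:factor]
view_strides = [s // tile.itemsize for s in tile.strides]
print(view_strides)  # [64, 1]
```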

The second thing I do not understand is how the loop order affects the outcome of tensorization. If I reorder the loops differently:

sch[res].reorder(outer_i, outer_j, outer_k, inner_k, inner_j, inner_i)

I get another error: TVMError: Bind have an unmet assertion: T.bool(False), on argument tensir_intrin.reduction.extent. I don’t understand why a different loop order leads to a different error; I thought tensorize simply replaces the for loops, regardless of their order. If that is not the case, how exactly does tensorize operate?

After some more experimentation with the GEMV example, I also want to ask what exactly offset_factor and elem_offset denote. If I don’t explicitly set offset_factor=1 as the tutorial does, I get this error:

Check failed: (is_zero(value->elem_offset)) is false: Trying to bind a Buffer with offset into one without offset required elem_offset=0, provided elem_offset=i * 64
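The i * 64 in this message matches the flat element offset of row i, assuming the tutorial's GEMV shapes (A is (N, L) = (1024, 64), stored row-major, with one row A[i, 0:64] passed to the intrinsic per iteration). The sketch below reproduces that offset and checks it against the rule stated in the decl_buffer docstring, that elem_offset must be a multiple of offset_factor; both helpers are hypothetical illustrations, not TVM code.

```python
# Assumed shapes from the GEMV tutorial: A is (N, L), row-major.
N, L = 1024, 64

def elem_offset(row, col, row_len=L):
    """Flat element offset of A[row, col] in a row-major (N, row_len) array."""
    return row * row_len + col

def satisfies_offset_factor(offset, offset_factor):
    """Assumed rule (decl_buffer docstring): elem_offset must be a multiple of offset_factor."""
    return offset % offset_factor == 0

# The slice A[i, 0:L] bound to the intrinsic starts at element i * 64
row_offsets = [elem_offset(i, 0) for i in range(N)]

print(row_offsets[:4])  # [0, 64, 128, 192]
print(all(satisfies_offset_factor(o, 1) for o in row_offsets))    # True
print(all(satisfies_offset_factor(o, 16) for o in row_offsets))   # True
print(all(satisfies_offset_factor(o, 128) for o in row_offsets))  # False: odd rows start at 64 mod 128
```

Under that rule, offset_factor=1 accepts every row offset, and so would 16, since 64 is a multiple of 16. The message itself shows that the default setting required elem_offset=0, which only row 0 satisfies.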

I did not find any documentation on how these two values interact, or on how their default values are inferred.