Error during tensorization: Cannot bind a compact buffer res to a strided buffer res_slice with strides

Hello,

I am trying to figure out how to tensorize GEMM kernels in order to familiarize myself with TVM, but I am running into some issues. I started with a simple matrix-matrix multiplication where I split every loop into two levels, reorder them, and tensorize over the resulting three inner loops. The compute and schedule look like this:

res = te.compute(
    c_shape,
    lambda r_o, c_o: te.sum(
        a[r_o, k_o].astype(ENV.inp_dtype)
        * b[k_o, c_o].astype(ENV.inp_dtype),
        axis=[k_o],
    ),
    name="res",
    tag="dense",
)

sch = te.create_schedule(res.op)

outer_m, inner_m = sch[res].split(res.op.axis[0], factor=factor)
outer_n, inner_n = sch[res].split(res.op.axis[1], factor=factor)
outer_k, inner_k = sch[res].split(res.op.reduce_axis[0], factor=factor)

sch[res].reorder(outer_m, outer_n, outer_k,
                inner_m, inner_n, inner_k)

sch[res].tensorize(inner_m, intrin_gemm(factor, factor, factor))
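For readers less familiar with TVM schedules, the loop nest that the split and reorder above describe can be sketched in plain Python. This is only an illustration of the loop structure and makes no TVM calls. The function name `tiled_matmul` and the list-of-lists matrices are assumptions made for the sketch; the loop variable names mirror the ones in the schedule.

```python
def tiled_matmul(a, b, factor):
    """Multiply a (M x K) by b (K x N) with every loop split into two levels.

    Hypothetical helper for illustration only; it mirrors the loop order
    produced by the split + reorder in the schedule, nothing more.
    """
    m, k, n = len(a), len(b), len(b[0])
    # The sketch assumes every dimension divides evenly by the tile size.
    assert m % factor == 0 and k % factor == 0 and n % factor == 0
    res = [[0] * n for _ in range(m)]
    for outer_m in range(m // factor):
        for outer_n in range(n // factor):
            for outer_k in range(k // factor):
                # The three inner loops form the factor x factor x factor
                # region that tensorize(inner_m, ...) is meant to replace.
                for inner_m in range(factor):
                    for inner_n in range(factor):
                        for inner_k in range(factor):
                            r = outer_m * factor + inner_m
                            c = outer_n * factor + inner_n
                            kk = outer_k * factor + inner_k
                            res[r][c] += a[r][kk] * b[kk][c]
    return res
```

Note that each inner block reads and writes a factor x factor window of the full matrices, so the rows of a window are not adjacent in memory.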

In the intrinsic I want to use, the buffers are declared like this:

a = te.placeholder(a_shape, dtype=ENV.inp_dtype, name="ifmap_tile")
b = te.placeholder(b_shape, dtype=ENV.wgt_dtype, name="kernel_tile")
d = te.placeholder(d_shape, ENV.inp_dtype, name="out_tile")

Aa = tvm.tir.decl_buffer(a.shape, a.dtype, name="Ifmap_buf", offset_factor=1, strides=[te.var("aa_s1"), 1])
Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="Kernel_buf", offset_factor=1, strides=[te.var("bb_s1"), 1])
Dd = tvm.tir.decl_buffer(d.shape, d.dtype, name="Out_buf", offset_factor=1, strides=[te.var("dd_s1"), 1])
res = te.compute(
    d_shape,
    lambda no, mo: te.sum(
        a[no, rc].astype(ENV.inp_dtype)
        * b[rc, mo].astype(ENV.inp_dtype),
        axis=[rc],
    ),
    name="res"
)

The shapes are all 16x16. During tensorization I run into this error: Cannot bind a compact buffer res to a strided buffer res_slice with strides [64, 1].

I don’t understand why I run into this error. According to this tutorial, using a te.var for the stride should automatically capture the actual stride, so what am I missing here?
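To make the two layouts in the error message concrete, here is a small plain-Python sketch of the stride arithmetic. It does not use TVM. The 64x64 full shape is an assumption inferred from the reported strides [64, 1]; the helper names `row_major_strides` and `flat_offset` are hypothetical and exist only for this illustration.

```python
def row_major_strides(shape):
    """Element strides of a densely packed (compact) row-major buffer."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def flat_offset(index, strides):
    """Flat element offset of a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


full_shape = (64, 64)  # assumption: inferred from the stride 64 in the error
tile_shape = (16, 16)  # tile shape stated in the post

# A standalone 16x16 buffer is compact: consecutive rows are 16 elements apart.
compact_strides = row_major_strides(tile_shape)

# A 16x16 window into the 64x64 matrix keeps the parent's strides:
# consecutive rows are 64 elements apart.
slice_strides = row_major_strides(full_shape)

# The same tile element (row 1, column 0) sits at different offsets.
offset_in_compact = flat_offset((1, 0), compact_strides)
offset_in_slice = flat_offset((1, 0), slice_strides)
```

Under these assumptions `compact_strides` is `[16, 1]` and `slice_strides` is `[64, 1]`, which matches the strides reported for `res_slice`. The message therefore suggests that the buffer associated with `res` inside the intrinsic is being treated as compact while the region it is matched against is strided.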

Hi brother, did you solve the problem?