[Maybe A Bug] Tensorize fails with input dimension mismatch

I am using TVM schedule stages to optimize a dense (matrix multiplication) workload. The input and weight matrices are first packed, and the computation is then carried out block by block on the packed matrices. After defining the compute and schedule, the lowered IR is as follows:

@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [128, 1024], []),
             B: Buffer(B_2: Pointer(float32), float32, [1024, 1536], []),
             A: Buffer(A_2: Pointer(float32), float32, [128, 1536], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  allocate(packedB: Pointer(global float32), float32, [262144]), storage_scope = global;
  allocate(packedA: Pointer(global float32), float32, [16384]), storage_scope = global {
    for (m.outer.init: int32, 0, 2) {
      for (n.inner.outer.init: int32, 0, 64) {
        for (m.inner.outer.init: int32, 0, 8) {
          for (m.inner.inner.init: int32, 0, 8) {
            for (n.inner.inner.init: int32, 0, 16) {
              C_2[(((((m.outer.init*65536) + (m.inner.outer.init*8192)) + (m.inner.inner.init*1024)) + (n.inner.outer.init*16)) + n.inner.inner.init)] = 0f32
            }
          }
        }
      }
    }
    for (k.outer: int32, 0, 6) {
      for (nn: int32, 0, 64) {
        for (ki: int32, 0, 256) {
          for (ni: int32, 0, 16) {
            packedB[(((nn*4096) + (ki*16)) + ni)] = (float32*)B_2[((((nn*24576) + (ni*1536)) + (k.outer*256)) + ki)]
          }
        }
      }
      for (m.outer: int32, 0, 2) {
        for (mm: int32, 0, 8) {
          for (ki_1: int32, 0, 256) {
            for (mi: int32, 0, 8) {
              packedA[(((mm*2048) + (ki_1*8)) + mi)] = (float32*)A_2[(((((m.outer*98304) + (mm*12288)) + (mi*1536)) + (k.outer*256)) + ki_1)]
            }
          }
        }
        for (n.inner.outer: int32, 0, 64) {
          for (m.inner.outer: int32, 0, 8) {
            for (m.inner.inner: int32, 0, 8) {
              for (k.inner: int32, 0, 256) {
                for (n.inner.inner: int32, 0, 16) {
                  C_2[(((((m.outer*65536) + (m.inner.outer*8192)) + (m.inner.inner*1024)) + (n.inner.outer*16)) + n.inner.inner)] = ((float32*)C_2[(((((m.outer*65536) + (m.inner.outer*8192)) + (m.inner.inner*1024)) + (n.inner.outer*16)) + n.inner.inner)] + ((float32*)packedA[(((m.inner.outer*2048) + (k.inner*8)) + m.inner.inner)]*(float32*)packedB[(((n.inner.outer*4096) + (k.inner*16)) + n.inner.inner)]))
                }
              }
            }
          }
        }
      }
    }
  }
}

As you can see, the three innermost loops compute an 8x16 block, so next I try to define a micro kernel as a tensorize stage that calls an extern C function. The tensor intrinsic is defined as follows:

def _gemm_micro_kernel(tile_m, tile_k, tile_n):
    a = te.placeholder((tile_k, tile_m), name="kernel_a")
    b = te.placeholder((tile_k, tile_n), name="kernel_b")
    k = te.reduce_axis((0, tile_k), name="kernel_k")
    c = te.compute((tile_m, tile_n), lambda i, j: te.sum(
        a[k, i] * b[k, j], axis=k), name="kernel_c")

    Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="kernel_a_buffer",
                             offset_factor=1, strides=[te.var('s1'), te.var('s2')])
    Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="kernel_b_buffer",
                             offset_factor=1, strides=[te.var('s1'), te.var('s2')])
    Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="kernel_c_buffer",
                             offset_factor=1, strides=[te.var('s1'), te.var('s2')])
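    # The symbolic strides above (with offset_factor=1) let the intrinsic match
    # tile views that are not contiguous in the enclosing packed buffers.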

    def gemm_kernel(ins, outs):
        """
        If tensorization covers all the reduce axes, the body function will be invoked;
        otherwise, the _reduce_reset and _reduce_update functions will be invoked.
        """
        data, weight = ins
        output = outs[0]

        def _body():
            ib = tvm.tir.ir_builder.create()
            ib.emit(
                tvm.tir.call_extern(
                    "int32",
                    "gemm8x16_kernel",
                    data.access_ptr("r"),
                    weight.access_ptr("r"),
                    output.access_ptr("w"),
                    tile_k,
                    output.strides[0]
                )
            )
            return ib.get()

        def _reduce_reset():
            ib = tvm.tir.ir_builder.create()
            ib.emit(
                tvm.tir.call_extern(
                    "int32",
                    "gemm8x16_reset",
                    output.access_ptr("w"),
                    output.strides[0]
                )
            )
            return ib.get()

        def _reduce_update():
            return _body()

        return _body(), _reduce_reset(), _reduce_update()

    return te.decl_tensor_intrin(c.op, gemm_kernel, name="gemm_micro_kernel", binds={a: Ab, b: Bb, c: Cb})
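
For completeness, the extern symbols gemm8x16_kernel and gemm8x16_reset must be resolvable at build time. A minimal sketch of one way to do that, following the import_llvm pattern from the TVM tensorize tutorial (the naive C bodies here are placeholders for the real AVX2 micro kernel, and they assume the A/B tiles are contiguous as in the packed layout above):

from tvm.contrib import clang, utils

def gemm_impl():
    # The A tile is laid out as (k, m) and the B tile as (k, n), matching the
    # intrinsic's placeholders; the kernel accumulates into C, which matches
    # _body doubling as _reduce_update.
    cc_code = """
    extern "C" int gemm8x16_reset(float *c, int stride_c) {
        for (int i = 0; i < 8; ++i)
            for (int j = 0; j < 16; ++j)
                c[i * stride_c + j] = 0.f;
        return 0;
    }
    extern "C" int gemm8x16_kernel(float *a, float *b, float *c,
                                   int K, int stride_c) {
        for (int k = 0; k < K; ++k)
            for (int i = 0; i < 8; ++i)
                for (int j = 0; j < 16; ++j)
                    c[i * stride_c + j] += a[k * 8 + i] * b[k * 16 + j];
        return 0;
    }
    """
    temp = utils.tempdir()
    ll_path = temp.relpath("gemm8x16.ll")
    # Compile the C source to LLVM IR so TVM can link it into the module.
    return clang.create_llvm(cc_code, output=ll_path)

# Attached to an outer axis of the tensorized stage, e.g.:
# s[C].pragma(lno, "import_llvm", gemm_impl())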

After tensorizing the m.inner.inner loop with this intrinsic, an error occurred:

Traceback (most recent call last):
  File "x.py", line 152, in <module>
    main()
  File "x.py", line 131, in main
    print(tvm.lower(s, [A, B, C], simple_mode=True))
  File "/tvm/python/tvm/driver/build_module.py", line 133, in lower
    return ffi.lower_schedule(inp, args, name, binds, simple_mode)
  File "/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  10: TVMFuncCall
  9: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)>::AssignTypedLambda<tvm::{lambda(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)#5}>(tvm::{lambda(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)#5}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  8: tvm::LowerSchedule(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&, bool)
  7: tvm::ScheduleToModule(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&)
  6: tvm::te::ScheduleOps(tvm::te::Schedule, tvm::runtime::Map<tvm::tir::IterVar, tvm::Range, void, void>, bool)
  5: tvm::te::MakePipeline(tvm::te::Stage const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, tvm::tir::Stmt, bool)
  4: tvm::te::ComputeOpNode::BuildProvide(tvm::te::Stage const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, bool) const
  3: tvm::te::MakeTensorize(tvm::te::ComputeOpNode const*, tvm::te::Stage const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, bool)
  2: tvm::te::VerifyTensorizeBody(tvm::te::ComputeOpNode const*, tvm::te::Stage const&, std::unordered_map<tvm::tir::IterVar, tvm::PrimExpr, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::PrimExpr> > > const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, std::unordered_map<tvm::te::Tensor, tvm::runtime::Array<tvm::Range, void>, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::runtime::Array<tvm::Range, void> > > > const&, tvm::te::TensorIntrin const&)
  1: tvm::te::MatchTensorizeBody(tvm::te::ComputeOpNode const*, tvm::te::Stage const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, std::unordered_map<tvm::te::Tensor, tvm::runtime::Array<tvm::Range, void>, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::runtime::Array<tvm::Range, void> > > > const&, tvm::te::TensorIntrin const&, tvm::runtime::Map<tvm::tir::Var, tvm::Range, void, void>*)
  0: tvm::te::TensorIntrinMatcher::Init(tvm::te::ComputeOpNode const*, tvm::te::Stage const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, std::unordered_map<tvm::te::Tensor, tvm::runtime::Array<tvm::Range, void>, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::runtime::Array<tvm::Range, void> > > > const&, tvm::te::TensorIntrin const&, tvm::runtime::Map<tvm::tir::Var, tvm::Range, void, void>*)
  File "/tvm/src/te/operation/tensorize.cc", line 227
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (is_one(canonical_extent)) is false: Tensorize gemm_micro_kernel: Input dimension mismatch with tensor intrin  expected shape=[256, 8], given region=[range(min=floordiv(((m.outer*64) + (m.inner.outer*8)), 64), ext=((floordiv(((m.inner.outer*8) + 7), 64) + 1) - floordiv(m.inner.outer, 8))), range(min=floordiv((k.outer*256), 256), ext=1), range(min=0, ext=8), range(min=0, ext=256), range(min=0, ext=8)]

The exception looks like a shape mismatch. But in the three innermost loops the tile shapes of the input, weight, and output are all fixed by this algorithm. Reading the message more closely: the intrinsic expects a (256, 8) region of packedA, while the inferred "given region" is five-dimensional, and the leading extents cannot all be proven to be 1. For instance, the first extent, (floordiv(m.inner.outer*8 + 7, 64) + 1) - floordiv(m.inner.outer, 8), is always 1 for m.inner.outer in [0, 8), yet the is_one(canonical_extent) check still fails, presumably because the floordiv/indexmod indexing of the packed buffers defeats the bound simplification. So is this a bug in the tensorize stage? My TVM version is v0.8.0.
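
A quick standalone probe of the offending extent expression (a hypothetical check; whether Analyzer.simplify proves it here need not match what the tensorize matcher does internally):

import tvm

# Extent of packedA's first dimension, with m.inner.outer bound to [0, 8).
ana = tvm.arith.Analyzer()
x = tvm.tir.Var("m_inner_outer", "int32")
ana.bind(x, tvm.ir.Range(0, 8))
ext = (tvm.tir.floordiv(x * 8 + 7, 64) + 1) - tvm.tir.floordiv(x, 8)
print(ana.simplify(ext))  # mathematically this is always 1 on that range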

You can reproduce this using the script below:

import tvm
import tvm.testing
from tvm import te
import numpy
import tempfile

# The sizes of the matrices:
# C(M, N) = A(M, K) x B(N, K)^T (B is stored with K as its second axis)
# You are free to try different shapes; sometimes TVM outperforms numpy with MKL.
M = 128
K = 1536
N = 1024

# The default tensor type in tvm
dtype = "float32"

# Use the Intel AVX2 (Advanced Vector Extensions) ISA for SIMD.
# To get the best performance, set -mcpu below to match the CPU you use.
target = "llvm -mcpu=core-avx2"
dev = tvm.device(target, 0)


def _gemm_micro_kernel(tile_m, tile_k, tile_n):
    a = te.placeholder((tile_k, tile_m), name="kernel_a")
    b = te.placeholder((tile_k, tile_n), name="kernel_b")
    k = te.reduce_axis((0, tile_k), name="kernel_k")
    c = te.compute((tile_m, tile_n), lambda i, j: te.sum(
        a[k, i] * b[k, j], axis=k), name="kernel_c")

    Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="kernel_a_buffer",
                             offset_factor=1, strides=[te.var('s1'), te.var('s2')])
    Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="kernel_b_buffer",
                             offset_factor=1, strides=[te.var('s1'), te.var('s2')])
    Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="kernel_c_buffer",
                             offset_factor=1, strides=[te.var('s1'), te.var('s2')])
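    # The symbolic strides above (with offset_factor=1) let the intrinsic match
    # tile views that are not contiguous in the enclosing packed buffers.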

    def gemm_kernel(ins, outs):
        """
        If tensorization covers all the reduce axes, the body function will be invoked;
        otherwise, the _reduce_reset and _reduce_update functions will be invoked.
        """
        data, weight = ins
        output = outs[0]

        def _body():
            ib = tvm.tir.ir_builder.create()
            ib.emit(
                tvm.tir.call_extern(
                    "int32",
                    "gemm8x16_kernel",
                    data.access_ptr("r"),
                    weight.access_ptr("r"),
                    output.access_ptr("w"),
                    tile_k,
                    output.strides[0]
                )
            )
            return ib.get()

        def _reduce_reset():
            ib = tvm.tir.ir_builder.create()
            ib.emit(
                tvm.tir.call_extern(
                    "int32",
                    "gemm8x16_reset",
                    output.access_ptr("w"),
                    output.strides[0]
                )
            )
            return ib.get()

        def _reduce_update():
            return _body()

        return _body(), _reduce_reset(), _reduce_update()

    return te.decl_tensor_intrin(c.op, gemm_kernel, name="gemm_micro_kernel", binds={a: Ab, b: Bb, c: Cb})


def main():

    A = te.placeholder((M, K), name="A")
    B = te.placeholder((N, K), name="B")

    # -----------------------------
    # GoToBlas Algorithm
    # -----------------------------
    nc = 1024
    mc = 64
    kc = 256

    mr = 8
    nr = 16

    mc_packed = M // mc
    nc_packed = N // nc
    kc_packed = K // kc

    packedA = te.compute((mc_packed, kc_packed, mc // mr, kc, mr), lambda mo,
                         ko, mm, ki, mi: A[mo*mc + mm*mr + mi, ko*kc + ki], name="packedA")

    packedB = te.compute((nc_packed, kc_packed, nc // nr, kc, nr), lambda no,
                         ko, nn, ki, ni: B[no * nc + nn * nr + ni, ko * kc + ki], name="packedB")

    k = te.reduce_axis((0, K), "k")

    C = te.compute((M, N), lambda m, n: te.sum(
        packedA[m // mc, k // kc, tvm.tir.indexmod(m, mc) // mr, tvm.tir.indexmod(k, kc), tvm.tir.indexmod(tvm.tir.indexmod(m, mc), mr)] *
        packedB[n // nc, k // kc, tvm.tir.indexmod(n, nc) // nr, tvm.tir.indexmod(k, kc), tvm.tir.indexmod(tvm.tir.indexmod(n, nc), nr)], axis=k), name="C")
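    # Layout note: packedA[mo, ko, mm, ki, mi] = A[mo*mc + mm*mr + mi, ko*kc + ki]
    # and packedB[no, ko, nn, ki, ni] = B[no*nc + nn*nr + ni, ko*kc + ki], so C
    # reaches each packed element through floordiv/indexmod of (m, n, k).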
    s = te.create_schedule(C.op)

    print(tvm.lower(s, [A, B, C], simple_mode=True))

    lmo, lno, lmi, lni = s[C].tile(C.op.axis[0], C.op.axis[1], mc, nc)
    (k_axis,) = s[C].op.reduce_axis
    lko, lki = s[C].split(k_axis, factor=kc)
    lmro, lmri = s[C].split(lmi, factor=mr)
    lnro, lnri = s[C].split(lni, factor=nr)
    s[C].reorder(lno, lko, lmo, lnro, lmro, lmri, lki, lnri)
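    # Loop nest after the reorder:
    #   n.outer -> k.outer -> m.outer -> n.inner.outer -> m.inner.outer
    #   -> m.inner.inner -> k.inner -> n.inner.inner
    # (n.outer has extent 1 because nc == N, so it is elided in the lowered IR).
    # The tensorize target below is lmri (m.inner.inner), which together with
    # lki and lnri covers one 8x16 output tile reduced over k.inner.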
    # s[C].vectorize(lnri)

    s[packedB].compute_at(s[C], lko)
    s[packedA].compute_at(s[C], lmo)

    print(tvm.lower(s, [A, B, C], simple_mode=True))

    # Micro Kernel
    gemm_kernel = _gemm_micro_kernel(mr, kc, nr)
    s[C].tensorize(lmri, gemm_kernel)
    print(tvm.lower(s, [A, B, C], simple_mode=True))

    # # Test Correctness
    # func = tvm.build(s, [A, B, C], target=target, name="packed_multi")

    # a = tvm.nd.array(numpy.random.rand(M, K).astype(dtype), dev)
    # b = tvm.nd.array(numpy.random.rand(N, K).astype(dtype), dev)
    # c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), dev)
    # func(a, b, c)

    # tvm.testing.assert_allclose(c.numpy(), numpy.dot(
    #     a.numpy(), b.numpy().T), rtol=1e-5)
    # print('======>Correct')

    # evaluator = func.time_evaluator(func.entry_name, dev, number=100)
    # print("======>Time: %f" % evaluator(a, b, c).mean)

    # print('======>Done')


if __name__ == '__main__':
    main()

Have you figured it out? I am encountering the same problem. :frowning_face: