[TensorIR] Error with tensorization: The stmt tir.Block#0 doesn't match the tensor intrin

Here is the computation that I want to tensorize:

# Excerpt: the outer loops ax0_0, ax0_1, ax3_0, ax3, ax4 and ax5_0 are not shown.
for ax5_1_0, ax0, ax1, ax2, ax5_1_1 in T.grid(4, 10, 10, 32, 32):
    with T.block("Conv"):
        v_nn = T.axis.spatial(128, ax0_0 * 4 + ax0_1)
        v_yy, v_xx = T.axis.remap("SS", [ax0, ax1])
        v_ff = T.axis.spatial(128, ax3_0 * 32 + ax2)
        v_ry, v_rx = T.axis.remap("RR", [ax3, ax4])
        v_rc = T.axis.reduce(256, ax5_0 * 128 + ax5_1_0 * 32 + ax5_1_1)
        T.reads(A_shared[v_nn, v_yy + v_ry - 1, v_xx + v_rx - 1, v_rc], W_shared[v_ry, v_rx, v_rc, v_ff])
        T.writes(Conv_shared[v_nn, v_yy, v_xx, v_ff])
        with T.init():
            Conv_shared[v_nn, v_yy, v_xx, v_ff] = T.float16(0)
        Conv_shared[v_nn, v_yy, v_xx, v_ff] = Conv_shared[v_nn, v_yy, v_xx, v_ff] + T.if_then_else(
            1 <= v_yy + v_ry and v_yy + v_ry < 11 and 1 <= v_xx + v_rx and v_xx + v_rx < 11,
            A_shared[v_nn, v_yy + v_ry - 1, v_xx + v_rx - 1, v_rc],
            T.float16(0),
            dtype="float16",
        ) * W_shared[v_ry, v_rx, v_rc, v_ff]
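To make the semantics of the "Conv" block concrete, here is a plain-Python reference of what it computes (no TVM involved): a stride-1 convolution with padding 1, NHWC input, and a weight indexed as W[ry][rx][rc][ff]. The `if` mirrors the T.if_then_else guard `1 <= v_yy + v_ry < 11`, which for a 10x10 input is the same as `0 <= v_yy + v_ry - 1 < 10`. The function name `direct_conv_nhwc` is hypothetical, and the sketch assumes a 3x3 kernel with pad = 1 so that the output has the same height and width as the input, as in the block.

```python
def direct_conv_nhwc(A, W, pad=1):
    """Reference for the "Conv" block.

    A: nested lists indexed A[n][y][x][rc] (NHWC).
    W: nested lists indexed W[ry][rx][rc][ff].
    Assumes stride 1 and len(W) == 2 * pad + 1, so the output H/W equal the input H/W.
    """
    N, H, Wd, C = len(A), len(A[0]), len(A[0][0]), len(A[0][0][0])
    R, S, F = len(W), len(W[0]), len(W[0][0][0])
    out = [[[[0.0] * F for _ in range(Wd)] for _ in range(H)] for _ in range(N)]
    for n in range(N):
        for yy in range(H):
            for xx in range(Wd):
                for ff in range(F):
                    acc = 0.0
                    for ry in range(R):
                        for rx in range(S):
                            y, x = yy + ry - pad, xx + rx - pad
                            # Same guard as T.if_then_else: padded positions contribute 0.
                            if 0 <= y < H and 0 <= x < Wd:
                                for rc in range(C):
                                    acc += A[n][y][x][rc] * W[ry][rx][rc][ff]
                    out[n][yy][xx][ff] = acc
    return out


# Usage: all-ones 3x3 input and 3x3 kernel, one input and one output channel.
ones_A = [[[[1.0] for _ in range(3)] for _ in range(3)]]
ones_W = [[[[1.0]] for _ in range(3)] for _ in range(3)]
res = direct_conv_nhwc(ones_A, ones_W)
print(res[0][0][0][0], res[0][0][1][0], res[0][1][1][0])  # 4.0 6.0 9.0
```

The corner, edge and center outputs differ only because of how many taps fall inside the padding.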

I copied the tir.if_then_else into the intrinsic description so that it describes the same computation. This is part of the intrinsic function:

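import tvm
from tvm import tir
from tvm.script import tir as T

# Created outside gemm_desc; iter_type 2 is tir.IterVar.CommReduce (a reduction iterator).
# h, w, batch_in_channel, batch_out_channel, in_dtype and out_dtype are defined elsewhere.
v_ry = tir.IterVar((0, 3), "v_ry", 2)
v_rx = tir.IterVar((0, 3), "v_rx", 2)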

@T.prim_func
def gemm_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(
        a, (h, w, batch_in_channel), in_dtype, align=32, offset_factor=16, scope="shared"
    )
    B = T.match_buffer(
        b,
        (batch_in_channel, batch_out_channel),
        in_dtype,
        align=32,
        offset_factor=16,
        scope="shared",
    )
    C = T.match_buffer(
        c, (h, w, batch_out_channel), out_dtype, align=32, offset_factor=16, scope="shared"
    )

    with T.block("root"):
        T.reads(A[0:h, 0:w, 0:batch_in_channel], B[0:batch_in_channel, 0:batch_out_channel])
        T.writes(C[0:h, 0:w, 0:batch_out_channel])
        for i, j, k, l in T.grid(h, w, batch_out_channel, batch_in_channel):
            with T.block(""):
                vii, vjj, vkk, vll = T.axis.remap("SSSR", [i, j, k, l])

                C[vii, vjj, vkk] = C[vii, vjj, vkk] + tvm.tir.if_then_else(
                    tvm.tir.all(1 <= (vii + v_ry), (vii + v_ry) < 11, 1 <= (vjj + v_rx), (vjj + v_rx) < 11),
                    A[((vii + v_ry) - 1), ((vjj + v_rx) - 1), vll],
                    tvm.tir.const(0.0, "float16"),
                ) * B[vll, vkk]
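In gemm_desc, v_ry and v_rx are not block iterators of the description; they are fixed values captured from outside the function. So one application of the description is a GEMM over an (h*w, in) x (in, out) tile for a single kernel tap, applied to an input window shifted by (v_ry - 1, v_rx - 1), and a full 3x3 convolution needs nine such applications. The plain-Python sketch below (hypothetical names, no TVM) checks that decomposition. Judging from the error reported below, which flags `(v_yy_i + v_ry) - 1` against `(vii + v_ry) - 1`, the matcher apparently does not treat the workload's v_ry and the description's v_ry as the same variable; that is an inference from the log, not something verified here.

```python
def tap_gemm(A, B, C, ry, rx, pad=1):
    """One application of gemm_desc for a fixed kernel tap (ry, rx).

    A: input tile A[i][j][l] (h x w x in), B: weight slice B[l][k] (in x out),
    C: accumulator C[i][j][k] (h x w x out), updated in place.
    """
    h, w = len(A), len(A[0])
    cin, cout = len(B), len(B[0])
    for i in range(h):
        for j in range(w):
            y, x = i + ry - pad, j + rx - pad
            if not (0 <= y < h and 0 <= x < w):
                continue  # the if_then_else else-branch contributes 0
            for k in range(cout):
                for l in range(cin):
                    C[i][j][k] += A[y][x][l] * B[l][k]
    return C


def conv_by_taps(A, W, pad=1):
    """Full convolution as a sum of per-tap GEMMs; W is indexed W[ry][rx][rc][ff]."""
    h, w, cout = len(A), len(A[0]), len(W[0][0][0])
    C = [[[0.0] * cout for _ in range(w)] for _ in range(h)]
    for ry in range(len(W)):
        for rx in range(len(W[0])):
            tap_gemm(A, W[ry][rx], C, ry, rx, pad)
    return C


# Usage: 3x3 single-channel input with A[y][x][0] = 10*y + x.
A = [[[10 * y + x] for x in range(3)] for y in range(3)]
box = [[[[1]] for _ in range(3)] for _ in range(3)]  # 3x3 all-ones kernel
print(conv_by_taps(A, box)[0][0][0], conv_by_taps(A, box)[1][1][0])  # 22.0 99.0
```

A kernel that is 1 only at the center tap (ry, rx) = (1, 1) turns conv_by_taps into an identity, which is a quick sanity check for the shift direction.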

It reported an error:

Error message: The stmt tir.Block#0 doesn't match the tensor intrin
The pattern attempting to be matched:
block Conv(iter_var(v_yy_i, range(min=0, ext=10)), iter_var(v_xx_i, range(min=0, ext=10)), iter_var(v_ff_i, range(min=0, ext=32)), iter_var(v_rc_i, range(min=0, ext=32))) {
  reads([Conv_shared[v_nn, v_yy_i, v_xx_i, ((v_ff_o*32) + v_ff_i)], A_shared[v_nn, ((v_yy_i + v_ry) - 1), ((v_xx_i + v_rx) - 1), ((v_rc_o*32) + v_rc_i)], W_shared[v_ry, v_rx, ((v_rc_o*32) + v_rc_i), ((v_ff_o*32) + v_ff_i)]])
  writes([Conv_shared[v_nn, v_yy_i, v_xx_i, ((v_ff_o*32) + v_ff_i)]])
  Conv_shared[v_nn, v_yy_i, v_xx_i, ((v_ff_o*32) + v_ff_i)] = (Conv_shared[v_nn, v_yy_i, v_xx_i, ((v_ff_o*32) + v_ff_i)] + (tir.if_then_else(((((1 <= (v_yy_i + v_ry)) && ((v_yy_i + v_ry) < 11)) && (1 <= (v_xx_i + v_rx))) && ((v_xx_i + v_rx) < 11)), A_shared[v_nn, ((v_yy_i + v_ry) - 1), ((v_xx_i + v_rx) - 1), ((v_rc_o*32) + v_rc_i)], 0h)*W_shared[v_ry, v_rx, ((v_rc_o*32) + v_rc_i), ((v_ff_o*32) + v_ff_i)]))
}

Does not match the tensorize description:
block (iter_var(vii, range(min=0, ext=10)), iter_var(vjj, range(min=0, ext=10)), iter_var(vkk, range(min=0, ext=32)), iter_var(vll, range(min=0, ext=32))) {
  reads([C[vii, vjj, vkk], A[((vii + v_ry) - 1), ((vjj + v_rx) - 1), vll], B[vll, vkk]])
  writes([C[vii, vjj, vkk]])
  C[vii, vjj, vkk] = (C[vii, vjj, vkk] + (tir.if_then_else(((((1 <= (vii + v_ry)) && ((vii + v_ry) < 11)) && (1 <= (vjj + v_rx))) && ((vjj + v_rx) < 11)), A[((vii + v_ry) - 1), ((vjj + v_rx) - 1), vll], 0h)*B[vll, vkk]))
}
CompareBufferRegion buffer region min mismatch. lhs->region[i + offset]=range(min=((v_yy_i + v_ry) - 1), ext=1) vs rhs->region[i]=range(min=((vii + v_ry) - 1), ext=1)
BlockNode read buffers regions do not match: op->reads=[Conv_shared[v_nn, v_yy_i, v_xx_i, ((v_ff_o*32) + v_ff_i)], A_shared[v_nn, ((v_yy_i + v_ry) - 1), ((v_xx_i + v_rx) - 1), ((v_rc_o*32) + v_rc_i)], W_shared[v_ry, v_rx, ((v_rc_o*32) + v_rc_i), ((v_ff_o*32) + v_ff_i)]] vs rhs->reads=[C[vii, vjj, vkk], A[((vii + v_ry) - 1), ((vjj + v_rx) - 1), vll], B[vll, vkk]]

I don't understand why it doesn't work. Could anyone explain it to me? I'd really appreciate your help!

Most of the time, we use the tensorize primitive to leverage an accelerator's instructions, such as Tensor Core instructions.

Let's take Tensor Cores as an example. There are two different ways to tensorize conv2d. The first is to apply tensorization to direct conv2d; you can learn how to do this from this tutorial. However, this cannot be implemented on upstream TensorIR because of some issues related to iteration extents, and direct conv is not very friendly to tensorization anyway. The most common way is to convert conv2d into a general GEMM problem using im2col. Here is an example of how to implement this with TensorIR.
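The im2col route can be sketched in plain Python (no TVM; all function names are hypothetical) to show the data layout it produces. Each output pixel becomes one row holding its R*S*C receptive field, with explicit zeros where the window hits the padding; the weight W[ry][rx][rc][ff] is flattened to an (R*S*C) x F matrix in the same (ry, rx, rc) order; and the convolution becomes a single matrix multiply, which is the shape a GEMM intrinsic expects. The sketch assumes stride 1 and a same-sized output, as in the Conv block in the question.

```python
def im2col_nhwc(A, R, S, pad):
    """One row per output pixel (n, y, x); columns ordered (ry, rx, rc)."""
    N, H, Wd, C = len(A), len(A[0]), len(A[0][0]), len(A[0][0][0])
    rows = []
    for n in range(N):
        for yy in range(H):
            for xx in range(Wd):
                row = []
                for ry in range(R):
                    for rx in range(S):
                        y, x = yy + ry - pad, xx + rx - pad
                        inside = 0 <= y < H and 0 <= x < Wd
                        # Padding is materialized as explicit zeros in the patch row.
                        row.extend(A[n][y][x] if inside else [0.0] * C)
                rows.append(row)
    return rows


def flatten_weight(W):
    """W[ry][rx][rc][ff] -> (R*S*C) x F matrix, rows in (ry, rx, rc) order."""
    return [list(W[ry][rx][rc])
            for ry in range(len(W))
            for rx in range(len(W[0]))
            for rc in range(len(W[0][0]))]


def matmul(X, Y):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def conv2d_im2col(A, W, pad=1):
    N, H, Wd = len(A), len(A[0]), len(A[0][0])
    gemm = matmul(im2col_nhwc(A, len(W), len(W[0]), pad), flatten_weight(W))
    # gemm has N*H*Wd rows of F outputs; reshape back to NHWC.
    return [[[gemm[(n * H + y) * Wd + x] for x in range(Wd)]
             for y in range(H)] for n in range(N)]


# Usage: all-ones 3x3 input and 3x3 kernel, one input and one output channel.
ones_A = [[[[1.0] for _ in range(3)] for _ in range(3)]]
ones_W = [[[[1.0]] for _ in range(3)] for _ in range(3)]
print(len(im2col_nhwc(ones_A, 3, 3, 1)))       # 9 (one row per output pixel)
print(conv2d_im2col(ones_A, ones_W)[0][1][1])  # [9.0]
```

The trade-off is memory: the patch matrix duplicates each input element up to R*S times, in exchange for a plain GEMM that maps directly onto a matrix-multiply instruction.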

Thanks for your reply. As you said, I want to use the tensorize primitive to take advantage of my own accelerator's instructions, so im2col may not be the right approach for me. I have solved the same if_then_else problem in TE in a similar way. This is part of the TE intrinsic function:

import tvm
from tvm import te

# rh: an axis taken from Tensor.op.axis (the outer row/height axis).
# in_size, in_channel and pad are defined elsewhere.
def intrin_load_matrix(rh):
    n = in_size
    m = in_channel
    SA = te.placeholder((n, m), name="SA", dtype="float16")
    BA = tvm.tir.decl_buffer(SA.shape, SA.dtype, scope="global", data_alignment=32, offset_factor=16)
    # Pad the tile with `pad` zero rows on each side; the whole tile is zero
    # when rh itself falls in the padding region.
    SC = te.compute(
        (n + 2 * pad, m),
        lambda j, k: tvm.tir.if_then_else(
            tvm.tir.all(rh >= pad, rh - pad < in_size, j >= pad, j - pad < in_size),
            SA[j - pad, k],
            tvm.tir.const(0.0, "float16"),
        ),
        name="SC",
    )
    BC = tvm.tir.decl_buffer(SC.shape, SC.dtype, scope="shared", data_alignment=32, offset_factor=16)
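For comparison, the SC compute in the TE snippet can be read as the following plain-Python function (hypothetical name, no TVM): for a given value of the outer row index rh, it returns the SA tile with `pad` zero rows added on each side, or an all-zero tile when rh itself lies in the padding. The sketch assumes in_size equals the tile's row count n, as in the snippet.

```python
def load_padded_tile(SA, rh, pad):
    """Mirror of SC[j, k] for a fixed rh; SA has shape (in_size, m)."""
    in_size, m = len(SA), len(SA[0])
    row_inside = pad <= rh < in_size + pad  # rh >= pad and rh - pad < in_size
    return [
        [SA[j - pad][k] if row_inside and pad <= j < in_size + pad else 0.0
         for k in range(m)]
        for j in range(in_size + 2 * pad)
    ]


SA = [[1.0, 2.0], [3.0, 4.0]]
print(load_padded_tile(SA, rh=1, pad=1))  # [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]
print(load_padded_tile(SA, rh=0, pad=1))  # all zeros: rh is in the top padding
```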

What is the difference between TE and TIR? Why is it not working in TIR?