Here is the matrix operation that I want to tensize,
for ax5_1_0, ax0, ax1, ax2, ax5_1_1 in T.grid(4, 10, 10, 32, 32):
with T.block("Conv"):
v_nn = T.axis.spatial(128, ax0_0 * 4 + ax0_1)
v_yy, v_xx = T.axis.remap("SS", [ax0, ax1])
v_ff = T.axis.spatial(128, ax3_0 * 32 + ax2)
v_ry, v_rx = T.axis.remap("RR", [ax3, ax4])
v_rc = T.axis.reduce(256, ax5_0 * 128 + ax5_1_0 * 32 + ax5_1_1)
T.reads(A_shared[v_nn, v_yy + v_ry - 1, v_xx + v_rx - 1, v_rc], W_shared[v_ry, v_rx, v_rc, v_ff])
T.writes(Conv_shared[v_nn, v_yy, v_xx, v_ff])
with T.init():
Conv_shared[v_nn, v_yy, v_xx, v_ff] = T.float16(0)
Conv_shared[v_nn, v_yy, v_xx, v_ff] = Conv_shared[v_nn, v_yy, v_xx, v_ff] + T.if_then_else(1 <= v_yy + v_ry and v_yy + v_ry < 11 and 1 <= v_xx + v_rx and v_xx + v_rx < 11, A_shared[v_nn, v_yy + v_ry - 1, v_xx + v_rx - 1, v_rc], T.float16(0), dtype="float16") * W_shared[v_ry, v_rx, v_rc, v_ff]
I copied tir.if_then_else in intrin func to describe the calculation process. This is part of the intrin func:
v_ry = tir.IterVar((0,3),"v_ry",2)
v_rx = tir.IterVar((0,3),"v_rx",2)
@T.prim_func
def gemm_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(
a, (h, w, batch_in_channel), in_dtype, align=32, offset_factor=16, scope="shared"
)
B = T.match_buffer(
b,
(batch_in_channel, batch_out_channel),
in_dtype,
align=32,
offset_factor=16,
scope="shared",
)
C = T.match_buffer(
c, (h, w, batch_out_channel), out_dtype, align=32, offset_factor=16, scope="shared"
)
with T.block("root"):
T.reads(A[0:h, 0:w, 0:batch_in_channel], B[0:batch_in_channel, 0:batch_out_channel])
T.writes(C[0:h, 0:w, 0:batch_out_channel])
for i, j, k, l in T.grid(h, w, batch_out_channel, batch_in_channel):
with T.block(""):
vii, vjj, vkk, vll = T.axis.remap("SSSR", [i, j, k, l])
C[vii, vjj, vkk] = C[vii, vjj, vkk] + tvm.tir.if_then_else(
tvm.tir.all(1 <= (vii + v_ry), (vii + v_ry) < 11, 1 <= (vjj + v_rx), (vjj + v_rx) < 11),
A[((vii + v_ry) - 1), ((vjj + v_rx) - 1), vll],
tvm.tir.const(0.0, "float16"),
) * (B[vll, vkk])
It reported an error:
Error message: The stmt tir.Block#0 doesn't match the tensor intrin
The pattern attempting to be matched:
block Conv(iter_var(v_yy_i, range(min=0, ext=10)), iter_var(v_xx_i, range(min=0, ext=10)), iter_var(v_ff_i, range(min=0, ext=32)), iter_var(v_rc_i, range(min=0, ext=32))) {
reads([Conv_shared[v_nn, v_yy_i, v_xx_i, ((v_ff_o*32) + v_ff_i)], A_shared[v_nn, ((v_yy_i + v_ry) - 1), ((v_xx_i + v_rx) - 1), ((v_rc_o*32) + v_rc_i)], W_shared[v_ry, v_rx, ((v_rc_o*32) + v_rc_i), ((v_ff_o*32) + v_ff_i)]])
writes([Conv_shared[v_nn, v_yy_i, v_xx_i, ((v_ff_o*32) + v_ff_i)]])
Conv_shared[v_nn, v_yy_i, v_xx_i, ((v_ff_o*32) + v_ff_i)] = (Conv_shared[v_nn, v_yy_i, v_xx_i, ((v_ff_o*32) + v_ff_i)] + (tir.if_then_else(((((1 <= (v_yy_i + v_ry)) && ((v_yy_i + v_ry) < 11)) && (1 <= (v_xx_i + v_rx))) && ((v_xx_i + v_rx) < 11)), A_shared[v_nn, ((v_yy_i + v_ry) - 1), ((v_xx_i + v_rx) - 1), ((v_rc_o*32) + v_rc_i)], 0h)*W_shared[v_ry, v_rx, ((v_rc_o*32) + v_rc_i), ((v_ff_o*32) + v_ff_i)]))
}
Does not match the tensorize description:
block (iter_var(vii, range(min=0, ext=10)), iter_var(vjj, range(min=0, ext=10)), iter_var(vkk, range(min=0, ext=32)), iter_var(vll, range(min=0, ext=32))) {
reads([C[vii, vjj, vkk], A[((vii + v_ry) - 1), ((vjj + v_rx) - 1), vll], B[vll, vkk]])
writes([C[vii, vjj, vkk]])
C[vii, vjj, vkk] = (C[vii, vjj, vkk] + (tir.if_then_else(((((1 <= (vii + v_ry)) && ((vii + v_ry) < 11)) && (1 <= (vjj + v_rx))) && ((vjj + v_rx) < 11)), A[((vii + v_ry) - 1), ((vjj + v_rx) - 1), vll], 0h)*B[vll, vkk]))
}
CompareBufferRegion buffer region min mismatch. lhs->region[i + offset]=range(min=((v_yy_i + v_ry) - 1), ext=1) vs rhs->region[i]=range(min=((vii + v_ry) - 1), ext=1)
BlockNode read buffers regions do not match: op->reads=[Conv_shared[v_nn, v_yy_i, v_xx_i, ((v_ff_o*32) + v_ff_i)], A_shared[v_nn, ((v_yy_i + v_ry) - 1), ((v_xx_i + v_rx) - 1), ((v_rc_o*32) + v_rc_i)], W_shared[v_ry, v_rx, ((v_rc_o*32) + v_rc_i), ((v_ff_o*32) + v_ff_i)]] vs rhs->reads=[C[vii, vjj, vkk], A[((vii + v_ry) - 1), ((vjj + v_rx) - 1), vll], B[vll, vkk]]
I don’t understand why it doesn’t work. Could anyone explain it to me? Really appreciate your help!