I’ve tried simply moving the iteration variable into the intrin, but that leads to issues when I try to perform the tensorization. Here is the error message:
The pattern attempting to be matched:
for i_1 in range(16):
with T.block("res_o_update"):
v_i_o_i = T.axis.spatial(16, i_1)
res_local_acc = T.Buffer((512, 256), "int32", scope="local.acc")
v_i_o_o = T.int32()
v_j_o = T.int32()
a_in_local_spad = T.Buffer((512, 768), "int8", scope="local.spad")
v_k_o = T.int32()
b_in_local_spad_w = T.Buffer((768, 256), "int8", scope="local.spad_w")
T.reads(res_local_acc[v_i_o_o * 64 + v_i_o_i * 4:v_i_o_o * 64 + v_i_o_i * 4 + 4, v_j_o * 4:v_j_o * 4 + 4], a_in_local_spad[v_i_o_o * 64 + v_i_o_i * 4:v_i_o_o * 64 + v_i_o_i * 4 + 4, v_k_o * 4:v_k_o * 4 + 4], b_in_local_spad_w[v_k_o * 4:v_k_o * 4 + 4, v_j_o * 4:v_j_o * 4 + 4])
T.writes(res_local_acc[v_i_o_o * 64 + v_i_o_i * 4:v_i_o_o * 64 + v_i_o_i * 4 + 4, v_j_o * 4:v_j_o * 4 + 4])
for k_2, j_2, i_2 in T.grid(4, 4, 4):
with T.block("res"):
v_i_i, v_j_i, v_k_i = T.axis.remap("SSR", [i_2, j_2, k_2])
T.reads(res_local_acc[v_i_o_o * 64 + v_i_o_i * 4 + v_i_i, v_j_o * 4 + v_j_i], a_in_local_spad[v_i_o_o * 64 + v_i_o_i * 4 + v_i_i, v_k_o * 4 + v_k_i], b_in_local_spad_w[v_k_o * 4 + v_k_i, v_j_o * 4 + v_j_i])
T.writes(res_local_acc[v_i_o_o * 64 + v_i_o_i * 4 + v_i_i, v_j_o * 4 + v_j_i])
res_local_acc[v_i_o_o * 64 + v_i_o_i * 4 + v_i_i, v_j_o * 4 + v_j_i] = res_local_acc[v_i_o_o * 64 + v_i_o_i * 4 + v_i_i, v_j_o * 4 + v_j_i] + T.Cast("int32", a_in_local_spad[v_i_o_o * 64 + v_i_o_i * 4 + v_i_i, v_k_o * 4 + v_k_i]) * T.Cast("int32", b_in_local_spad_w[v_k_o * 4 + v_k_i, v_j_o * 4 + v_j_i])
Does not match the tensorize description:
for io, k, j, i in T.grid(16, 4, 4, 4):
with T.block(""):
vio, vii, vjj, vkk = T.axis.remap("SSSR", [io, i, j, k])
C = T.Buffer((64, 4), "int32", scope="local.acc", offset_factor=1)
A = T.Buffer((64, 4), "int8", scope="local.spad", offset_factor=1)
B = T.Buffer((4, 4), "int8", scope="local.spad_w", offset_factor=1)
T.reads(C[vii + vio, vjj + vio], A[vii + vio, vkk], B[vkk, vjj + vio])
T.writes(C[vii + vio, vjj + vio])
C[vii + vio, vjj + vio] = C[vii + vio, vjj + vio] + T.Cast("int32", A[vii + vio, vkk]) * T.Cast("int32", B[vkk, vjj + vio])
And the intrin is described like this:
# Description (compute-pattern) function of the tensor intrinsic: an
# int8 x int8 -> int32 multiply-accumulate, C += cast(A) * cast(B).
# dim_io, dim_k, dim_ko, dim_j, dim_i and io_extent are not defined in this
# snippet; they must come from the enclosing Python scope.
@T.prim_func
def my_matmul_desc(a: T.handle, b:T.handle, c:T.handle, ) -> None:
# Operand buffers: A and B are int8, C is the int32 accumulator; each one
# is bound to its own storage scope.
A = T.match_buffer(a, (dim_io, dim_k), "int8", offset_factor=1, scope="local.spad")
# NOTE(review): B's first extent is dim_ko, but below it is indexed by vkk,
# whose loop runs over dim_k -- confirm dim_ko == dim_k.
B = T.match_buffer(b, (dim_ko, dim_j), "int8", offset_factor=1, scope="local.spad_w")
C = T.match_buffer(c, (dim_io, dim_j), "int32", offset_factor=1, scope="local.acc")
with T.block("root"):
# The root block declares the full extent of all three buffers as read
# and the full extent of C as written.
T.reads(C[0:dim_io, 0:dim_j],
A[0:dim_io, 0:dim_k],
B[0:dim_ko, 0:dim_j])
T.writes(C[0:dim_io, 0:dim_j])
# Outer loop over row tiles; the inner grid iterates k, then j, then i.
for io in T.serial(0, io_extent):
for k, j, i in T.grid(dim_k, dim_j, dim_i):
with T.block(""):
# "SSSR": io, i and j are bound as spatial axes, k as the reduction axis.
vio, vii, vjj, vkk = T.axis.remap("SSSR", [io, i, j, k])
# The row index is vii + vio*dim_i, i.e. tile vio starts at row vio*dim_i.
# NOTE(review): this presumably requires dim_io == io_extent * dim_i -- confirm.
# NOTE(review): the description printed in the error message indexes
# C[vii + vio, vjj + vio], which differs from this expression; the two
# may come from different revisions of the intrin -- confirm.
C[vii+vio*dim_i, vjj] = C[vii+vio*dim_i, vjj] + T.cast(A[vii+vio*dim_i, vkk], "int32") * T.cast(B[vkk, vjj], "int32")
What am I doing wrong here in the description of the intrin? I noticed that some variables like v_i_o_o are automatically inserted, but I think they are always zero.
I think the main issue is that I don’t quite understand how I need to describe the buffer-matching part.