Hello,
I am trying to learn how to use TIR for scheduling operations and am running into some issues with how to describe the intrinsics. I want to tensorize a 16x16x16 matmul block with a custom intrinsic. This is the intrin I defined:
from tvm.script import tir as T

def get_intrin_gemm_tir(
    dim_i: int,
    dim_k: int,
    dim_j: int,
):
    @T.prim_func
    def my_matmul_desc(a: T.handle, b: T.handle, d: T.handle) -> None:
        A = T.match_buffer(a, (dim_i, dim_k), dtype="int8", scope="local.spad")
        B = T.match_buffer(b, (dim_k, dim_j), dtype="int8", scope="local.spad_w")
        D = T.match_buffer(d, (dim_i, dim_j), dtype="int8", scope="local.acc")
        with T.block("root"):
            T.reads(A[0:16, 0:16], B[0:16, 0:16], D[0:16, 0:16])
            T.writes(D[0:16, 0:16])
            for i, j, k in T.grid(dim_i, dim_j, dim_k):
                with T.block("res"):
                    vii, vjj, vkoi = T.axis.remap("SSR", [i, j, k])
                    D[vii, vjj] = D[vii, vjj] + A[vii, vkoi] * B[vjj, vkoi]

    @T.prim_func
    def my_matmul_impl(a: T.handle, b: T.handle, d: T.handle) -> None:
        sai = T.int32()
        sak = T.int32()
        sbk = T.int32()
        sbj = T.int32()
        sdi = T.int32()
        sdj = T.int32()
        A = T.match_buffer(a, (dim_i, dim_k), strides=[sai, sak], dtype="int8", scope="local.spad")
        B = T.match_buffer(b, (dim_k, dim_j), strides=[sbk, sbj], dtype="int8", scope="local.spad_w")
        D = T.match_buffer(d, (dim_i, dim_j), strides=[sdi, sdj], dtype="int8", scope="local.acc")
        with T.block("root"):
            T.reads(A[0:16, 0:16], B[0:16, 0:16], D[0:16, 0:16])
            T.writes(D[0:16, 0:16])
            # implementation is just a stub extern call while I debug the matching
            T.evaluate(T.call_extern("test", dtype=""))

    return my_matmul_desc, my_matmul_impl
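For reference, I register the intrin roughly like this (simplified sketch; "my_gemm_16x16x16" is just a placeholder name I picked):

from tvm import tir

# Register the description/implementation pair under a name so the schedule can refer to it.
desc, impl = get_intrin_gemm_tir(16, 16, 16)
tir.TensorIntrin.register("my_gemm_16x16x16", desc, impl)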
But when I run this on my schedule, I get this error:
Error message: The stmt tir.Block#0 doesn't match the tensor intrin
The pattern attempting to be matched:

with T.block("res", no_realize=True):
    v_i_i = T.axis.spatial(16)
    v_j_i = T.axis.spatial(16)
    v_k_o_i = T.axis.reduce(16)
    res_local_acc = T.Buffer((128, 64), "int8", scope="local.acc")
    v_i_o = T.int32()
    v_j_o = T.int32()
    a_in_local_spad = T.Buffer((128, 256), "int8", scope="local.spad")
    v_k_o_o = T.int32()
    b_in_local_spad_w = T.Buffer((256, 64), "int8", scope="local.spad_w")
    bias_in_local_acc = T.Buffer((128, 64), "int32", scope="local.acc")
    T.reads(res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i], a_in_local_spad[v_i_o * 16 + v_i_i, v_k_o_o * 16 + v_k_o_i], b_in_local_spad_w[v_k_o_o * 16 + v_k_o_i, v_j_o * 16 + v_j_i], bias_in_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i])
    T.writes(res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i])
    res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i] = res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i] + (a_in_local_spad[v_i_o * 16 + v_i_i, v_k_o_o * 16 + v_k_o_i] * b_in_local_spad_w[v_k_o_o * 16 + v_k_o_i, v_j_o * 16 + v_j_i] + T.Cast("int8", bias_in_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i]))

Does not match the tensorize description:

with T.block("root", no_realize=True):
    A = T.Buffer((16, 16), "int8", scope="local.spad")
    B = T.Buffer((16, 16), "int8", scope="local.spad_w")
    D = T.Buffer((16, 16), "int8", scope="local.acc")
    T.reads(A[0:16, 0:16], B[0:16, 0:16], D[0:16, 0:16])
    T.writes(D[0:16, 0:16])
    for i, j, k in T.grid(16, 16, 16):
        with T.block("res"):
            vii, vjj, vkoi = T.axis.remap("SSR", [i, j, k])
            T.reads(D[vii, vjj], A[vii, vkoi], B[vjj, vkoi])
            T.writes(D[vii, vjj])
            D[vii, vjj] = D[vii, vjj] + A[vii, vkoi] * B[vjj, vkoi]

CompareBufferRegion buffer extent mismatch: lhs->region[i + offset]=T.Range(v_i_o * 16 + v_i_i, 1) vs rhs->region[i]=T.Range(0, 16)
BlockNode write buffers do not match: op->writes=[res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i]] vs rhs->writes=[D[0:16, 0:16]]
I don’t know why the buffers like res_local_acc are so large here. If I look at the block that I want to tensorize over, it looks like this:
for k_o_1, j_1, i_1 in T.grid(16, 16, 16):
    with T.block("res"):
        v_i_i, v_j_i, v_k_o_i = T.axis.remap("SSR", [i_1, j_1, k_o_1])
        T.reads(res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i], a_in_local_spad[v_i_o * 16 + v_i_i, v_k_o_o * 16 + v_k_o_i], b_in_local_spad_w[v_k_o_o * 16 + v_k_o_i, v_j_o * 16 + v_j_i], bias_in_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i])
        T.writes(res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i])
        res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i] = res_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i] + (a_in_local_spad[v_i_o * 16 + v_i_i, v_k_o_o * 16 + v_k_o_i] * b_in_local_spad_w[v_k_o_o * 16 + v_k_o_i, v_j_o * 16 + v_j_i] + T.Cast("int8", bias_in_local_acc[v_i_o * 16 + v_i_i, v_j_o * 16 + v_j_i]))
It only uses a 16x16 tile of each buffer. What am I doing wrong here?
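For context, the scheduling steps that produced this block look roughly like the sketch below (simplified; the loop variables, factors, and "mod" are my reconstruction for illustration, not the exact code):

import tvm

# Hypothetical reconstruction of my schedule; "mod" stands for the module
# containing the 128x64x256 matmul with the bias add.
sch = tvm.tir.Schedule(mod)
block = sch.get_block("res")
i, j, k_o = sch.get_loops(block)

# Tile every axis by 16 so the innermost loops cover one 16x16x16 tile.
i_o, i_i = sch.split(i, factors=[None, 16])
j_o, j_i = sch.split(j, factors=[None, 16])
k_o_o, k_o_i = sch.split(k_o, factors=[None, 16])
sch.reorder(i_o, j_o, k_o_o, k_o_i, j_i, i_i)

# Tensorize the inner tile with the intrin registered above -- this is the
# call that raises the mismatch error.
sch.tensorize(k_o_i, "my_gemm_16x16x16")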