Thank you for your interest.
Tensorize in TensorIR works quite differently from tensorize in TE. In TensorIR, an intrinsic is defined by two functions: desc_func, which describes the computation to be matched, and intrin_func, which gives the implementation that replaces it. Here is an example intrinsic (note that TensorIR is still WIP, so the API may change):
@tvm.hybrid.script
def desc_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [16, 16])
    B = tir.match_buffer(b, [16, 16])
    C = tir.match_buffer(c, [16, 16])

    with tir.block([16, 16, tir.reduce_axis(0, 16)], "root") as [vi, vj, vk]:
        for i, j, k in tir.grid(16, 16, 16):
            with tir.block([16, 16, tir.reduce_axis(0, 16)], "update") as [vii, vjj, vkk]:
                tir.bind(vii, vi + i)
                tir.bind(vjj, vj + j)
                tir.bind(vkk, vk + k)
                C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]

@tvm.hybrid.script
def intrin_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [16, 16])
    B = tir.match_buffer(b, [16, 16])
    C = tir.match_buffer(c, [16, 16])

    with tir.block([16, 16, tir.reduce_axis(0, 16)], "root") as [vi, vj, vk]:
        tir.evaluate(tir.tvm_mma_sync(C.data, C.elem_offset // 256,
                                      A.data, A.elem_offset // 256,
                                      B.data, B.elem_offset // 256,
                                      C.data, C.elem_offset // 256,
                                      dtype="handle"))
Tensorize matches a sub-AST (usually a block) against desc_func and then replaces it with intrin_func.
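To make the match-and-replace idea concrete, here is a toy sketch in plain Python. It is not the TVM implementation and uses no TVM API: the nested-tuple "AST", the `$`-prefixed placeholders, and the names `structurally_equal`, `tensorize`, and `make_intrin` are all made up for illustration. It only shows the general shape of the idea: walk the program, find a sub-tree that is structurally equal to the description, and swap in the intrinsic call with the matched buffers.

```python
# Toy model of tensorize. NOT the TVM implementation or API.
# A statement is a nested tuple; strings starting with "$" in the
# description are placeholders that bind to whatever they match.

def structurally_equal(stmt, desc, bindings):
    """Match `stmt` against `desc`, recording placeholder bindings."""
    if isinstance(desc, str) and desc.startswith("$"):
        if desc in bindings:
            # A placeholder seen before must match the same thing again.
            return bindings[desc] == stmt
        bindings[desc] = stmt
        return True
    if isinstance(desc, tuple):
        return (
            isinstance(stmt, tuple)
            and len(stmt) == len(desc)
            and all(structurally_equal(s, d, bindings) for s, d in zip(stmt, desc))
        )
    return stmt == desc


def tensorize(stmt, desc, make_intrin):
    """Replace each sub-tree of `stmt` that matches `desc` with an intrinsic."""
    bindings = {}
    if structurally_equal(stmt, desc, bindings):
        return make_intrin(bindings)
    if isinstance(stmt, tuple):
        return tuple(tensorize(s, desc, make_intrin) for s in stmt)
    return stmt


# Description: a 16x16x16 block computing C += A * B^T (hypothetical encoding).
desc = ("block", ("extent", 16, 16, 16), ("update", "$C", "$A", "$B"))

# Hypothetical stand-in for intrin_func: emit a call on the matched buffers.
def make_intrin(b):
    return ("call", "mma_sync", b["$C"], b["$A"], b["$B"])

# A program whose innermost block has the shape the description expects.
program = (
    "for", "io", 4,
    ("for", "jo", 4,
     ("block", ("extent", 16, 16, 16), ("update", "C", "A", "B"))),
)

result = tensorize(program, desc, make_intrin)

# A block with a different extent does not match and is left untouched.
other = ("block", ("extent", 8, 8, 8), ("update", "C", "A", "B"))
unchanged = tensorize(other, desc, make_intrin)
```

The outer loops survive unchanged and only the matched block is rewritten, which is why the block has to be scheduled into the 16x16x16 shape before tensorize can apply.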
TensorIR works at the schedule level and is not coupled to the low-level passes. However, you can schedule each loop directly and add primitives as needed.