Hi,
I am trying to apply tvm.tir.Schedule.Tensorize on an IR Module. However, I wasn’t successful. Later, I have tried to run example provided in Tensorize documentation (tvm.tir — tvm 0.9.dev0 documentation) in order to discover what I am doing wrong. Yet, I couldn’t get tensorize example working either. Here are the script that I try to test given example and the error message that I receive. Can you help me to find out what am I missing since the provided example is not working too.
from tvm import testing
testing.utils.install_request_hook(depth=3)
import tvm
from tvm.script import tir as T
@T.prim_func
def before_tensorize(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(128, 128), "float32"],
C: T.Buffer[(128, 128), "float32"],
) -> None:
# body
# with T.block("root")
for i_0, j_0, k_0, i_1, j_1, k_1 in T.grid(8, 8, 8, 16, 16, 16):
with T.block("update"):
vi = T.axis.spatial(128, i_0 * 16 + i_1)
vj = T.axis.spatial(128, j_0 * 16 + j_1)
vk = T.axis.reduce(128, k_0 * 16 + k_1)
T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
T.writes(C[vi, vj])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
@T.prim_func
def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
with T.block("root"):
T.reads(C[0: 16, 0: 16], A[0: 16, 0: 16], B[0: 16, 0: 16])
T.writes(C[0: 16, 0: 16])
for i, j, k in T.grid(16, 16, 16):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
@T.prim_func
def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
with T.block("root"):
T.reads(C[0: 16, 0: 16], A[0: 16, 0: 16], B[0: 16, 0: 16])
T.writes(C[0: 16, 0: 16])
T.evaluate(
T.tvm_mma_sync(
C.data,
C.elem_offset // 256,
A.data,
A.elem_offset // 256,
B.data,
B.elem_offset // 256,
C.data,
C.elem_offset // 256,
dtype="handle",
)
)
tvm.tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin)
sch = tvm.tir.Schedule(before_tensorize)
update = sch.get_block("update")
_, _, _, i1, _, _ = sch.get_loops(update)
sch.tensorize(i1, "test_mma_intrin")
Error Message :
Traceback (most recent call last):
File "/opt/pycharm-community-2021.2.1/plugins/python-ce/helpers/pydev/pydevd.py", line 1483, in _exec
pydev_imports.execfile(file, globals, locals) # execute the script
File "/opt/pycharm-community-2021.2.1/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/root/dockerhome/tvm-s4e/python/tvm/relay/backend/contrib/uma/rb_npu/test.py", line 122, in <module>
sch.tensorize(i1, "test_mma_intrin")
File "/tvm/python/tvm/tir/schedule/_type_checker.py", line 237, in wrap
return func(*args, **kwargs)
File "/tvm/python/tvm/tir/schedule/schedule.py", line 2192, in tensorize
_ffi_api.ScheduleTensorize( # type: ignore # pylint: disable=no-member
File "/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
raise get_last_ffi_error()
tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
3: TVMFuncCall
2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String)#13}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String)#13}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
1: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&)
0: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&) [clone .cold]
ScheduleError: An error occurred in the schedule primitive 'tensorize'.
The IR with diagnostic is:
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
# body
# with T.block("root")
for i_0, j_0, k_0 in T.grid(8, 8, 8):
with T.block("update_o"):
vi_o, vj_o, vk_o = T.axis.remap("SSR", [i_0, j_0, k_0])
T.reads(C[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16], A[vi_o * 16 : vi_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16], B[vj_o * 16 : vj_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16])
T.writes(C[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
with T.init():
for i_1, j_1 in T.grid(16, 16):
with T.block("update_init"):
vi_init, vj_init = T.axis.remap("SS", [i_1, j_1])
T.reads()
T.writes(C[vi_o * 16 + vi_init, vj_o * 16 + vj_init])
C[vi_o * 16 + vi_init, vj_o * 16 + vj_init] = T.float32(0)
for i_1, j_1, k_1 in T.grid(16, 16, 16):
# tir.Block#0
with T.block("update"):
^^^^^^^^^^^^^^^^^^^^^^^
vi, vj, vk = T.axis.remap("SSR", [i_1, j_1, k_1])
T.reads(C[vi_o * 16 + vi, vj_o * 16 + vj], C[vi_o * 16 + vi, vj_o * 16 + vj], A[vi_o * 16 + vi, vk_o * 16 + vk], B[vj_o * 16 + vj, vk_o * 16 + vk])
T.writes(C[vi_o * 16 + vi, vj_o * 16 + vj])
C[vi_o * 16 + vi, vj_o * 16 + vj] = C[vi_o * 16 + vi, vj_o * 16 + vj] + A[vi_o * 16 + vi, vk_o * 16 + vk] * B[vj_o * 16 + vj, vk_o * 16 + vk]
Error message: The stmt tir.Block#0 doesn't match the tensor intrin
block update(iter_var(vi, range(min=0, ext=16)), iter_var(vj, range(min=0, ext=16)), iter_var(vk, range(min=0, ext=16))) {
reads([C[vi, vj], A[vi, vk], B[vj, vk]])
writes([C[vi, vj]])
C[vi, vj] = (C[vi, vj] + (A[vi, vk]*B[vj, vk]))
}
Process finished with exit code 1