def get_take_2d_intrin(M: T.int64, D: T.int64, L: T.int64, dtype="float32", idx_dtype="int32"):
    """Build the ``(desc, impl)`` prim_func pair of a 2-D ``take`` tensor intrinsic.

    ``desc`` describes ``T_take[m, d] = A[xx[m], d]`` for ``m < M`` and ``d < D``.
    Here ``A`` is matched with shape ``(L, D)`` and ``xx`` with shape ``(M,)``.
    ``impl`` replaces that loop nest with a single
    ``T.call_packed("take_m_d", ...)``.

    Args:
        M: Extent of the index buffer ``xx`` and of the first output axis.
        D: Extent of the second axis of ``A`` and of the output.
        L: Extent of the first axis of ``A``.
        dtype: Element dtype of ``A`` and ``T_take``.
        idx_dtype: Element dtype of ``xx``. Presumably this must equal the dtype
            of the index buffer in the workload being tensorized. The reported
            workload uses ``"int64"``; the default ``"int32"`` would likely not
            match it. TODO confirm.

    Returns:
        A ``(desc, impl)`` tuple of prim_funcs.

    NOTE(review): the row of ``A`` that is read depends on the *contents* of
    ``xx`` (a gather). The reported tensorize failure was
    "CompareBufferRegion buffer region min mismatch", with ``min=xx[v0_i]``
    printed on both sides. This suggests the comparator cannot prove two loads
    through *different* ``xx`` buffer objects equal. If so, ``sch.tensorize``
    cannot match this intrinsic no matter how its signature is written.
    Confirm this, and consider emitting the packed call outside of TIR
    tensorize instead.
    """

    @T.prim_func
    def desc(a: T.handle, indices: T.handle, out: T.handle):
        A = T.match_buffer(a, (L, D), dtype=dtype)
        xx = T.match_buffer(indices, (M,), dtype=idx_dtype)
        T_take = T.match_buffer(out, (M, D), dtype=dtype)
        with T.block("root"):
            T.reads(A[T.int64(0):L, T.int64(0):D], xx[T.int64(0):M])
            T.writes(T_take[T.int64(0):M, T.int64(0):D])
            for ax0, ax1 in T.grid(M, D):
                with T.block("T_take"):
                    v0_i = T.axis.spatial(M, ax0)
                    v1_i = T.axis.spatial(D, ax1)
                    # Gather: the row of A read here is chosen by the value xx[v0_i].
                    T.reads(A[xx[v0_i], v1_i], xx[v0_i])
                    T.writes(T_take[v0_i, v1_i])
                    T_take[v0_i, v1_i] = A[xx[v0_i], v1_i]

    @T.prim_func
    def impl(a: T.handle, indices: T.handle, out: T.handle):
        # offset_factor=1 is presumably there so that elem_offset is a symbolic
        # value. It is forwarded to the packed function below.
        A = T.match_buffer(a, (T.int64(L), T.int64(D)), dtype=dtype, offset_factor=1)
        xx = T.match_buffer(indices, (T.int64(M),), dtype=idx_dtype, offset_factor=1)
        T_take = T.match_buffer(out, (T.int64(M), T.int64(D)), dtype=dtype, offset_factor=1)
        with T.block("take_impl"):
            # Use int64 lower bounds to match the int64 extents, as `desc` does,
            # instead of mixing a plain int32 `0` into int64 ranges.
            T.reads(
                A[T.int64(0):T.int64(L), T.int64(0):T.int64(D)],
                xx[T.int64(0):T.int64(M)],
            )
            T.writes(T_take[T.int64(0):T.int64(M), T.int64(0):T.int64(D)])
            # Argument order: output, table, indices (data pointer + element
            # offset each), then the extents M, D, L and the two dtype strings.
            T.call_packed(
                "take_m_d",
                T_take.data, T_take.elem_offset,
                A.data, A.elem_offset,
                xx.data, xx.elem_offset,
                T.int64(M), T.int64(D), T.int64(L),
                dtype, idx_dtype,
            )

    return desc, impl
# Workload the intrinsic is matched against: a gather of 5 rows out of a
# (10, 96) float32 table, driven by an int64 index vector.
# The original one-line paste used curly quotes (a SyntaxError). It also put
# the "# with T.block(...)" comment on the same line as the body, which turned
# the rest of the function into a comment.
@T.prim_func(private=True)
def main(
    A: T.Buffer((T.int64(10), T.int64(96)), "float32"),
    xx: T.Buffer((T.int64(5),), "int64"),
    T_take: T.Buffer((T.int64(5), T.int64(96)), "float32"),
):
    T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1 in T.grid(T.int64(5), T.int64(96)):
        with T.block("T_take"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            # Row index into A is the loaded value xx[v0] (data-dependent).
            T.reads(A[xx[v0], v1], xx[v0])
            T.writes(T_take[v0, v1])
            T_take[v0, v1] = A[xx[v0], v1]
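My tensorize intrinsic and the IR it is applied to are shown above. Running `sch.tensorize` on the `T_take` loop nest fails with the following error. Note that both sides of the final "buffer region min mismatch" print identically as `min=xx[v0_i]`, yet the comparison still fails: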
[10:54:57] /home/zte/TVM/tvm/src/script/printer/ir/./../utils.h:46: Warning: TVMScript printer falls back to the legacy ReprPrinter with the error:
[10:54:57] /home/zte/TVM/tvm/src/script/printer/tir/buffer.cc:319: IndexError: Buffer is not defined in the environment: xx
Stack trace:
0: operator()
at /home/zte/TVM/tvm/src/script/printer/tir/buffer.cc:319
1: tvm::script::printer::Doc tvm::script::printer::IRDocsifierFunctor<tvm::script::printer::Doc, tvm::ObjectPath, tvm::script::printer::IRDocsifier>::operator()<tvm::runtime::ObjectRef>(tvm::runtime::String const&, tvm::runtime::ObjectRef, tvm::ObjectPath, tvm::script::printer::IRDocsifier) const
at /home/zte/TVM/tvm/include/tvm/script/printer/ir_docsifier_functor.h:71
2: tvm::script::printer::ExprDoc tvm::script::printer::IRDocsifierNode::AsDoc<tvm::script::printer::ExprDoc>(tvm::runtime::ObjectRef const&, tvm::ObjectPath const&) const
at /home/zte/TVM/tvm/include/tvm/script/printer/ir_docsifier.h:315
3: operator()
at /home/zte/TVM/tvm/src/script/printer/tir/buffer.cc:294
4: tvm::script::printer::Doc tvm::script::printer::IRDocsifierFunctor<tvm::script::printer::Doc, tvm::ObjectPath, tvm::script::printer::IRDocsifier>::operator()<tvm::runtime::ObjectRef>(tvm::runtime::String const&, tvm::runtime::ObjectRef, tvm::ObjectPath, tvm::script::printer::IRDocsifier) const
at /home/zte/TVM/tvm/include/tvm/script/printer/ir_docsifier_functor.h:71
5: tvm::script::printer::ExprDoc tvm::script::printer::IRDocsifierNode::AsDoc<tvm::script::printer::ExprDoc>(tvm::runtime::ObjectRef const&, tvm::ObjectPath const&) const
at /home/zte/TVM/tvm/include/tvm/script/printer/ir_docsifier.h:315
6: operator()
at /home/zte/TVM/tvm/src/script/printer/ir/ir.cc:161
7: tvm::script::printer::Doc tvm::script::printer::IRDocsifierFunctor<tvm::script::printer::Doc, tvm::ObjectPath, tvm::script::printer::IRDocsifier>::operator()<tvm::runtime::ObjectRef>(tvm::runtime::String const&, tvm::runtime::ObjectRef, tvm::ObjectPath, tvm::script::printer::IRDocsifier) const
at /home/zte/TVM/tvm/include/tvm/script/printer/ir_docsifier_functor.h:68
8: tvm::script::printer::Doc tvm::script::printer::IRDocsifierNode::AsDoc<tvm::script::printer::Doc>(tvm::runtime::ObjectRef const&, tvm::ObjectPath const&) const
at /home/zte/TVM/tvm/include/tvm/script/printer/ir_docsifier.h:315
9: tvm::script::printer::Docsify[abi:cxx11](tvm::runtime::ObjectRef const&, tvm::script::printer::IRDocsifier const&, tvm::script::printer::Frame const&, tvm::PrinterConfig const&)
at /home/zte/TVM/tvm/src/script/printer/ir/./../utils.h:63
10: tvm::script::printer::ReprPrintIR[abi:cxx11](tvm::runtime::ObjectRef const&, tvm::PrinterConfig const&)
at /home/zte/TVM/tvm/src/script/printer/ir/./utils.h:67
11: tvm::TVMScriptPrinter::Script[abi:cxx11](tvm::runtime::ObjectRef const&, tvm::runtime::Optional<tvm::PrinterConfig> const&)
at /home/zte/TVM/tvm/src/node/script_printer.cc:37
12: tvm::script::printer::RedirectedReprPrinterMethod(tvm::runtime::ObjectRef const&, tvm::ReprPrinter*)
at /home/zte/TVM/tvm/src/script/printer/ir/./../utils.h:43
13: tvm::runtime::operator<<(std::ostream&, tvm::runtime::ObjectRef const&)
at /home/zte/TVM/tvm/include/tvm/node/repr_printer.h:98
14: tvm::tir::TensorizeComparator::CompareBufferRegion(tvm::tir::BufferRegion const&, tvm::tir::BufferRegion const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:531
15: bool tvm::tir::TensorizeComparator::CompareArray<tvm::tir::BufferRegion, tvm::tir::TensorizeComparator, bool (tvm::tir::BufferRegion const&, tvm::tir::BufferRegion const&)>(tvm::runtime::Array<tvm::tir::BufferRegion, std::enable_if<std::is_base_of<tvm::runtime::ObjectRef, tvm::tir::BufferRegion>::value, void>::type> const&, tvm::runtime::Array<tvm::tir::BufferRegion, std::enable_if<std::is_base_of<tvm::runtime::ObjectRef, tvm::tir::BufferRegion>::value, void>::type> const&, bool (tvm::tir::TensorizeComparator::*)(tvm::tir::BufferRegion const&, tvm::tir::BufferRegion const&))
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:588
16: tvm::tir::TensorizeComparator::VisitStmt_(tvm::tir::BlockNode const*, tvm::tir::Stmt const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:239
17: tvm::tir::TensorizeComparator::VisitStmt(tvm::tir::Stmt const&, tvm::tir::Stmt const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:69
18: tvm::tir::TensorizeComparator::VisitStmt(tvm::tir::Stmt const&, tvm::tir::Stmt const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:69
19: tvm::tir::TensorizeComparator::VisitStmt(tvm::tir::Stmt const&, tvm::tir::Stmt const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:69
20: tvm::tir::TensorizeComparator::VisitStmt(tvm::tir::Stmt const&, tvm::tir::Stmt const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:69
21: tvm::tir::TensorizeComparator::VisitStmt(tvm::tir::Stmt const&, tvm::tir::Stmt const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:69
22: tvm::tir::TensorizeComparator::VisitStmt(tvm::tir::Stmt const&, tvm::tir::Stmt const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:69
23: tvm::tir::Tensorize(tvm::tir::ScheduleState, tvm::tir::StmtSRef const&, tvm::tir::TensorIntrin const&, bool)
at /home/zte/TVM/tvm/src/tir/schedule/primitive/blockize_tensorize.cc:780
24: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
at /home/zte/TVM/tvm/src/tir/schedule/concrete_schedule.cc:898
25: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
at /home/zte/TVM/tvm/src/tir/schedule/traced_schedule.cc:615
26: operator()
at /home/zte/TVM/tvm/src/tir/schedule/schedule.cc:248
[10:54:57] /home/zte/TVM/tvm/src/script/printer/ir/./../utils.h:46: Warning: TVMScript printer falls back to the legacy ReprPrinter with the error:
[10:54:57] /home/zte/TVM/tvm/src/script/printer/tir/buffer.cc:319: IndexError: Buffer is not defined in the environment: xx
Stack trace:
0: operator()
at /home/zte/TVM/tvm/src/script/printer/tir/buffer.cc:319
1: tvm::script::printer::Doc tvm::script::printer::IRDocsifierFunctor<tvm::script::printer::Doc, tvm::ObjectPath, tvm::script::printer::IRDocsifier>::operator()<tvm::runtime::ObjectRef>(tvm::runtime::String const&, tvm::runtime::ObjectRef, tvm::ObjectPath, tvm::script::printer::IRDocsifier) const
at /home/zte/TVM/tvm/include/tvm/script/printer/ir_docsifier_functor.h:71
2: tvm::script::printer::ExprDoc tvm::script::printer::IRDocsifierNode::AsDoc<tvm::script::printer::ExprDoc>(tvm::runtime::ObjectRef const&, tvm::ObjectPath const&) const
at /home/zte/TVM/tvm/include/tvm/script/printer/ir_docsifier.h:315
3: operator()
at /home/zte/TVM/tvm/src/script/printer/tir/buffer.cc:294
4: tvm::script::printer::Doc tvm::script::printer::IRDocsifierFunctor<tvm::script::printer::Doc, tvm::ObjectPath, tvm::script::printer::IRDocsifier>::operator()<tvm::runtime::ObjectRef>(tvm::runtime::String const&, tvm::runtime::ObjectRef, tvm::ObjectPath, tvm::script::printer::IRDocsifier) const
at /home/zte/TVM/tvm/include/tvm/script/printer/ir_docsifier_functor.h:71
5: tvm::script::printer::ExprDoc tvm::script::printer::IRDocsifierNode::AsDoc<tvm::script::printer::ExprDoc>(tvm::runtime::ObjectRef const&, tvm::ObjectPath const&) const
at /home/zte/TVM/tvm/include/tvm/script/printer/ir_docsifier.h:315
6: operator()
at /home/zte/TVM/tvm/src/script/printer/ir/ir.cc:161
7: tvm::script::printer::Doc tvm::script::printer::IRDocsifierFunctor<tvm::script::printer::Doc, tvm::ObjectPath, tvm::script::printer::IRDocsifier>::operator()<tvm::runtime::ObjectRef>(tvm::runtime::String const&, tvm::runtime::ObjectRef, tvm::ObjectPath, tvm::script::printer::IRDocsifier) const
at /home/zte/TVM/tvm/include/tvm/script/printer/ir_docsifier_functor.h:68
8: tvm::script::printer::Doc tvm::script::printer::IRDocsifierNode::AsDoc<tvm::script::printer::Doc>(tvm::runtime::ObjectRef const&, tvm::ObjectPath const&) const
at /home/zte/TVM/tvm/include/tvm/script/printer/ir_docsifier.h:315
9: tvm::script::printer::Docsify[abi:cxx11](tvm::runtime::ObjectRef const&, tvm::script::printer::IRDocsifier const&, tvm::script::printer::Frame const&, tvm::PrinterConfig const&)
at /home/zte/TVM/tvm/src/script/printer/ir/./../utils.h:63
10: tvm::script::printer::ReprPrintIR[abi:cxx11](tvm::runtime::ObjectRef const&, tvm::PrinterConfig const&)
at /home/zte/TVM/tvm/src/script/printer/ir/./utils.h:67
11: tvm::TVMScriptPrinter::Script[abi:cxx11](tvm::runtime::ObjectRef const&, tvm::runtime::Optional<tvm::PrinterConfig> const&)
at /home/zte/TVM/tvm/src/node/script_printer.cc:37
12: tvm::script::printer::RedirectedReprPrinterMethod(tvm::runtime::ObjectRef const&, tvm::ReprPrinter*)
at /home/zte/TVM/tvm/src/script/printer/ir/./../utils.h:43
13: tvm::runtime::operator<<(std::ostream&, tvm::runtime::ObjectRef const&)
at /home/zte/TVM/tvm/include/tvm/node/repr_printer.h:98
14: tvm::tir::TensorizeComparator::CompareBufferRegion(tvm::tir::BufferRegion const&, tvm::tir::BufferRegion const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:531
15: bool tvm::tir::TensorizeComparator::CompareArray<tvm::tir::BufferRegion, tvm::tir::TensorizeComparator, bool (tvm::tir::BufferRegion const&, tvm::tir::BufferRegion const&)>(tvm::runtime::Array<tvm::tir::BufferRegion, std::enable_if<std::is_base_of<tvm::runtime::ObjectRef, tvm::tir::BufferRegion>::value, void>::type> const&, tvm::runtime::Array<tvm::tir::BufferRegion, std::enable_if<std::is_base_of<tvm::runtime::ObjectRef, tvm::tir::BufferRegion>::value, void>::type> const&, bool (tvm::tir::TensorizeComparator::*)(tvm::tir::BufferRegion const&, tvm::tir::BufferRegion const&))
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:588
16: tvm::tir::TensorizeComparator::VisitStmt_(tvm::tir::BlockNode const*, tvm::tir::Stmt const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:239
17: tvm::tir::TensorizeComparator::VisitStmt(tvm::tir::Stmt const&, tvm::tir::Stmt const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:69
18: tvm::tir::TensorizeComparator::VisitStmt(tvm::tir::Stmt const&, tvm::tir::Stmt const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:69
19: tvm::tir::TensorizeComparator::VisitStmt(tvm::tir::Stmt const&, tvm::tir::Stmt const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:69
20: tvm::tir::TensorizeComparator::VisitStmt(tvm::tir::Stmt const&, tvm::tir::Stmt const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:69
21: tvm::tir::TensorizeComparator::VisitStmt(tvm::tir::Stmt const&, tvm::tir::Stmt const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:69
22: tvm::tir::TensorizeComparator::VisitStmt(tvm::tir::Stmt const&, tvm::tir::Stmt const&)
at /home/zte/TVM/tvm/src/tir/schedule/ir_comparator.cc:69
23: tvm::tir::Tensorize(tvm::tir::ScheduleState, tvm::tir::StmtSRef const&, tvm::tir::TensorIntrin const&, bool)
at /home/zte/TVM/tvm/src/tir/schedule/primitive/blockize_tensorize.cc:780
24: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
at /home/zte/TVM/tvm/src/tir/schedule/concrete_schedule.cc:898
25: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
at /home/zte/TVM/tvm/src/tir/schedule/traced_schedule.cc:615
26: operator()
at /home/zte/TVM/tvm/src/tir/schedule/schedule.cc:248
Traceback (most recent call last):
  File "/home/zte/TVM/tvm/testcase/pointwise/test_take.py", line 112, in <module>
    main()
  File "/home/zte/TVM/tvm/testcase/pointwise/test_take.py", line 89, in main
    lib, params_from_torch = compile_model(mod_from_torch, target, "hh_pipeline")
  File "/home/zte/TVM/tvm/testcase/pointwise/test_take.py", line 45, in compile_model
    lib = relax.build(mod_from_torch, target=target, relax_pipeline=pipeline)
  File "/home/zte/TVM/tvm/python/tvm/relax/vm_build.py", line 253, in build
    mod = relax_pipeline(mod)
  File "/home/zte/TVM/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
  File "/home/zte/TVM/tvm/python/tvm/_ffi/base.py", line 468, in raise_last_ffi_error
    raise py_err
  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/home/zte/TVM/tvm/python/tvm/contrib/npu/pipeline.py", line 110, in _pipeline
    mod = seq(mod)
  File "/home/zte/TVM/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/home/zte/TVM/tvm/python/tvm/ir/transform.py", line 307, in _pass_func
    return inst.transform_module(mod, ctx)
  File "/home/zte/TVM/tvm/python/tvm/dlight/base/transform.py", line 71, in transform_module
    sch = _apply_rules(func, target, self.rules, tunable=False)
  File "/home/zte/TVM/tvm/python/tvm/dlight/base/transform.py", line 87, in _apply_rules
    space = rule.apply(func, target, tunable)
  File "/home/zte/TVM/tvm/python/tvm/dlight/npu/injective.py", line 93, in apply
    return self._take_intrin(sch, loops, output_dtype, original_input_shape, original_output_shape, block_rv)
  File "/home/zte/TVM/tvm/python/tvm/dlight/npu/injective.py", line 287, in _take_intrin
    sch.tensorize(loops[0], op_intrin_name)
  File "/home/zte/TVM/tvm/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
    return func(*args, **kwargs)
  File "/home/zte/TVM/tvm/python/tvm/tir/schedule/schedule.py", line 3046, in tensorize
    _ffi_api.ScheduleTensorize(  # type: ignore # pylint: disable=no-member
  File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 284, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
  File "/home/zte/TVM/tvm/python/tvm/_ffi/base.py", line 468, in raise_last_ffi_error
    raise py_err
tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
2: operator()
at /home/zte/TVM/tvm/src/tir/schedule/schedule.cc:248
1: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
at /home/zte/TVM/tvm/src/tir/schedule/traced_schedule.cc:615
0: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
at /home/zte/TVM/tvm/src/tir/schedule/concrete_schedule.cc:901
ScheduleError: An error occurred in the schedule primitive 'tensorize'.
The IR with diagnostic is:
from tvm.script import ir as I
from tvm.script import tir as T
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def main(var_A: T.handle, var_xx: T.handle, var_T_take: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        A = T.match_buffer(var_A, (T.int64(10), T.int64(96)))
        xx = T.match_buffer(var_xx, (T.int64(5),), "int64")
        T_take = T.match_buffer(var_T_take, (T.int64(5), T.int64(96)))
        with T.block("root"):
            T.reads()
            T.writes()
            for ax0 in range(T.int64(5)):
                for ax1 in range(T.int64(96)):
                    with T.block("T_take"):
                        v0 = T.axis.spatial(T.int64(5), ax0)
                        v1 = T.axis.spatial(T.int64(96), ax1)
                        T.reads(A[xx[v0], v1], xx[v0])
                        T.writes(T_take[v0, v1])
                        T_take[v0, v1] = A[xx[v0], v1]

Error message: The stmt tir.Block#0 doesn't match the tensor intrin
The pattern attempting to be matched:
with T.block("T_take", no_realize=True):
    v0_i = T.axis.spatial(T.int64(5))
    v1_i = T.axis.spatial(T.int64(96))
    A = T.Buffer((T.int64(10), T.int64(96)))
    xx = T.Buffer((T.int64(5),), "int64")
    T.reads(A[xx[v0_i], v1_i], xx[v0_i])
    T_take = T.Buffer((T.int64(5), T.int64(96)))
    T.writes(T_take[v0_i, v1_i])
    T_take[v0_i, v1_i] = A[xx[v0_i], v1_i]
Does not match the tensorize description:
with T.block("T_take", no_realize=True):
    v0_i = T.axis.spatial(T.int64(5))
    v1_i = T.axis.spatial(T.int64(96))
    A = T.Buffer((T.int64(10), T.int64(96)))
    xx = T.Buffer((T.int64(5),), "int64")
    T.reads(A[xx[v0_i], v1_i], xx[v0_i])
    T_take = T.Buffer((T.int64(5), T.int64(96)))
    T.writes(T_take[v0_i, v1_i])
    T_take[v0_i, v1_i] = A[xx[v0_i], v1_i]
CompareBufferRegion buffer region min mismatch. lhs->region[i + offset]=range(min=xx[v0_i], ext=(int64)1)Range(0x55e61839eb60) vs rhs->region[i]=range(min=xx[v0_i], ext=(int64)1)Range(0x55e618375580)
BlockNode read buffers regions do not match: op->reads=[A[xx[v0_i], v1_i], xx[v0_i]] vs rhs->reads=[A[xx[v0_i], v1_i], xx[v0_i]]