I’m writing an op schedule that uses both the storage_align and double_buffer schedule primitives, but the two seem to conflict with each other. Below is my test code:
import tvm
from tvm import te

M = 63
N = 63
A = te.placeholder((M, N), dtype="float32", name="A")
B = te.compute((M, N), lambda *i: A(*i), name="B")
s = te.create_schedule(B.op)
tx = te.thread_axis("threadIdx.x")

# Stage A through shared memory for B.
AS = s.cache_read(A, "shared", [B])
cx, ci = B.op.axis
s[B].bind(cx, tx)
ax, ai = AS.op.axis
s[AS].compute_at(s[B], cx)
s[AS].bind(ax, tx)

# Pad the row stride of the shared copy from 63 to 64 floats.
s[AS].storage_align(ax, 63, 64)
# s[AS].double_buffer()  # uncommenting this line triggers the error below

print(tvm.lower(s, [A, B]))
If I don’t use the double_buffer schedule, everything is fine and I get the expected lowered code:
@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: Pointer(float32), float32, [3969], []),
             A: Buffer(A_2: Pointer(float32), float32, [3969], [])}
  buffer_map = {A_1: A, B_1: B} {
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 63;
  allocate(A.shared: Pointer(shared float32), float32, [4032]), storage_scope = shared {
    for (ax1: int32, 0, 63) {
      A.shared_1: Buffer(A.shared, float32, [4032], [], scope="shared")[((threadIdx.x*64) + ax1)] = A[((threadIdx.x*63) + ax1)]
    }
    for (i1: int32, 0, 63) {
      B[((threadIdx.x*63) + i1)] = A.shared_1[((threadIdx.x*64) + i1)]
    }
  }
}
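The effect of storage_align is visible here: the shared copy uses a padded row stride of 64 floats instead of 63, so the allocation grows to 63 × 64 = 4032 elements, while the global buffer A keeps its natural stride of 63. Spelled out in plain Python (just my reading of the IR above):

rows, cols, stride = 63, 63, 64              # row stride padded from 63 to 64
alloc_elems = rows * stride                  # 4032, matching allocate(A.shared, ..., [4032])
shared_addr = lambda i, j: i * stride + j    # A.shared[threadIdx.x*64 + ax1]
global_addr = lambda i, j: i * cols + j      # A[threadIdx.x*63 + ax1]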
Then I uncomment s[AS].double_buffer(), and it produces a strange error:
tvm._ffi.base.TVMError: Traceback (most recent call last):
33: TVMFuncCall
32: _ZN3tvm7runtime13PackedFun
31: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)>::AssignTypedLambda<tvm::{lambda(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)#5}>(tvm::{lambda(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)#5}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
30: tvm::LowerSchedule(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&, bool)
29: tvm::LowerWithPassList(tvm::IRModule, tvm::runtime::Array<tvm::transform::Pass, void>)
28: tvm::transform::Pass::operator()(tvm::IRModule) const
27: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
26: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
25: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
24: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
23: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_3tir8PrimFuncES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_9transform14StorageFlattenEibEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SG_SK_
22: tvm::tir::StorageFlatten(tvm::tir::PrimFunc, int, bool)
21: tvm::transform::Pass::operator()(tvm::IRModule) const
20: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
19: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
18: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
17: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
16: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_3tir8PrimFuncES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_16StorageFlattener4PassEibEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SG_SK_
15: tvm::tir::StorageFlattener::Pass(int, bool)::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1}::operator()(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext) const
14: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
13: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7r
12: tvm::tir::StorageFlattener::VisitStmt_(tvm::tir::AttrStmtNode const*)
11: tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
10: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
9: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7r
8: tvm::tir::StorageFlattener::VisitStmt_(tvm::tir::BufferRealizeNode const*)
7: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
6: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runt
5: tvm::tir::StmtMutator::VisitStmt_(tvm::tir::SeqStmtNode const*)
4: void tvm::runtime::Array<tvm::tir::Stmt, void>::MutateByApply<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}>(tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
3: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
2: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7r
1: tvm::tir::StorageFlattener::VisitStmt_(tvm::tir::AttrStmtNode const*)
0: tvm::tir::StorageFlattener::GetBufferEntry(tvm::tir::Buffer)
File "/home/chenugray/source_code/tvm/src/tir/transforms/storage_flatten.cc", line 1545
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: (it != buf_map_.end()) is false: Cannot find allocated buffer for buffer(A.shared, 0x150bb10)
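The failing check is in StorageFlattener::GetBufferEntry, which can no longer find the allocation for A.shared once double_buffer is applied on top of storage_align. To make this easier to poke at, here is a small sweep over both primitives (a convenience wrapper around the same schedule; the only outcomes I have verified myself are that align-only lowers fine and align + double buffer fails as above):

import tvm
from tvm import te

def lower_with(use_align, use_double_buffer):
    # Identical schedule to the repro above, with each primitive toggled.
    M = N = 63
    A = te.placeholder((M, N), dtype="float32", name="A")
    B = te.compute((M, N), lambda *i: A(*i), name="B")
    s = te.create_schedule(B.op)
    tx = te.thread_axis("threadIdx.x")
    AS = s.cache_read(A, "shared", [B])
    cx, ci = B.op.axis
    s[B].bind(cx, tx)
    ax, ai = AS.op.axis
    s[AS].compute_at(s[B], cx)
    s[AS].bind(ax, tx)
    if use_align:
        s[AS].storage_align(ax, 63, 64)
    if use_double_buffer:
        s[AS].double_buffer()
    return tvm.lower(s, [A, B])

for flags in [(True, False), (False, True), (True, True)]:
    try:
        lower_with(*flags)
        print(flags, "-> lowered OK")
    except tvm.TVMError as err:
        print(flags, "-> failed:", str(err).splitlines()[-1])

Is this a known limitation of combining storage_align with double_buffer on a shared-memory stage, or am I misusing one of the primitives?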