Problem with storage_align and double_buffer

I’m writing my op schedule using the storage_align and double_buffer schedule primitives, but the two seem to conflict with each other. Below is my test code:

import tvm
from tvm import te

M = 63
N = 63
A = te.placeholder((M, N), dtype="float32", name="A")
B = te.compute((M, N), lambda *i: A(*i), name="B")
s = te.create_schedule(B.op)
tx = te.thread_axis("threadIdx.x")

# Stage A in shared memory.
AS = s.cache_read(A, "shared", [B])
cx, ci = B.op.axis
s[B].bind(cx, tx)
ax, ai = AS.op.axis
s[AS].compute_at(s[B], cx)
s[AS].bind(ax, tx)
# Pad the stride of ax (factor=63, offset=64) to avoid bank conflicts.
s[AS].storage_align(ax, 63, 64)
# s[AS].double_buffer()  # uncommenting this triggers the error below
print(tvm.lower(s, [A, B]))

If I don’t use the double_buffer schedule, everything seems fine and I get the correct tensor expression:

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: Pointer(float32), float32, [3969], []),
             A: Buffer(A_2: Pointer(float32), float32, [3969], [])}
  buffer_map = {A_1: A, B_1: B} {
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 63;
  allocate(A.shared: Pointer(shared float32), float32, [4032]), storage_scope = shared {
    for (ax1: int32, 0, 63) {
      A.shared_1: Buffer(A.shared, float32, [4032], [], scope="shared")[((threadIdx.x*64) + ax1)] = A[((threadIdx.x*63) + ax1)]
    }
    for (i1: int32, 0, 63) {
      B[((threadIdx.x*63) + i1)] = A.shared_1[((threadIdx.x*64) + i1)]
    }
  }
}
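As expected, storage_align(ax, 63, 64) pads the stride of ax from 63 up to 64 (the documented constraint is stride == k * factor + offset for some k), so the shared allocation grows to 63 * 64 = 4032 elements. A quick sanity check on the lowered module (a sketch; it only greps the printed IR, reusing the schedule s from the script above):

# Reuses s, A, B from the script above, with double_buffer still commented out.
mod = tvm.lower(s, [A, B])
assert "4032" in str(mod)  # inner stride padded from 63 to 64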

Then when I uncomment s[AS].double_buffer(), I get a strange error:

tvm._ffi.base.TVMError: Traceback (most recent call last):
  33: TVMFuncCall
  32: _ZN3tvm7runtime13PackedFun
  31: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)>::AssignTypedLambda<tvm::{lambda(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)#5}>(tvm::{lambda(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, tvm::runtime::String const&, tvm::runtime::Map<tvm::te::Tensor, tvm::tir::Buffer, void, void> const&, bool)#5}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  30: tvm::LowerSchedule(tvm::te::Schedule, tvm::runtime::Array<tvm::runtime::ObjectRef, void> const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::unordered_map<tvm::te::Tensor, tvm::tir::Buffer, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::tir::Buffer> > > const&, bool)
  29: tvm::LowerWithPassList(tvm::IRModule, tvm::runtime::Array<tvm::transform::Pass, void>)
  28: tvm::transform::Pass::operator()(tvm::IRModule) const
  27: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  26: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  25: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  24: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  23: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_3tir8PrimFuncES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_9transform14StorageFlattenEibEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SG_SK_
  22: tvm::tir::StorageFlatten(tvm::tir::PrimFunc, int, bool)
  21: tvm::transform::Pass::operator()(tvm::IRModule) const
  20: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  19: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  18: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  17: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  16: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_3tir8PrimFuncES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_16StorageFlattener4PassEibEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SG_SK_
  15: tvm::tir::StorageFlattener::Pass(int, bool)::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1}::operator()(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext) const
  14: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  13: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7r
  12: tvm::tir::StorageFlattener::VisitStmt_(tvm::tir::AttrStmtNode const*)
  11: tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
  10: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  9: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7r
  8: tvm::tir::StorageFlattener::VisitStmt_(tvm::tir::BufferRealizeNode const*)
  7: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  6: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runt
  5: tvm::tir::StmtMutator::VisitStmt_(tvm::tir::SeqStmtNode const*)
  4: void tvm::runtime::Array<tvm::tir::Stmt, void>::MutateByApply<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}>(tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
  3: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  2: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7r
  1: tvm::tir::StorageFlattener::VisitStmt_(tvm::tir::AttrStmtNode const*)
  0: tvm::tir::StorageFlattener::GetBufferEntry(tvm::tir::Buffer)
  File "/home/chenugray/source_code/tvm/src/tir/transforms/storage_flatten.cc", line 1545
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (it != buf_map_.end()) is false: Cannot find allocated buffer for buffer(A.shared, 0x150bb10)

Using both storage_align and double_buffer results in the "Cannot find allocated buffer for buffer(A.shared, 0x150bb10)" error raised in storage_flatten.cc. In the example above, A’s data is loaded from global memory into shared memory, and the storage_align schedule is used to solve the bank conflict problem. As a workaround, you can instead use the tensorize schedule to load A’s data into shared memory and resolve the bank conflicts manually inside the intrinsic. This way, tensorize and double_buffer can be used at the same time:

import tvm
from tvm import te

def intrin_load_matrix_to_slb():
    # Tensor intrinsic: copy a 16x64 float32 tile from global memory into
    # shared memory, one column per thread. Bank conflicts would be resolved
    # by hand here (e.g., by padding the destination layout).
    output_shape = (16, 64)
    strides_src = [64, 1]
    strides_dst = [64, 1]

    A = te.placeholder(output_shape, name="A", dtype="float32")
    C = te.compute(output_shape, lambda *i: A(*i), name="C")

    BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope="global", strides=strides_src, data_alignment=64, offset_factor=1)
    BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope="shared",  strides=strides_dst, data_alignment=64, offset_factor=1)

    def intrin_func(ins, outs):
        ib = tvm.tir.ir_builder.create()

        BA = ins[0]
        BC = outs[0]

        tx = te.thread_axis("threadIdx.x")
        ib.scope_attr(tx, "thread_extent", 64)
        index = tx // 1  # floordiv by 1, i.e. just the thread index

        # Each thread copies its column; the Python loop unrolls the 16 rows.
        for outer in range(0, 16):
            ib.emit(BC.vstore([outer, index], BA.vload([outer, index], "float32")))

        return ib.get()
    return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})

M = 64
N = 64
A = te.placeholder((M, N), dtype="float32", name="A")
B = te.compute((M, N), lambda *i: A(*i), name="B")
s = te.create_schedule(B.op)
tx = te.thread_axis("threadIdx.x")
AS = s.cache_read(A, "shared", [B])
cx, ci = B.op.axis
cxo, cxi = s[B].split(cx, factor=16)
s[B].reorder(cxo, cxi, ci)
s[B].bind(ci, tx)

s[AS].compute_at(s[B], cxo)
ax, ai = AS.op.axis
# s[AS].storage_align(ax, 63, 64)  # not needed here; the intrinsic owns the layout
s[AS].tensorize(ax, intrin_load_matrix_to_slb())
s[AS].double_buffer()
print(tvm.lower(s, [A, B]))

Output with the tensorize schedule primitive but without double_buffer (s[AS].double_buffer() commented out):

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: Pointer(float32), float32, [64, 64], []),
             A: Buffer(A_2: Pointer(float32), float32, [64, 64], [])}
  buffer_map = {A_1: A, B_1: B} {
  allocate(A.shared: Pointer(shared float32), float32, [1024]), storage_scope = shared;
  for (i0.outer: int32, 0, 4) {
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64 {
      A.shared[threadIdx.x] = (float32*)A_2[((i0.outer*1024) + threadIdx.x)]
      A.shared[(threadIdx.x + 64)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 64)]
      A.shared[(threadIdx.x + 128)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 128)]
      A.shared[(threadIdx.x + 192)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 192)]
      A.shared[(threadIdx.x + 256)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 256)]
      A.shared[(threadIdx.x + 320)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 320)]
      A.shared[(threadIdx.x + 384)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 384)]
      A.shared[(threadIdx.x + 448)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 448)]
      A.shared[(threadIdx.x + 512)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 512)]
      A.shared[(threadIdx.x + 576)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 576)]
      A.shared[(threadIdx.x + 640)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 640)]
      A.shared[(threadIdx.x + 704)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 704)]
      A.shared[(threadIdx.x + 768)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 768)]
      A.shared[(threadIdx.x + 832)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 832)]
      A.shared[(threadIdx.x + 896)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 896)]
      A.shared[(threadIdx.x + 960)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 960)]
    }
    for (i0.inner: int32, 0, 16) {
      attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64;
      B_2[(((i0.outer*1024) + (i0.inner*64)) + threadIdx.x_1)] = (float32*)A.shared[((i0.inner*64) + threadIdx.x_1)]
    }
  }
}

With both the tensorize and double_buffer schedule primitives, the shared allocation doubles to 2048 elements. A prologue loads the first tile into one half, the main loop prefetches tile i0.outer.outer + 1 into the half selected by floormod(i0.outer.outer + 1, 2) while computing from the other half, and an epilogue consumes the last tile.
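The rotation is easier to see in plain Python; a minimal, illustrative sketch of the pattern in the IR below (all names here are mine):

# Two halves of the shared buffer alternate between being filled and consumed.
NUM_TILES = 4
halves = [None, None]                       # the two 1024-element halves
halves[0] = "tile 0"                        # prologue: load the first tile
for i in range(NUM_TILES - 1):
    halves[(i + 1) % 2] = f"tile {i + 1}"   # prefetch the next tile
    compute_from = halves[i % 2]            # consume the current tile
compute_from = halves[(NUM_TILES - 1) % 2]  # epilogue: last tile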

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: Pointer(float32), float32, [64, 64], []),
             A: Buffer(A_2: Pointer(float32), float32, [64, 64], [])}
  buffer_map = {A_1: A, B_1: B} {
  allocate(A.shared: Pointer(shared float32), float32, [2048]), storage_scope = shared {
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64 {
      A.shared[threadIdx.x] = (float32*)A_2[threadIdx.x]
      A.shared[(threadIdx.x + 64)] = (float32*)A_2[(threadIdx.x + 64)]
      A.shared[(threadIdx.x + 128)] = (float32*)A_2[(threadIdx.x + 128)]
      A.shared[(threadIdx.x + 192)] = (float32*)A_2[(threadIdx.x + 192)]
      A.shared[(threadIdx.x + 256)] = (float32*)A_2[(threadIdx.x + 256)]
      A.shared[(threadIdx.x + 320)] = (float32*)A_2[(threadIdx.x + 320)]
      A.shared[(threadIdx.x + 384)] = (float32*)A_2[(threadIdx.x + 384)]
      A.shared[(threadIdx.x + 448)] = (float32*)A_2[(threadIdx.x + 448)]
      A.shared[(threadIdx.x + 512)] = (float32*)A_2[(threadIdx.x + 512)]
      A.shared[(threadIdx.x + 576)] = (float32*)A_2[(threadIdx.x + 576)]
      A.shared[(threadIdx.x + 640)] = (float32*)A_2[(threadIdx.x + 640)]
      A.shared[(threadIdx.x + 704)] = (float32*)A_2[(threadIdx.x + 704)]
      A.shared[(threadIdx.x + 768)] = (float32*)A_2[(threadIdx.x + 768)]
      A.shared[(threadIdx.x + 832)] = (float32*)A_2[(threadIdx.x + 832)]
      A.shared[(threadIdx.x + 896)] = (float32*)A_2[(threadIdx.x + 896)]
      A.shared[(threadIdx.x + 960)] = (float32*)A_2[(threadIdx.x + 960)]
    }
    for (i0.outer.outer: int32, 0, 3) {
      attr [A.shared] "double_buffer_write" = 1;
      attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64 {
        A.shared[((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1024)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 64)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1088)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 128)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1152)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 192)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1216)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 256)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1280)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 320)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1344)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 384)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1408)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 448)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1472)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 512)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1536)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 576)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1600)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 640)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1664)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 704)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1728)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 768)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1792)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 832)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1856)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 896)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1920)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 960)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1984)]
      }
      for (i0.inner: int32, 0, 16) {
        attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64;
        B_2[(((i0.outer.outer*1024) + (i0.inner*64)) + threadIdx.x_1)] = (float32*)A.shared[(((floormod(i0.outer.outer, 2)*1024) + (i0.inner*64)) + threadIdx.x_1)]
      }
    }
    for (i0.inner_1: int32, 0, 16) {
      attr [IterVar(threadIdx.x_1, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64;
      B_2[(((i0.inner_1*64) + threadIdx.x_1) + 3072)] = (float32*)A.shared[(((i0.inner_1*64) + threadIdx.x_1) + 1024)]
    }
  }
}
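For completeness, the tensorized, double-buffered schedule can be checked numerically. A minimal sketch, assuming a CUDA-capable device is available (the target string and test values are mine, not from the original post):

import numpy as np

# Build and run the tensorize + double_buffer schedule from above.
func = tvm.build(s, [A, B], target="cuda")
dev = tvm.cuda(0)
a_np = np.random.uniform(size=(M, N)).astype("float32")
a_tvm = tvm.nd.array(a_np, dev)
b_tvm = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
func(a_tvm, b_tvm)
np.testing.assert_allclose(b_tvm.numpy(), a_np)  # B is a copy of A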