Problem with storage_align and double_buffer

Using both storage_align and double_buffer results in a "Cannot find allocated buffer for buffer(A.shared, 0x150bb10)" error in storage_flatten.cc. In the example below, the global memory data of A is loaded into shared memory, and the storage_align schedule is used to solve the bank conflict problem. You can also use the tensorize schedule to load A's data into shared memory and manually resolve the bank conflict problem. This way tensorize and double_buffer can be used at the same time.

import tvm
from tvm import te

def intrin_load_matrix_to_slb(rows=16, cols=64):
    """Declare a tensor intrinsic that copies a (rows, cols) float32 tile
    from global memory into shared memory.

    One thread is launched per column (thread_extent == cols); each thread
    copies its column across all `rows` with an unrolled sequence of
    vload/vstore pairs. The defaults (16, 64) reproduce the original
    hard-coded behavior, so existing callers are unaffected.

    Parameters
    ----------
    rows : int
        Number of rows in the tile (extent of the unrolled copy loop).
    cols : int
        Number of columns in the tile; also the threadIdx.x extent.

    Returns
    -------
    TensorIntrin
        An intrinsic suitable for `s[...].tensorize(...)`.
    """
    output_shape = (rows, cols)
    # Row-major strides for both source (global) and destination (shared).
    strides_src = [cols, 1]
    strides_dst = [cols, 1]

    A = te.placeholder(output_shape, name="A", dtype="float32")
    C = te.compute(output_shape, lambda *i: A(*i), name="C")

    # Buffer declarations must match the tensorized region: same strides,
    # offset_factor=1 so the intrinsic can be placed at any offset.
    BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope="global", strides=strides_src, data_alignment=64, offset_factor=1)
    BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope="shared", strides=strides_dst, data_alignment=64, offset_factor=1)

    def intrin_func(ins, outs):
        """Emit the lowered copy body: one vstore per row, indexed by threadIdx.x."""
        ib = tvm.tir.ir_builder.create()

        src = ins[0]
        dst = outs[0]

        tx = te.thread_axis("threadIdx.x")
        ib.scope_attr(tx, "thread_extent", cols)
        # NOTE(review): `tx // 1` looks like a no-op kept from the original;
        # preserved verbatim in case the expression form matters to lowering.
        index = tx // 1

        for outer in range(rows):
            ib.emit(dst.vstore([outer, index], src.vload([outer, index], "float32")))

        return ib.get()

    return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})

# Repro: copy a (64, 64) float32 matrix through shared memory.
M = 64
N = 64

A = te.placeholder((M, N), dtype="float32", name="A")
B = te.compute((M, N), lambda *i: A(*i), name="B")

schedule = te.create_schedule(B.op)
thread_x = te.thread_axis("threadIdx.x")

# Stage A through shared memory before it feeds B.
A_shared = schedule.cache_read(A, "shared", [B])

row, col = B.op.axis
row_outer, row_inner = schedule[B].split(row, factor=16)
schedule[B].reorder(row_outer, row_inner, col)
schedule[B].bind(col, thread_x)

# Load one 16-row tile of A into shared memory per outer iteration.
schedule[A_shared].compute_at(schedule[B], row_outer)
shared_row, shared_col = A_shared.op.axis
# schedule[A_shared].storage_align(shared_row, 63, 64)
schedule[A_shared].tensorize(shared_row, intrin_load_matrix_to_slb())
schedule[A_shared].double_buffer()

print(tvm.lower(schedule, [A, B]))

Output:

Using the tensorize schedule primitive, and not using the double_buffer schedule primitive.

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: Pointer(float32), float32, [64, 64], []),
             A: Buffer(A_2: Pointer(float32), float32, [64, 64], [])}
  buffer_map = {A_1: A, B_1: B} {
  allocate(A.shared: Pointer(shared float32), float32, [1024]), storage_scope = shared;
  for (i0.outer: int32, 0, 4) {
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64 {
      A.shared[threadIdx.x] = (float32*)A_2[((i0.outer*1024) + threadIdx.x)]
      A.shared[(threadIdx.x + 64)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 64)]
      A.shared[(threadIdx.x + 128)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 128)]
      A.shared[(threadIdx.x + 192)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 192)]
      A.shared[(threadIdx.x + 256)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 256)]
      A.shared[(threadIdx.x + 320)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 320)]
      A.shared[(threadIdx.x + 384)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 384)]
      A.shared[(threadIdx.x + 448)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 448)]
      A.shared[(threadIdx.x + 512)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 512)]
      A.shared[(threadIdx.x + 576)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 576)]
      A.shared[(threadIdx.x + 640)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 640)]
      A.shared[(threadIdx.x + 704)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 704)]
      A.shared[(threadIdx.x + 768)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 768)]
      A.shared[(threadIdx.x + 832)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 832)]
      A.shared[(threadIdx.x + 896)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 896)]
      A.shared[(threadIdx.x + 960)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 960)]
    }
    for (i0.inner: int32, 0, 16) {
      attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64;
      B_2[(((i0.outer*1024) + (i0.inner*64)) + threadIdx.x_1)] = (float32*)A.shared[((i0.inner*64) + threadIdx.x_1)]
    }
  }
}

Using the tensorize and double_buffer schedule primitives.

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {B: Buffer(B_2: Pointer(float32), float32, [64, 64], []),
             A: Buffer(A_2: Pointer(float32), float32, [64, 64], [])}
  buffer_map = {A_1: A, B_1: B} {
  allocate(A.shared: Pointer(shared float32), float32, [2048]), storage_scope = shared {
    attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64 {
      A.shared[threadIdx.x] = (float32*)A_2[threadIdx.x]
      A.shared[(threadIdx.x + 64)] = (float32*)A_2[(threadIdx.x + 64)]
      A.shared[(threadIdx.x + 128)] = (float32*)A_2[(threadIdx.x + 128)]
      A.shared[(threadIdx.x + 192)] = (float32*)A_2[(threadIdx.x + 192)]
      A.shared[(threadIdx.x + 256)] = (float32*)A_2[(threadIdx.x + 256)]
      A.shared[(threadIdx.x + 320)] = (float32*)A_2[(threadIdx.x + 320)]
      A.shared[(threadIdx.x + 384)] = (float32*)A_2[(threadIdx.x + 384)]
      A.shared[(threadIdx.x + 448)] = (float32*)A_2[(threadIdx.x + 448)]
      A.shared[(threadIdx.x + 512)] = (float32*)A_2[(threadIdx.x + 512)]
      A.shared[(threadIdx.x + 576)] = (float32*)A_2[(threadIdx.x + 576)]
      A.shared[(threadIdx.x + 640)] = (float32*)A_2[(threadIdx.x + 640)]
      A.shared[(threadIdx.x + 704)] = (float32*)A_2[(threadIdx.x + 704)]
      A.shared[(threadIdx.x + 768)] = (float32*)A_2[(threadIdx.x + 768)]
      A.shared[(threadIdx.x + 832)] = (float32*)A_2[(threadIdx.x + 832)]
      A.shared[(threadIdx.x + 896)] = (float32*)A_2[(threadIdx.x + 896)]
      A.shared[(threadIdx.x + 960)] = (float32*)A_2[(threadIdx.x + 960)]
    }
    for (i0.outer.outer: int32, 0, 3) {
      attr [A.shared] "double_buffer_write" = 1;
      attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64 {
        A.shared[((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1024)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 64)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1088)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 128)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1152)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 192)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1216)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 256)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1280)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 320)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1344)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 384)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1408)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 448)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1472)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 512)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1536)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 576)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1600)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 640)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1664)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 704)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1728)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 768)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1792)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 832)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1856)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 896)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1920)]
        A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 960)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1984)]
      }
      for (i0.inner: int32, 0, 16) {
        attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64;
        B_2[(((i0.outer.outer*1024) + (i0.inner*64)) + threadIdx.x_1)] = (float32*)A.shared[(((floormod(i0.outer.outer, 2)*1024) + (i0.inner*64)) + threadIdx.x_1)]
      }
    }
    for (i0.inner_1: int32, 0, 16) {
      attr [IterVar(threadIdx.x_1, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64;
      B_2[(((i0.inner_1*64) + threadIdx.x_1) + 3072)] = (float32*)A.shared[(((i0.inner_1*64) + threadIdx.x_1) + 1024)]
    }
  }
}