Using both storage_align and double_buffer results in a "Cannot find allocated buffer for buffer(A.shared, 0x150bb10)" error in storage_flatten.cc. In the above example, the global-memory data of A is loaded into shared memory, and the storage_align schedule is used to solve the bank-conflict problem. Alternatively, you can use the tensorize schedule to load A's data into shared memory and resolve the bank conflict manually; this way tensorize and double_buffer can be used at the same time.
import tvm
from tvm import te
def intrin_load_matrix_to_slb():
    """Declare a tensor intrinsic that copies one 16x64 float32 tile from
    global memory into shared memory.

    The emitted body binds 64 threadIdx.x threads; each thread copies one
    element of each of the 16 rows (16 unrolled stores per invocation).

    Returns
    -------
    TensorIntrin
        Suitable for use with ``s[...].tensorize(...)``.
    """
    output_shape = (16, 64)  # tile shape handled by one intrinsic invocation
    strides_src = [64, 1]    # row-major layout of the global-memory source
    strides_dst = [64, 1]    # row-major layout of the shared-memory destination
    A = te.placeholder(output_shape, name="A", dtype="float32")
    C = te.compute(output_shape, lambda *i: A(*i), name="C")  # identity copy A -> C
    # Buffer views the intrinsic matches against: A in global scope, C in shared
    # scope. offset_factor=1 lets the intrinsic be applied at arbitrary offsets.
    BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope="global", strides=strides_src, data_alignment=64, offset_factor=1)
    BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope="shared", strides=strides_dst, data_alignment=64, offset_factor=1)

    def intrin_func(ins, outs):
        # Build the lowered TIR body that replaces the matched computation.
        ib = tvm.tir.ir_builder.create()
        BA = ins[0]   # matched input buffer (global)
        BC = outs[0]  # matched output buffer (shared)
        tx = te.thread_axis("threadIdx.x")
        ib.scope_attr(tx, "thread_extent", 64)
        # NOTE(review): `tx // 1` looks like a no-op division used to coerce the
        # IterVar into a PrimExpr index — confirm before simplifying to `tx`.
        index = tx // 1
        # Python-level loop: the 16 row copies are fully unrolled in the TIR body.
        for outer in range(0, 16):
            ib.emit(BC.vstore([outer, index], BA.vload([outer, index], "float32")))
        return ib.get()

    return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
# Workload: B is an identity copy of the 64x64 matrix A, staged through
# shared memory so the load can be tensorized and double-buffered.
M = 64
N = 64
A = te.placeholder((M, N), dtype="float32", name="A")
B = te.compute((M, N), lambda *i: A(*i), name="B")
s = te.create_schedule(B.op)

# Stage A through shared memory for the consumer B.
AS = s.cache_read(A, "shared", [B])

row_axis, col_axis = B.op.axis
row_outer, row_inner = s[B].split(row_axis, factor=16)
s[B].reorder(row_outer, row_inner, col_axis)
tx = te.thread_axis("threadIdx.x")
s[B].bind(col_axis, tx)

# Load one 16x64 tile of A per outer iteration; the copy is replaced by
# the hand-written intrinsic, and the shared buffer is double-buffered.
s[AS].compute_at(s[B], row_outer)
shared_row, shared_col = AS.op.axis
# s[AS].storage_align(shared_row, 63, 64)  # combining this with double_buffer triggers the storage_flatten error
s[AS].tensorize(shared_row, intrin_load_matrix_to_slb())
s[AS].double_buffer()

print(tvm.lower(s, [A, B]))
Output:
Output when using the tensorize schedule primitive without the double_buffer schedule primitive:
@main = primfn(A_1: handle, B_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {B: Buffer(B_2: Pointer(float32), float32, [64, 64], []),
A: Buffer(A_2: Pointer(float32), float32, [64, 64], [])}
buffer_map = {A_1: A, B_1: B} {
allocate(A.shared: Pointer(shared float32), float32, [1024]), storage_scope = shared;
for (i0.outer: int32, 0, 4) {
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64 {
A.shared[threadIdx.x] = (float32*)A_2[((i0.outer*1024) + threadIdx.x)]
A.shared[(threadIdx.x + 64)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 64)]
A.shared[(threadIdx.x + 128)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 128)]
A.shared[(threadIdx.x + 192)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 192)]
A.shared[(threadIdx.x + 256)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 256)]
A.shared[(threadIdx.x + 320)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 320)]
A.shared[(threadIdx.x + 384)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 384)]
A.shared[(threadIdx.x + 448)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 448)]
A.shared[(threadIdx.x + 512)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 512)]
A.shared[(threadIdx.x + 576)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 576)]
A.shared[(threadIdx.x + 640)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 640)]
A.shared[(threadIdx.x + 704)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 704)]
A.shared[(threadIdx.x + 768)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 768)]
A.shared[(threadIdx.x + 832)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 832)]
A.shared[(threadIdx.x + 896)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 896)]
A.shared[(threadIdx.x + 960)] = (float32*)A_2[(((i0.outer*1024) + threadIdx.x) + 960)]
}
for (i0.inner: int32, 0, 16) {
attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64;
B_2[(((i0.outer*1024) + (i0.inner*64)) + threadIdx.x_1)] = (float32*)A.shared[((i0.inner*64) + threadIdx.x_1)]
}
}
}
Output when using both the tensorize and double_buffer schedule primitives:
@main = primfn(A_1: handle, B_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {B: Buffer(B_2: Pointer(float32), float32, [64, 64], []),
A: Buffer(A_2: Pointer(float32), float32, [64, 64], [])}
buffer_map = {A_1: A, B_1: B} {
allocate(A.shared: Pointer(shared float32), float32, [2048]), storage_scope = shared {
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64 {
A.shared[threadIdx.x] = (float32*)A_2[threadIdx.x]
A.shared[(threadIdx.x + 64)] = (float32*)A_2[(threadIdx.x + 64)]
A.shared[(threadIdx.x + 128)] = (float32*)A_2[(threadIdx.x + 128)]
A.shared[(threadIdx.x + 192)] = (float32*)A_2[(threadIdx.x + 192)]
A.shared[(threadIdx.x + 256)] = (float32*)A_2[(threadIdx.x + 256)]
A.shared[(threadIdx.x + 320)] = (float32*)A_2[(threadIdx.x + 320)]
A.shared[(threadIdx.x + 384)] = (float32*)A_2[(threadIdx.x + 384)]
A.shared[(threadIdx.x + 448)] = (float32*)A_2[(threadIdx.x + 448)]
A.shared[(threadIdx.x + 512)] = (float32*)A_2[(threadIdx.x + 512)]
A.shared[(threadIdx.x + 576)] = (float32*)A_2[(threadIdx.x + 576)]
A.shared[(threadIdx.x + 640)] = (float32*)A_2[(threadIdx.x + 640)]
A.shared[(threadIdx.x + 704)] = (float32*)A_2[(threadIdx.x + 704)]
A.shared[(threadIdx.x + 768)] = (float32*)A_2[(threadIdx.x + 768)]
A.shared[(threadIdx.x + 832)] = (float32*)A_2[(threadIdx.x + 832)]
A.shared[(threadIdx.x + 896)] = (float32*)A_2[(threadIdx.x + 896)]
A.shared[(threadIdx.x + 960)] = (float32*)A_2[(threadIdx.x + 960)]
}
for (i0.outer.outer: int32, 0, 3) {
attr [A.shared] "double_buffer_write" = 1;
attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64 {
A.shared[((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1024)]
A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 64)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1088)]
A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 128)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1152)]
A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 192)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1216)]
A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 256)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1280)]
A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 320)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1344)]
A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 384)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1408)]
A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 448)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1472)]
A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 512)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1536)]
A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 576)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1600)]
A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 640)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1664)]
A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 704)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1728)]
A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 768)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1792)]
A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 832)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1856)]
A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 896)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1920)]
A.shared[(((floormod((i0.outer.outer + 1), 2)*1024) + threadIdx.x) + 960)] = (float32*)A_2[(((i0.outer.outer*1024) + threadIdx.x) + 1984)]
}
for (i0.inner: int32, 0, 16) {
attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64;
B_2[(((i0.outer.outer*1024) + (i0.inner*64)) + threadIdx.x_1)] = (float32*)A.shared[(((floormod(i0.outer.outer, 2)*1024) + (i0.inner*64)) + threadIdx.x_1)]
}
}
for (i0.inner_1: int32, 0, 16) {
attr [IterVar(threadIdx.x_1, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64;
B_2[(((i0.inner_1*64) + threadIdx.x_1) + 3072)] = (float32*)A.shared[(((i0.inner_1*64) + threadIdx.x_1) + 1024)]
}
}
}