A long time ago I used TVM 0.9.dev0; after upgrading to the Unity version, I found a performance regression in the max_pool2d CUDA schedule caused by redundant buffer indexing. The schedule in question is https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/pooling.py. When I print the scheduled TIR with tvm.lower() on the latest unity branch (now the main branch), the indexing into the input buffer p0_1 is extremely redundant and no longer a simple linear expression, which I suspect hurts memory-access performance. See the p0_1 indexing in the dump below (a minimal repro sketch follows it):
@I.ir_module
class Module:
    @T.prim_func
    def nn_max_pool2d(p0: T.Buffer((32, 20, 20, 384), "int8")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        pool_max = T.allocate([4915200], "int8", "global")
        blockIdx_x = T.launch_thread("blockIdx.x", 4800)
        pool_max_local = T.allocate([1], "int8", "local")
        threadIdx_x = T.launch_thread("threadIdx.x", 1024)
        pool_max_local_1 = T.Buffer((1,), "int8", data=pool_max_local, scope="local", align=1)
        pool_max_local_1[0] = T.int8(-128)
        p0_1 = T.Buffer((4915200,), "int8", data=p0.data)
        pool_max_local_1[0] = T.max(
            pool_max_local_1[0],
            T.if_then_else(
                30 <= blockIdx_x % 150 * 2 + threadIdx_x // 512 and 6 <= (blockIdx_x * 8 + threadIdx_x // 128) % 60,
                p0_1[blockIdx_x // 150 * 153600
                     + (blockIdx_x % 150 * 2 + threadIdx_x // 512) // 15 * 7680
                     + (blockIdx_x * 8 + threadIdx_x // 128) % 60 // 3 * 384
                     + (blockIdx_x * 256 + threadIdx_x) % 384
                     - 16128],
                T.int8(-128)))
        ...
        pool_max_1 = T.Buffer((4915200,), "int8", data=pool_max)
        pool_max_1[blockIdx_x * 1024 + threadIdx_x] = pool_max_local_1[0]
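For reference, here is roughly how I generate this dump. It is a minimal sketch: the kernel/stride/padding values are my assumption (chosen to keep the 20x20 spatial extent, matching the 4915200-element pool_max allocation above); any max pooling over this input shape shows the same pattern.

import tvm
from tvm import te, topi

# Input shape and dtype are taken from the dump above: NHWC int8, 32x20x20x384.
data = te.placeholder((32, 20, 20, 384), dtype="int8", name="p0")

# NOTE: the exact pooling parameters are an assumption on my part; these keep
# the output at 20x20 so it matches the 4915200-element pool_max buffer.
out = topi.nn.pool2d(
    data,
    kernel=(5, 5),
    stride=(1, 1),
    dilation=(1, 1),
    padding=(2, 2, 2, 2),
    pool_type="max",
    layout="NHWC",
)

# schedule_pool lives in python/tvm/topi/cuda/pooling.py; it reads
# max_num_threads from the current target, so a CUDA target must be in scope.
with tvm.target.Target("cuda"):
    s = topi.cuda.schedule_pool([out], layout="NHWC")

# Print the scheduled TIR; this is where the redundant indexing shows up.
print(tvm.lower(s, [data, out], simple_mode=True))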
For comparison, the same code lowered with TVM 0.9.dev0 produces a compact indexing form, e.g. placeholder_2[(((blockIdx.x*1024) + threadIdx.x) - 16128)]:
@nn_max_pool2d = primfn(placeholder_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "nn_max_pool2d", "tir.noalias": True}
  buffers = {placeholder: Buffer(placeholder_2: Pointer(int8), int8, [32, 20, 20, 384], [])}
  buffer_map = {placeholder_1: placeholder} {
  allocate(tensor: Pointer(global int8), int8, [4915200]), storage_scope = global;
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 4800;
  allocate(tensor.local: Pointer(local int8), int8, [1]), storage_scope = local;
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024 {
    tensor.local[0] = -128i8
    tensor.local[0] = max((int8*)tensor.local[0],
      @tir.if_then_else(((15360 <= floormod(((blockIdx.x*1024) + threadIdx.x), 153600)) && (768 <= floormod(((blockIdx.x*1024) + threadIdx.x), 7680))),
        (int8*)placeholder_2[(((blockIdx.x*1024) + threadIdx.x) - 16128)], -128i8, dtype=int8))
    ...
    tensor[((blockIdx.x*1024) + threadIdx.x)] = (int8*)tensor.local[0]
  }
}
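Unless I am misreading it, the two index expressions compute exactly the same address: the unity-branch form decomposes blockIdx.x*1024 + threadIdx.x into batch/row/column/channel components and then re-linearizes them without folding. A quick brute-force check in plain Python over the full launch grid confirms this, so the addresses themselves are unchanged; the regression is the redundant per-thread integer arithmetic and whatever simplification it blocks downstream:

# Brute-force check that the unity-branch index (first dump) equals the
# 0.9.dev0 index (second dump) for every (blockIdx.x, threadIdx.x) pair.
def unity_index(b, t):
    return (b // 150 * 153600
            + (b % 150 * 2 + t // 512) // 15 * 7680
            + (b * 8 + t // 128) % 60 // 3 * 384
            + (b * 256 + t) % 384
            - 16128)

def old_index(b, t):
    return b * 1024 + t - 16128

assert all(unity_index(b, t) == old_index(b, t)
           for b in range(4800)    # blockIdx.x extent
           for t in range(1024))   # threadIdx.x extent
print("identical over the full 4800x1024 launch grid")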
Can anyone provide assistance? I would be extremely grateful!