[unity] Performance degradation caused by redundant buffer indexing

A long time ago I used TVM 0.9.dev0; I have now upgraded to the Unity version, only to find a performance degradation in max_pool2d's schedule caused by redundant buffer indexing. Below is a detailed description of the issue in https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/pooling.py. I used tvm.lower() to print the scheduled TIR and found that, on the latest unity branch (now the main branch), the buffer indexing feeding the local-memory compute is far too redundant and no longer spatially contiguous, resulting in poor memory-access performance; see the p0_1 buffer's indexing:

    @I.ir_module
    class Module:
        @T.prim_func
        def nn_max_pool2d(p0: T.Buffer((32, 20, 20, 384), "int8")):
            T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
            pool_max = T.allocate([4915200], "int8", "global")
            blockIdx_x = T.launch_thread("blockIdx.x", 4800)
            pool_max_local = T.allocate([1], "int8", "local")
            threadIdx_x = T.launch_thread("threadIdx.x", 1024)
            pool_max_local_1 = T.Buffer((1,), "int8", data=pool_max_local, scope="local", align=1)
            pool_max_local_1[0] = T.int8(-128)
            p0_1 = T.Buffer((4915200,), "int8", data=p0.data)
            pool_max_local_1[0] = T.max(pool_max_local_1[0], 
              T.if_then_else(30 <= blockIdx_x % 150 * 2 + threadIdx_x // 512 and 6 <= (blockIdx_x * 8 + threadIdx_x // 128) % 60, 
                    p0_1[blockIdx_x // 150 * 153600 + 
                          (blockIdx_x % 150 * 2 + threadIdx_x // 512) // 15 * 7680 + 
                          (blockIdx_x * 8 + threadIdx_x // 128) % 60 // 3 * 384 + 
                          (blockIdx_x * 256 + threadIdx_x) % 384 - 16128], T.int8(-128)))
                          ... 
            pool_max_1 = T.Buffer((4915200,), "int8", data=pool_max)
            pool_max_1[blockIdx_x * 1024 + threadIdx_x] = pool_max_local_1[0]

The same code run on TVM 0.9.dev0 produces a clean, compact index such as placeholder_2[(((blockIdx.x*1024) + threadIdx.x) - 16128)]:

 @nn_max_pool2d = primfn(placeholder_1: handle) -> ()
      attr = {"from_legacy_te_schedule": True, "global_symbol": "nn_max_pool2d", "tir.noalias": True}
      buffers = {placeholder: Buffer(placeholder_2: Pointer(int8), int8, [32, 20, 20, 384], [])}
      buffer_map = {placeholder_1: placeholder} {
      allocate(tensor: Pointer(global int8), int8, [4915200]), storage_scope = global;
      attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 4800;
      allocate(tensor.local: Pointer(local int8), int8, [1]), storage_scope = local;
      attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1024 {
        tensor.local[0] = -128i8
        tensor.local[0] = max((int8*)tensor.local[0], 
          @tir.if_then_else(((15360 <= floormod(((blockIdx.x*1024) + threadIdx.x), 153600)) && (768 <= floormod(((blockIdx.x*1024) + threadIdx.x), 7680))), 
            (int8*)placeholder_2[(((blockIdx.x*1024) + threadIdx.x) - 16128)], -128i8, dtype=int8))
            ... 
        tensor[((blockIdx.x*1024) + threadIdx.x)] = (int8*)tensor.local[0]
      }
    }
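
To double-check that the two forms really address the same element of p0, here is a small equivalence check with tvm.arith.Analyzer (the constants are copied from the dumps above; if the simplifier can fold the difference, it prints 0):

    import tvm
    from tvm import tir
    from tvm.tir import floordiv as fld, floormod as flm

    bx = tir.Var("blockIdx_x", "int32")
    tx = tir.Var("threadIdx_x", "int32")

    # Redundant index emitted on the unity/main branch (copied from the dump above).
    new_idx = (
        fld(bx, 150) * 153600
        + fld(flm(bx, 150) * 2 + fld(tx, 512), 15) * 7680
        + fld(flm(bx * 8 + fld(tx, 128), 60), 3) * 384
        + flm(bx * 256 + tx, 384)
        - 16128
    )
    # Compact index produced by 0.9.dev0.
    old_idx = bx * 1024 + tx - 16128

    ana = tvm.arith.Analyzer()
    ana.bind(bx, tvm.ir.Range(0, 4800))  # blockIdx.x thread extent
    ana.bind(tx, tvm.ir.Range(0, 1024))  # threadIdx.x thread extent
    # If the analyzer can prove the two forms equal, this prints 0.
    print(ana.simplify(new_idx - old_idx))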

Can anyone provide assistance? I would be extremely grateful!

And the following is my testing code:

import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.relay.testing import run_infer_type
from tvm.relay import testing
from tvm.relay.import_model import import_model_to_igie
import onnx
import numpy as np

target = tvm.target.cuda()
device = tvm.cuda(0)

data_shape = [32, 20, 20, 384]
data_type = r'int8'
data = relay.var("data", shape=data_shape, dtype=data_type)
out = relay.nn.max_pool2d(data,
                        pool_size=(5, 5),  # pool_size=(1, 1),   
                        padding=(2, 2),
                        layout="NHWC",
                        out_layout="",
                        ceil_mode=False)

free_vars = relay.analysis.free_vars(out)
func = relay.Function(free_vars, out)
mod, params = testing.create_workload(func)
# import ipdb as pdb; pdb.set_trace()
# mod.show()
lib = tvm.relay.build(mod, target=target, params=params)
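
To check whether the redundant indexing survives all the way into the generated kernel (not only the lowered TIR), I also dump the CUDA source from the built library, roughly like this (assuming the default graph-executor factory returned by relay.build, whose first imported module is the CUDA module):

    # Inspect the generated CUDA kernel source.
    dev_mod = lib.get_lib().imported_modules[0]
    print(dev_mod.get_source())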

Note: to print the lowered TIR, I added a tvm.lower() call inside /path/tvm/python/tvm/topi/cuda/pooling.py, like this:

def schedule_pool(outs, layout):

    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])

    def _schedule(PaddedInput, Pool):
        if isinstance(PaddedInput.op, tvm.te.ComputeOp):
            s[PaddedInput].compute_inline()
        num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
        if Pool.op in s.outputs:
            Out = Pool
            OL = s.cache_write(Pool, "local")
        else:
            Out = outs[0].op.output(0)
            s[Pool].set_scope("local")
        fused = s[Out].fuse(*s[Out].op.axis)
        bx, tx = s[Out].split(fused, factor=num_thread)
        s[Out].bind(bx, te.thread_axis("blockIdx.x"))
        s[Out].bind(tx, te.thread_axis("threadIdx.x"))
        if Pool.op in s.outputs:
            s[OL].compute_at(s[Out], tx)
        else:
            s[Pool].compute_at(s[Out], tx)

        # My addition: grab the original (un-padded) input tensor and dump the lowered TIR.
        aa = s[PaddedInput].op.input_tensors[0]
        tvm.lower(s, [aa], simple_mode=True, name="nn_max_pool2d").show()

    scheduled_ops = []
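
For completeness, the same lowered TIR can be reproduced without patching pooling.py by building the pooling compute directly with TOPI and calling the same schedule_pool; a rough sketch (assuming the current topi.nn.pool2d signature and the cuda target from above):

    import tvm
    from tvm import te, topi

    # Same workload as the relay test above, expressed directly with TOPI.
    data = te.placeholder((32, 20, 20, 384), name="p0", dtype="int8")
    pool = topi.nn.pool2d(
        data,
        kernel=(5, 5),
        stride=(1, 1),
        dilation=(1, 1),
        padding=(2, 2, 2, 2),
        pool_type="max",
        layout="NHWC",
    )

    # schedule_pool reads the current target, so enter a cuda target scope.
    with tvm.target.Target("cuda"):
        s = topi.cuda.schedule_pool([pool], layout="NHWC")
        tvm.lower(s, [data, pool], simple_mode=True, name="nn_max_pool2d").show()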