[Questions][meta_schedule] Some questions about nested block when tune_tir

Hello, I’m trying to implement a matmul using TIR, and when I introduce a nested block, some schedule rules and postprocs behave strangely. Here is my demo code:

from tvm.script import ir as I
from tvm.script import tir as T

# B, C, H, W and dtype are compile-time constants defined elsewhere.
@I.ir_module
class Module:
    @T.prim_func
    def main(
        Q: T.Buffer((B, C, H, W), dtype),
        K: T.Buffer((B, C, H, W), dtype),
        out: T.Buffer((B, C, W, W), dtype),
    ):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        compute = T.alloc_buffer((W, W), "float32", scope="shared")
        for b, c in T.grid(B, C):
            with T.block("total"):
                for i, j, k in T.grid(W, W, H):
                    with T.block("compute_qk"):
                        nb, nc = T.axis.remap("SS", [b, c])
                        ni, nj, nk = T.axis.remap("SSR", [i, j, k])
                        with T.init():
                            compute[ni, nj] = T.float32(0)
                        compute[ni, nj] += T.Cast("float32", Q[nb, nc, nk, nj]) * T.Cast("float32", K[nb, nc, nk, ni])
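
For context, I tune it roughly like this (the target tag, trial count and work_dir below are just placeholders, not my exact setup):

from tvm import meta_schedule as ms

# Rough sketch of how I invoke the tuner; target / max_trials_global / work_dir
# are placeholders for illustration only.
database = ms.tune_tir(
    mod=Module,
    target="nvidia/geforce-rtx-3080",
    work_dir="./tune_tmp",
    max_trials_global=64,
)
sch = ms.tir_integration.compile_tir(database, Module, "nvidia/geforce-rtx-3080")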

Q1: AutoBind acts on each block separately, so the outer block and the inner block can end up bound with different thread extents, which then fails in UnifyThreadBinding. Does AutoBind support nested blocks, or is there a way to flatten the blocks or extend AutoBind to handle this?

tvm/src/tir/transforms/unify_thread_binding.cc", line 100
ValueError: Check failed: (ana.CanProveEqual(dom->extent, new_iter_var->dom->extent)) is false: All loops that are bound to `blockIdx.x` should have the same extent. However, there are two loops with extent 1 and 64, which are not equal
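
For reference, what I would like to end up with is something like the manual binding below. This is only a sketch, and it assumes a flattened (non-nested) variant of the kernel where the b/c loops belong directly to compute_qk; the split factor is arbitrary:

import tvm

# Sketch only: assumes a flattened variant of the kernel with no nested "total"
# block, so that b and c are loops of "compute_qk". Factors are arbitrary.
sch = tvm.tir.Schedule(Module)
qk = sch.get_block("compute_qk")
b, c, i, j, k = sch.get_loops(qk)
bc = sch.fuse(b, c)
sch.bind(bc, "blockIdx.x")
ij = sch.fuse(i, j)
io, ii = sch.split(ij, factors=[None, 128])
sch.bind(io, "blockIdx.y")
sch.bind(ii, "threadIdx.x")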

Q2: When I use match_buffer with indices to reference a region of a given buffer and drop the high-dimension indices, I get this error: “TVMError: Check failed: (analyzer_.CanProve(lhs == rhs)) is false: The buffer match constraint for Q_re.elem_offset unmet: 0==blockIdx_x * 16384 + threadIdx_y * 8192.”

Here the declared elem_offset is 0, while the actual start of the matched region is calculated from nb and nc:

@I.ir_module
class Module:
    @T.prim_func
    def main(
        Q: T.Buffer((B, C, H, W), dtype),
        K: T.Buffer((B, C, H, W), dtype),
        out: T.Buffer((B, C, W, W), dtype),
    ):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        compute = T.alloc_buffer((W, W), "float32", scope="shared")
        for b, c in T.grid(B, C):
            with T.block("total"):
                nb, nc = T.axis.remap("SS", [b, c])
                Q_re = T.match_buffer(Q[nb, nc, :, :], (H, W), dtype)
                K_re = T.match_buffer(K[nb, nc, :, :], (H, W), dtype)
                for i, j, k in T.grid(W, W, H):
                    with T.block("compute_qk"):
                        ni, nj, nk = T.axis.remap("SSR", [i, j, k])
                        with T.init():
                            compute[ni, nj] = T.float32(0)
                        compute[ni, nj] += T.Cast("float32", Q_re[nk, nj]) * T.Cast("float32", K_re[nk, ni])
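
Is the intended fix here to let elem_offset be matched symbolically instead of defaulting to the constant 0? My guess, based on how tensorize intrinsic descriptions declare their buffers, would be something like the following (offset_factor=1 is purely my assumption):

# Guess only: give the matched buffers a non-zero offset_factor so that
# elem_offset becomes a free variable matched against the actual offset,
# instead of the default constant 0.
Q_re = T.match_buffer(Q[nb, nc, :, :], (H, W), dtype, offset_factor=1)
K_re = T.match_buffer(K[nb, nc, :, :], (H, W), dtype, offset_factor=1)

Would this be the recommended way, or is there a better approach for taking a 2D view of the buffer per (b, c)?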