Cross thread reduce cannot swap with normal data axis

I want to bind reduce axis to blockIdx, but I got error(Check failed: thread_red == 0 (1 vs. 0) : Cross thread reduce cannot swap with normal data axis) when buiding the schedule.

I modified the schedule_batch_matmul_tensorcore function in python/tvm/topi/cuda/batch_matmul_tensorcore.py.

@autotvm.register_topi_schedule(“batch_matmul_tensorcore.cuda”) def schedule_batch_matmul_tensorcore(cfg, outs): “”"Schedule for batch_matmul operator using Tensorcore

Parameters
----------
outs: Array of Tensor
      The computation graph description of batch_matmul
      in the format of an array of tensors.

Returns
-------
s: Schedule
    The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _schedule(cfg, s, C):
    A, B = s[C].op.input_tensors
    batch, m_dim, k_dim = get_const_tuple(A.shape)
    batch, n_dim, k_dim = get_const_tuple(B.shape)
    out_dtype = C.dtype
    # inline astype fp16
    s[A].compute_inline()
    s[B].compute_inline()

    # Explicit memory access
    AS = s.cache_read(A, "shared", [C])
    BS = s.cache_read(B, "shared", [C])
    AF = s.cache_read(AS, "wmma.matrix_a", [C])
    BF = s.cache_read(BS, "wmma.matrix_b", [C])
    CF = s.cache_write(C, "wmma.accumulator")
    CS = s.cache_read(CF, "shared", [C])

    # fallback support
    target = tvm.target.Target.current()
    if cfg.is_fallback:
        ref_log = autotvm.tophub.load_reference_log(
            target.kind.name, target.model, "batch_matmul_tensorcore.cuda"
        )
        cfg.fallback_with_reference_log(ref_log)

    # Deal with op fusion, such as bias/relu and slice after padding
    if C.op not in s.outputs and "injective" in s.outputs[0].tag:
        s[C].compute_inline()
        C = s.outputs[0].output(0)

    # create tuning space
    cfg.define_knob("block_row_warps", [1, 2, 4])
    cfg.define_knob("block_col_warps", [1, 2, 4])
    cfg.define_knob("warp_row_tiles", [1, 2, 4, 8, 16])
    cfg.define_knob("warp_col_tiles", [1, 2, 4, 8, 16])
    cfg.define_knob("chunk", [1, 2, 4, 8])                  # K steps
    cfg.define_knob("offset", [0, 8])
    cfg.define_knob("offsetCS", [0, 8])
    cfg.define_knob("vec", [1, 2, 4, 8])

    # Ensure that the default parameters are applicable when autotvm is not in use
    if m_dim % 32 == 0 and n_dim % 8 == 0:
        cfg.define_knob("wmma_m", [32, 16, 8])
    elif m_dim % 16 == 0 and n_dim % 16 == 0:
        cfg.define_knob("wmma_m", [16, 8, 32])
    elif m_dim % 8 == 0 and n_dim % 32 == 0:
        cfg.define_knob("wmma_m", [8, 16, 32])

    warp_size = 32
    wmma_k = 16
    block_row_warps = cfg["block_row_warps"].val
    block_col_warps = cfg["block_col_warps"].val
    warp_row_tiles = cfg["warp_row_tiles"].val
    warp_col_tiles = cfg["warp_col_tiles"].val
    chunk = cfg["chunk"].val
    offset = cfg["offset"].val
    offsetCS = cfg["offsetCS"].val
    wmma_m = cfg["wmma_m"].val
    vec = cfg["vec"].val

    if wmma_m == 16:
        wmma_n = 16
    elif wmma_m == 8:
        wmma_n = 32
    elif wmma_m == 32:
        wmma_n = 8

    # Define the stride of intrin functions
    AS_align = chunk * wmma_k + offset
    BS_align = chunk * wmma_k + offset
    CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS
    AS_stride = [AS_align, 1]
    BS_stride = [BS_align, 1]
    AF_stride = [wmma_k, 1]
    BF_stride = [wmma_k, 1]
    CF_stride = [warp_col_tiles * wmma_n, 1]
    CS_stride = [CS_align, 1]

    block_x = te.thread_axis("blockIdx.x")
    block_y = te.thread_axis("blockIdx.y")
    block_z = te.thread_axis("blockIdx.z")
    thread_x = te.thread_axis("threadIdx.x")
    thread_y = te.thread_axis("threadIdx.y")
    thread_z = te.thread_axis("threadIdx.z")

    # Schedule for dense computation
    block_factor_m = wmma_m * warp_row_tiles * block_row_warps
    block_factor_n = wmma_n * warp_col_tiles * block_col_warps
    b, m, n = C.op.axis
    block_i, bc = s[C].split(m, factor=block_factor_m)
    block_j, oc = s[C].split(n, factor=block_factor_n)
    s[C].reorder(b, block_i, block_j, bc, oc)
    t = s[C].fuse(bc, oc)
    t, vi = s[C].split(t, factor=vec)
    t, tx = s[C].split(t, factor=warp_size)
    t, ty = s[C].split(t, factor=block_row_warps)
    t, tz = s[C].split(t, factor=block_col_warps)
    s[C].bind(block_i, block_x)
    s[C].bind(block_j, block_y)
    # s[C].bind(b, block_z)
    s[C].bind(tz, thread_z) 
    s[C].bind(ty, thread_y) 
    s[C].bind(tx, thread_x) 
    s[C].vectorize(vi)

    # Schedule for wmma store
    s[CS].compute_at(s[C], block_j)
    bs, bb, oo = CS.op.axis
    s[CS].storage_align(bb, CS_align - 1, CS_align)
    bb, bbi = s[CS].split(bb, factor=wmma_m)
    oo, ooi = s[CS].split(oo, factor=wmma_n)
    bb, bbii = s[CS].split(bb, factor=warp_row_tiles)
    oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
    s[CS].reorder(bs, bb, oo, bbii, ooii, bbi, ooi)

    # Schedule for wmma computation
    s[CF].compute_at(s[CS], oo)
    bs, warp_i, warp_j = CF.op.axis
    warp_i, _ii = s[CF].split(warp_i, factor=wmma_m) 
    warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
    (k,) = CF.op.reduce_axis
    k, _k = s[CF].split(k, factor=wmma_k)
    ko, ki = s[CF].split(k, factor=chunk)
    s[CF].reorder(bs, ko, ki, warp_i, warp_j, _ii, _jj, _k)
    s[CF].bind(ko, block_z)

    # Schedule for  wmma_matrix_a load
    s[AF].compute_at(s[CF], ki)
    bs, b, i = AF.op.axis
    b, b_ii = s[AF].split(b, factor=wmma_m)
    i, i_jj = s[AF].split(i, factor=wmma_k)
    s[AF].reorder(bs, b, i, b_ii, i_jj)

    # Schedule for  wmma_matrix_b load
    s[BF].compute_at(s[CF], ki)
    bs, o, i = BF.op.axis
    o, o_ii = s[BF].split(o, factor=wmma_n)
    i, i_ii = s[BF].split(i, factor=wmma_k)
    s[BF].reorder(bs, o, i, o_ii, i_ii)

    # Schedule for A's(B's) shared memory load
    def shared_shedule(stage, strides):
        s[stage].compute_at(s[CF], ko)
        bs, xo, yo = stage.op.axis
        s[stage].storage_align(xo, strides - 1, strides)
        t = s[stage].fuse(xo, yo)
        t, vi = s[stage].split(t, factor=vec)
        t, tx = s[stage].split(t, factor=warp_size)
        t, ty = s[stage].split(t, factor=block_row_warps)
        _, tz = s[stage].split(t, factor=block_col_warps)
        s[stage].bind(ty, thread_y)
        s[stage].bind(tz, thread_z)
        s[stage].bind(tx, thread_x)
        s[stage].vectorize(vi)

    shared_shedule(AS, AS_align)
    shared_shedule(BS, BS_align)

    shape = (wmma_m, wmma_n, wmma_k)
    # TODO: add checking here, datatype casting may cause precision loss
    in_dtype = "float16"
    AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype)
    BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype)
    k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm")
    CL_compute = te.compute(
        (wmma_m, wmma_n),
        lambda ii, jj: te.sum(
            AL_gemm[ii, k_gemm].astype(out_dtype) * BL_gemm[jj, k_gemm].astype(out_dtype),
            axis=k_gemm,
        ),
        name="CL_compute",
    )

    # lower the computation loops down to TensorCore hardware intrinsics
    # by mapping the dense tensorcore to tensor intrinsics
    s[AF].tensorize(
        b_ii,
        intrin_wmma_load_matrix_A(
            AF_stride,
            AS_stride,
            shape,
            "row_major",
            (wmma_m, wmma_k),
            (wmma_m, wmma_k),
            "float16",
        ),
    )
    s[BF].tensorize(
        o_ii,
        intrin_wmma_load_matrix_W(
            BF_stride,
            BS_stride,
            shape,
            "col_major",
            (wmma_n, wmma_k),
            (wmma_n, wmma_k),
            "float16",
        ),
    )
    s[CF].tensorize(
        _ii,
        intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride, BF_stride, CF_stride, shape),
    )
    s[CS].tensorize(
        bbi,
        intrin_wmma_store_matrix(
            CS_stride, CF_stride, shape, out_dtype, (wmma_m, wmma_n), (wmma_m, wmma_n)
        ),
    )

def _callback(op):
    if "batch_matmul_tensorcore" in op.tag:
        _schedule(cfg, s, op.output(0))

traverse_inline(s, outs[0].op, _callback)
return s