GPU thread binding and iter_var inference

In my schedule there are two ops: one computes the result with a GEMM, and the other reshapes it into the output layout.
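To illustrate the structure, the pair of ops can be written roughly as below (a minimal sketch: the shapes and names are assumptions read off the loop extents that follow, not my real definitions):

    from tvm import te

    # Assumed shapes chosen to match the loop extents below
    # (98*8 = 784 rows, 16*8 = 128 columns); K is a placeholder.
    M, N, K = 784, 128, 64
    A = te.placeholder((M, K), name="A")
    B = te.placeholder((K, N), name="B")
    k = te.reduce_axis((0, K), name="k")
    gemm_C = te.compute(
        (M, N),
        lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
        name="gemm_C",
    )
    # Second op: copy / rearrange gemm_C into the output layout.
    output = te.compute((M, N), lambda i, j: gemm_C[i, j], name="output")

The lowered function looks like this: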

  for (i.outer.outer, 0, 98) {
    for (j.outer.outer, 0, 16) {
      for (ii, 0, 8) {
        for (jj, 0, 8) {
          gemm_C[((((i.outer.outer*1024) + (j.outer.outer*64)) + (ii*8)) + jj)] = gemm_C.wmma.accumulator[((((i.outer.outer*1024) + (j.outer.outer*64)) + (ii*8)) + jj)]
        }
      }
    }
  }

  for (n.oh.fused.ow.fused.outer.outer.outer, 0, 98) {
    for (oc.outer.outer.outer, 0, 16) {
      for (n.oh.fused.ow.fused.inner, 0, 8) {
        for (oc.inner, 0, 8) {
          output[((((n.oh.fused.ow.fused.outer.outer.outer*1024) + (n.oh.fused.ow.fused.inner*128)) + (oc.outer.outer.outer*8)) + oc.inner)] = gemm_C[((((n.oh.fused.ow.fused.outer.outer.outer*1024) + (oc.outer.outer.outer*64)) + (n.oh.fused.ow.fused.inner*8)) + oc.inner)]
        }
      }
    }
  }

I want these two operations to end up in the same kernel, with the gemm_C result stored in shared memory. I first bind the output axes to blocks and threads.
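Concretely, the binding is roughly the following (continuing the sketch above; s, output, and the axis names are placeholders for my real schedule):

    # Create a schedule and bind the outer axes of output to blocks.
    s = te.create_schedule(output.op)
    i, j = s[output].op.axis
    bx_ax, row_in = s[output].split(i, factor=8)   # outer extent 98, inner 8
    by_ax, col_in = s[output].split(j, factor=8)   # outer extent 16, inner 8
    s[output].reorder(bx_ax, by_ax, row_in, col_in)
    s[output].bind(bx_ax, te.thread_axis("blockIdx.x"))
    s[output].bind(by_ax, te.thread_axis("blockIdx.y"))

After this binding, the IR becomes: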

 for (i, 0, 98) {
    for (j, 0, 16) {
      for (ii, 0, 8) {
        for (jj, 0, 8) {
          gemm_C[((((i*1024) + (j*64)) + (ii*8)) + jj)] = gemm_C.wmma.accumulator[((((i*1024) + (j*64)) + (ii*8)) + jj)]
        }
      }
    }
  }

  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 98
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 16
  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
  // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
  for (n.oh.fused.ow.fused.inner, 0, 8) {
    for (oc.inner, 0, 8) {
      output[((((blockIdx.x*1024) + (n.oh.fused.ow.fused.inner*128)) + (blockIdx.y*8)) + oc.inner)] = gemm_C[((((blockIdx.x*1024) + (blockIdx.y*64)) + (n.oh.fused.ow.fused.inner*8)) + oc.inner)]
    }
  }

Then I try to set the scope of gemm_C to shared memory with s[gemm_C].set_scope('shared') or compute_at().
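Schematically (continuing the sketch; the attach axis is my best guess at where it should go):

    # Move gemm_C to shared memory and/or attach it under the blockIdx.y
    # axis of output (by_ax from the binding sketch above).
    s[gemm_C].set_scope("shared")
    s[gemm_C].compute_at(s[output], by_ax)

Both methods give a result like: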

  for (i, 0, 98) {
    for (j, 0, (16 - blockIdx.y)) {
      for (ii, 0, 8) {
        for (jj, 0, 8) {
          if (likely(((j + blockIdx.y) < 16))) {
            gemm_C[(((((i*(16 - blockIdx.y))*64) + (j*64)) + (ii*8)) + jj)] = gemm_C.wmma.accumulator[(((((i*(16 - blockIdx.y))*64) + (j*64)) + (ii*8)) + jj)]
          }
        }
      }
    }
  }

  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 98
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 16
  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
  // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
  for (n.oh.fused.ow.fused.inner, 0, 8) {
    for (oc.inner, 0, 8) {
      output[((((blockIdx.x*1024) + (n.oh.fused.ow.fused.inner*128)) + (blockIdx.y*8)) + oc.inner)] = gemm_C[((((blockIdx.x*(16 - blockIdx.y))*64) + (n.oh.fused.ow.fused.inner*8)) + oc.inner)]
    }
  }

The j-axis of gemm_C is inferred as (j, 0, (16 - blockIdx.y)). Because of this inference I cannot bind that axis to blockIdx.y.

Is this the right approach to achieve my goal? What could cause the iter_var extent to be inferred this way, and how can I fix it?