Extra loop scale when doing tensorize with TIR?

I recently attempted to use the scheduling primitives provided by TensorIR to re-implement this TVM tutorial, but I ran into an issue when doing tensorize.

https://tvm.apache.org/docs/how_to/optimize_operators/opt_conv_tensorcore.html

The message that I get when I tensorize ldmatrix_wmma_a_fragment is:

        at /workspace/v-leiwang3/tvm/src/tir/schedule/concrete_schedule.cc:712
  0: tvm::tir::Tensorize(tvm::tir::ScheduleState, tvm::tir::StmtSRef const&, tvm::tir::TensorIntrin const&, bool)
        at /workspace/v-leiwang3/tvm/src/tir/schedule/primitive/blockize_tensorize.cc:586
  File "/workspace/v-leiwang3/tvm/src/tir/schedule/primitive/blockize_tensorize.cc", line 586
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (comparator.rhs_buffer_map_.count(desc)) is false: 

I found that other parts of the wmma tensorization, such as wmma_ldmatrix_b, wmma_fill, and wmma_store, work fine except for this specific part. Initially, I thought the issue may have been with my schedule, but I have double-checked and am confident that it is correct.

Upon further investigation, I discovered that the problem is caused by an affine transform division, SubspaceDivide. The statement that needs to be tensorized before this affine transform is:


for (ax2, 0, 16) {
  block Apad_shared_wmma.matrix_a(iter_var(v0, range(min=0, ext=16)), iter_var(v1, range(min=0, ext=16)), iter_var(v2, range(min=0, ext=16)), iter_var(v3, range(min=0, ext=16)), iter_var(v4, range(min=0, ext=16)), iter_var(v5, range(min=0, ext=16))) {
    bind(v0, (((n_0_0*8) + (n_0_1*2)) + ax0))
    bind(v1, (h + kh))
    bind(v2, (w + kw))
    bind(v3, ((ic_0*2) + ic_1))
    bind(v4, ax1)
    bind(v5, ax2)
    reads([Apad_shared[v0, v1, v2, v3, v4, v5]])
    writes([Apad_shared_wmma.matrix_a[v0, v1, v2, v3, v4, v5]])
    Apad_shared_wmma.matrix_a[v0, v1, v2, v3, v4, v5] = Apad_shared[v0, v1, v2, v3, v4, v5]
  }
}

which appears to be correct, but after this affine transform, the stmt becomes:

Apad_shared_wmma.matrix_a_o(iter_var(v0, range(min=0, ext=16)), iter_var(v1, range(min=0, ext=16)), iter_var(v2, range(min=0, ext=16)), iter_var(v3, range(min=0, ext=16)), iter_var(v4_o, range(min=0, ext=1)), iter_var(v5_o, range(min=0, ext=1))) {
  bind(v0, (((n_0_0*8) + (n_0_1*2)) + ax0))
  bind(v1, ((kh*14) + h))
  bind(v2, ((kw*14) + w))
  bind(v3, ((ic_0*2) + ic_1))
  bind(v4_o, 0)
  bind(v5_o, 0)
  where(((((kh*14) + h) < 16) && (((kw*14) + w) < 16)))
  reads([Apad_shared[v0, v1, v2, v3, 0:16, 0:16]])
  writes([Apad_shared_wmma.matrix_a[v0, v1, v2, v3, 0:16, 0:16]])
  for (ax1, 0, 16) {
    for (ax2, 0, 16) {
      block Apad_shared_wmma.matrix_a(iter_var(v4_i, range(min=0, ext=16)), iter_var(v5_i, range(min=0, ext=16))) {
        bind(v4_i, ax1)
        bind(v5_i, ax2)
        reads([Apad_shared[v0, v1, v2, v3, v4_i, v5_i]])
        writes([Apad_shared_wmma.matrix_a[v0, v1, v2, v3, v4_i, v5_i]])
        Apad_shared_wmma.matrix_a[v0, v1, v2, v3, v4_i, v5_i] = Apad_shared[v0, v1, v2, v3, v4_i, v5_i]
      }
    }
  }
}

The loops kh and kw have an extra scale of 14, and the stmt also has an extra where condition, which makes tensorize fail.
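To make the mismatch concrete, here is a pure-Python check of the two bindings of v1 (not TVM code; the extents h ∈ [0, 14) and kh ∈ [0, 3) are assumptions taken from the conv tutorial, and the same reasoning applies to v2 with w and kw). Both expressions cover the indices 0..15, but they disagree on which (h, kh) pair touches which index:

```python
# Extents assumed from the opt_conv_tensorcore tutorial:
# output height 14, kernel height 3.
H_EXT, KH_EXT = 14, 3

# Binding of v1 before SubspaceDivide: v1 = h + kh
before = {(h, kh): h + kh for kh in range(KH_EXT) for h in range(H_EXT)}

# Binding after SubspaceDivide: v1 = kh*14 + h,
# guarded by where((kh*14 + h) < 16)
after = {(h, kh): kh * 14 + h
         for kh in range(KH_EXT) for h in range(H_EXT)
         if kh * 14 + h < 16}

# Both expressions range over 0..15 ...
print(sorted(set(before.values())) == list(range(16)))  # True
print(sorted(after.values()) == list(range(16)))        # True

# ... but they disagree on which (h, kh) pair touches which index:
print(before[(1, 1)], after[(1, 1)])  # 2 15
```

So while the divided form still covers the same index range as a set, the per-iteration access pattern differs from the original binding, which may be why the intrin comparator no longer matches.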

I haven’t dug into the details of why this happens yet, so I don’t know whether it is a bug in SubspaceDivide or a mistake in my own code. I’ll keep working on it, but it would be helpful if somebody could give me some advice.

Code to reproduce this issue: reproduce_issue_20230114.py · GitHub
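Addendum: for comparison, on the v4/v5 axes the division behaves as expected — an index of extent 16 is factored into an outer iterator of extent 1 and an inner iterator of extent 16, which is what the v4_o/v4_i pair in the dump above shows. A pure-Python sketch of that factoring (not TVM code; 16 is the wmma fragment extent from the dump):

```python
# Pure-Python sketch of the well-behaved subspace division on v4/v5:
# an index of extent 16 is factored as v4 = v4_o * 16 + v4_i.
EXT, INNER = 16, 16
pairs = [(v // INNER, v % INNER) for v in range(EXT)]

# Since EXT == INNER, the outer iterator only ever takes the value 0,
# matching iter_var(v4_o, range(min=0, ext=1)) in the dump above.
assert all(outer == 0 for outer, _ in pairs)
assert [inner for _, inner in pairs] == list(range(16))
```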