In my schedule there are two ops: one computes the result with a GEMM, and the other reshapes it. The lowered function looks like this:
for (i.outer.outer, 0, 98) {
  for (j.outer.outer, 0, 16) {
    for (ii, 0, 8) {
      for (jj, 0, 8) {
        gemm_C[((((i.outer.outer*1024) + (j.outer.outer*64)) + (ii*8)) + jj)] = gemm_C.wmma.accumulator[((((i.outer.outer*1024) + (j.outer.outer*64)) + (ii*8)) + jj)]
      }
    }
  }
}
for (n.oh.fused.ow.fused.outer.outer.outer, 0, 98) {
  for (oc.outer.outer.outer, 0, 16) {
    for (n.oh.fused.ow.fused.inner, 0, 8) {
      for (oc.inner, 0, 8) {
        output[((((n.oh.fused.ow.fused.outer.outer.outer*1024) + (n.oh.fused.ow.fused.inner*128)) + (oc.outer.outer.outer*8)) + oc.inner)] = gemm_C[((((n.oh.fused.ow.fused.outer.outer.outer*1024) + (oc.outer.outer.outer*64)) + (n.oh.fused.ow.fused.inner*8)) + oc.inner)]
      }
    }
  }
}
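For context, the two stages correspond roughly to the following TE pattern. This is a simplified sketch I reconstructed from the IR above: the real first op is a wmma/TensorCore GEMM lowered from a conv2d, and the input names, dtypes, and reduction size K here are placeholders, not my actual code.

import tvm
from tvm import te

# Shapes reconstructed from the IR: gemm_C is tiled as (98, 16, 8, 8),
# i.e. 98*8 = 784 rows by 16*8 = 128 columns. K is a placeholder.
ROW_TILES, COL_TILES, TILE, K = 98, 16, 8, 256

A = te.placeholder((ROW_TILES * TILE, K), name="A", dtype="float16")
B = te.placeholder((K, COL_TILES * TILE), name="B", dtype="float16")
k = te.reduce_axis((0, K), name="k")

# First op: tiled GEMM (stands in for the real wmma-based compute).
gemm_C = te.compute(
    (ROW_TILES, COL_TILES, TILE, TILE),
    lambda io, jo, ii, jj: te.sum(
        A[io * TILE + ii, k].astype("float32")
        * B[k, jo * TILE + jj].astype("float32"),
        axis=k,
    ),
    name="gemm_C",
)

# Second op: de-tile gemm_C back into the plain (rows, cols) output layout.
output = te.compute(
    (ROW_TILES * TILE, COL_TILES * TILE),
    lambda r, c: gemm_C[r // TILE, c // TILE, r % TILE, c % TILE],
    name="output",
)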
I want these two operations to end up in the same kernel, with the gemm_C result stored in shared memory. I first bind the output axes to blocks and threads:
for (i, 0, 98) {
  for (j, 0, 16) {
    for (ii, 0, 8) {
      for (jj, 0, 8) {
        gemm_C[((((i*1024) + (j*64)) + (ii*8)) + jj)] = gemm_C.wmma.accumulator[((((i*1024) + (j*64)) + (ii*8)) + jj)]
      }
    }
  }
}
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 98
// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 16
// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
for (n.oh.fused.ow.fused.inner, 0, 8) {
  for (oc.inner, 0, 8) {
    output[((((blockIdx.x*1024) + (n.oh.fused.ow.fused.inner*128)) + (blockIdx.y*8)) + oc.inner)] = gemm_C[((((blockIdx.x*1024) + (blockIdx.y*64)) + (n.oh.fused.ow.fused.inner*8)) + oc.inner)]
  }
}
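The binding above comes from a schedule along these lines (continuing the sketch; the threadIdx.y/threadIdx.z bindings and the wmma tensorization from my real schedule are omitted):

# Split the output into 8x8 tiles and bind the tile indices to blocks.
s = te.create_schedule(output.op)
r, c = s[output].op.axis
ro, ri = s[output].split(r, factor=8)   # -> (98, 8)
co, ci = s[output].split(c, factor=8)   # -> (16, 8)
s[output].reorder(ro, co, ri, ci)
s[output].bind(ro, te.thread_axis("blockIdx.x"))
s[output].bind(co, te.thread_axis("blockIdx.y"))

print(tvm.lower(s, [A, B, output], simple_mode=True))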
Then I try to set the scope of gemm_C to shared memory with s[gemm_C].set_scope('shared'), or to use compute_at(). Both methods give a result like this:
for (i, 0, 98) {
  for (j, 0, (16 - blockIdx.y)) {
    for (ii, 0, 8) {
      for (jj, 0, 8) {
        if (likely(((j + blockIdx.y) < 16))) {
          gemm_C[(((((i*(16 - blockIdx.y))*64) + (j*64)) + (ii*8)) + jj)] = gemm_C.wmma.accumulator[(((((i*(16 - blockIdx.y))*64) + (j*64)) + (ii*8)) + jj)]
        }
      }
    }
  }
}
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 98
// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 16
// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
for (n.oh.fused.ow.fused.inner, 0, 8) {
  for (oc.inner, 0, 8) {
    output[((((blockIdx.x*1024) + (n.oh.fused.ow.fused.inner*128)) + (blockIdx.y*8)) + oc.inner)] = gemm_C[((((blockIdx.x*(16 - blockIdx.y))*64) + (n.oh.fused.ow.fused.inner*8)) + oc.inner)]
  }
}
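Concretely, the two attempts look roughly like this (continuing the sketch above; I tried each separately, and which output axis to pass to compute_at is my guess here):

# Attempt 1: only change the storage scope of gemm_C.
s[gemm_C].set_scope("shared")

# Attempt 2: also move gemm_C under the output's block loops
# (co is the output axis bound to blockIdx.y above).
s[gemm_C].compute_at(s[output], co)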
The j axis of gemm_C is now inferred as (j, 0, (16 - blockIdx.y)), and because of this strange bound I can't bind it to block_y. Am I going about this the right way? What could cause the iter_var bound to be inferred like this, and how can I fix it?