Here is a code snippet:
def Mod(A: T.Buffer((T.int64(36), T.int64(2), T.int64(32640), T.int64(2), T.int64(32)), "float16"), T_transpose: T.Buffer((T.int64(32640), T.int64(2), T.int64(32), T.int64(36), T.int64(2)), "float16")):
    # with T.block("root"):
    A_shared = T.alloc_buffer((T.int64(32640), T.int64(2), T.int64(2), T.int64(32), T.int64(32), T.int64(2)), "float16", scope="shared")
    for k in T.thread_binding(T.int64(32640), thread="blockIdx.x"):
        for i_0 in T.thread_binding(T.int64(2), thread="blockIdx.y"):
            for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.y"):
                for ax3 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    for ax1_ax2_fused in T.unroll(T.int64(4)):
                        with T.block("A_shared"):
                            vi = T.axis.spatial(T.int64(36), i_0 * T.int64(32) + ax0)
                            vj = T.axis.spatial(T.int64(2), ax1_ax2_fused // T.int64(2))
                            vk = T.axis.spatial(T.int64(32640), k)
                            vf = T.axis.spatial(T.int64(2), ax1_ax2_fused % T.int64(2))
                            vg = T.axis.spatial(T.int64(32), ax3)
                            T.where(i_0 * T.int64(32) + ax0 < T.int64(36))
                            T.reads(A[vi, vj, vk, vf, vg])
                            T.writes(A_shared[vk, vf, vi // T.int64(32), vg, vi % T.int64(32), vj])
                            T.block_attr({"buffer_dim_align": [[0, 3, 66, 0]]})
                            A_shared[vk, vf, vi // T.int64(32), vg, vi % T.int64(32), vj] = A[vi, vj, vk, vf, vg]
When generated as CUDA code, it becomes:
extern "C" __global__ void __launch_bounds__(1024) Mod(half* __restrict__ A, half* __restrict__ T_transpose) {
__shared__ half A_shared[4224];
if (((((int)blockIdx.y) * 8) + (((int)threadIdx.y) >> 2)) < 9) {
A_shared[((((int)threadIdx.x) * 66) + (((int)threadIdx.y) * 2))] = A[((((((int)blockIdx.y) * 133693440) + (((int)threadIdx.y) * 4177920)) + (((int)blockIdx.x) * 64)) + ((int)threadIdx.x))];
}
if (((((int)blockIdx.y) * 8) + (((int)threadIdx.y) >> 2)) < 9) {
A_shared[(((((int)threadIdx.x) * 66) + (((int)threadIdx.y) * 2)) + 2112)] = A[(((((((int)blockIdx.y) * 133693440) + (((int)threadIdx.y) * 4177920)) + (((int)blockIdx.x) * 64)) + ((int)threadIdx.x)) + 32)];
}
if (((((int)blockIdx.y) * 8) + (((int)threadIdx.y) >> 2)) < 9) {
A_shared[(((((int)threadIdx.x) * 66) + (((int)threadIdx.y) * 2)) + 1)] = A[(((((((int)blockIdx.y) * 133693440) + (((int)threadIdx.y) * 4177920)) + (((int)blockIdx.x) * 64)) + ((int)threadIdx.x)) + 2088960)];
}
if (((((int)blockIdx.y) * 8) + (((int)threadIdx.y) >> 2)) < 9) {
A_shared[(((((int)threadIdx.x) * 66) + (((int)threadIdx.y) * 2)) + 2113)] = A[(((((((int)blockIdx.y) * 133693440) + (((int)threadIdx.y) * 4177920)) + (((int)blockIdx.x) * 64)) + ((int)threadIdx.x)) + 2088992)];
}
Since " if (((((int)blockIdx.y) * 8) + (((int)threadIdx.y) >> 2)) < 9) " condition judge is the same for all share mem store here, would it be possible that merge those 4 statement under one block of condition judgment?
Thx~