[Relax TIR] how to lift up T.where?

Here is a code snippet:

    def Mod(A: T.Buffer((T.int64(36), T.int64(2), T.int64(32640), T.int64(2), T.int64(32)), "float16"), T_transpose: T.Buffer((T.int64(32640), T.int64(2), T.int64(32), T.int64(36), T.int64(2)), "float16")):
        # with T.block("root"):
        A_shared = T.alloc_buffer((T.int64(32640), T.int64(2), T.int64(2), T.int64(32), T.int64(32), T.int64(2)), "float16", scope="shared")
        for k in T.thread_binding(T.int64(32640), thread="blockIdx.x"):
            for i_0 in T.thread_binding(T.int64(2), thread="blockIdx.y"):
                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.y"):
                    for ax3 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        for ax1_ax2_fused in T.unroll(T.int64(4)):
                            with T.block("A_shared"):
                                vi = T.axis.spatial(T.int64(36), i_0 * T.int64(32) + ax0)
                                vj = T.axis.spatial(T.int64(2), ax1_ax2_fused // T.int64(2))
                                vk = T.axis.spatial(T.int64(32640), k)
                                vf = T.axis.spatial(T.int64(2), ax1_ax2_fused % T.int64(2))
                                vg = T.axis.spatial(T.int64(32), ax3)
                                T.where(i_0 * T.int64(32) + ax0 < T.int64(36))
                                T.reads(A[vi, vj, vk, vf, vg])
                                T.writes(A_shared[vk, vf, vi // T.int64(32), vg, vi % T.int64(32), vj])
                                T.block_attr({"buffer_dim_align": [[0, 3, 66, 0]]})
                                A_shared[vk, vf, vi // T.int64(32), vg, vi % T.int64(32), vj] = A[vi, vj, vk, vf, vg]

When generated as CUDA code, it becomes:

extern "C" __global__ void __launch_bounds__(1024) Mod(half* __restrict__ A, half* __restrict__ T_transpose) {
  __shared__ half A_shared[4224];
  if (((((int)blockIdx.y) * 8) + (((int)threadIdx.y) >> 2)) < 9) {
    A_shared[((((int)threadIdx.x) * 66) + (((int)threadIdx.y) * 2))] = A[((((((int)blockIdx.y) * 133693440) + (((int)threadIdx.y) * 4177920)) + (((int)blockIdx.x) * 64)) + ((int)threadIdx.x))];
  }
  if (((((int)blockIdx.y) * 8) + (((int)threadIdx.y) >> 2)) < 9) {
    A_shared[(((((int)threadIdx.x) * 66) + (((int)threadIdx.y) * 2)) + 2112)] = A[(((((((int)blockIdx.y) * 133693440) + (((int)threadIdx.y) * 4177920)) + (((int)blockIdx.x) * 64)) + ((int)threadIdx.x)) + 32)];
  }
  if (((((int)blockIdx.y) * 8) + (((int)threadIdx.y) >> 2)) < 9) {
    A_shared[(((((int)threadIdx.x) * 66) + (((int)threadIdx.y) * 2)) + 1)] = A[(((((((int)blockIdx.y) * 133693440) + (((int)threadIdx.y) * 4177920)) + (((int)blockIdx.x) * 64)) + ((int)threadIdx.x)) + 2088960)];
  }
  if (((((int)blockIdx.y) * 8) + (((int)threadIdx.y) >> 2)) < 9) {
    A_shared[(((((int)threadIdx.x) * 66) + (((int)threadIdx.y) * 2)) + 2113)] = A[(((((((int)blockIdx.y) * 133693440) + (((int)threadIdx.y) * 4177920)) + (((int)blockIdx.x) * 64)) + ((int)threadIdx.x)) + 2088992)];
  }

Since " if (((((int)blockIdx.y) * 8) + (((int)threadIdx.y) >> 2)) < 9) " condition judge is the same for all share mem store here, would it be possible that merge those 4 statement under one block of condition judgment?

Thx~

  1. We SHOULD do that, but we cannot right now.
  2. In my experience, there is no significant performance gap whether or not we merge them. I guess that is because NVCC can optimize it away; see the sketch below for a quick way to check this outside of TVM.
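
A minimal sketch of such a check, assuming you just want to compare the compiled output of the two forms in isolation (the file name `guard_merge.cu` and both kernels are made up for illustration and are much simpler than the generated kernel above):

    // guard_merge.cu: toy stand-in for the generated kernel, comparing four
    // stores under repeated copies of the same guard against one merged guard.
    // Buffers are assumed to hold at least 4 * n halves. Compile both kernels
    // and compare the branch structure, e.g.:
    //   nvcc -arch=sm_80 -ptx guard_merge.cu
    //   nvcc -arch=sm_80 -cubin guard_merge.cu && cuobjdump -sass guard_merge.cubin
    #include <cuda_fp16.h>

    // Each store carries its own copy of the guard, as in the code TVM emits.
    extern "C" __global__ void repeated_guard(const half* __restrict__ in,
                                              half* __restrict__ out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[4 * i + 0] = in[4 * i + 0];
      if (i < n) out[4 * i + 1] = in[4 * i + 1];
      if (i < n) out[4 * i + 2] = in[4 * i + 2];
      if (i < n) out[4 * i + 3] = in[4 * i + 3];
    }

    // The same four stores hoisted under a single guard (the merged form).
    extern "C" __global__ void merged_guard(const half* __restrict__ in,
                                            half* __restrict__ out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        out[4 * i + 0] = in[4 * i + 0];
        out[4 * i + 1] = in[4 * i + 1];
        out[4 * i + 2] = in[4 * i + 2];
        out[4 * i + 3] = in[4 * i + 3];
      }
    }

If NVCC does merge the repeated guards, the two kernels should lower to essentially the same PTX/SASS, which would explain why no performance gap shows up.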