Can two reduce stages be fused?

Hello, I’m trying to fuse two reduce stages with schedule functions. Here is the current TIR:

inputs=[Tensor(shape=[10240, 512], op.name=placeholder)]
outputs=[Tensor(shape=[1, 512], op.name=T_divide)]
function:
#[version = "0.0.5"]
primfn(placeholder_1: handle, T_divide_1: handle) -> ()
  attr = {"global_symbol": "fused_add_mean_add_mean", "tir.noalias": True}
  buffers = {T_divide: Buffer(T_divide_2: Pointer(float32), float32, [1, 512], []),
             placeholder: Buffer(placeholder_2: Pointer(float32), float32, [10240, 512], [])}
  buffer_map = {placeholder_1: placeholder, T_divide_1: T_divide} {
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 16;
  attr [T_add_red.rf: Pointer(float32)] "storage_scope" = "local";
  allocate(T_add_red.rf, float32, [1]);
  attr [reduce_temp0: handle] "storage_scope" = "local";
  allocate(reduce_temp0, float32, [1]);
  attr [T_add_red.rf_1: Pointer(float32)] "storage_scope" = "local";
  allocate(T_add_red.rf_1, float32, [1]);
  attr [reduce_temp0_1: handle] "storage_scope" = "local";
  allocate(reduce_temp0_1, float32, [1]);
  attr [IterVar(threadIdx.y: int32, [0:32], "ThreadIndex", "threadIdx.y")] "thread_extent" = 32 {
    attr [IterVar(threadIdx.x: int32, [0:32], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32 {
      T_add_red.rf[0] = 0f32
      for (k0.outer: int32, 0, 320) {
        T_add_red.rf[0] = ((float32*)T_add_red.rf[0] + ((float32*)placeholder_2[((((k0.outer*16384) + (threadIdx.x*512)) + (blockIdx.x*32)) + threadIdx.y)] + 1f32))
      }
      attr [meta[tir.CommReducer][0]] "reduce_scope" = @tir.reinterpret(0u64, dtype=handle);
      @tir.tvm_thread_allreduce(1u32, (float32*)T_add_red.rf[0], True, reduce_temp0, threadIdx.x, dtype=handle)
    }
    reduce_temp0[0] = ((float32*)reduce_temp0[0]*9.76563e-05f32)
    attr [IterVar(threadIdx.x, [0:32], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32 {
      T_add_red.rf_1[0] = 0f32
      for (k0.outer_1: int32, 0, 320) {
        T_add_red.rf_1[0] = ((float32*)T_add_red.rf_1[0] + (((float32*)placeholder_2[((((k0.outer_1*16384) + (threadIdx.x*512)) + (blockIdx.x*32)) + threadIdx.y)] + 1f32) + (float32*)reduce_temp0[0]))
      }
      attr [meta[tir.CommReducer][1]] "reduce_scope" = @tir.reinterpret(0u64, dtype=handle);
      @tir.tvm_thread_allreduce(1u32, (float32*)T_add_red.rf_1[0], True, reduce_temp0_1, threadIdx.x, dtype=handle)
    }
    T_divide_2[((blockIdx.x*32) + threadIdx.y)] = ((float32*)reduce_temp0_1[0]*9.76563e-05f32)
}
}

Here the two tvm_thread_allreduce calls are bound to threadIdx.x over the same range, but independently of each other. Can they share the same threadIdx.x so that only one kernel is created? Thank you very much!
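For reference, here is a minimal TE sketch of roughly how such a compute and schedule could look, with both cross-thread reductions bound to one shared threadIdx.x IterVar via the usual split + rfactor + bind pattern. The tensor names, split factors, and folding the divide into the sum are assumptions made to mirror the dump above, not the exact code that produced it:

# Sketch only: shapes and ops are inferred from the dump (add 1, mean over axis 0,
# add the mean back, mean again); not necessarily the original schedule.
import tvm
from tvm import te

N, C = 10240, 512
data = te.placeholder((N, C), name="placeholder", dtype="float32")

x1 = te.compute((N, C), lambda i, j: data[i, j] + 1.0, name="T_add")
k0 = te.reduce_axis((0, N), name="k0")
# mean expressed as a scaled sum, to keep each mean to a single reduce stage
mean1 = te.compute((1, C), lambda _, j: te.sum(x1[k0, j] * (1.0 / N), axis=k0),
                   name="mean1")
x2 = te.compute((N, C), lambda i, j: x1[i, j] + mean1[0, j], name="T_add_1")
k1 = te.reduce_axis((0, N), name="k1")
out = te.compute((1, C), lambda _, j: te.sum(x2[k1, j] * (1.0 / N), axis=k1),
                 name="T_divide")

s = te.create_schedule(out.op)
s[x1].compute_inline()
s[x2].compute_inline()

block_x = te.thread_axis("blockIdx.x")
thread_y = te.thread_axis("threadIdx.y")
thread_x = te.thread_axis("threadIdx.x")  # single IterVar reused for both reductions

# Second reduce stage: 512 columns -> 16 blocks x 32 threadIdx.y;
# reduction 10240 -> 320 x 32, inner part rfactored and bound to threadIdx.x.
ko, ki = s[out].split(out.op.reduce_axis[0], factor=32)
out_rf = s.rfactor(out, ki)
_, j = s[out].op.axis
bx, ty = s[out].split(j, factor=32)
s[out].bind(bx, block_x)
s[out].bind(ty, thread_y)
s[out].bind(s[out].op.reduce_axis[0], thread_x)
s[out_rf].compute_at(s[out], s[out].op.reduce_axis[0])

# First reduce stage: pulled into the same kernel via compute_at, with its
# cross-thread reduce axis bound to the *same* thread_x IterVar.
ko1, ki1 = s[mean1].split(mean1.op.reduce_axis[0], factor=32)
mean1_rf = s.rfactor(mean1, ki1)
s[mean1].compute_at(s[out], ty)
s[mean1].bind(s[mean1].op.reduce_axis[0], thread_x)
s[mean1_rf].compute_at(s[mean1], s[mean1].op.reduce_axis[0])

print(tvm.lower(s, [data, out], simple_mode=True))

Reusing one te.thread_axis IterVar across stages is the usual way TE schedules (e.g. the topi CUDA reduction schedules) make several stages share a thread dimension; each reduce stage still lowers to its own tvm_thread_allreduce call, and fusing the first reduction into the consumer stage with compute_at is what keeps everything inside a single kernel.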

Hello, I also have the same question. May I ask if you have found the answer yet?