Hello, I’m trying to fuse two reduce stages using schedule functions. Here is the current TIR:
inputs=[Tensor(shape=[10240, 512], op.name=placeholder)]
outputs=[Tensor(shape=[1, 512], op.name=T_divide)]
function:
#[version = "0.0.5"]
primfn(placeholder_1: handle, T_divide_1: handle) -> ()
attr = {"global_symbol": "fused_add_mean_add_mean", "tir.noalias": True}
buffers = {T_divide: Buffer(T_divide_2: Pointer(float32), float32, [1, 512], []),
placeholder: Buffer(placeholder_2: Pointer(float32), float32, [10240, 512], [])}
buffer_map = {placeholder_1: placeholder, T_divide_1: T_divide} {
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 16;
attr [T_add_red.rf: Pointer(float32)] "storage_scope" = "local";
allocate(T_add_red.rf, float32, [1]);
attr [reduce_temp0: handle] "storage_scope" = "local";
allocate(reduce_temp0, float32, [1]);
attr [T_add_red.rf_1: Pointer(float32)] "storage_scope" = "local";
allocate(T_add_red.rf_1, float32, [1]);
attr [reduce_temp0_1: handle] "storage_scope" = "local";
allocate(reduce_temp0_1, float32, [1]);
attr [IterVar(threadIdx.y: int32, [0:32], "ThreadIndex", "threadIdx.y")] "thread_extent" = 32 {
attr [IterVar(threadIdx.x: int32, [0:32], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32 {
T_add_red.rf[0] = 0f32
for (k0.outer: int32, 0, 320) {
T_add_red.rf[0] = ((float32*)T_add_red.rf[0] + ((float32*)placeholder_2[((((k0.outer*16384) + (threadIdx.x*512)) + (blockIdx.x*32)) + threadIdx.y)] + 1f32))
}
attr [meta[tir.CommReducer][0]] "reduce_scope" = @tir.reinterpret(0u64, dtype=handle);
@tir.tvm_thread_allreduce(1u32, (float32*)T_add_red.rf[0], True, reduce_temp0, threadIdx.x, dtype=handle)
}
reduce_temp0[0] = ((float32*)reduce_temp0[0]*9.76563e-05f32)
attr [IterVar(threadIdx.x, [0:32], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32 {
T_add_red.rf_1[0] = 0f32
for (k0.outer_1: int32, 0, 320) {
T_add_red.rf_1[0] = ((float32*)T_add_red.rf_1[0] + (((float32*)placeholder_2[((((k0.outer_1*16384) + (threadIdx.x*512)) + (blockIdx.x*32)) + threadIdx.y)] + 1f32) + (float32*)reduce_temp0[0]))
}
attr [meta[tir.CommReducer][1]] "reduce_scope" = @tir.reinterpret(0u64, dtype=handle);
@tir.tvm_thread_allreduce(1u32, (float32*)T_add_red.rf_1[0], True, reduce_temp0_1, threadIdx.x, dtype=handle)
}
T_divide_2[((blockIdx.x*32) + threadIdx.y)] = ((float32*)reduce_temp0_1[0]*9.76563e-05f32)
}
}
Here the two tvm_thread_allreduce calls bind to threadIdx.x over the same range independently. Can they share the same threadIdx.x so that only one kernel is created? Thank you very much!