Hi, I want to use tensorize to store a reduction's result, but I get the error `Cannot mix cross thread reduction with Tensorize` when building the schedule. If I delete `sch[real_output].tensorize()`, the error goes away. Is it true that tensorize and kCrossThreadReduction cannot be used at the same time? If so, how can I use tensorize with a reduction? Thanks a lot!
# Schedule the reduction `op` on schedule `sch`:
# - cache its input in "local" scope;
# - fuse and split the reduce axes, then rfactor the inner part;
# - bind the remaining reduce axis to threadIdx.x;
# - tensorize the local load and the final store with project intrinsics.
data_in = op.input_tensors[0]
data_out = op.output(0)
# Split factor for the fused reduce axis below. It is also the threadIdx.x
# extent in the all-reduce branch and the extent given to the load intrinsic.
warp_size = 64
# Stage the reduction input into "local" scope; only `op` reads the cache.
IL = sch.cache_read(data_in, "local", [op])
if len(sch[data_out].op.axis) > 0:
    # The output still has data-parallel axes (not a reduce-to-one-value).
    all_reduce = False
    num_thread = 64
    target = tvm.target.Target.current()
    if target and (target.kind.name == "opencl" or target.kind.name == "metal"):
        # Without this, CL_INVALID_WORK_GROUP_SIZE occurred when running
        # test_topi_reduce.py; the root cause is unknown.
        num_thread = 16
    block_x = te.thread_axis("blockIdx.x")
    # NOTE(review): thread_x has extent num_thread (64, or 16 on
    # OpenCL/Metal), but the reduce axis later bound to it comes from a
    # split with factor=warp_size (64). On OpenCL/Metal the two disagree.
    # Confirm this is intended.
    thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
    thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
else:
    # The output has no data-parallel axes: everything reduces to one value.
    all_reduce = True
    # NOTE(review): in this branch num_thread is not used by the code visible
    # here; thread_x below is sized by warp_size instead.
    num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
    thread_x = te.thread_axis((0, warp_size), "threadIdx.x")
# Fuse all reduce axes of data_out into a single axis, then split it.
fused_reduce = sch[data_out].fuse(
    *[sch[data_out].op.reduce_axis[i] for i in range(len(sch[data_out].op.reduce_axis))]
)
# ko: outer part; ki: inner part, whose extent is the split factor warp_size.
ko, ki = sch[data_out].split(fused_reduce, factor=warp_size)
# rfactor on ki does two things:
# - pulls ki out as an explicit axis of a new stage (data_out_rf) that holds
#   the partial results;
# - makes data_out reduce over those partials.
data_out_rf = sch.rfactor(data_out, ki)
# Compute the cached input IL inside the loop of data_out_rf's first reduce
# axis, and tensorize IL's innermost axis with the project load intrinsic
# intrin_xxxx_load_tensor_image (defined outside this view).
sch[IL].compute_at(sch[data_out_rf], sch[data_out_rf].op.reduce_axis[0])
sch[IL].tensorize(
    IL.op.axis[-1],
    intrin_xxxx_load_tensor_image(data_in.storage_scope, "local",
                                  data_in.dtype, (1,), (0,), (warp_size,))
)
# After rfactor, data_out's op has a single reduce axis (over the ki
# partials); this stage produces the final output.
tx = sch[data_out].op.reduce_axis[0]
# Bind that reduction to threadIdx.x so the partials are reduced in parallel
# across threads. This is what makes data_out a cross-thread reduction stage.
sch[data_out].bind(tx, thread_x)
# Attach data_out_rf at data_out's tx axis, so it no longer needs a separate
# loop nest of its own.
sch[data_out_rf].compute_at(sch[data_out], tx)
real_output = data_out
# Only the thread with threadIdx.x == 0 writes the reduced result.
# NOTE(review): `result` keeps set_store_predicate's return value, which is
# presumably None in upstream TVM. Verify before relying on it.
result = sch[real_output].set_store_predicate(thread_x.equal(0))
# NOTE(review): this call triggers the error
# "Cannot mix cross thread reduction with Tensorize".
# - tx is a reduce axis of data_out and is bound to threadIdx.x above.
# - A stage with a thread-bound reduce axis is a cross-thread reduction, and
#   TVM refuses to also tensorize that stage. Removing this call removes the
#   error.
# - Untested suggestion: to tensorize the store, run it on a separate stage
#   that has no thread-bound reduce axis, e.g. a copy stage made with
#   sch.cache_write on data_out before the fuse/split/rfactor above.
# - A 0-d output (the all_reduce case) has no data-parallel axis that such
#   a copy stage could tensorize.
sch[real_output].tensorize(
    tx,
    intrin_xxxx_store_tensor_image(
        "local",
        real_output.storage_scope,
        real_output.dtype,
        (1,)
    )
)