[Discussion] More general cross thread allreduce

Currently, we have to use separate thread axes for spatial and reduction loops to do cross-thread reduction. Binding one thread axis to both a spatial and a reduction loop causes incorrect codegen, even when the resulting thread mapping is effectively the same.

For example, suppose we use 16 threads to reduce a 4x4 tensor A into 4x1 (the equivalent of torch.sum(A, dim=-1)): with one element per thread, the single thread axis has to cover both a spatial (row) and a reduction (column) iterator.

A repro: allreduce.py · GitHub
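A minimal sketch of the scenario (not the linked repro itself), assuming TVM's TensorIR schedule API; the prim_func name `sum_rows` and the block name `"B"` are illustrative:

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def sum_rows(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4,), "float32")):
    # Row-wise sum: i is spatial, k is the reduction axis.
    for i, k in T.grid(4, 4):
        with T.block("B"):
            vi, vk = T.axis.remap("SR", [i, k])
            with T.init():
                B[vi] = T.float32(0)
            B[vi] = B[vi] + A[vi, vk]


sch = tvm.tir.Schedule(sum_rows)
blk = sch.get_block("B")
i, k = sch.get_loops(blk)
# Fuse the spatial loop `i` with the reduction loop `k` and bind the fused
# 16-iteration loop to a single thread axis. This is the "one thread axis
# carries both spatial and reduction" case that currently lowers incorrectly
# (or is rejected), which is the subject of this thread.
fused = sch.fuse(i, k)
sch.bind(fused, "threadIdx.x")
print(sch.mod.script())
```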

Why do we probably want this?

Suppose A is stored in a 4x4 packed layout, and we'd like to reduce 8 rows per thread block with 64 threads per CTA. To obtain maximally coalesced reads (contiguous threads accessing contiguous memory positions), the thread mapping also needs to follow the 4x4 packed layout, but there is no way to arrange tx and ty to achieve this as long as each of them must be purely spatial or purely reduction.
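Illustrative index arithmetic only, under the assumption that each CTA handles an 8x8 logical tile of A stored as a 2x2 grid of 4x4 packets (row-major within each packet), one element per thread:

```python
def packed_coords(tx: int, packets_per_row: int = 2):
    # Consecutive tx values touch consecutive memory positions (coalesced).
    packet = tx // 16                                    # which 4x4 packet
    lane = tx % 16                                       # position inside it
    row = (packet // packets_per_row) * 4 + lane // 4    # spatial index
    col = (packet % packets_per_row) * 4 + lane % 4      # reduction index
    return row, col


# Both `row` (spatial) and `col` (reduction) depend on tx, so tx cannot be
# declared purely spatial or purely reduction under the current restriction.
for tx in range(8):
    print(tx, packed_coords(tx))
```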


This is a great direction; perhaps we can use an affine map to handle this.

Great observation, and thanks for bringing it up! It is worth a try. That said, it looks to me that the intrinsic tir::builtin::tvm_thread_allreduce() currently does not support such semantics (a reduction that spans only part of tx), so supporting this is not trivial and would require revisiting the intrinsic interface as well as its lowering pass.
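For concreteness, a rough schematic of the lowered intrinsic call as I remember it from lowered TIR dumps; argument names and exact order are paraphrased and should be checked against the LowerThreadAllreduce pass:

```python
# Schematic only, not an authoritative signature:
#
#   T.tvm_thread_allreduce(
#       T.uint32(1),        # number of reduction buffers
#       local_sum[0],       # per-thread partial value
#       True,               # predicate
#       cross_thread_buf,   # staging buffer for the reduced result
#       tx,                 # thread IterVar(s) the reduction spans
#   )
#
# The trailing arguments are whole thread IterVars, so "reduce over only part
# of tx" has no direct way to be expressed with the current interface.
```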