[Relay FuseOps] Not fusing kInjective → kCommReduce ops

Consider a simple module with a kInjective → kCommReduce pattern (reshape followed by sum):

def @main(%a: Tensor[(5, 5), float32]) -> Tensor[(25), float32] {
  %0 = reshape(%a, newshape=[25, 1]) /* from_string */ /* ty=Tensor[(25, 1), float32] */;
  sum(%0, axis=[1]) /* from_string */ /* ty=Tensor[(25), float32] */
}

The FuseOps pass outputs:

def @main(%a: Tensor[(5, 5), float32]) -> Tensor[(25), float32] {
  %0 = fn (%p0: Tensor[(5, 5), float32], Primitive=1) -> Tensor[(25, 1), float32] {
    reshape(%p0, newshape=[25, 1]) /* from_string */ /* ty=Tensor[(25, 1), float32] */
  };
  %1 = %0(%a) /* ty=Tensor[(25, 1), float32] */;
  %2 = fn (%p01: Tensor[(25, 1), float32], Primitive=1) -> Tensor[(25), float32] {
    sum(%p01, axis=[1]) /* from_string */ /* ty=Tensor[(25), float32] */
  };
  %2(%1) /* ty=Tensor[(25), float32] */
}

Thus, codegen will not fuse reshape and sum; it creates two separate kernels, one per op. The reshape could clearly be inlined into the sum, saving the cost of an extra kernel launch (and of writing and re-reading the intermediate (25, 1) tensor).
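To make the saving concrete, here is a minimal pure-Python sketch of the two lowerings. It is not TVM code and not what TVM codegen actually emits. The hypothetical `reshape_kernel` and `sum_axis1_kernel` stand in for the two primitive functions above and pass an intermediate (25, 1) buffer between them, while the hypothetical `fused_reshape_sum` inlines the reshape as index arithmetic inside the reduction loop, so no intermediate buffer is written.

```python
from typing import List

Matrix = List[List[float]]


def reshape_kernel(a: Matrix, n: int, m: int) -> Matrix:
    """Kernel 1 (unfused): reshape to (n, m) and write an intermediate buffer.

    Assumes n * m equals the number of elements in `a` (row-major layout).
    """
    cols = len(a[0])
    return [[a[(i * m + k) // cols][(i * m + k) % cols] for k in range(m)] for i in range(n)]


def sum_axis1_kernel(x: Matrix) -> List[float]:
    """Kernel 2 (unfused): sum over axis 1, reading the intermediate buffer."""
    return [sum(row) for row in x]


def fused_reshape_sum(a: Matrix, n: int, m: int) -> List[float]:
    """Single fused kernel: the reshape is inlined as index math, no intermediate buffer."""
    cols = len(a[0])
    out = []
    for i in range(n):
        acc = 0.0
        for k in range(m):
            flat = i * m + k  # row-major flat index of element (i, k) of the reshaped view
            acc += a[flat // cols][flat % cols]
        out.append(acc)
    return out


a = [[float(5 * r + c) for c in range(5)] for r in range(5)]  # Tensor[(5, 5)]
unfused = sum_axis1_kernel(reshape_kernel(a, 25, 1))  # two "launches"
fused = fused_reshape_sum(a, 25, 1)                    # one "launch"
assert fused == unfused
```

Because the reduction axis here has extent 1, both versions just return the flattened input; the point is only that the fused loop never materializes the reshaped tensor.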

Is there a reason that these ops aren’t fused in FuseOps?

cc: @masahi @MarisaKirisame @jroesch @tqchen

Hmm, it’s not clear to me whether we can safely fuse injective + reduce ops in general. GPU reductions, in particular, often need to make multiple passes over their inputs.

cc @altanh re gather + sum fusion in EmbeddingBag

I’m not sure I follow how multiple passes over the inputs would make fusion unsafe. Do reductions reuse (overwrite) the memory of their input?

No, I was thinking that fused injective ops might be recomputed on every pass over the input (“safely” was not the best word for my concern, sorry). For injective ops, though, the “recompute” is just indexing math, so the overhead might not be significant.
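A small pure-Python sketch of this concern, assuming a hypothetical two-pass, block-strided sum as a rough stand-in for a GPU reduction schedule (not TVM's actual schedule). `InlinedReshape` is a hypothetical read-only view that evaluates the inlined reshape's index math on every read and counts those evaluations.

```python
from typing import List


class InlinedReshape:
    """Hypothetical flat view of a 2-D input: each read recomputes the reshape index math."""

    def __init__(self, a: List[List[float]]) -> None:
        self.a = a
        self.cols = len(a[0])
        self.index_evals = 0  # how many times the inlined index math ran

    def __len__(self) -> int:
        return len(self.a) * self.cols

    def __getitem__(self, flat: int) -> float:
        self.index_evals += 1
        # The "recomputed" injective op is just this integer arithmetic.
        return self.a[flat // self.cols][flat % self.cols]


def two_pass_sum(x: InlinedReshape, num_blocks: int) -> float:
    n = len(x)
    # Pass 1: each block reduces a strided slice of the input to a partial sum.
    partials = [sum(x[i] for i in range(b, n, num_blocks)) for b in range(num_blocks)]
    # Pass 2: reduce the partials; the input (and its index math) is not touched.
    return sum(partials)


a = [[float(5 * r + c) for c in range(5)] for r in range(5)]
x = InlinedReshape(a)
total = two_pass_sum(x, num_blocks=4)
print(total, x.index_evals)  # 300.0 25
```

In this particular schedule the first pass reads each input element once and the second pass only reads the per-block partials, so the index math runs exactly once per element. A schedule that rereads the input would rerun the `//` and `%` on each read, which is the recompute cost being discussed.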

You could try modifying kInjective to kCommReduce and see what happens. I’m not completely sure if this is the right change (I haven’t looked at this code for some time), though.

For the record, it looks like changing kInjective to kCommReduce here does the right thing.
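A toy Python model of the kind of check being changed, assuming the line in question is a path condition of the form `kind <= kInjective` that FuseOps applies to every op between an injective producer and the op it would fuse into. The enum values mirror Relay's `OpPatternKind`; `check_path`, `original_cond`, `relaxed_cond` and the path representation are hypothetical simplifications, not the actual fuse_ops.cc code.

```python
from enum import IntEnum
from typing import Callable, List, Tuple


class OpPatternKind(IntEnum):
    # Values mirror Relay's OpPatternKind enum.
    kElemWise = 0
    kBroadcast = 1
    kInjective = 2
    kCommReduce = 3
    kOutEWiseFusable = 4
    kTuple = 7
    kOpaque = 8


# Hypothetical path representation: (kind, is_sink) for each op between an
# injective producer and the op it would be fused into (the sink).
Path = List[Tuple[OpPatternKind, bool]]
Cond = Callable[[OpPatternKind, bool], bool]


def original_cond(kind: OpPatternKind, is_sink: bool) -> bool:
    # Only injective-or-simpler ops may appear on the path.
    return kind <= OpPatternKind.kInjective


def relaxed_cond(kind: OpPatternKind, is_sink: bool) -> bool:
    # The change from the thread: kInjective -> kCommReduce.
    # is_sink is unused here; it is kept so a stricter variant
    # (e.g. allowing a reduce only at the sink) could be expressed.
    return kind <= OpPatternKind.kCommReduce


def check_path(path: Path, cond: Cond) -> bool:
    return all(cond(kind, is_sink) for kind, is_sink in path)


# reshape (kInjective) -> sum (kCommReduce): the sum is the sink.
reshape_to_sum: Path = [(OpPatternKind.kCommReduce, True)]
print(check_path(reshape_to_sum, original_cond))  # False
print(check_path(reshape_to_sum, relaxed_cond))   # True
```

Note that in this toy version the relaxed condition admits a reduce anywhere on the path, not only at the sink; whether that matches the intended fusion semantics is exactly what benchmarks and tests would need to confirm.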

Thanks for clarifying. It would be great to have model performance benchmarks comparing

  1. current state of TVM
  2. allowing fusion between kInjective → kCommReduce

to test your theory, but it’s not critical 🙂