Aggressive operator fusion and its consequence: huge inlined TIR expressions

Hi, I’ve recently been working with people on the hummingbird project to help with their TVM backend development.

One interesting thing about models coming from hummingbird is that the network structure changes with the size and characteristics of the dataset, since the network structure implicitly encodes the structure of a decision tree and its traversal.

In particular, as the dataset grows, one of the fused functions can end up containing an unbounded number of operations, which leads to a huge inlined one-liner TIR expression and a compilation that never finishes.

For example, even with a tiny dataset of size (6, 10), one of the fused functions ends up with more than 20 ops.

  %21 = fn (%p0: Tensor[(21, 1), float32], %p1: Tensor[(6, 10), float32], %p2: Tensor[(21), int64], %p3: Tensor[(6, 3), int64], %p4: Tensor[(6, 3), float32], %p5: Tensor[(6, 3), float32], %p6: Tensor[(6, 3), float32], %p7: Tensor[(1, 3), int64], %p8: Tensor[(21), float32], %p9: Tensor[(21), float32], %p10: Tensor[(21), float32], Primitive=1) -> Tensor[(6, 1, 3), float32] {
    %0 = gather(%p1, %p3, axis=1) /* ty=Tensor[(6, 3), float32] */;
    %1 = greater_equal(%0, %p4) /* ty=Tensor[(6, 3), bool] */;
    %2 = where(%1, %p5, %p6) /* ty=Tensor[(6, 3), float32] */;
    %3 = cast(%2, dtype="int64") /* ty=Tensor[(6, 3), int64] */;
    %4 = add(%3, %p7) /* ty=Tensor[(6, 3), int64] */;
    %5 = reshape(%4, newshape=[-1]) /* ty=Tensor[(18), int64] */;
    %6 = take(%p2, %5, axis=0) /* ty=Tensor[(18), int64] */;
    %7 = reshape(%6, newshape=[-1, 3]) /* ty=Tensor[(6, 3), int64] */;
    %8 = gather(%p1, %7, axis=1) /* ty=Tensor[(6, 3), float32] */;
    %9 = take(%p8, %5, axis=0) /* ty=Tensor[(18), float32] */;
    %10 = reshape(%9, newshape=[-1, 3]) /* ty=Tensor[(6, 3), float32] */;
    %11 = greater_equal(%8, %10) /* ty=Tensor[(6, 3), bool] */;
    %12 = take(%p9, %5, axis=0) /* ty=Tensor[(18), float32] */;
    %13 = reshape(%12, newshape=[-1, 3]) /* ty=Tensor[(6, 3), float32] */;
    %14 = take(%p10, %5, axis=0) /* ty=Tensor[(18), float32] */;
    %15 = reshape(%14, newshape=[-1, 3]) /* ty=Tensor[(6, 3), float32] */;
    %16 = where(%11, %13, %15) /* ty=Tensor[(6, 3), float32] */;
    %17 = cast(%16, dtype="int64") /* ty=Tensor[(6, 3), int64] */;
    %18 = add(%17, %p7) /* ty=Tensor[(6, 3), int64] */;
    %19 = reshape(%18, newshape=[-1]) /* ty=Tensor[(18), int64] */;
    %20 = take(%p0, %19, axis=0) /* ty=Tensor[(18, 1), float32] */;
    reshape(%20, newshape=[6, 1, 3]) /* ty=Tensor[(6, 1, 3), float32] */
  };

The fused Relay function above gets translated into the following huge TIR, before the StorageFlatten pass:

buffer_realize T_reshape([0, 6], [0, 1], [0, 3]) {
  parallel (ax0.ax1.fused, 0, 6) {
    vectorized (ax2.inner, 0, 3) {
      T_reshape[ax0.ax1.fused, 0, ax2.inner] = placeholder[min(max((int64)0,
      (int64(select((int32((placeholder[floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6),
      placeholder[min(max((int64)0, (int64(select((int32((placeholder[floormod(floordiv(floormod(((
      floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) +
      ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
      18), 3)), 18), 3)]] >= placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
      18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)),
      18), 3)])) != 0), placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3)
      + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)],
      placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
      floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) +  ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) +
      placeholder[0, floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
      floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])), (int64)20)]] >= placeholder[min(max((int64)0,
      (int64(select((int32((placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
      18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), placeholder[floormod(floordiv(floormod(((
      floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
      18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
      floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)]] >= placeholder[floormod(floordiv(floormod(((floormod(floordiv(
      floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3),
      6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
      floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) != 0), placeholder[floormod(floordiv(floormod(((floormod(floordiv(
      floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6),
      floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)], placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3)
      + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((
      ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) +
      placeholder[0, floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
      floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])), (int64)20)])) != 0), placeholder[min(max((int64)0,
      (int64(select((int32((placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
      floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), placeholder[floormod(floordiv(floormod(((floormod(floordiv(
      floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6),
      floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3)
      + ax2.inner), 18), 3)), 18), 3)]] >= placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18),
      3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((
      ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) != 0),
      placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
      floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused +
      0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)],
      placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((
      ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18)
      , 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) + placeholder[0, floormod(floormod(((floormod(floordiv(
      floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])),
      (int64)20)], placeholder[min(max((int64)0, (int64(select((int32((placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((
      ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6),
      placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod(((
      (ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
      18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)]] >= placeholder[floormod(floordiv(floormod(((floormod(
      floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)
      , 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)
      *3) + ax2.inner), 18), 3)), 18), 3)])) != 0), placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) +
      ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6),
      floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3)
      + ax2.inner), 18), 3)), 18), 3)], placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)
      , 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((
      ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) +
      placeholder[0, floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((
      ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])), (int64)20)])) + placeholder[0, floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
      18), 3)])), (int64)20), 0]
    }
  }
}

For a dataset of size (20, 10), a fused function gets 70-90 ops inside it, and the resulting TIR is so huge that it takes forever to compile. In particular, compilation seems to get stuck in the StorageFlatten pass (it keeps calling VisitExpr or VisitStmt on some nodes without finishing).

There is more discussion in the hummingbird repo: https://github.com/microsoft/hummingbird/issues/232

My questions are:

  1. Is the TIR above reasonable for the given Relay function? I think fusing 20 or so operations is totally reasonable, but I was surprised to see such a complicated expression come out of the lowering process.
  2. The only solution I can think of is to make the maximum size of a fused function configurable. Are there other approaches?

cc @tqchen @zhiics @kevinthesun @jwfromm

There are a few things that might be helpful.

First of all, limiting the fusion depth will certainly help in this case (or marking the where op as opaque).

A second potential place for improvement would be trying to fold some of the diamond patterns. In particular, TIR does not treat an expression as a DAG; it treats it as a tree. This means that if a node gets referenced twice (e.g. in the case of where), it will get expanded twice.
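To make the tree-versus-DAG point concrete, here is a small, self-contained Python sketch. It is a toy model, not TVM code: the Node class, build_levels, and the "one where over four lookups per tree level" shape are assumptions chosen to loosely mimic the fused function in the first post. It compares how many nodes exist when shared nodes are counted once (DAG) versus once per reference (tree).

```python
class Node:
    """Toy expression node (hypothetical, not a TVM class)."""

    def __init__(self, op, *args):
        self.op = op
        self.args = args


def tree_size(node):
    # Count a shared node once per reference, as a tree-based IR would.
    return 1 + sum(tree_size(a) for a in node.args)


def dag_size(node, seen=None):
    # Count every distinct node object exactly once.
    seen = set() if seen is None else seen
    if id(node) in seen:
        return 0
    seen.add(id(node))
    return 1 + sum(dag_size(a, seen) for a in node.args)


def build_levels(depth):
    # Each level reads the current index four times
    # (feature, threshold, left child, right child).
    idx = Node("index")
    for _ in range(depth):
        feature = Node("take", idx)
        threshold = Node("take", idx)
        left = Node("take", idx)
        right = Node("take", idx)
        cond = Node("ge", feature, threshold)
        idx = Node("where", cond, left, right)
    return idx


for depth in (1, 2, 3):
    expr = build_levels(depth)
    print(depth, dag_size(expr), tree_size(expr))
```

In this toy model the DAG grows by six nodes per level, while the tree form roughly quadruples with every level, which matches the kind of blow-up seen in the TIR dump.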

We intentionally expand the index expressions so they can be simplified more easily.

Now that we have a good amount of LetNode support, we could instead bind some of the values via let nodes, e.g. let x = lhs, let y = rhs in select(x < value, x, y). There might be a few ways to get these expressions. For example, one could run common sub-expression folding early (into let) on the computation only, while trying to avoid touching the index expressions when possible.
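As a rough illustration of folding common sub-expressions into let bindings, here is a minimal Python sketch over nested tuples. It is not a TIR pass: the tuple encoding, the hoist_repeated name, and the generated x0, x1, ... variable names are all assumptions made for this example.

```python
from collections import Counter


def count_subexprs(expr, counts):
    # Expressions are tuples of the form (op, arg, ...); leaves are strings.
    if isinstance(expr, tuple):
        counts[expr] += 1
        for arg in expr[1:]:
            count_subexprs(arg, counts)


def hoist_repeated(expr):
    """Return (bindings, body) where repeated sub-expressions are let-bound."""
    counts = Counter()
    count_subexprs(expr, counts)
    bindings = []  # ordered list of (name, value) pairs
    names = {}     # original sub-expression -> bound name

    def rewrite(e):
        if not isinstance(e, tuple):
            return e
        if e in names:
            return names[e]
        new = (e[0],) + tuple(rewrite(a) for a in e[1:])
        if counts[e] > 1:
            name = f"x{len(bindings)}"
            bindings.append((name, new))
            names[e] = name
            return name
        return new

    return bindings, rewrite(expr)


lhs = ("add", "a", "b")
rhs = ("mul", "a", "b")
bindings, body = hoist_repeated(("select", ("lt", lhs, "value"), lhs, rhs))
print(bindings)
print(body)
```

Here lhs appears twice and is bound once, while rhs appears once and stays inline. Note that this sketch still walks the expanded tree while counting, so a real pass would want to do the sharing before the expression is fully inlined.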


Makes sense. I also had a feeling that there could be a lot of duplication going on.

Does duplication at the TIR level mean the resulting LLVM IR is also duplicated? Are we relying on LLVM to dedup? If not, then even if we limit the fuse depth or wait long enough for lowering to finish, the generated code would be suboptimal.

Generating let sounds interesting; I want to look into it.

The CSE in LLVM will likely remove the duplication, so it is mainly a compilation efficiency issue.


Since the hummingbird use case is completely memory bound, I think it is better to fuse as much as possible rather than to set a fuse depth bound as a workaround for compilation time.

Looking at the fused function below a bit more and comparing it with Algorithm 2 in the hummingbird paper, I think the source of the duplication is %5 below, which is referenced 4 times.

  %21 = fn (%p0: Tensor[(21, 1), float32], %p1: Tensor[(6, 10), float32], %p2: Tensor[(21), int64], %p3: Tensor[(6, 3), int64], %p4: Tensor[(6, 3), float32], %p5: Tensor[(6, 3), float32], %p6: Tensor[(6, 3), float32], %p7: Tensor[(1, 3), int64], %p8: Tensor[(21), float32], %p9: Tensor[(21), float32], %p10: Tensor[(21), float32], Primitive=1) -> Tensor[(6, 1, 3), float32] {
    %0 = gather(%p1, %p3, axis=1) /* ty=Tensor[(6, 3), float32] */;
    %1 = greater_equal(%0, %p4) /* ty=Tensor[(6, 3), bool] */;
    %2 = where(%1, %p5, %p6) /* ty=Tensor[(6, 3), float32] */;
    %3 = cast(%2, dtype="int64") /* ty=Tensor[(6, 3), int64] */;
    %4 = add(%3, %p7) /* ty=Tensor[(6, 3), int64] */;
    %5 = reshape(%4, newshape=[-1]) /* ty=Tensor[(18), int64] */;
    %6 = take(%p2, %5, axis=0) /* ty=Tensor[(18), int64] */;
    %7 = reshape(%6, newshape=[-1, 3]) /* ty=Tensor[(6, 3), int64] */;
    %8 = gather(%p1, %7, axis=1) /* ty=Tensor[(6, 3), float32] */;
    %9 = take(%p8, %5, axis=0) /* ty=Tensor[(18), float32] */;
    %10 = reshape(%9, newshape=[-1, 3]) /* ty=Tensor[(6, 3), float32] */;
    %11 = greater_equal(%8, %10) /* ty=Tensor[(6, 3), bool] */;
    %12 = take(%p9, %5, axis=0) /* ty=Tensor[(18), float32] */;
    %13 = reshape(%12, newshape=[-1, 3]) /* ty=Tensor[(6, 3), float32] */;
    %14 = take(%p10, %5, axis=0) /* ty=Tensor[(18), float32] */;
    %15 = reshape(%14, newshape=[-1, 3]) /* ty=Tensor[(6, 3), float32] */;
    %16 = where(%11, %13, %15) /* ty=Tensor[(6, 3), float32] */;
    %17 = cast(%16, dtype="int64") /* ty=Tensor[(6, 3), int64] */;
    %18 = add(%17, %p7) /* ty=Tensor[(6, 3), int64] */;
    %19 = reshape(%18, newshape=[-1]) /* ty=Tensor[(18), int64] */;
    %20 = take(%p0, %19, axis=0) /* ty=Tensor[(18, 1), float32] */;
    reshape(%20, newshape=[6, 1, 3]) /* ty=Tensor[(6, 1, 3), float32] */
  };

%5 corresponds to the updated node indices at one tree level (T_i in the paper), which are used as the indices of the lookup ops (take in the Relay function above) to look up, for each interior node:

  • Feature index
  • Threshold value
  • Left child index
  • Right child index

I’m hoping that, instead of referencing %5 four times directly, let-binding %5 once (via Relay Let) and referencing the bound var would fix the duplication in TIR. Does that sound about right? This way, we wouldn’t need to insert let at the TIR level, or rely on CSE in TIR or LLVM.

I quickly tried the A-normal form transformation after op fusion, but it seems the ToANormalForm pass doesn’t generate let bindings inside a function. Even if it worked, the A-normal form transform would be overkill for this use case.

If the above reasoning sounds good, I’m thinking about introducing a let insertion pass that walks over fused functions and adds a let binding whenever an expression is referenced at least a configurable number of times (4 in the hummingbird case above).
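The reference-counting part of that idea can be sketched in plain Python. This is not a Relay pass: FUSED_FN is a hand-transcribed table of the fused function above (op name and arguments only, attributes dropped, final result named %out for convenience), and count_refs / let_candidates are hypothetical helpers.

```python
from collections import Counter

# Hand-transcribed from the fused function above: name -> (op, args).
FUSED_FN = {
    "%0": ("gather", ["%p1", "%p3"]),
    "%1": ("greater_equal", ["%0", "%p4"]),
    "%2": ("where", ["%1", "%p5", "%p6"]),
    "%3": ("cast", ["%2"]),
    "%4": ("add", ["%3", "%p7"]),
    "%5": ("reshape", ["%4"]),
    "%6": ("take", ["%p2", "%5"]),
    "%7": ("reshape", ["%6"]),
    "%8": ("gather", ["%p1", "%7"]),
    "%9": ("take", ["%p8", "%5"]),
    "%10": ("reshape", ["%9"]),
    "%11": ("greater_equal", ["%8", "%10"]),
    "%12": ("take", ["%p9", "%5"]),
    "%13": ("reshape", ["%12"]),
    "%14": ("take", ["%p10", "%5"]),
    "%15": ("reshape", ["%14"]),
    "%16": ("where", ["%11", "%13", "%15"]),
    "%17": ("cast", ["%16"]),
    "%18": ("add", ["%17", "%p7"]),
    "%19": ("reshape", ["%18"]),
    "%20": ("take", ["%p0", "%19"]),
    "%out": ("reshape", ["%20"]),
}


def count_refs(fn):
    # How many times each value is used as an argument.
    refs = Counter()
    for _op, args in fn.values():
        refs.update(args)
    return refs


def let_candidates(fn, threshold):
    # Intermediate values (not %p parameters) referenced >= threshold times.
    refs = count_refs(fn)
    return sorted(name for name, n in refs.items()
                  if n >= threshold and name in fn)


print(let_candidates(FUSED_FN, threshold=4))
```

On this function the only intermediate value that is referenced more than once is %5, which is consistent with the analysis above.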

@tqchen @junrushao @MarisaKirisame


Due to the current mechanism of Relay-to-TIR operator generation, ANF in Relay won’t translate to let bindings in TIR, so CSE in TIR might be the right path here.


Hello, I wonder how to get the TIR output from Relay as shown above?