Hi, I’ve recently been working with people on the Hummingbird project to help with their TVM backend development.
One interesting thing about models coming from Hummingbird is that the network structure changes as the size and characteristics of the dataset change, since the network structure implicitly encodes the structure of the decision tree and its traversal.
In particular, as the size of the dataset grows, one of the fused functions can end up containing an unbounded number of operations, which leads to a huge, fully inlined one-line TIR expression and a never-ending compilation time.
For example, even with a tiny dataset of size (6, 10), one of the fused functions ends up having more than 20 ops.
%21 = fn (%p0: Tensor[(21, 1), float32], %p1: Tensor[(6, 10), float32], %p2: Tensor[(21), int64], %p3: Tensor[(6, 3), int64], %p4: Tensor[(6, 3), float32], %p5: Tensor[(6, 3), float32], %p6: Tensor[(6, 3), float32], %p7: Tensor[(1, 3), int64], %p8: Tensor[(21), float32], %p9: Tensor[(21), float32], %p10: Tensor[(21), float32], Primitive=1) -> Tensor[(6, 1, 3), float32] {
%0 = gather(%p1, %p3, axis=1) /* ty=Tensor[(6, 3), float32] */;
%1 = greater_equal(%0, %p4) /* ty=Tensor[(6, 3), bool] */;
%2 = where(%1, %p5, %p6) /* ty=Tensor[(6, 3), float32] */;
%3 = cast(%2, dtype="int64") /* ty=Tensor[(6, 3), int64] */;
%4 = add(%3, %p7) /* ty=Tensor[(6, 3), int64] */;
%5 = reshape(%4, newshape=[-1]) /* ty=Tensor[(18), int64] */;
%6 = take(%p2, %5, axis=0) /* ty=Tensor[(18), int64] */;
%7 = reshape(%6, newshape=[-1, 3]) /* ty=Tensor[(6, 3), int64] */;
%8 = gather(%p1, %7, axis=1) /* ty=Tensor[(6, 3), float32] */;
%9 = take(%p8, %5, axis=0) /* ty=Tensor[(18), float32] */;
%10 = reshape(%9, newshape=[-1, 3]) /* ty=Tensor[(6, 3), float32] */;
%11 = greater_equal(%8, %10) /* ty=Tensor[(6, 3), bool] */;
%12 = take(%p9, %5, axis=0) /* ty=Tensor[(18), float32] */;
%13 = reshape(%12, newshape=[-1, 3]) /* ty=Tensor[(6, 3), float32] */;
%14 = take(%p10, %5, axis=0) /* ty=Tensor[(18), float32] */;
%15 = reshape(%14, newshape=[-1, 3]) /* ty=Tensor[(6, 3), float32] */;
%16 = where(%11, %13, %15) /* ty=Tensor[(6, 3), float32] */;
%17 = cast(%16, dtype="int64") /* ty=Tensor[(6, 3), int64] */;
%18 = add(%17, %p7) /* ty=Tensor[(6, 3), int64] */;
%19 = reshape(%18, newshape=[-1]) /* ty=Tensor[(18), int64] */;
%20 = take(%p0, %19, axis=0) /* ty=Tensor[(18, 1), float32] */;
reshape(%20, newshape=[6, 1, 3]) /* ty=Tensor[(6, 1, 3), float32] */
};
The fused Relay function above gets translated into the following huge TIR, shown here before the StorageFlatten pass:
buffer_realize T_reshape([0, 6], [0, 1], [0, 3]) {
parallel (ax0.ax1.fused, 0, 6) {
vectorized (ax2.inner, 0, 3) {
T_reshape[ax0.ax1.fused, 0, ax2.inner] = placeholder[min(max((int64)0,
(int64(select((int32((placeholder[floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6),
placeholder[min(max((int64)0, (int64(select((int32((placeholder[floormod(floordiv(floormod(((
floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) +
ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
18), 3)), 18), 3)]] >= placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)),
18), 3)])) != 0), placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3)
+ ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)],
placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) +
placeholder[0, floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])), (int64)20)]] >= placeholder[min(max((int64)0,
(int64(select((int32((placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), placeholder[floormod(floordiv(floormod(((
floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)]] >= placeholder[floormod(floordiv(floormod(((floormod(floordiv(
floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3),
6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) != 0), placeholder[floormod(floordiv(floormod(((floormod(floordiv(
floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6),
floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)], placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3)
+ floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((
ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) +
placeholder[0, floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])), (int64)20)])) != 0), placeholder[min(max((int64)0,
(int64(select((int32((placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), placeholder[floormod(floordiv(floormod(((floormod(floordiv(
floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6),
floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3)
+ ax2.inner), 18), 3)), 18), 3)]] >= placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18),
3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((
ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) != 0),
placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused +
0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)],
placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((
ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18)
, 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) + placeholder[0, floormod(floormod(((floormod(floordiv(
floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])),
(int64)20)], placeholder[min(max((int64)0, (int64(select((int32((placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((
ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6),
placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod(((
(ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)]] >= placeholder[floormod(floordiv(floormod(((floormod(
floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)
, 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)
*3) + ax2.inner), 18), 3)), 18), 3)])) != 0), placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) +
ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6),
floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3)
+ ax2.inner), 18), 3)), 18), 3)], placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)
, 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((
ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) +
placeholder[0, floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((
ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])), (int64)20)])) + placeholder[0, floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
18), 3)])), (int64)20), 0]
}
}
}
For a dataset of size (20, 10), a fused function gets 70-90 ops inside it, and the resulting TIR is so huge that it takes forever to compile. In particular, compilation seems to get stuck in the StorageFlatten pass (it endlessly runs VisitExpr or VisitStmt on some nodes).
There is more discussion in the Hummingbird repo here: https://github.com/microsoft/hummingbird/issues/232
My questions are:
- Is the TIR above reasonable for the given Relay function? I think fusing 20 or so operations is totally reasonable, but I was surprised to see such a complicated expression come out of the lowering process.
- The only solution I can think of is to make the maximum size of a fused function configurable. Are there other approaches?