[RFC] Add While loop node to TIR

Motivation

The TIR For node is very simple: it always performs a fixed number of iterations. But some irregular algorithms, like NMS, need a number of iterations that depends on the input. In NMS, as soon as we have found max_output_size boxes, there is no need to look for more. Without a While loop, however, we cannot realize this “early exit”, so our NMS is always slower than the implementations in other frameworks.
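
To make the “early exit” concrete, here is a minimal pure-Python sketch of greedy NMS (not TVM code; the box format (x1, y1, x2, y2), the function names and the return of the iteration count are assumptions for illustration). The loop condition depends on how many boxes have been kept so far, which a fixed-extent For cannot express without scanning every candidate:

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # assumed format: (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def nms_early_exit(boxes: Sequence[Box], scores: Sequence[float],
                   iou_threshold: float, max_output_size: int) -> Tuple[List[int], int]:
    """Greedy NMS; returns kept indices and how many candidates were examined."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    i = 0
    # Data-dependent trip count: stop as soon as max_output_size boxes are kept.
    while i < len(order) and len(keep) < max_output_size:
        cand = order[i]
        if all(iou(boxes[cand], boxes[k]) <= iou_threshold for k in keep):
            keep.append(cand)
        i += 1
    return keep, i
```

With five candidates and max_output_size = 2, the loop may stop after examining only a prefix of the sorted list, which is exactly the saving a fixed-extent loop gives up.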

Binary search is another interesting use case for While loops. Binary search is needed to add support for a NumPy-style searchsorted function. @mbrookhart also needs binary search to improve his merge sort, using the “merge-path” algorithm.
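
For reference, a minimal sketch of the binary search behind a searchsorted-style function, written in plain Python (the function name and the side parameter mirror NumPy's searchsorted semantics but are not TVM API). The body is a textbook While loop: it runs until the search interval is empty.

```python
from typing import Sequence


def searchsorted(sorted_seq: Sequence[float], value: float, side: str = "left") -> int:
    """Insertion index of value into sorted_seq, like numpy.searchsorted."""
    lo, hi = 0, len(sorted_seq)
    # Loop until the half-open interval [lo, hi) is empty.
    while lo < hi:
        mid = (lo + hi) // 2
        # "left": first index with element >= value; "right": first with element > value.
        go_right = sorted_seq[mid] < value if side == "left" else sorted_seq[mid] <= value
        if go_right:
            lo = mid + 1
        else:
            hi = mid
    return lo
```

For example, searching for 2 in [1, 2, 2, 3] gives 1 with side="left" and 3 with side="right".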

Design alternatives

As @tqchen pointed out, there are two alternatives for how we might approach adding a While loop:

  • V0: Put the condition into the For node as an additional field (what I’ve implemented)
  • V1: Introduce a separate While node

The pros of V0 are:

  • The implementation is very simple, with minimal changes
  • If we want a single loop abstraction that encompasses all kinds of loops, so that more advanced analysis (than what TVM does today) can be done in a unified manner, this approach might be better

Interestingly, ispc, which offers an implicit SPMD programming model to programmers, takes this approach: For and While loops are represented by the same IR construct. Although a While loop usually only makes sense as a serial loop, ispc allows programmers to use While loops and still vectorizes them (!!), using sophisticated analysis and codegen. It does so by carefully updating a mask per vector that keeps track of which lanes are active, and only doing the work for the active lanes. So for ispc, unifying For and While makes a lot of sense.
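
The mask trick described above can be simulated in a few lines of plain Python. This is only an illustration of the idea, not ispc's actual codegen: each list element stands for one SIMD lane, each lane runs its own while loop (here, the hypothetical example of counting halvings until the value is at most 1), and the vector loop keeps going while any lane is still active.

```python
from typing import List


def masked_while_halving(xs: List[int]) -> List[int]:
    """Per-lane count of halvings until the value is <= 1, run in lockstep."""
    vals = list(xs)
    steps = [0] * len(xs)
    mask = [v > 1 for v in vals]  # lanes whose loop condition still holds
    while any(mask):  # the vector loop exits only when every lane has exited
        for lane in range(len(vals)):  # stands in for one masked SIMD instruction
            if mask[lane]:  # inactive lanes do no work
                vals[lane] //= 2
                steps[lane] += 1
        # Once a lane goes inactive it stays inactive.
        mask = [m and v > 1 for m, v in zip(mask, vals)]
    return steps
```

Lanes with short trip counts sit idle (masked off) while the longest lane finishes, which is the cost ispc pays for vectorizing divergent While loops.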

The cons of V0 are:

  • It adds a new field to the For node which only makes sense for the while-loop case (kind == “serial”). The other fields of the For node (begin, extent, etc.) are irrelevant to a While loop
  • Related to the above, users always have to specify a loop upper bound, since For requires one.
  • I need to write a lot of if (op->test) { ... } checks to determine whether this is a while loop or not.
  • It complicates analysis on For node. While loop doesn’t need any analysis (unless we want to vectorize it like ispc does), so this coupling is not nice.
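
To illustrate the shape of the two designs, here is a schematic Python sketch. The classes only loosely mirror TIR node fields and are hypothetical stand-ins, not TVM's actual Python API. Under V0 the extra test field rides along on For (which still demands min/extent), and every control-dependent pass has to inspect it; under V1 the node type alone tells a pass what it is looking at.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Expr:
    text: str  # placeholder standing in for a real TIR expression


@dataclass
class Stmt:
    text: str  # placeholder standing in for a real TIR statement


@dataclass
class For:
    """V0 (hypothetical): the loop condition is an extra field on For."""
    loop_var: str
    min: Expr
    extent: Expr  # still required, even when the loop is really a while loop
    kind: str
    body: Stmt
    test: Optional[Expr] = None  # only meaningful when kind == "serial"


@dataclass
class While:
    """V1 (hypothetical): a separate node with only a condition and a body."""
    condition: Expr
    body: Stmt


def is_while_loop(node) -> bool:
    """The case split a control-dependent pass needs under each design."""
    if isinstance(node, For):
        # V0: a For may secretly be a while loop, so inspect the extra field.
        return node.test is not None
    # V1: the node type alone answers the question.
    return isinstance(node, While)
```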

The pros of V1 are:

  • A clean decoupling of For and While.
  • Codegen for While loops is simpler: it doesn’t need a phi node.
  • The user API would also be simpler (just needs condition)

The cons of V1 are the exact opposite of the pros of V0 (more effort, no unified analysis, etc.).

Discussion

I think for us and for now, V1 is the clear winner. Any other thoughts in favor of V0? Personally I do find the design of ispc elegant and I’m amazed at what it does, but I think such implicit SPMD + auto-vectorization is out of reach for TVM for now.

@tqchen @junrushao1994 @vinx13 @mbrookhart @zhiics @kevinthesun @trevor-m @anijain2305

I agree that V1 is better. It makes the IR clearer without adding complexity to the existing For node that only applies to the serial case.

Vote for V1. It creates a clear separation for different use cases.

Thanks for the RFC @masahi!

V0 suggests adding a field to the For node, while V1 introduces a new While node. Both have pros and cons.

Adding a new While node requires all TIR passes and codegen to handle it correctly, which means quite a bit of labor-intensive work. So V0 is definitely the better choice in this respect.

Adding a field to For definitely avoids modifying so many passes, but the problem is that if we keep adding fields to For, we may reach a point where we lose track of which passes don’t handle which additional fields. However, adding a single field doesn’t really seem that bad IMHO.

In conclusion, I think both approaches have pros and cons, and would love to hear about more opinions :slight_smile:

@junrushao1994 Do we expect to modify a lot of passes? I think we just need to update the base visitor, and the rest of the passes can ignore the new While node. We do need to add new codegen functions, but that’s only for LLVM and C source codegen, which I already did in my PoC PR.

The work required to add a While node is mostly boilerplate stuff, I hope.
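
A toy Python sketch of the “just update the base visitor” argument (the node classes and Visitor here are hypothetical, not TVM's StmtExprVisitor): the base class gains one default method that recurses into the While condition and body, and an existing pass that only overrides expression handling picks up While support for free.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Var:
    name: str


@dataclass
class IntImm:
    value: int


@dataclass
class Add:
    a: object
    b: object


@dataclass
class LT:
    a: object
    b: object


@dataclass
class Store:
    var: Var
    value: object


@dataclass
class SeqStmt:
    stmts: List[object]


@dataclass
class While:
    condition: object
    body: object


class Visitor:
    """Hypothetical base visitor: dispatches on node type, recurses by default."""

    def visit(self, node):
        return getattr(self, "visit_" + type(node).__name__)(node)

    def visit_Var(self, node):
        pass

    def visit_IntImm(self, node):
        pass

    def visit_Add(self, node):
        self.visit(node.a)
        self.visit(node.b)

    visit_LT = visit_Add  # same two-operand traversal

    def visit_Store(self, node):
        self.visit(node.var)
        self.visit(node.value)

    def visit_SeqStmt(self, node):
        for stmt in node.stmts:
            self.visit(stmt)

    # The only new code for While: recurse into the condition and the body.
    def visit_While(self, node):
        self.visit(node.condition)
        self.visit(node.body)


class CollectVars(Visitor):
    """A pass written without any knowledge of While, yet it handles While."""

    def __init__(self):
        self.names = set()

    def visit_Var(self, node):
        self.names.add(node.name)
```

Passes that care about control structure still need explicit handling, as discussed later in the thread; this only covers the “ignore it” case.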

Not sure if this is relevant to this discussion, but Rust, in its intermediate representation MIR, reduces all loops (there are three kinds in proper Rust) to goto constructs, which also unifies if/else.
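
A small Python sketch of what “reducing loops and if/else to goto” looks like, in the spirit of MIR or LLVM basic blocks (the instruction tuples, helper names and interpreter are all hypothetical, for illustration only). For is lowered as a While with an explicit counter, and If becomes a conditional branch, so a single branch/jump form covers all three constructs:

```python
from itertools import count
from typing import Callable, Dict, List, Tuple

Env = Dict[str, int]
# ("label", name) | ("jump", name) | ("br_false", cond, name) | ("exec", fn)
Instr = Tuple

_ids = count()


def fresh(prefix: str) -> str:
    return f"{prefix}{next(_ids)}"


def lower_while(cond: Callable[[Env], bool], body: List[Instr]) -> List[Instr]:
    head, exit_ = fresh("head"), fresh("exit")
    return [("label", head), ("br_false", cond, exit_), *body,
            ("jump", head), ("label", exit_)]


def lower_for(var: str, extent: int, body: List[Instr]) -> List[Instr]:
    # A For is just a While with an explicit counter: no separate construct needed.
    init = ("exec", lambda env: env.__setitem__(var, 0))
    step = ("exec", lambda env: env.__setitem__(var, env[var] + 1))
    return [init, *lower_while(lambda env: env[var] < extent, [*body, step])]


def lower_if(cond: Callable[[Env], bool], then: List[Instr],
             else_: List[Instr]) -> List[Instr]:
    other, end = fresh("else"), fresh("end")
    return [("br_false", cond, other), *then, ("jump", end),
            ("label", other), *else_, ("label", end)]


def run(prog: List[Instr], env: Env) -> Env:
    """Tiny interpreter for the flat, goto-style program."""
    labels = {ins[1]: pc for pc, ins in enumerate(prog) if ins[0] == "label"}
    pc = 0
    while pc < len(prog):
        ins = prog[pc]
        if ins[0] == "jump":
            pc = labels[ins[1]]
            continue
        if ins[0] == "br_false" and not ins[1](env):
            pc = labels[ins[2]]
            continue
        if ins[0] == "exec":
            ins[1](env)
        pc += 1
    return env
```

For instance, summing the even values of i in range(5) lowers to one flat instruction list and yields 6.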

I’d like to vote for adding a While node. We can have a pass to convert While to For, though, if we want to benefit from the analysis.

To add another “what is done elsewhere” data point: TorchScript uses a combined while/for prim::Loop. As far as I understand, it expects to combine loops / normalize break in a way that might produce “mixed” loops, but I don’t know the exact rationale or how implementations use it.

Great discussion. To offer a few thoughts about separation vs. unification: the design choice depends on the set of transformations we have in mind.

T0: CFG style analysis

Most traditional compilers (e.g. LLVM) perform CFG analysis to handle things like loop invariants, common subexpressions, and dead code. As @jknight said, in most of these cases a goto-like construct is used to unify While, For and If, because all of them can be analyzed with data-flow analysis in a unified way. Reducing everything to a single construct simplifies such analysis, since a single type of analysis can be performed for all of them. Notably, traditional CFG analysis is also mostly applied to loops that are serial (no parallelization and no thread context).

T1: Transformations of Regular Loops

In the specific case of TIR, we also perform transformations of “regular loops”, i.e. loops whose iteration domains correspond to nice, compact integer sets. We can perform schedule transformations such as tiling and changing compute locations. Additionally, the loops themselves can correspond to GPU thread contexts, and the analysis still holds for such generalizations.

In the case of T1, most transformations and proofs cannot be applied for While loops, or require non-trivial generalization. This would mean different code path for the case of common regular For and While.
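
As a tiny illustration of why T1 transformations need a regular iteration domain, here is a hypothetical Python sketch of a loop split (the names run_loop and run_split are made up for this example). The split computes the outer extent from the known extent n and guards the tail; a While loop has no extent to split, so the same rewrite has no counterpart without extra analysis.

```python
from typing import List


def run_loop(n: int) -> List[int]:
    """Original regular loop: for i in range(n), recording the visit order."""
    return [i for i in range(n)]


def run_split(n: int, factor: int) -> List[int]:
    """The same loop after a split by `factor`: outer io, inner ii, tail guard."""
    order = []
    for io in range((n + factor - 1) // factor):  # outer extent derived from n
        for ii in range(factor):  # fixed inner extent: candidate for vectorization/threads
            i = io * factor + ii
            if i < n:  # guard the tail when factor does not divide n
                order.append(i)
    return order
```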

The Current Situation

Because we are designing with automation (and a broader transformation space) in mind, TIR focuses on T1-style passes and will add more transformations in a similar style. Note that traditional compilers normally do not perform T1, or would have to use analysis to recover regular loop structures before performing such analysis. Meanwhile, because we use a CFG-based compiler as the backend (e.g. LLVM), we do not have to perform T0-style analysis, since it can mostly be handled by the backend.

The Impact to Pass Upgrading

Both V0 and V1 will likely impact our passes, since this is an addition to the possible semantics of the IR.

  • For passes that do not depend on control structure, e.g. simplifying the expressions in the IR, using the default visitor for While loops certainly works.
  • For passes that depend on control structure, e.g. vectorizing a block, we will need to add handling of While blocks and skip vectorization in such cases, regardless of V0 or V1 representation.

A typical rule of thumb could be: if the pass requires handling of IfThenElse (as a result of being control dependent), then we likely also need to think about handling While. The related passes need careful checks so that we do not produce code that could result in bugs.
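
A hypothetical Python sketch of the second bullet above (this is not TVM's actual vectorization pass; the Stmt class and plan_vectorize are made up for illustration): a control-dependent pass scans the loop body and falls back to a serial loop when it finds a While anywhere inside, no matter how deeply nested.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Stmt:
    kind: str  # e.g. "SeqStmt", "Evaluate", "IfThenElse", "While"
    children: List["Stmt"] = field(default_factory=list)


def contains(stmt: Stmt, kind: str) -> bool:
    """True if a statement of the given kind appears anywhere in the subtree."""
    return stmt.kind == kind or any(contains(c, kind) for c in stmt.children)


def plan_vectorize(loop_body: Stmt) -> str:
    """Hypothetical decision for a loop marked as vectorized."""
    # Control-dependent pass: a While inside the body means we keep the loop serial.
    if contains(loop_body, "While"):
        return "serial"
    return "vectorized"
```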

If both proposals could require careful updating of passes, then the “simplicity” argument for V0 diminishes, and it gives more pros to V1, I think.

With V1 we don’t have to touch existing For node handling code, while V0 requires adding a lot of messy case analysis to the existing code, on whether it is a regular For node or a While node in disguise.

@masahi Thanks for the proposal and it sounds great.

My two cents: it would be pragmatic to start with V1; however, we would need to think about a path to enabling vectorization (and similar optimization passes that work on regular iteration domains).

Looking forward: while “While” nodes are serial in nature (as their bounds cannot be determined at compile time), we could create a regular block of For nodes inside them using predication (the block size could perhaps be determined by the tuner/scheduler?).

Vectorization of While loops (using predication, aka masks) is an interesting topic, but I don’t see a good use case for it in TVM. Vectorized binary search comes to mind if one wants a super fast, multi-threaded, vectorized searchsorted function on CPU :slightly_smiling_face:

@masahi, well, I was thinking of the use case you mentioned, NMS: in the search for the early exit, one can simply step by more than 1 (say 4, this being the block size), and that could be vectorized, given that the body of the blocked For can use predicates to handle the tail. What do you think?
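
A minimal Python sketch of the blocked early-exit idea (the function name, the generic predicate and the default block size are assumptions, and it applies to searches where the elements in a block are independent). The outer While advances by block, the inner fixed-extent loop is the part a vectorizer could map to lanes, and a mask predicates off the tail:

```python
from typing import Callable, List, Optional, Sequence, Tuple


def find_first_blocked(xs: Sequence[int], pred: Callable[[int], bool],
                       block: int = 4) -> Tuple[Optional[int], int]:
    """Index of the first x with pred(x), plus the number of blocks scanned."""
    n = len(xs)
    start, blocks = 0, 0
    while start < n:  # outer While: data-dependent exit
        blocks += 1
        # Inner loop of fixed extent `block`: the vectorizable part.
        mask = [start + j < n for j in range(block)]  # predicate off the tail
        hits: List[bool] = [mask[j] and pred(xs[start + j]) for j in range(block)]
        if any(hits):
            return start + hits.index(True), blocks
        start += block
    return None, blocks
```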

No, the outer loop of NMS, which I’m replacing with a While loop, needs to run sequentially: the loop proceeds in sorted order of box scores, and one box can invalidate an unbounded number of succeeding boxes, including boxes in the same vector (if vectorized). If a box has been invalidated by an earlier, higher-scoring box, it shouldn’t invalidate other boxes.
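
A small Python counterexample illustrating this dependency, using 1D intervals and a plain overlap test instead of IoU to keep it short (both simplifications are assumptions for illustration). Box A overlaps B and B overlaps C, but A does not overlap C. Sequentially, B is suppressed by A, so C survives. Deciding all boxes at once, as a naive vectorized version would, lets the already-suppressed B wrongly suppress C.

```python
from typing import List, Tuple

Interval = Tuple[float, float]


def overlaps(a: Interval, b: Interval) -> bool:
    return a[0] < b[1] and b[0] < a[1]


def nms_sequential(items: List[Interval]) -> List[int]:
    """items are assumed sorted by descending score."""
    keep: List[int] = []
    for i, box in enumerate(items):
        # Only boxes that were actually kept may suppress later ones.
        if not any(overlaps(box, items[k]) for k in keep):
            keep.append(i)
    return keep


def nms_naive_parallel(items: List[Interval]) -> List[int]:
    """Wrong: every box checks all higher-scoring boxes, suppressed or not."""
    return [i for i, box in enumerate(items)
            if not any(overlaps(box, items[j]) for j in range(i))]
```

With items [(0, 2), (1, 3), (2.5, 4)], the sequential version keeps boxes 0 and 2, while the naive version keeps only box 0.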

PR posted: [TIR] Add TIR While node by masahi · Pull Request #7425 · apache/tvm · GitHub

Also check out the discussion on the MLIR forum, where they discussed adding a while op to their control flow dialect: [RFC] Add scf.while to scf dialect - MLIR - LLVM Discussion Forums

While loops are used extensively in their TACO-style sparse codegen; see their recent update: MLIR Support for Sparse Tensors - #17 by aartbik - MLIR - LLVM Discussion Forums
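
For context, the kind of loop TACO-style sparse codegen produces is a co-iteration over the nonzeros of two sparse operands. A minimal Python sketch of a sparse dot product (coordinate lists assumed sorted; names are hypothetical) shows why While is needed: the trip count depends on the sparsity patterns of both inputs, not on any fixed extent.

```python
from typing import Sequence


def sparse_dot(a_idx: Sequence[int], a_val: Sequence[float],
               b_idx: Sequence[int], b_val: Sequence[float]) -> float:
    """Dot product of two sparse vectors given as sorted (index, value) lists."""
    i = j = 0
    acc = 0.0
    # Co-iterate until either operand runs out of nonzeros.
    while i < len(a_idx) and j < len(b_idx):
        if a_idx[i] == b_idx[j]:  # both nonzero at this coordinate
            acc += a_val[i] * b_val[j]
            i += 1
            j += 1
        elif a_idx[i] < b_idx[j]:  # advance whichever side is behind
            i += 1
        else:
            j += 1
    return acc
```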

I’d love to see a TACO-style sparse compiler realized in TIR, building on the While loop in this proposal. Let me know if people have thoughts @tqchen @junrushao1994 @tkonolige

@masahi I think TACO in TVM would be really cool. But right now, I’m not sure how necessary it is. I think we only have two ops that could benefit from it, and I haven’t seen much interest in sparsity on this forum.