[Unity] Dealing with Phase Ordering

As we discussed in the Apr. 11, 2023, TVM Unity Open Development Meeting, one of the issues in the current Relax implementation is that the compiler at various points relies on a specific ordering of passes (phase ordering) without clearly advertising this fact. Some of this phase ordering exists for good engineering reasons, so the existence of these dependencies among passes is not itself an issue; rather, what we should address is the fact that many of these dependencies are subtle and are not documented.

Some examples of phase ordering:

  • build() in vm_build.py uses the ToNonDataflow pass to eliminate dataflow blocks, meaning that the low-level code generation passes do not have to deal with dataflow blocks.
  • VMShapeLower has a comment indicating that it does not deal with nested functions and expects LambdaLifting to be called first, though it is not included in the default build().
  • The default VM code generator expects operators to be legalized, even though this is not included in the default build().
  • The LegalizeOps pass cannot handle certain operators. For example, invoke_closure, which may be introduced by LambdaLifting, does not have a legalization, meaning that LegalizeOps should be called before LambdaLifting (this is not documented anywhere).
  • Somewhat related: The tensor_to_shape operator is lowered into a builtin by DecomposeOpsForInference and not by VMBuiltinLower. This operator thus creates a hidden dependency on DecomposeOpsForInference, which is not documented anywhere.
  • Normally, we expect TIR functions to be called only via call_tir (though this is not enforced in the compiler); however, CallTIRRewrite lowers call_tir operator calls into explicit tensor allocations and direct calls to the PrimFuncs (it would be good to note which passes should expect to deal with such calls and which should not).
  • Additionally, phase ordering has resulted in headaches in my purity tracking PR, as having to reason about purity makes it very difficult to deal with lower-level code generation (e.g., lowering operators to builtins). This problem was solved by stripping away the purity checks during the default build(), but even that creates some issues in cases like using a BYOC custom code generator (see the RunCodegen pass in that PR).

In the community meeting, we proposed certain measures that we can take to deal with this complexity:

  1. We should certainly document our expectations as far as pass ordering goes. We could write it out in some file, e.g., have a src/relax/transform/README.md to explain this.
  2. For a technical measure, @tqchen proposed using module attributes to indicate what “phase” of compilation the module is on and check that any passes invoked correspond to the correct phase. We could use the well-formedness checker to enforce invariants pertaining to the different phases.
  3. Another technical measure might be to use the required_passes field in the current Relax pass infrastructure (presently unused), though the attendees of the discussion were not in favor of having passes run automatically (and thus “silently”), which may surprise users. If we use this field to indicate dependencies, it should ideally be used to give warnings rather than run passes automatically.
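To make measures 2 and 3 concrete, here is a minimal pure-Python sketch that does not use TVM at all. It assumes a module carries a hypothetical "phase" attribute and a record of applied passes; a pass rejects modules in the wrong phase and only warns (never silently runs anything) when a pass it expects has not been applied. The names ToyModule, PhasedPass, and the phase labels are illustrative assumptions, not existing Relax APIs.

```python
import warnings
from dataclasses import dataclass, field
from typing import Optional, Tuple


@dataclass
class ToyModule:
    """Stand-in for an IRModule: module attributes plus a log of applied passes."""
    attrs: dict = field(default_factory=dict)
    applied: Tuple[str, ...] = ()


@dataclass
class PhasedPass:
    name: str
    allowed_phases: Tuple[str, ...]          # phases in which this pass may run
    required_passes: Tuple[str, ...] = ()    # passes expected to have run earlier
    next_phase: Optional[str] = None         # phase transition performed, if any

    def __call__(self, mod: ToyModule) -> ToyModule:
        phase = mod.attrs.get("phase", "ingestion")
        if phase not in self.allowed_phases:
            # Measure 2: reject a pass invoked in the wrong phase.
            raise RuntimeError(f"{self.name} cannot run in phase '{phase}'")
        missing = [p for p in self.required_passes if p not in mod.applied]
        if missing:
            # Measure 3, warnings only: the missing passes are NOT run here.
            warnings.warn(f"{self.name} expects {missing} to have run first")
        attrs = dict(mod.attrs)
        if self.next_phase is not None:
            attrs["phase"] = self.next_phase
        return ToyModule(attrs, mod.applied + (self.name,))


# Hypothetical declarations mirroring the VMShapeLower/LambdaLifting example.
lambda_lifting = PhasedPass("LambdaLifting", ("ingestion", "optimization"))
vm_shape_lower = PhasedPass(
    "VMShapeLower", ("optimization",),
    required_passes=("LambdaLifting",), next_phase="lowered",
)
```

Running vm_shape_lower on an optimization-phase module that skipped LambdaLifting emits one warning and moves the module to "lowered"; running it again on the lowered module raises, because the pass is not allowed in that phase.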

To pursue any of these measures, however, we would have to agree on what phases compilation should have and which passes should act as transitions between them. At the meeting, we noted that at present there are two de facto phases: high-level transformations on a model, followed by low-level code generation (in build()). However, we may want a finer division of stages. One example brought up involved GPU code generation, where some options might apply; BYOC might also factor into this discussion, per the example that came up in the purity tracking PR. Additionally, we should decide which dependencies are acceptable within a phase: should there be an enforced ordering? (It might be reasonable to rely on one in build().)


Thanks @slyubomirsky for documenting these dependencies.

Overall, I think we could proceed with option 2. Option 1, which is collecting expectations on these passes, would happen as part of option 2. We should definitely put that in the README.md when it is ready. Option 3 as you mentioned was not a popular choice among the attendees.

For option 2, I suppose the next step is to take a first stab at classifying the existing passes (~35 in-tree passes) into different categories (ideally <=3). This could be a community effort, or maybe one of us could do it and post it here for feedback from pass owners.

  • The LegalizeOps pass cannot handle certain operators. For example, invoke_closure, which may be introduced by LambdaLifting, does not have a legalization, meaning that LegalizeOps should be called before LambdaLifting (this is not documented anywhere).

I don’t see a problem with calling LambdaLifting before LegalizeOps. LegalizeOps would simply ignore invoke_closure-type ops, leaving them to be lowered later by VMBuiltinLower.
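The behavior described here (rewrite what has a legalization, pass everything else through untouched) can be sketched with a toy list of bindings. This is not the actual LegalizeOps implementation; LEGALIZE_MAP and the PrimFunc names are hypothetical stand-ins for the real legalization registry.

```python
# Toy bindings: (op_name, args). LEGALIZE_MAP is a hypothetical legalization
# table; only ops that appear in it are rewritten to call_tir form.
LEGALIZE_MAP = {
    "relax.add": lambda args: ("relax.call_tir", ("add_prim",) + args),
    "relax.matmul": lambda args: ("relax.call_tir", ("matmul_prim",) + args),
}


def legalize(bindings):
    """Rewrite ops that have a legalization; pass every other op through."""
    out = []
    for op, args in bindings:
        rule = LEGALIZE_MAP.get(op)
        # invoke_closure has no entry, so it is left for a later lowering pass.
        out.append(rule(args) if rule is not None else (op, args))
    return out


program = [
    ("relax.add", ("x", "y")),
    ("relax.invoke_closure", ("clo", "x")),
]
legalized = legalize(program)
```

Here the add becomes a call_tir of the hypothetical add_prim, while the invoke_closure binding comes out unchanged, which is why the ordering of the two passes need not matter as long as a later pass lowers it.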


It’s not an issue per se, but we should either clarify the ordering expectation or handle that case. This would be a good example of something to avoid in passes: a “silent” ordering dependency (with no fundamental reason for it).

One thing that should be figured out is what categories are necessary. Right now, there is the de facto categorization imposed by the VM build(), where certain low-level passes happen without dataflow blocks. Some of these low-level passes also expect certain operators to have been lowered to VM builtins. I am curious about BYOC passes, since those happen before the VM build().

@tqchen proposed having more than two stages, so I’m curious what stage might come between the initial input program and lower-level code generation.

Here is a proposal @psrivas2 and I (mostly Prakalp) have written up on where we might draw some of the dividing lines for phases. This comes after taking a survey of most of the passes and what their dependencies are.

It may require some changes to pass implementations to realize this plan, as well as adopting the mechanism @tqchen proposed of using a module-level attribute to indicate the current phase. The staging here roughly gathers up the passes based on their intended functionality (going from high-level transformations to lower-level and platform-specific ones) and on dependencies.

Proposed Phases

Phase 0 (P0): Ingestion. These are passes that should be applied to a model once it is parsed.

  • Normalize (this is already run after every pass and it will continue to be)
  • LambdaLifting: Currently, the build process does not require LambdaLifting to be run, but it would simplify many pass implementations to avoid having nested functions once a program has been parsed.
  • RewriteDataflowReshape (change from the current implementation): Currently, this is included in the default build(), but it is actually much higher-level than most of the other passes in the build. It is a nice simplification that replaces reshapes in TIR with the reshape operator. It might be helpful to apply this from the start after parsing, though it could also be included in P1.

P1: Optimizations. These are higher-level passes that are, in general, optional to apply—later passes will not strictly depend on these. There are two subgroups we may consider:

  1. General optimizations: These are classic compiler optimizations that could be run at any point in compilation and possibly would make sense to run multiple times (e.g., to fixpoint) depending on what other passes produce. These generally do not require further explanation.
    • FoldConstant
    • CanonicalizeBindings (Note: we might consider if this pass should be considered to be a stage, since unlike the others, the resulting program will have the invariant that no indirect bindings will be used, i.e., if some value is bound to v1 and v1 is bound to v2, there will never be any reference to v2 but rather only to v1 directly. We could potentially have passes rely on that invariant.)
    • EliminateCommonSubexpr
    • DeadCodeElimination
  2. Custom Optimizations: These are passes that implement domain-specific program transformations, usually relating to some specific TVM feature. Generally, these would not make sense to run more than once. We might consider imposing a pass ordering within this group, since many of these passes also do introduce invariants. However, there are no “hard dependencies” that we know of among the passes in this group.
    • BindParams: Not necessarily an “optimization” per se but it is helpful on some targets. This one will not disrupt other passes.
    • FusionPass: We propose gathering these passes into a “megapass” because they have hard dependencies only on each other and have related functionality. This would be a good candidate for being marked as a phase, since the resulting program will have fused primitive operators. (E.g., passes like CombineParallelMatmul may not work on the program after running these passes).
      • AnnotateTIROpPattern
      • FuseOps
      • FuseTIR
    • BYOC: It may also make sense to turn these three passes into a “megapass,” but that is not strictly necessary. However, these three passes have “soft” dependencies on each other: MergeCompositeFunctions and RunCodegen will do nothing if FuseOpsByPattern has not been used. MergeCompositeFunctions also must run before RunCodegen or else it will have no effect. One complication is that FuseOpsByPattern and MergeCompositeFunctions introduce nested functions that RunCodegen requires, which contradicts the P0 rule that nested functions will be lifted out. We may consider changing the implementations to use global definitions instead.
      • FuseOpsByPattern
      • MergeCompositeFunctions (introduces inner functions right now, so it could be a problem to enforce well-formed checker constraints)
      • RunCodegen
    • LiftTransformParams: This transforms the program in a meaningful way (turning weights into a tuple of parameters), so it might make sense to specify its order in the phases, though it generally is not disruptive to other passes.
    • ConvertLayout: Useful for certain platforms but unlikely to disrupt other passes.
    • ToMixedPrecision: Necessary on certain platforms but unlikely to disrupt other passes.
    • SplitCallTIRByPattern: Operates only on PrimFuncs, so unlikely to disrupt other passes.
    • CombineParallelMatmul: Not a phase in itself, little interaction with other passes.
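The invariant mentioned under CanonicalizeBindings (no reference through an indirect binding) is simple enough to state as a check. The sketch below uses a toy representation in which a binding's value is either a variable name (a trivial binding such as v2 = v1) or a tuple (op, *args); it only illustrates the invariant and is not the real pass, which works on Relax AST nodes.

```python
def canonicalize_bindings(bindings):
    """Toy CanonicalizeBindings.

    bindings: list of (var, value), where value is either a str (a trivial
    binding such as ``v2 = v1``) or a tuple (op, *arg_vars).
    """
    root = {}  # alias var -> the variable it ultimately refers to
    out = []
    for var, value in bindings:
        if isinstance(value, str):
            root[var] = root.get(value, value)
            out.append((var, root[var]))  # kept; dead-code elimination may drop it
        else:
            op, *args = value
            out.append((var, (op, *[root.get(a, a) for a in args])))
    return out


def no_indirect_refs(bindings):
    """Invariant: no call refers to a variable that is merely an alias."""
    aliases = {var for var, value in bindings if isinstance(value, str)}
    return all(
        arg not in aliases
        for _, value in bindings
        if isinstance(value, tuple)
        for arg in value[1:]
    )
```

A pass relying on this invariant could assert no_indirect_refs on its input instead of chasing alias chains itself, which is the kind of guarantee that would justify treating CanonicalizeBindings as a stage.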

P2.1: Early Build: These passes correspond to those used in build() and have the collective effect of removing dataflow blocks, inserting explicit memory allocations, and turning Relax operators into PackedFuncs.

  • RewriteDataflowReshape (currently included in build() but we propose having it earlier)
  • LegalizeOps (proposal): This is not currently included in build(), but it is, in reality, necessary for any build using operators to succeed, so we propose including it there.
  • ToNonDataflow
  • CallTIRRewrite
  • StaticPlanBlockMemory

P2.2: Target-Specific Passes: These are passes that involve low-level target-specific information.

  • MetaSchedule: This could be made a “megapass,” potentially. These passes all depend on LegalizeOps. However, since they operate only on TIR PrimFuncs, we might consider making LegalizeOps and the MetaSchedule passes an earlier phase, as there is nothing that inherently requires these passes to come after ToNonDataflow, etc. (though the results of tuning will, of course, depend on target details).
    • MetaScheduleApplyDatabase
    • MetaScheduleTuneTIR
    • MetaScheduleTuneIRMod
  • RewriteCUDAGraph: This has a dependency on StaticPlanBlockMemory. Naturally, it is specific to CUDA.
  • It is possible that other platform-specific passes may be added here.

P2.3: Lowering Builtins: These passes (part of the current build()) cumulatively have the effect of lowering core language features and builtin operators into PackedFuncs and preparing the IRModule for VM code generation.

  • VMBuiltinLower
  • VMShapeLower
  • AttachGlobalSymbol

Still Uncategorized: Passes we have not yet characterized because they are experimental and highly subject to change:

  • Gradient
  • DecomposeOpsForInference
  • DecomposeOpsForTraining
  • AlterOpImpl

Enforcing Invariants

We could consider enforcing phase-specific invariants inside the well-formed checker as has been discussed. Here are conditions the well-formed check could enforce per phase in addition to checking that the AST is in normal form:

  • After P0: No nested functions
  • After P1: P1 does not introduce additional invariants.
  • After P2.1:
    • No dataflow blocks.
    • No uses of call_tir or call_dps_packed.
    • No Relax tensor operators (only builtins).
    • Explicit memory allocations / deallocations.
  • After P2.2: P2.2 does not introduce additional invariants.
  • After P2.3: We have a VM executable rather than a Relax program.
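A per-phase well-formedness check along these lines could look roughly like the following. This is a minimal sketch over a toy module representation (functions described by a count of nested functions and a list of blocks), not the Relax well-formed checker; the op sets are hypothetical, and only the nested-function, dataflow-block, call_tir/call_dps_packed, and tensor-operator rules are modeled.

```python
# Hypothetical phase labels and op sets, chosen only for illustration.
PHASE_ORDER = ["P0", "P1", "P2.1", "P2.2"]
DPS_CALLS = {"relax.call_tir", "relax.call_dps_packed"}
TENSOR_OPS = {"relax.add", "relax.matmul"}


def reached(phase, target):
    """True if `phase` is `target` or a later phase."""
    return PHASE_ORDER.index(phase) >= PHASE_ORDER.index(target)


def check_phase(mod, phase):
    """Return a list of invariant violations for a toy module.

    mod maps function names to {"nested_functions": int,
    "blocks": [{"dataflow": bool, "ops": [op_name, ...]}]}.
    """
    errors = []
    for name, fn in mod.items():
        if reached(phase, "P0") and fn["nested_functions"] > 0:
            errors.append(f"{name}: nested functions remain after P0")
        if not reached(phase, "P2.1"):
            continue
        for i, block in enumerate(fn["blocks"]):
            if block["dataflow"]:
                errors.append(f"{name}: block {i} is still a dataflow block")
            for op in block["ops"]:
                if op in DPS_CALLS:
                    errors.append(f"{name}: {op} was not rewritten")
                elif op in TENSOR_OPS:
                    errors.append(f"{name}: tensor operator {op} was not lowered")
    return errors
```

Because each rule applies from its introducing phase onward, a module that is valid after P1 (dataflow block, call_tir, tensor op) is flagged three times if it is claimed to be past P2.1.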

Main Questions to Address

  1. Do these categorizations all make sense? Are the distinctions perhaps ad hoc?
  2. How would we handle passes having hard dependencies on other passes? (As seen in some of the proposed “megapasses,” but also in cases like the MetaSchedule passes having a hard dependency on LegalizeOps.) Should these simply be documented and have it be the user’s responsibility to run those passes? Should these hard dependencies always result in a specific stage to have the requirement be technically enforced?
  3. Are there low-level constructs from later passes that we should prohibit in earlier ones? Should it be permitted, for example, to have explicit memory allocations or deallocations before StaticPlanBlockMemory? Some passes may not be able to handle them. Should the well-formed checker enforce this?
  4. How should we handle the fact that the BYOC passes (namely FuseOpsByPattern and MergeCompositeFunctions) introduce nested functions? Should we modify their implementations to avoid it or have those be in a separate stage where the nested function rule isn’t enforced?

May 16 Unity Meeting Notes

@tqchen’s concerns: Can we simplify the flow for the user? (Keep a simple mental model.) Customizing for particular targets is useful, even if it could mean some “mixed phases.” Customization does mean exposing some internals, e.g., to allow users to substitute their own libraries.

It would be best to keep things to three phases: import → optimize (may be target-dependent) → lowering (remove dataflow and lower to builtins, etc.) → low-level customization → build the executable

@psrivas2: Perhaps we could just fold LambdaLifting into the normalizer. This would eliminate the “import” phase and would also allow pass writers to introduce nested functions and have them work.

  • Disadvantage: This would result in more work after each pass.
  • Advantage: Would allow passes to be kept to function/dataflow level while still using nested functions.

We don’t really imagine this case coming up often, so that might be a circumstance when we would ask for users to manually run LambdaLifting or another pass afterwards.

@MasterJH5574: Why RewriteDataflowReshape was designed to happen late in compilation: TIR functions doing reshapes can be introduced by legalization, which could result in unnecessary copying. Instead, this pass will turn these functions into reshapes, which are lowered into NDArray views by other passes.

  • This seems a little bit circular. Perhaps we shouldn’t legalize these reshapes into PrimFuncs in the first place.
  • We didn’t want to introduce a new “builtin reshape” operator to complicate things, so that is why there is this odd dance of legalizing reshapes into PrimFuncs and then rewriting the PrimFuncs into reshapes. There are some complex interactions with fusion, so that is why we separate those steps out.

Thank you for the proposal! Here are my thoughts:

  • Personally, I would prefer to apply optimization passes like MetaSchedule before the build so that we can keep the build simple and minimal. I assume regular users would not want to look inside build to play around with the optimization passes they are interested in. If we put passes like MetaSchedule, whose settings users will want to customize all the time, inside build, then build would need to provide an interface exposing such configurations, and I’m concerned that this would not be scalable or easy to use.
  • I think the megapass is an interesting idea for enforcing the ordering within the same phase. I’m wondering how we would define a megapass. Can we define it in Python so that we can conveniently group passes? Also, I think the current BYOC flow may require some refactoring, since the passes we apply differ between CUTLASS and other backends (see cutlass vs tensorrt).

Sorry, it seems I went into verbose mode and did not explain things very clearly in the dev meeting today. Let me debrief on pass RewriteDataflowReshape here in text.


In short, pass RewriteDataflowReshape is a pass which we don’t want to expose to users and want to apply at a stage where

  • the computational graph will no longer be rewritten (e.g., no fusion will happen anymore), and
  • all high-level Relax ops are lowered to TIR PrimFuncs.

Considering the existing design and implementation of the optimization and build stages, the requirements above are only met inside relax.build(...), since we currently assume that people are free to apply any pass (whether from the Relax codebase or user-written in Python) before build. Thus, we apply the RewriteDataflowReshape pass in relax.build.

To provide more details,

  • We need the first requirement because applying this pass before the computational graph is finalized may prevent optimizations intended by subsequent passes. For example, suppose we have a sequence of call_tirs that can be fused together, one of which, in the middle, is a reshape. If we apply RewriteDataflowReshape before fusion, the fusion pass will no longer fuse all of them together.
  • The second requirement comes more from the current implementation of RewriteDataflowReshape. At this moment, we rewrite eligible call_tir(reshape_tir_func, ...) calls to relax.reshape(...). With this implementation, if we apply the pass while high-level operator calls still exist in Relax functions, then when we later legalize those operators, the relax.reshape(...) we produced will be turned back into call_tir(reshape_tir_func, ...), undoing what RewriteDataflowReshape did. Nevertheless, we could introduce a new builtin operator as the rewrite result of RewriteDataflowReshape to get around this issue, so it is not a blocker.
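The second point can be reproduced with a toy model of the two passes. In this sketch, RESHAPE_PRIMFUNCS stands in for the analysis that recognizes reshape PrimFuncs and LEGALIZED for the legalization rules; both are hypothetical simplifications, and calls are plain tuples rather than Relax expressions.

```python
# Hypothetical: PrimFuncs that an analysis has proven to be pure reshapes.
RESHAPE_PRIMFUNCS = {"reshape_prim"}
# Hypothetical legalization table: high-level op -> generated PrimFunc.
LEGALIZED = {"relax.reshape": "reshape_prim", "relax.add": "add_prim"}


def rewrite_dataflow_reshape(calls):
    """Turn call_tir(reshape PrimFunc, ...) into relax.reshape(...)."""
    return [
        ("relax.reshape",) + c[2:]
        if c[0] == "relax.call_tir" and c[1] in RESHAPE_PRIMFUNCS
        else c
        for c in calls
    ]


def legalize_ops(calls):
    """Turn every high-level op in LEGALIZED into a call_tir."""
    return [
        ("relax.call_tir", LEGALIZED[c[0]]) + c[1:] if c[0] in LEGALIZED else c
        for c in calls
    ]


program = [("relax.call_tir", "reshape_prim", "x"), ("relax.add", "x", "y")]
too_early = legalize_ops(rewrite_dataflow_reshape(program))
as_in_build = rewrite_dataflow_reshape(legalize_ops(program))
```

In too_early the reshape ends up as a call_tir again, because legalization reverses the rewrite; in as_in_build, where legalization runs first, the relax.reshape survives. Emitting a dedicated builtin instead of relax.reshape, as suggested above, would make the two orders agree.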

For these reasons, I feel it might not be ideal to move RewriteDataflowReshape to the very beginning of our optimization and build pipeline.


Regarding the overall phase organization, I’m good with the optimization stage, and am imagining we could have a “lowering” phase which applies a sequence of “lowering” passes before we do build. This sequence contains

  • what we now have in relax.build, and
  • possibly LegalizeOps and also, some tuning pass (or default schedule pass) in the front.

And the lowering can be target-dependent. The lowering here is conceptually similar to TIR PrimFunc lowering, just with an additional target-dependent factor.

The lowering stage generally serves as a black box to users. All passes here are well baked, and people don’t need to know what happened when lowering (and also I believe when we wrote these passes in relax.build, we are not expecting them to be used explicitly by users). Because of this, I think it is fine we have dependency between passes in the lowering stage, as we don’t expose the internal flow to users. As proposed, we can do phase partition internally for these passes under the scope of “lowering.”

As discussed in the dev meeting today, we can have a customization stage after lowering, right before entering the final build. One very minor concern is that, as far as I remember, we don’t yet have examples of build-stage customization. Perhaps we can omit the customization stage for now and bring it back when the need arises? But I’m open on this, and it also makes sense to reserve a customization stage here.

@MasterJH5574, thank you for the more detailed explanation. I think you’re right, then, that RewriteDataflowReshape should probably happen during build().

@sunggg, thank you for your reply. “Megapasses” could be implemented by just applying the passes in their current sequence (e.g., just use a Sequential pass made up of those passes), though @tqchen pointed out that there may be situations where users would want the finer-grained control, so maybe we shouldn’t force users to use the megapasses. In any case, users should be alerted clearly to the ordering constraints on those passes.
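As a rough illustration of a megapass that still leaves the individual passes available, here is a pure-Python stand-in for a Sequential of the fusion passes. The toy passes only set flags, and they deliberately do nothing (rather than fail) when their predecessor has not run, modeling a silent ordering dependency; this is an assumption for illustration, not a description of how the real FuseOps or FuseTIR behave. In TVM itself, the analogous grouping would presumably be a Sequential pass built from the three passes, as described above.

```python
def annotate_tir_op_pattern(mod):
    return {**mod, "op_patterns": True}


def fuse_ops(mod):
    # Without pattern annotations, this toy FuseOps silently fuses nothing.
    return {**mod, "fused_groups": True} if mod.get("op_patterns") else mod


def fuse_tir(mod):
    # Likewise, with no fused groups there is nothing to turn into PrimFuncs.
    return {**mod, "fused_primfuncs": True} if mod.get("fused_groups") else mod


def sequential(passes, name):
    """Compose passes in a fixed order, like a Sequential pass."""
    def megapass(mod):
        for p in passes:
            mod = p(mod)
        return mod
    megapass.__name__ = name
    megapass.passes = tuple(passes)  # still exposed for fine-grained control
    return megapass


fusion_pass = sequential(
    [annotate_tir_op_pattern, fuse_ops, fuse_tir], "FusionPass"
)
```

Calling fusion_pass always applies the passes in the intended order, while fusion_pass.passes lets an advanced user run or replace individual steps, at which point the ordering constraint becomes their responsibility again.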

Regarding making build() simple and monolithic, would it make sense to have legalization and metascheduling as a phase that immediately precedes the build, then? @MasterJH5574 points out that we don’t have examples of customization within the last stage of the build, so we could leave that point until (or if) it comes up.

One thing that I usually find helpful is to think about the phases in terms of barriers between them, where each barrier has some prebuilt passes that are not part of customization. In our case, we have

input -> barrierA(ingestion) -> stepA -> barrierB(lower) -> stepB -> barrierC(build) -> library

I find it helpful to think about each barrier as a prebuilt function:

  • build takes anything to pass barrierC
  • lower takes anything to get them to pass barrierB
  • canonicalize or import(TBD) takes anything to barrierA
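The barrier model can be sketched as idempotent functions, each of which first passes the earlier barriers itself, so that "takes anything" holds literally. The integer stages and the custom_step helper are hypothetical; a real module would carry its stage in a module attribute rather than a dict field.

```python
INPUT, INGESTED, LOWERED, BUILT = range(4)


def _advance(mod, stage, step):
    """Move `mod` past a barrier unless it is already past it."""
    if mod["stage"] >= stage:
        return mod
    return {**mod, "stage": stage, "log": mod["log"] + [step]}


def canonicalize(mod):   # barrierA
    return _advance(mod, INGESTED, "canonicalize")


def lower(mod):          # barrierB: also accepts modules that never passed A
    return _advance(canonicalize(mod), LOWERED, "lower")


def build(mod):          # barrierC: accepts anything
    return _advance(lower(mod), BUILT, "build")


def custom_step(mod, name):
    """A user customization between barriers; here it is only recorded."""
    return {**mod, "log": mod["log"] + [name]}
```

With this design, build applied directly to raw input runs canonicalize, lower, and build in order, while a user who inserts custom steps between barriers gets the same prebuilt passes around them, and calling a barrier twice is harmless.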

Another rule we could enforce in the well-formedness checker: Not permitting direct calls of PrimFuncs (i.e., not through call_tir) before P2.1. This one came up during the purity tracking discussion and I remembered after rereading those notes.

Now that the purity tracking PR has been merged, we should probably include it in the phase order: We use RemovePurityChecking after ToNonDataflow, since it’s not very helpful for the low-level codegen to have to reason about purity.

Also, here is one issue that just occurred to me: LambdaLifting replaces some function calls with the invocation of VMClosures, which seems like a relatively low-level construct. Does that cause any difficulties for Phase 1 passes? That should probably be something we alert users about in any case.

I’ve posted a tracking issue for phase ordering on GitHub so that we can discuss implementation issues, since there seems to be broad agreement that we should have some sort of ordering and enforce it to some degree. It would be good to focus on the aspects that are less certain and discuss how they affect the implementation.