Here is a proposal @psrivas2 and I (mostly Prakalp) have written up on where we might draw some of the dividing lines for phases. This comes after taking a survey of most of the passes and what their dependencies are.
It may require some changes to pass implementations to realize this plan, as well as adopting the mechanism @tqchen proposed of using a module-level attribute to indicate the current phase. The staging here roughly gathers up the passes based on their intended functionality (going from high-level transformations to lower-level and platform-specific ones) and on dependencies.
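To make the attribute mechanism concrete, here is a minimal sketch of how a phase marker could be attached and read. The attribute key `"relax.phase"` and both helper functions are hypothetical, chosen only for illustration:

```python
import tvm

def mark_phase(mod: tvm.IRModule, phase: str) -> tvm.IRModule:
    # Record the last completed phase as a module-level attribute
    # (the key "relax.phase" is a placeholder, not an existing convention).
    return mod.with_attr("relax.phase", phase)

def current_phase(mod: tvm.IRModule) -> str:
    # Treat a module with no marker as freshly parsed (pre-P0).
    if mod.attrs is not None and "relax.phase" in mod.attrs:
        return str(mod.attrs["relax.phase"])
    return "P0"
```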
## Proposed Phases
Phase 0 (P0): Ingestion. These are passes that should be applied to a model once it is parsed.
- `Normalize` (this is already run after every pass and will continue to be)
- `LambdaLifting`: Currently, the build process does not require `LambdaLifting` to be run, but it would simplify many pass implementations to avoid having nested functions once a program has been parsed.
- `RewriteDataflowReshape` (change from the current implementation): Currently, this is included in the default `build()`, but it is actually much higher-level than most of the other passes in the build. It is a nice simplification that replaces reshapes in TIR with the `reshape` operator. It might be helpful to apply this from the start after parsing, though it could also be included in P1.
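Under this grouping, P0 could be packaged as a single pipeline run right after parsing. A minimal sketch, assuming these passes keep their current `tvm.relax.transform` entry points (note that `LambdaLifting` above corresponds to the `LambdaLift` pass in the current codebase):

```python
import tvm
from tvm import relax

# P0: passes applied once, immediately after a model is parsed.
p0_ingestion = tvm.transform.Sequential(
    [
        relax.transform.Normalize(),
        relax.transform.LambdaLift(),
        relax.transform.RewriteDataflowReshape(),
    ],
    name="P0_Ingestion",
)
```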
Phase 1 (P1): Optimizations. These are higher-level passes that are, in general, optional to apply; later passes will not strictly depend on them. There are two subgroups we may consider:
- General optimizations: These are classic compiler optimizations that could be run at any point in compilation and possibly would make sense to run multiple times (e.g., to fixpoint) depending on what other passes produce. These generally do not require further explanation.
  - `FoldConstant`
  - `CanonicalizeBindings` (Note: we might consider whether this pass should be treated as a stage, since unlike the others, the resulting program will have the invariant that no indirect bindings are used; i.e., if some value is bound to `v1` and `v1` is bound to `v2`, there will never be any reference to `v2` but rather only to `v1` directly. We could potentially have passes rely on that invariant.)
  - `EliminateCommonSubexpr`
  - `DeadCodeElimination`
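Because these optimizations are safe to repeat, one natural usage is to iterate the whole group until the module stops changing. A rough sketch, assuming the current pass names (the fixpoint driver and iteration limit are ours):

```python
import tvm
from tvm import relax

general_opts = tvm.transform.Sequential(
    [
        relax.transform.FoldConstant(),
        relax.transform.CanonicalizeBindings(),
        relax.transform.EliminateCommonSubexpr(),
        relax.transform.DeadCodeElimination(),
    ]
)

def run_to_fixpoint(mod: tvm.IRModule, limit: int = 10) -> tvm.IRModule:
    # Apply the general optimizations until no further change (or a cap).
    for _ in range(limit):
        new_mod = general_opts(mod)
        if tvm.ir.structural_equal(new_mod, mod):
            break
        mod = new_mod
    return mod
```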
- Custom Optimizations: These are passes that implement domain-specific program transformations, usually relating to some specific TVM feature. Generally, it would not make sense to run these more than once. We might consider imposing a pass ordering within this group, since many of these passes do introduce invariants. However, there are no "hard dependencies" that we know of among the passes in this group.
  - `BindParams`: Not necessarily an "optimization" per se, but it is helpful on some targets. It will not disrupt other passes.
  - `FusionPass`: We propose gathering the following passes into a "megapass" because they have hard dependencies only on each other and have related functionality. This would be a good candidate for being marked as a phase, since the resulting program will have fused primitive operators. (E.g., passes like `CombineParallelMatmul` may not work on the program after running these passes.)
    - `AnnotateTIROpPattern`
    - `FuseOps`
    - `FuseTIR`
  - BYOC: It may also make sense to turn these three passes into a "megapass," but that is not strictly necessary. However, these three passes have "soft" dependencies on each other: `MergeCompositeFunctions` and `RunCodegen` will do nothing if `FuseOpsByPattern` has not been used. `MergeCompositeFunctions` also must run before `RunCodegen`, or else it will have no effect. One complication is that `FuseOpsByPattern` and `MergeCompositeFunctions` introduce nested functions that `RunCodegen` requires, which contradicts the P0 rule that nested functions will be lifted out. We may consider changing the implementations to use global definitions instead. (A sketch of this ordering appears after this list.)
    - `FuseOpsByPattern`
    - `MergeCompositeFunctions` (introduces inner functions right now, so it could be a problem to enforce well-formed checker constraints)
    - `RunCodegen`
  - `LiftTransformParams`: This transforms the program in a meaningful way (turning weights into a tuple of parameters), so it might make sense to specify its order in the phases, though it generally is not disruptive to other passes.
  - `ConvertLayout`: Useful for certain platforms but unlikely to disrupt other passes.
  - `ToMixedPrecision`: Necessary on certain platforms but unlikely to disrupt other passes.
  - `SplitCallTIRByPattern`: Operates only on `PrimFunc`s, so it is unlikely to disrupt other passes.
  - `CombineParallelMatmul`: Not a phase in itself; little interaction with other passes.
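To illustrate the BYOC soft dependencies mentioned above, here is a sketch of the ordering, assuming the current `relax.transform` entry points and the pattern registry; the `"cutlass"` prefix is just an illustrative example of a registered pattern table:

```python
import tvm
from tvm import relax
from tvm.relax.backend import get_patterns_with_prefix

def byoc_pipeline(mod: tvm.IRModule) -> tvm.IRModule:
    patterns = get_patterns_with_prefix("cutlass")
    seq = tvm.transform.Sequential(
        [
            # Each later pass is a no-op if the earlier ones have not run.
            relax.transform.FuseOpsByPattern(patterns),
            relax.transform.MergeCompositeFunctions(),
            relax.transform.RunCodegen(),
        ]
    )
    return seq(mod)
```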
Phase 2.1 (P2.1): Early Build. These passes correspond to those used in `build()` and have the collective effect of removing dataflow blocks, inserting explicit memory allocations, and turning Relax operators into `PackedFunc`s.

- `RewriteDataflowReshape` (currently included in `build()`, but we propose running it earlier)
- `LegalizeOps` (proposal): This is not currently included in `build()`, but it is, in reality, necessary for any build using operators to succeed, so we propose including it there.
- `ToNonDataflow`
- `CallTIRRewrite`
- `StaticPlanBlockMemory`
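Put together, P2.1 might look like the following sketch (with `RewriteDataflowReshape` moved out per the proposal above, and `LegalizeOps` moved in):

```python
import tvm
from tvm import relax

p2_1_early_build = tvm.transform.Sequential(
    [
        relax.transform.LegalizeOps(),           # proposed addition
        relax.transform.ToNonDataflow(),         # remove dataflow blocks
        relax.transform.CallTIRRewrite(),        # make DPS calls explicit
        relax.transform.StaticPlanBlockMemory(), # insert memory allocations
    ],
    name="P2_1_EarlyBuild",
)
```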
Phase 2.2 (P2.2): Target-Specific Passes. These are passes that involve low-level, target-specific information.
- MetaSchedule: These could potentially be made a "megapass." They all depend on `LegalizeOps`. However, since they operate only on TIR `PrimFunc`s, we might consider making `LegalizeOps` and the MetaSchedule passes an earlier phase, as nothing inherently requires them to come after `ToNonDataflow`, etc. (though the results of tuning will, of course, depend on target details). A sketch of their use appears after this list.
  - `MetaScheduleApplyDatabase`
  - `MetaScheduleTuneTIR`
  - `MetaScheduleTuneIRMod`
- `RewriteCUDAGraph`: This has a dependency on `StaticPlanBlockMemory`. Naturally, it is specific to CUDA.
- It is possible that other platform-specific passes may be added here.
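A rough sketch of how the MetaSchedule passes might be invoked, assuming the current `relax.transform` signatures (the target, work directory, and trial count are illustrative):

```python
import tvm
from tvm import relax

def tune_and_apply(mod: tvm.IRModule) -> tvm.IRModule:
    # Hard dependency: MetaSchedule operates on the TIR PrimFuncs that
    # LegalizeOps produces, so legalization must come first.
    mod = relax.transform.LegalizeOps()(mod)
    with tvm.target.Target("nvidia/geforce-rtx-3070"):
        mod = relax.transform.MetaScheduleTuneIRMod(
            params={}, work_dir="./tuning_logs", max_trials_global=64
        )(mod)
        mod = relax.transform.MetaScheduleApplyDatabase(
            work_dir="./tuning_logs"
        )(mod)
    return mod
```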
Phase 2.3 (P2.3): Lowering Builtins. These passes (part of the current `build()`) cumulatively have the effect of lowering core language features and builtin operators into `PackedFunc`s and preparing the `IRModule` for VM code generation.

- `VMBuiltinLower`
- `VMShapeLower`
- `AttachGlobalSymbol`
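As with the earlier stages, these could be packaged as one pipeline; a short sketch under the same naming assumptions:

```python
import tvm
from tvm import relax

p2_3_lower_builtins = tvm.transform.Sequential(
    [
        relax.transform.VMBuiltinLower(),     # lower builtin operators
        relax.transform.VMShapeLower(),       # lower shape computations
        relax.transform.AttachGlobalSymbol(), # prepare symbols for codegen
    ],
    name="P2_3_LoweringBuiltins",
)
```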
Still Uncategorized: Passes we have not yet characterized because they are experimental and highly subject to change:

- `Gradient`
- `DecomposeOpsForInference`
- `DecomposeOpsForTraining`
- `AlterOpImpl`
## Enforcing Invariants

We could consider enforcing phase-specific invariants inside the well-formed checker, as has been discussed. Here are conditions the well-formed check could enforce per phase, in addition to checking that the AST is in normal form:
- After P0: No nested functions.
- After P1: P1 does not introduce additional invariants.
- After P2.1:
  - No dataflow blocks.
  - No uses of `call_tir` or `call_dps_packed`.
  - No Relax tensor operators (only builtins).
  - Explicit memory allocations/deallocations.
- After P2.2: P2.2 does not introduce additional invariants.
- After P2.3: We have a VM executable rather than a Relax program.
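As a concrete illustration of what the checker could do, here is a minimal sketch of the "no dataflow blocks after P2.1" rule, written against the Python AST bindings (the function name is ours, and the check only inspects top-level `SeqExpr` bodies, which suffices for normalized modules with no nested functions):

```python
import tvm
from tvm import relax

def check_no_dataflow_blocks(mod: tvm.IRModule) -> None:
    """Assert the post-P2.1 invariant that no dataflow blocks remain."""
    for gvar, func in mod.functions.items():
        if not isinstance(func, relax.Function):
            continue  # skip TIR PrimFuncs and extern functions
        body = func.body
        # Normalized Relax functions have a SeqExpr body.
        if isinstance(body, relax.SeqExpr):
            for block in body.blocks:
                assert not isinstance(block, relax.DataflowBlock), (
                    f"{gvar.name_hint} still contains a dataflow block"
                )
```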
## Main Questions to Address
- Do these categorizations all make sense? Are the distinctions perhaps ad hoc?
- How would we handle passes having hard dependencies on other passes? (As seen in some of the proposed "megapasses," but also in cases like the MetaSchedule passes having a hard dependency on `LegalizeOps`.) Should these simply be documented, making it the user's responsibility to run those passes? Or should these hard dependencies always result in a specific stage, so that the requirement is technically enforced?
- Are there low-level constructs from later passes that we should prohibit in earlier ones? Should it be permitted, for example, to have explicit memory allocations or deallocations before `StaticPlanBlockMemory`? Some passes may not be able to handle them. Should the well-formed checker enforce this?
- How should we handle the fact that the BYOC passes (namely `FuseOpsByPattern` and `MergeCompositeFunctions`) introduce nested functions? Should we modify their implementations to avoid this, or should those passes be in a separate stage where the nested-function rule isn't enforced?