Here is a proposal @psrivas2 and I (mostly Prakalp) have written up on where we might draw some of the dividing lines for phases. This comes after taking a survey of most of the passes and what their dependencies are.
It may require some changes to pass implementations to realize this plan, as well as adopting the mechanism @tqchen proposed of using a module-level attribute to indicate the current phase. The staging here roughly gathers up the passes based on their intended functionality (going from high-level transformations to lower-level and platform-specific ones) and on dependencies.
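For concreteness, here is a minimal sketch of what that module-level attribute mechanism might look like. The attribute key `relax.phase` and both helper functions are placeholders, not existing API; only `IRModule.with_attr` is real.

```python
import tvm

# Hypothetical attribute key; the actual name would need to be agreed upon.
PHASE_ATTR = "relax.phase"

def mark_phase(mod: tvm.IRModule, phase: str) -> tvm.IRModule:
    # IRModule.with_attr returns a copy of the module with the attribute set.
    return mod.with_attr(PHASE_ATTR, phase)

def current_phase(mod: tvm.IRModule) -> str:
    if mod.attrs and PHASE_ATTR in mod.attrs:
        return str(mod.attrs[PHASE_ATTR])
    return "P0"  # assume freshly parsed modules start at ingestion
```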
Proposed Phases
Phase 0 (P0): Ingestion. These are passes that should be applied to a model as soon as it is parsed. (A sketch of this pipeline follows the list.)
- `Normalize` (this is already run after every pass and will continue to be)
- `LambdaLifting`: Currently, the build process does not require `LambdaLifting` to be run, but it would simplify many pass implementations to avoid having nested functions once a program has been parsed.
- `RewriteDataflowReshape` (change from the current implementation): Currently, this is included in the default `build()`, but it is much higher-level than most of the other passes in the build. It is a nice simplification that replaces reshapes in TIR with the `reshape` operator. It might be helpful to apply it immediately after parsing, though it could also be included in P1.
P1: Optimizations. These are higher-level passes that are, in general, optional to apply; later passes will not strictly depend on them. There are two subgroups we may consider:
- General optimizations: These are classic compiler optimizations that could be run at any point in compilation and possibly would make sense to run multiple times (e.g., to a fixpoint), depending on what other passes produce. These generally do not require further explanation. (A sketch of running them to a fixpoint follows this subgroup.)
  - `FoldConstant`
  - `CanonicalizeBindings` (Note: we might consider whether this pass should constitute a stage of its own, since unlike the others, the resulting program will have the invariant that no indirect bindings are used, i.e., if some value is bound to `v1` and `v1` is in turn bound to `v2`, there will never be any reference to `v2`, only to `v1` directly. We could potentially have passes rely on that invariant.)
  - `EliminateCommonSubexpr`
  - `DeadCodeElimination`
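"Running to a fixpoint" could look roughly like the following. This is a sketch: it assumes the default arguments of each pass suffice and uses structural equality to detect convergence.

```python
import tvm
from tvm import relax

general_opts = tvm.transform.Sequential(
    [
        relax.transform.FoldConstant(),
        relax.transform.CanonicalizeBindings(),
        relax.transform.EliminateCommonSubexpr(),
        relax.transform.DeadCodeElimination(),
    ],
    name="P1_GeneralOpts",
)

def run_to_fixpoint(mod: tvm.IRModule, max_iters: int = 10) -> tvm.IRModule:
    # Reapply the group until the module stops changing (or we hit a cap).
    for _ in range(max_iters):
        new_mod = general_opts(mod)
        if tvm.ir.structural_equal(new_mod, mod):
            break
        mod = new_mod
    return mod
```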
- Custom Optimizations: These are passes that implement domain-specific program transformations, usually relating to some specific TVM feature. Generally, these would not make sense to run more than once. We might consider imposing a pass ordering within this group, since many of these passes do introduce invariants. However, there are no “hard dependencies” that we know of among the passes in this group.
  - `BindParams`: Not necessarily an “optimization” per se, but it is helpful on some targets. This one will not disrupt other passes.
  - `FusionPass`: We propose gathering the passes below into a “megapass,” since they have hard dependencies only on each other and have related functionality. This would be a good candidate for being marked as a phase, since the resulting program will have fused primitive operators (e.g., passes like `CombineParallelMatmul` may not work on the program after these have run). See the sketch after this list.
    - `AnnotateTIROpPattern`
    - `FuseOps`
    - `FuseTIR`
  - BYOC: It may also make sense to turn these three passes into a “megapass,” though that is not strictly necessary. They have “soft” dependencies on each other: `MergeCompositeFunctions` and `RunCodegen` will do nothing if `FuseOpsByPattern` has not been run, and `MergeCompositeFunctions` must run before `RunCodegen` or else it will have no effect. One complication is that `FuseOpsByPattern` and `MergeCompositeFunctions` introduce nested functions that `RunCodegen` requires, which contradicts the P0 rule that nested functions are lifted out. We may consider changing the implementations to use global definitions instead. See the sketch after this list.
    - `FuseOpsByPattern`
    - `MergeCompositeFunctions` (introduces inner functions right now, so it could be a problem to enforce well-formed checker constraints)
    - `RunCodegen`
  - `LiftTransformParams`: This transforms the program in a meaningful way (turning weights into a tuple of parameters), so it might make sense to specify its order in the phases, though it generally is not disruptive to other passes.
  - `ConvertLayout`: Useful for certain platforms but unlikely to disrupt other passes.
  - `ToMixedPrecision`: Necessary on certain platforms but unlikely to disrupt other passes.
  - `SplitCallTIRByPattern`: Operates only on TIR `PrimFunc`s, so it is unlikely to disrupt other passes.
  - `CombineParallelMatmul`: Not a phase in itself; little interaction with other passes.
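The two proposed “megapasses” could each be a `Sequential` over their members, run in the required order. A sketch; the BYOC pattern table is left abstract, since it depends on the backend.

```python
import tvm
from tvm import relax

# The fusion megapass: hard dependencies exist only among these three.
fusion_megapass = tvm.transform.Sequential(
    [
        relax.transform.AnnotateTIROpPattern(),
        relax.transform.FuseOps(),
        relax.transform.FuseTIR(),
    ],
    name="P1_Fusion",
)

def make_byoc_megapass(patterns):
    # `patterns` would come from the backend in use (e.g., a pattern table
    # registered for CUTLASS or TensorRT); it is a placeholder here.
    return tvm.transform.Sequential(
        [
            relax.transform.FuseOpsByPattern(patterns),
            relax.transform.MergeCompositeFunctions(),
            relax.transform.RunCodegen(),
        ],
        name="P1_BYOC",
    )
```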
P2.1: Early Build. These passes correspond to those used in `build()` and have the collective effect of removing dataflow blocks, inserting explicit memory allocations, and turning Relax operators into `PackedFunc`s. (A sketch of the pipeline follows the list.)
- `RewriteDataflowReshape` (currently included in `build()`, but we propose running it earlier)
- `LegalizeOps` (proposal): This is not currently included in `build()`, but it is in practice necessary for any build using operators to succeed, so we propose including it there.
- `ToNonDataflow`
- `CallTIRRewrite`
- `StaticPlanBlockMemory`
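A sketch of P2.1 as a pipeline, with `LegalizeOps` included per the proposal above even though it is not part of `build()` today:

```python
import tvm
from tvm import relax

p2_1_early_build = tvm.transform.Sequential(
    [
        relax.transform.LegalizeOps(),            # proposed addition
        relax.transform.ToNonDataflow(),          # removes dataflow blocks
        relax.transform.CallTIRRewrite(),         # makes destinations explicit
        relax.transform.StaticPlanBlockMemory(),  # explicit memory allocation
    ],
    name="P2_1_EarlyBuild",
)
```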
P2.2: Target-Specific Passes. These are passes that involve low-level, target-specific information.
- `MetaSchedule`: These could potentially be made a “megapass.” They all depend on `LegalizeOps`. However, since they operate only on TIR `PrimFunc`s, we might consider making `LegalizeOps` and the `MetaSchedule` passes an earlier phase, as nothing inherently requires them to come after `ToNonDataflow`, etc. (though the results of tuning will, of course, depend on target details). A sketch of applying a tuning database follows this list.
  - `MetaScheduleApplyDatabase`
  - `MetaScheduleTuneTIR`
  - `MetaScheduleTuneIRMod`
- `RewriteCUDAGraph`: This has a dependency on `StaticPlanBlockMemory`. Naturally, it is specific to CUDA.
- It is possible that other platform-specific passes may be added here.
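As an illustration, applying a MetaSchedule tuning database might look like the following. A sketch only: the target string and `work_dir` are placeholders, and the pass is assumed to read the target from the enclosing context.

```python
import tvm
from tvm import relax

def apply_tuning_database(mod: tvm.IRModule, work_dir: str = "tuning_logs"):
    # MetaScheduleApplyDatabase rewrites PrimFuncs using schedules found in
    # the database under work_dir; the target comes from the context.
    with tvm.target.Target("nvidia/geforce-rtx-3070"):
        return relax.transform.MetaScheduleApplyDatabase(work_dir)(mod)
```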
P2.3: Lowering Builtins. These passes (part of the current `build()`) cumulatively have the effect of lowering core language features and builtin operators into `PackedFunc`s and preparing the `IRModule` for VM code generation. (A sketch follows the list.)
- `VMBuiltinLower`
- `VMShapeLower`
- `AttachGlobalSymbol`
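A sketch of P2.3 as a pipeline; note that `build()` already runs these internally today, so this only illustrates the grouping.

```python
import tvm
from tvm import relax

p2_3_lower_builtins = tvm.transform.Sequential(
    [
        relax.transform.VMBuiltinLower(),     # builtin ops -> PackedFuncs
        relax.transform.VMShapeLower(),       # shape computations -> VM-level code
        relax.transform.AttachGlobalSymbol(),
    ],
    name="P2_3_LowerBuiltins",
)
```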
Still Uncategorized: Passes we have not yet characterized because they are experimental and highly subject to change:
- `Gradient`
- `DecomposeOpsForInference`
- `DecomposeOpsForTraining`
- `AlterOpImpl`
Enforcing Invariants
We could consider enforcing phase-specific invariants inside the well-formed checker, as has been discussed. Here are the conditions the well-formed check could enforce per phase, in addition to checking that the AST is in normal form (a sketch of such a phase-aware check follows the list):
- After P0: No nested functions.
- After P1: P1 does not introduce additional invariants.
- After P2.1:
  - No dataflow blocks.
  - No uses of `call_tir` or `call_dps_packed`.
  - No Relax tensor operators (only builtins).
  - Explicit memory allocations/deallocations.
- After P2.2: P2.2 does not introduce additional invariants.
- After P2.3: We have a VM executable rather than a Relax program.
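To illustrate, a phase-aware check might be layered on the existing checker roughly as follows. A sketch: `well_formed` is the existing analysis, while the phase names, the `past` helper, and the checks themselves are hypothetical; a real implementation would do a full AST traversal rather than inspecting only top-level blocks.

```python
import tvm
from tvm import relax
from tvm.relax.analysis import well_formed

PHASES = ["P0", "P1", "P2.1", "P2.2", "P2.3"]

def check_phase_invariants(mod: tvm.IRModule, phase: str) -> None:
    # Base check: the AST is well formed and in normal form.
    assert well_formed(mod), "module is not well formed"
    past = lambda p: PHASES.index(phase) >= PHASES.index(p)
    for _, func in mod.functions.items():
        if not isinstance(func, relax.Function):
            continue
        body = func.body
        if past("P2.1") and isinstance(body, relax.SeqExpr):
            # After P2.1, ToNonDataflow should have removed dataflow blocks.
            # (Top-level blocks only; a full check would recurse.)
            for block in body.blocks:
                assert not isinstance(block, relax.DataflowBlock), (
                    "dataflow blocks are not permitted after P2.1"
                )
```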
Main Questions to Address
- Do these categorizations all make sense? Are the distinctions perhaps ad hoc?
- How would we handle passes having hard dependencies on other passes (as seen in some of the proposed “megapasses,” but also in cases like the `MetaSchedule` passes having a hard dependency on `LegalizeOps`)? Should these simply be documented, leaving it the user’s responsibility to run the prerequisite passes? Or should hard dependencies always result in a dedicated stage, so that the requirement can be technically enforced?
- Are there low-level constructs from later passes that we should prohibit in earlier ones? Should it be permitted, for example, to have explicit memory allocations or deallocations before `StaticPlanBlockMemory`? Some passes may not be able to handle them. Should the well-formed checker enforce this?
- How should we handle the fact that the BYOC passes (namely `FuseOpsByPattern` and `MergeCompositeFunctions`) introduce nested functions? Should we modify their implementations to avoid this, or should those passes live in a separate stage where the nested-function rule isn’t enforced?