Motivation
Currently, our build function has three stages (https://github.com/apache/tvm/blob/fb64be3f7807df18c2df6ebf5e68178e564ab0b4/python/tvm/driver/build_module.py#L140-L302):
- Lower the input (a te.Schedule, tir.PrimFunc, or IRModule) before it is built for the target.
- Annotate the modules with their targets.
- Run the target-aware lowering process (TIRToRuntime).
Limitations
However, this design has limitations that hinder some of our development:
- Some of the passes in the lowering process have to be hardware-aware. For example, BF16legalizeCompute and FP8LegalizeCompute need to know whether the target has native bf16/fp8 support.
- We cannot bind different PrimFuncs in an IRModule to different targets.
On-going efforts
- Since most of the passes in the lowering process are target-agnostic, it should be safe to make the lowering function target-aware. https://github.com/apache/tvm/pull/15183 will perform this refactor.
- Refactor the lowering flow to make it more flexible (https://github.com/apache/tvm/pull/14985 might be related).
A more fundamental question is whether we still need lowering to be a standalone process. If the lowering function becomes target-aware, there seems to be no need to decouple it from the later passes in TIRToRuntime. I think the only reason we kept a separate lowering function is compatibility with te.Schedule, which must be converted to an IRModule first.
Possible next steps could be:
- Merge the passes in the lowering function and the TIRToRuntime function.
- In the build function, refactor the three-stage lowering process into a single flow.
- Mark the lower function as obsolete and keep it only for legacy te.Schedule inputs.
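A merged, single-flow build could look roughly like the following pure-Python sketch. It assumes each PrimFunc carries its own target binding, so one pipeline runs every pass on each function together with that function's target, and functions in the same module can be bound to different targets. Target, PrimFunc, simplify, legalize_compute, codegen, PIPELINE, and build are hypothetical stand-ins, not TVM APIs, and the legacy te.Schedule conversion path is not modeled.

```python
from dataclasses import dataclass, replace
from typing import Callable, Dict, List


# Hypothetical stand-ins, not TVM APIs.
@dataclass(frozen=True)
class Target:
    kind: str
    native_dtypes: frozenset


@dataclass(frozen=True)
class PrimFunc:
    name: str
    compute_dtypes: tuple
    target: Target       # per-function target binding
    applied: tuple = ()  # names of the passes applied so far, for illustration


Pass = Callable[[PrimFunc], PrimFunc]


def simplify(func: PrimFunc) -> PrimFunc:
    # Placeholder for a target-agnostic pass: records itself, changes nothing else.
    return replace(func, applied=func.applied + ("simplify",))


def legalize_compute(func: PrimFunc) -> PrimFunc:
    # Target-aware pass: promote low-precision compute (assumed to float32)
    # only when the function's own target lacks native support.
    promoted = tuple(
        "float32" if d in ("bfloat16", "float8") and d not in func.target.native_dtypes else d
        for d in func.compute_dtypes
    )
    return replace(func, compute_dtypes=promoted,
                   applied=func.applied + ("legalize_compute",))


def codegen(func: PrimFunc) -> PrimFunc:
    # Placeholder for target-specific code generation.
    return replace(func, applied=func.applied + (f"codegen[{func.target.kind}]",))


# One merged pipeline instead of separate lower / annotate / TIRToRuntime stages.
PIPELINE: List[Pass] = [simplify, legalize_compute, codegen]


def build(mod: Dict[str, PrimFunc]) -> Dict[str, PrimFunc]:
    """Single flow: every pass sees each function together with its own target."""
    out = {}
    for name, func in mod.items():
        for p in PIPELINE:
            func = p(func)
        out[name] = func
    return out


# Usage: two functions in one module bound to different (made-up) targets.
cpu = Target("llvm", frozenset({"float32"}))
gpu = Target("cuda", frozenset({"float32", "bfloat16"}))
built = build({
    "pre": PrimFunc("pre", ("bfloat16",), cpu),
    "main": PrimFunc("main", ("bfloat16",), gpu),
})
print(built["pre"].compute_dtypes)  # ('float32',)
print(built["main"].applied)        # ('simplify', 'legalize_compute', 'codegen[cuda]')
```

Keeping target-agnostic and target-aware passes in one ordered list means there is no need for a standalone lowering stage that runs before targets are known.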