RFC for Relay MLIR Frontend
Authors: Yanan Cao, Yong Wu
Contributors: Yida Wang, Haichen Shen, Yao Wang
Summary
We propose a solution that can give TVM/Relay top-notch model/op coverage for TensorFlow with affordable effort.
Motivation
TensorFlow, as the most dominant machine learning framework, has a scarily large number of operations: over 1200 of them, of which 500+ are commonly used computational ops. TVM currently has hand-written lowering functions from TF to Relay for 145 of them. These 145 operations are enough to support some well-known computer vision models like ResNet and SSD, but this is far from enough when users want to bring custom models with extra operations.
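To give a sense of the per-op effort, each supported TensorFlow op today needs a hand-written converter function registered in a lookup table inside the Relay TensorFlow frontend. The sketch below illustrates that pattern; the function bodies and signatures are simplified for illustration rather than copied from the actual frontend code.

```python
# Simplified illustration of the hand-written per-op lowering pattern used by
# the existing Relay TensorFlow frontend (names/signatures are illustrative).
from tvm import relay

def _convert_relu(inputs, attrs):
    # One converter per TensorFlow op: take its inputs/attributes and emit
    # the equivalent Relay expression.
    return relay.nn.relu(inputs[0])

def _convert_bias_add(inputs, attrs):
    return relay.nn.bias_add(inputs[0], inputs[1])

# ~145 entries of this kind exist today; every additional TensorFlow op a user
# needs means writing and testing another converter by hand.
_convert_map = {
    "Relu": _convert_relu,
    "BiasAdd": _convert_bias_add,
}
```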
Proposal
Instead of converting from TF→Relay directly, we can consider using XLA HLO as an intermediate step to reduce engineering cost, namely TF→XLA HLO→Relay.
XLA is a TensorFlow-oriented machine learning compiler with XLA HLO as its frontend IR. HLO is designed exactly to address the problem of the TF op set overwhelming a compiler.
HLO has a few characteristics that, I believe, make it a great target for Relay to import from:
- HLO Operation Set is tiny: at the time of writing, there are fewer than 100 operations. Compared to the number of TensorFlow operations, it is much easier to implement lowering from HLO→Relay than from TF→Relay directly.
- HLO Operation Set is stable: new operations are rarely added to HLO, saving us the trouble of playing a catch-up game.
- TensorFlow GraphDef→HLO lowering support is top-notch: the TensorFlow repo contains symbolic execution kernels that can lower ~400 TensorFlow ops to XLA HLO, covering the vast majority of mathematical operations that model creators use. They are all well tested and maintained.
- Other than TensorFlow, major frameworks like PyTorch, JAX and Julia all have varying degrees of XLA lowering support. This means we can get additional model/op coverage for these three frameworks at almost no additional cost.
However, XLA HLO isn’t perfect: it requires strict static shapes for all tensors and operations. This limits model coverage and sometimes requires modifying the model in order to compile it. Luckily, with MLIR and its HLO dialect, representing dynamic shapes with HLO ops is now possible. More details can be found in the lowering path design below.
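For context on the Relay side, Relay can already express an unknown dimension via `Any`, which is what an import path that preserves dynamic shapes would target. A minimal sketch using the standard Relay API (nothing MLIR-specific is assumed here):

```python
from tvm import relay

# A tensor whose batch dimension is unknown at compile time; relay.Any()
# marks the dynamic dimension while the remaining dimensions stay static.
x = relay.var("x", shape=(relay.Any(), 3, 224, 224), dtype="float32")
y = relay.nn.relu(x)
func = relay.Function([x], y)
print(func)
```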
TensorFlow Lowering Paths Design
Instead of lowering from TensorFlow to HLO directly, we could alternatively leverage MLIR and its HLO dialect for faster implementation and support for model dynamism.
To be more specific, MLIR is a compiler infrastructure that helps define dialect IRs and their corresponding compilers. In our context, the most relevant dialects are:
- MLIR-TF Dialect is an isomorphic representation of TensorFlow GraphDef.
- MLIR-HLO Dialect is a near-isomorphic representation of XLA HLO. It differs from HLO in that it 1) can represent dynamic shapes and 2) has a slightly different set of operations, but it can still round-trip (convert to and from) vanilla XLA HLO.
- MLIR-Relay Dialect should be isomorphic to Relay with a trivial conversion between the two.
Components needed:
- tf-mlir-translate tool: Needs to be enhanced to functionalize Switch/Merge control flow nodes and automatically identify input/output node names/shapes.
- tf-opt tool: Contains lowerings from MLIR TF to MLIR HLO. This is in active development by Google. There are currently ~150 operations implemented.
- MLIR Relay Dialect and its thin conversion to Relay
- Lowerings from MLIR HLO to MLIR Relay: using the MLIR infrastructure, this can be implemented with less effort than hand-writing a custom converter for the direct (non-MLIR) path (see the sketch after this list).
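To make the proposed flow concrete, the sketch below wires the pieces together end to end. The tool names (tf-mlir-translate, tf-opt) are real, but the exact command-line flags vary across TensorFlow versions, and the final MLIR-HLO→Relay step does not exist yet; treat the flags and the `from_mlir_hlo` entry point as assumptions, not an existing API.

```python
# Hypothetical driver for the MLIR-based lowering path described above.
# Flags and the final import entry point are placeholders, not a real API.
import subprocess

def lower_graphdef_via_mlir(graphdef_path: str) -> None:
    # Step 1: TensorFlow GraphDef -> MLIR-TF dialect.
    subprocess.run(
        ["tf-mlir-translate", "--graphdef-to-mlir", graphdef_path,
         "-o", "model.tf.mlir"],
        check=True)

    # Step 2: MLIR-TF -> MLIR-HLO dialect (legalization passes in tf-opt).
    subprocess.run(
        ["tf-opt", "--xla-legalize-tf", "model.tf.mlir",
         "-o", "model.hlo.mlir"],
        check=True)

    # Step 3 (proposed in this RFC): MLIR-HLO -> MLIR-Relay -> Relay, e.g.
    # through a hypothetical relay.frontend.from_mlir_hlo("model.hlo.mlir").
```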
XLA HLO → Relay
To prove viability, we created a table with a rough mapping from each XLA HLO operation to its Relay equivalent. This table was created by referencing HLO operations and their semantics, Relay TOPI, HLO OpCode definitions, and XlaBuilder.
The following table lists all XLA HLO ops, in fact even more than HLO itself has. This is because MLIR HLO and vanilla HLO ops, though they largely overlap, still have some differences. Since I believe the MLIR-HLO lowering path is better due to its support for dynamic shapes, I listed all the vanilla-HLO-only ops at the bottom of the table with a comment noting that they are not in the XlaBuilder API.
XLA HLO | Needed? | Equivalent Relay Ops | Comment |
---|---|---|---|
AfterAll | No | | Control dependency operation, not applicable to Relay |
AllReduce | No | | Multi-core only, not applicable to single-core inference |
AllToAll | No | | Multi-core only, not applicable to single-core inference |
BatchNormGrad | No | | training related |
BatchNormInference | Yes | batch_norm | |
BatchNormTraining | No | | training related |
BitcastConvertType | Yes | reinterpret | Maps to BitcastConvert in hlo_opcode |
Broadcast | Yes | broadcast_to | slight difference in “shape” input, maps to broadcast in hlo_opcode |
BroadcastInDim | Yes | broadcast_to | maps to broadcast in hlo_opcode |
Call | Yes | ||
Cholesky | Yes | | composable with existing ops, formula here (https://www.tensorflow.org/xla/operation_semantics#cholesky) |
Clamp | Yes | clip | |
Collapse | Yes | reshape | Only in XlaBuilder API, maps to reshape in hlo_opcode |
CollectivePermute | No | | Multi-core only, not applicable to single-core inference |
Concatenate | Yes | concatenate | |
Conditional | Yes | relay.expr.If | |
Conv | Yes | conv2d_{nchw, hwcn}, conv1d, conv3d | Limited support for 3-4-5 ranks |
ConvWithGeneralPadding | Yes | conv2d_{nchw, hwcn}, conv1d, conv3d | Limited support for 3-4-5 ranks, maps to convolution in hlo_opcode |
ConvertElementType | Yes | cast | Only in XlaBuilder API, maps to “convert” in hlo_opcode |
CrossReplicaSum | No | | Only in XlaBuilder API, multi-core only, not applicable to single-core inference |
CustomCall | No | | This is a way to invoke arbitrary CPU code, very rarely used. |
Dot | Yes | multiply | Only in XlaBuilder API, the “dot” in hlo_opcode actually maps to DotGeneral |
DotGeneral | Yes | reshape+multiply | |
DynamicSlice | Yes | strided_slice | |
DynamicUpdateSlice | Yes | | May need to add new op into relay |
Add | Yes | add | |
Sub | Yes | subtract | |
Mul | Yes | mul | Maps to multiply in hlo_opcode |
Div | Yes | divide | |
Rem | Yes | divide, subtract | Maps to “remainder” in HloOpcode |
Max | Yes | max | |
Min | Yes | min | |
And | Yes | logical_and | |
Or | Yes | logical_or | |
Eq | Yes | equal | Maps to “compare” in HloOpcode |
Ne | Yes | not_equal | Maps to “compare” in HloOpcode |
Ge | Yes | greater, equal | Maps to “compare” in HloOpcode |
Gt | Yes | greater | Maps to “compare” in HloOpcode |
Le | Yes | less, equal | Maps to “compare” in HloOpcode |
Lt | Yes | less | Maps to “compare” in HloOpcode |
Abs | Yes | abs | |
Ceil | Yes | ceil | |
Cos | Yes | cos | |
Exp | Yes | exp | |
Floor | Yes | floor | |
IsFinite | Yes | isfinite | |
Log | Yes | log | |
Not | Yes | logical_not | |
PopulationCount | Yes | | May need to add new op into relay |
Neg | Yes | negative | |
Sign | Yes | sign | |
Tanh | Yes | tanh | |
Fft | Yes | | May need to add new op into relay |
Gather | Yes | gather_nd | |
GetDimensionSize | No | ndarray_size | Needed only for XLA dynamic padder |
SetDimensionSize | No | | Needed only for XLA dynamic padder |
GetTupleElement | Yes | TupleGetItem | |
Infeed | No | | not inference related |
Iota | Yes | | May need to add new op into relay |
Map | No | | Found no use in OSS TensorFlow |
Pad | Yes | nn.pad | |
Recv | No | | Cross device communication, not applicable to single-core inference |
Reduce | Yes | | Cannot support full flexibility because Relay doesn’t allow op to take function as argument. However, we can pattern match to support common cases. |
ReducePrecision | Yes | cast, cast | |
ReduceWindow | Yes | | Cannot support full flexibility because Relay doesn’t allow op to take function as argument. However, we can pattern match to support common cases. |
ReduceWindowWithGeneralPadding | Yes | | Cannot support full flexibility because Relay doesn’t allow op to take function as argument. However, we can pattern match to support common cases. |
ReplicaId | No | | Needed when data parallelism is involved |
Reshape | Yes | reshape | |
Rev | Yes | reverse | |
RngNormal | No | | Inference graphs should not need RNG ops |
RngUniform | No | | Inference graphs should not need RNG ops |
Scatter | Yes | | May need to add new op into relay |
Select | Yes | select | |
SelectAndScatter | Yes | | Same problem as Scatter, may need to add new op into relay |
Send | No | | Corresponds to “Send” and “SendDone” in HloOpcode, cross device communication |
Slice | Yes | strided_slice | |
Sort | Yes | argsort | Cannot support full flexibility because Relay doesn’t allow op to take function as argument. However, we can pattern match to support common cases. |
Transpose | Yes | transpose | |
TriangularSolve | Yes | | Should be composable with existing Relay ops, semantics here (https://www.tensorflow.org/xla/operation_semantics#triangularsolve) |
Tuple | Yes | tuple | |
While | Yes | Recursive calls | |
AddDependency | No | | Only for XLA internal use |
Atan2 | No | | Not in XLA builder API |
Clz | No | | Not in XLA builder API |
Compare | No | | Not in XLA builder API, internal implementation for all comparisons, no need to support separately |
Complex | No | | Not in XLA builder API |
Constant | No | | Not in XLA builder API |
Copy | No | | Not in XLA builder API |
CopyDone | No | | Not in XLA builder API |
CopyStart | No | | Not in XLA builder API |
Domain | No | | Not in XLA builder API, only for partitioning computation, no need to support |
Expm1 | No | exp, subtract | Not in XLA builder API |
Imag | No | | Not in XLA builder API |
Log1p | No | log, subtract | Not in XLA builder API |
Parameter | No | | Not in XLA builder API, represents input to computation |
PartitionId | No | | Not in XLA builder API, only needed for multi-device computations |
Power | No | power | Not in XLA builder API |
RngGetAndUpdateState | No | | Not in XLA builder API |
RoundNearestAfz | No | | Not in XLA builder API |
Rsqrt | No | rsqrt | Not in XLA builder API |
ShiftLeft | Yes | | Not in XLA builder API |
ShiftRightArithmetic | Yes | | Not in XLA builder API |
ShiftRightLogical | Yes | | Not in XLA builder API |
Sin | No | sin | Not in and not used by XLA builder API |
Sqrt | No | sqrt | Not in XLA builder API |
Trace | No | | Not in XLA builder API, only for profiling and tracing |
TupleSelect | No | | Not in XLA builder API |
Xor | Yes | xor | Not in XLA builder API |
We can see from the table that:
- Most operations have a straightforward one-to-one mapping
- A few operations require decomposition into several existing Relay ops, like Cholesky and TriangularSolve
- 4 HLO operations require the addition of new Relay ops
- 6 operations can only be partially supported, namely convolutions (Relay supports rank <= 5) and the reduce and sort ops (because Relay does not allow passing a function as an op argument). However, the most common use cases are covered, such as low-rank tensor convolution, ReduceMax, ReduceMin, ReduceMean, ReduceSum, etc. (see the sketch below)
Overall, even though HLO→Relay coverage isn’t perfect, I believe it is enough for us to cover all but the most extreme cases, like rank>5 convolution, reduction ops other than Min/Max/Sum/Mean, etc.
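To illustrate the pattern-matching idea mentioned above for Reduce/ReduceWindow/Sort: the converter can inspect the small computation body an HLO Reduce carries and map the common cases onto Relay's fixed reduction ops. The sketch below is a minimal illustration under assumed helper names (the function and its `body_opcode` argument are hypothetical, not an existing TVM API).

```python
# Sketch of pattern-matching an HLO Reduce body onto Relay's fixed reduction
# ops. body_opcode is assumed to be the opcode of the single binary op inside
# the Reduce computation; this converter is illustrative only.
from tvm import relay

_REDUCE_BODY_TO_RELAY = {
    "add": relay.sum,        # Reduce with an add body -> relay.sum
    "maximum": relay.max,    # -> relay.max
    "minimum": relay.min,    # -> relay.min
    "multiply": relay.prod,  # -> relay.prod
}

def convert_hlo_reduce(operand, body_opcode, axes):
    """Convert an HLO Reduce whose body is a single binary op."""
    relay_reduce = _REDUCE_BODY_TO_RELAY.get(body_opcode)
    if relay_reduce is None:
        # Arbitrary reduce bodies cannot be expressed as a Relay op argument;
        # they would need decomposition or a new Relay op.
        raise NotImplementedError("unsupported Reduce body: %s" % body_opcode)
    return relay_reduce(operand, axis=axes)
```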
Summary
We believe that TF→HLO→Relay is a good alternative path to address model/op coverage that minimizes the amount of effort while making TVM a first-class TensorFlow backend. Additionally, the same work can be reused to acquire coverage for PyTorch, JAX and Julia.