[RFC][Relay] FP32 -> FP16 Model Support

This RFC discusses post-training quantization of FP32 models to an FP16 form. We do not focus on specific codegen but rather the relay representation of the models.

Doing this transformation is useful for reducing model size, as it halves the expected size of the weights. In addition to potential improvements in memory bandwidth, many hardware platforms which support FP16 have theoretically higher throughput for FP16 operations than for FP32. However, using FP16 operations often requires casting from FP32 → FP16 or vice versa, which introduces some overhead. Therefore, we must take some care about when and where we insert casts.

Design Overview

The main aspect of the design is lifted from TensorFlow’s automatic mixed precision training. A relevant slidedeck is here (https://on-demand.gputechconf.com/gtcdc/2019/pdf/dc91247-automatic-mixed-precision-in-tensorflow.pdf) with a focus on slides 12 - 20.

The main idea is to have a “Green List” and “Gray List” of operations which might benefit from FP16 versions of the operations.

  • Green List operations are much faster in an FP16 form such that it is almost always preferable to transform the operation, even if we must cast the inputs.
  • Gray List operations either see no speedup from FP16, or the speedup is too small to justify the overhead of casting.

Operations not in either list are assumed to either not support FP16 or have good reasons not to use their FP16 versions. For example, some operations like exp need FP32 to maintain numerical stability. From these lists we can determine which parts of our computational graph are in the FP16 domain and which are in the FP32 domain. From here, it becomes clear which parts of the graph need casting operations.
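
As a rough illustration (the op sets below are placeholders for this RFC, not a final proposal), the two lists could start as simple collections of Relay op names:

# Illustrative placeholder lists only; the real lists are to be decided and user-extensible.
GREEN_LIST = {"nn.conv2d", "nn.dense", "nn.batch_matmul"}   # worth casting inputs for
GRAY_LIST = {"add", "multiply", "nn.bias_add", "nn.relu"}   # FP16 only when it comes for free
# Anything not listed (e.g. "exp") stays in the FP32 domain.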

Finally, there needs to be some control over the output data types for converted operations. Some FP16 operations might accumulate results into FP32 for numerical reasons and some might produce an FP16 number. In the rewrite of the graph, we will provide some control over this variable.

Pass Example

The main meat of this RFC is the creation of a new pass which, given the ability to map operations into the aforementioned green and gray lists, can insert the proper casts and operation transformations. An example of this can be seen in the linked slide deck, where all quantized functions return FP16 and the green and gray colors correspond to our lists.

The final step is to replace all green colored operations with FP16 versions.
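
For instance (a made-up fragment in the same spirit as the slides), a graph like

conv2d → relu → exp

would be rewritten as

[fp32 inputs] → cast(fp16) → conv2d → relu → cast(fp32) → exp

since conv2d is green, relu is gray and can stay in the FP16 domain at no extra cost, and exp is in neither list so it remains in FP32.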

What Should be Extensible?

We want this pass to be generally usable.

To accomplish this goal, we want to expose the ability to modify the green and gray lists for operations. Furthermore, we want to have granular control over what we put in a particular list. For example, it might make sense for a hardware platform to place a convolution into the green list only when the input tensors and weights are of sufficient size. For the initial implementation, however, we will do something simpler when classifying Relay nodes.
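
For example, a hypothetical classifier callback might look something like the following (the names and the size threshold are made up for illustration and are not part of the proposed interface):

# Hypothetical sketch: classify a Relay Call node into "green", "gray", or "red" (everything else).
def classify(call):
    if call.op.name == "nn.conv2d":
        # Only mark sufficiently large convolutions green; small ones may not amortize the casts.
        data_shape = call.args[0].checked_type.shape
        num_elements = 1
        for dim in data_shape:
            num_elements *= int(dim)
        return "green" if num_elements >= 4096 else "gray"
    if call.op.name in ("add", "multiply", "nn.relu", "nn.bias_add"):
        return "gray"
    return "red"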

Finally, we will want the user to have a little control over how an operation can be rewritten into an FP16 graph. Here, we assume a 1-1 mapping between operations, e.g. an FP32 convolution in Relay maps to the exact same convolution node, just with FP16 inputs. The one knob we will expose for certain operations is control over the accumulation datatype, e.g. we might have an FP16 convolution which accumulates the result into an FP32 buffer and one which accumulates into an FP16 buffer. Note this information is needed in the pass when propagating colors through our graph.

Future Work

  • Write code
  • Flesh out interface into pass
  • Test codegen works as expected for CUDA and x86 at the least
  • Support for bfloat16


Thanks for this much needed contribution!

Can you elaborate on the design you imagine for

there needs to be some control over the output data types for converted operations. Some FP16 operations might accumulate results into FP32 for numerical reasons and some might produce an FP16 number. In the rewrite of the graph, we will provide some control over this variable.

In some edge use cases, it is desirable for all parameters to be stored as fp16 to limit storage footprint. In that context, take the following graph as an example,

conv2d -> multiply -> bias_add -> relu -> max_pool

Greenlist{conv2d, max_pool}
Graylist{elemwise}

If the conv2d should accumulate in fp32, but the consecutive elemwise operators should run in fp16, how will a user express this? In this case I would expect the final graph to be:

[fp16] → conv2d → [fp32] → cast(fp16) → multiply → bias_add → relu → max_pool

i.e., after operator fusion:

fused_conv2d_cast_multiply_bias_add_relu → max_pool

An alternative option would be to add an accumulator_dtype separate from output_dtype to various operators and rewrite based on that field. Both can work, but I'd like to hear more on how you envision doing this with the mixed precision transform in the above context. Thanks!

Hey Chris,

The two extensible bits will be done through user defined callable functions.

For the green list/gray list/red list situation, we have the user define a function which, given a Call node, returns the color of the operation. For the initial implementation we will just do a naive solution like placing all conv2d ops in the green list, all elementwise ops in the gray list, etc.
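
Concretely, the naive default could be as simple as something like this (a sketch only; the exact interface is not final):

def color_of(call):
    # Naive classification: matrix-multiply style ops are green, elementwise ops are gray,
    # everything else is red (left in FP32).
    if call.op.name in ("nn.conv2d", "nn.dense"):
        return "green"
    if call.op.name in ("add", "multiply", "nn.relu", "nn.bias_add"):
        return "gray"
    return "red"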

For the accumulation datatype, I imagine a user-defined function which, given a Call node, returns the accumulation datatype and the output datatype of the operation. The accumulation datatype is self-explanatory and, confusingly, maps to the existing “output_dtype” field in existing Relay ops like conv and dense. Our “new” output datatype, meanwhile, tells at what precision other operations will ingest the results of the operation, for example:

weight (fp16 or fp32), data (fp16 or fp32) → conv2d (accumulation_dtype) → cast(output_dtype).

If the accumulation_dtype == output_dtype then we don’t need the cast.

Finally, to answer your question, in the scenario given we would simply express conv2d as an operator with an accumulation_dtype of fp32 and an output_dtype of fp16. This should give the final graph listed (I don’t know about the operator fusion part, to be honest; I’m not sure if all the knobs on the fused operator are there, and if not I guess I have to do something about that too). In a sense we do have separate “accumulator_dtype” and “output_dtype” fields that the user can define on a per-operation basis.

I hope that answers your question and I hope it is sufficient for most applications! For the default, I am going to do something simple: all operators which support an accumulation datatype separate from the input datatype will accumulate into fp32 but output fp16. Otherwise, we assume the operator accumulates in fp16 and outputs fp16 (e.g. elementwise operators).
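
As a sketch of that default policy (a hypothetical function, not the final interface):

def mixed_precision_dtypes(call):
    # Ops with a separate accumulation buffer: accumulate in fp32, hand fp16 to consumers.
    if call.op.name in ("nn.conv2d", "nn.dense"):
        return ("float32", "float16")   # (accumulation_dtype, output_dtype) -> a cast is inserted
    # Everything else accumulates and outputs in fp16, so no extra cast is needed.
    return ("float16", "float16")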

There are downsides to this simplistic method. One downside is that the only sort of analysis that is easy within this framework is looking at the current operator. That is to say, it’s kind of cumbersome to look ahead and backward to make decisions. There are some other theoretical limitations in the graphs it can easily generate, but I think it covers most reasonable scenarios!


As I understand it, we do not currently have an accumulation data type attribute for ops, right? It seems we just need to introduce one, similar to dtype. The default value should match dtype if it is defined, and the type if dtype is not defined. Changing the accumulator type is then a matter of some additional logic which would define, for example, that in a particular model the third convolution should accumulate into fp32 to get good accuracy. If there is no such additional logic, the default behaviour (matching the accumulator to dtype/type) is good enough.


PR here: [Relay] [Pass] Add FP16 model conversion pass by AndrewZhaoLuo · Pull Request #8069 · apache/tvm · GitHub

Hi @AndrewZhaoLuo this RFC is right on the edge of when we approved the new RFC process, but could you write up the RFC as a pull request for the new TVM RFCs repository?


In accordance with the new RFC process, please move the discussion here:

@AndrewZhaoLuo, when we enable mixed precision arithmetic, the output type of a model depends on the last operation in the Relay IR, i.e., if it is a convolution the output will be float16, for argmax it will be int32, and in the case of softmax it will be float32. So, is this expected? Or should the nodes that are marked as output nodes be treated differently, so that they return the expected out_type irrespective of the optimization (mixed precision pass) happening? Please let me know your thoughts on this.

Yeah, the last operation will always be cast to the original output dtype. The input dtypes are also preserved. This was done to guarantee that functions don’t have their types changed by the pass.
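
For reference, a minimal way to see this boundary behaviour (assuming the pass from the PR above, exposed as relay.transform.ToMixedPrecision):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")
w = relay.var("w", shape=(8, 3, 3, 3), dtype="float32")
out = relay.nn.relu(relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1)))
mod = tvm.IRModule.from_expr(relay.Function([x, w], out))
mod = relay.transform.InferType()(mod)

fp16_mod = relay.transform.ToMixedPrecision("float16")(mod)
# Inputs should stay float32 with casts inserted; if output types are preserved as
# described above, the final result is cast back to float32.
print(fp16_mod)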

The original plan was to write another pass to fold these transformations in (I think there is a PR somewhere, made by @anwang, that is languishing).

Hi @AndrewZhaoLuo, thank you for your reply. But for most models, when we enable the mixed precision pass we are getting out_type as float16, which should not be the case from what we discussed? Could you please point me towards more details regarding the pass written by @anwang.

Hey @gayatripk1, that is interesting; I might need to go back and reread the pass. Perhaps constant folding afterwards allows this to occur.

As for the pass: https://github.com/apache/tvm/pull/9357

I think this is what you want? I can ping @anwang for this.

Hi @AndrewZhaoLuo,

Thank you for the informative reply.

In one of the example code snippets [attached] for enabling the mixed precision pass, I saw the following passes used. Could you please elaborate on whether these passes need to be used in all cases when we work with mixed precision? I tested the models with and without these passes enabled and still get a float16 output in both cases. I also tried disabling the constant folding pass [not a good idea, as it will cause performance degradation] just to check whether it was causing this conversion, and found the output type to be float16 in that case as well.

Also, do any other passes need to be enabled/disabled when enabling the mixed precision pass?

It would be really good to know if PR #9357 is going to be merged.

You don’t need any of these passes. But if you are using BYOC, you need to be careful. I can expand on that if you like.

Hi @masahi, thank you for your reply. Could you please expand on that?

For BYOC backends which expect fp16 inputs and params, you do need to run FoldConstant etc. after ToMixedPrecision.

After ToMixedPrecision, parameters are not converted to fp16 “yet”, instead your model will have many op.cast(weight_fp32, dtype=fp16). These casts are computed at compile time when you run relay.build(...), but if you need fp16 params before relay.build, you need to run FoldConstant.
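
Roughly, the ordering is (a sketch, assuming the weights have already been bound into the module as constants):

from tvm import relay

mod = relay.transform.InferType()(mod)
mod = relay.transform.ToMixedPrecision("float16")(mod)
# Here the weights are still fp32 constants wrapped in cast(..., dtype="float16").
mod = relay.transform.FoldConstant()(mod)
# Now the casts on constants are folded away and the module carries fp16 params.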

The other passes you showed are not essential; they might or might not improve performance, but that’s all.

Hi @AndrewZhaoLuo, could you please look into this case and give more clarity regarding the mixed precision pass in the following context: “for most models, when we enable the mixed precision pass, we are getting out_type as float16.” Also, regarding “the original plan was to write another pass to fold these transformations in” by @anwang.

Hey @gayatripk1, probably won’t have any time soon to work on this. I’ve added it to https://github.com/apache/tvm/issues/8296. If you are interested perhaps you can make a contribution B)

@masahi

In my experiments, running only FoldConstant did not eliminate the cast ops. Instead, I needed to run the above BindPass before converting the model to fp16. I am not sure whether this helps performance, since in my experiments the performance of fp32 and fp16 is almost the same on CUDA.
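
For reference, the ordering I used looks roughly like this (a sketch; I am assuming “BindPass” here means binding the params into the module with bind_params_by_name so that FoldConstant can treat them as constants):

from tvm import relay

mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
mod = relay.transform.InferType()(mod)
mod = relay.transform.ToMixedPrecision("float16")(mod)
mod = relay.transform.FoldConstant()(mod)   # the cast(fp32_const, "float16") nodes now fold away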