Relay Automatic FP16 Downcasting

There have been a few discussions around automatic FP16 downcasting, but there is no implementation yet. We would like to propose an RFC for this topic.

The idea is pretty straightforward. For an fp32 model, we downcast the input and all the parameters to fp16. After this, the output is also fp16, so we upcast the output back to fp32 at the end.
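A minimal sketch of what the transformed graph looks like, written against the Relay Python API (exact names such as `tvm.IRModule.from_expr` may differ between TVM versions); this is only an illustration of the idea, not the draft pass itself:

```python
import tvm
from tvm import relay

# Illustrative only: an fp32 conv2d whose body runs in fp16.
# Inputs/parameters are cast down on entry; the result is cast back to fp32.
data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="float32")

data16 = relay.cast(data, "float16")
weight16 = relay.cast(weight, "float16")
conv = relay.nn.conv2d(data16, weight16, channels=16, kernel_size=(3, 3),
                       padding=(1, 1), out_dtype="float16")
out = relay.cast(conv, "float32")  # upcast the final output back to fp32

func = relay.Function([data, weight], out)
mod = tvm.IRModule.from_expr(func)
print(mod)
```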

I already have a draft PR for this here. Currently the pass is written in Python. We have tested it with several models in the Gluon model zoo and the accuracy looks good so far. The next step is to expand the model coverage by handling some edge cases and to re-implement it in C++.

We are planning to reuse the existing auto-quantization pass to downcast to FP16 (or bfloat, etc.), and to investigate whether we need target-aware transformation infrastructure or whether it can be represented using the QConfig.

The proposal has been discussed with @janimesh.

The work will be separated into two stages. The first stage will focus on accuracy and the second stage will focus on performance.

Stage 1: Framework Coverage + Model Coverage + Accuracy

Focus on the accuracy of the downcasted fp16 graph. We need to expand the model coverage across multiple model zoos. Currently we have tested with ResNet, VGG, and MobileNet from the Gluon model zoo. We are planning to expand the coverage to the GluonCV object detection models and the image classification models in the TensorFlow model zoo. Here, we will use Intel machines to get the accuracy numbers. The goal is to check the robustness of the downcast pass.

Stage 2: Performance Improvement for Nvidia GPU

Focus on performance. We are targeting CUDA on GPU and ARM CPUs with float16 support. Currently we are experiencing some errors during CUDA codegen, as in [ERROR]CUDA compilation error.
In the future, we can target ARM devices that have native FP16 support.

Any comments or thoughts are welcome :slight_smile:

@vinx13 @tqchen @zhiics @yzhliu @thierry

For the CUDA backend, most operators don't have fp16 overloads, which means we need to implement CUDA-specific codegen rules for fp16.

I have a simple solution to the second problem. The root cause is that CUDA fp16 computation uses "cuda_fp16.h", which does not support operations on operands with the volatile keyword, so for "add" we just need to add the following code:

```cpp
// Emit a volatile-qualified operator+ overload for __half into the generated
// CUDA source, forwarding to the __hadd intrinsic from cuda_fp16.h.
decl_stream << "__device__ half operator+(const volatile __half &a, const volatile __half &b)\n"
               "{\n  return __hadd(a, b);\n}\n";
```

The same applies to the other operators; the corresponding intrinsic to return can be found in "cuda_fp16.h".
Of course, if there is a better solution, I would be glad to hear it. @xyzhou

I tested the fp16 and fp32 speeds in the current TVM and found that fp16 inference was indeed faster. I want to get a better result, but I don't know what other code or ideas are available for reference. Can you help me? Thank you.

A possible optimization for fp16 on CUDA is to use intrinsics: https://devblogs.nvidia.com/mixed-precision-programming-cuda-8/

Very glad to see this; I have a WIP branch on a 16-bit instantiation of VTA; it would be great to have downcasting support to target the VTA design more easily.

Thanks @xyzhou This will be very useful. FYI, ARM also supports fp16, and we're pushing automatic codegen for fp16 FMA in LLVM: http://lists.llvm.org/pipermail/llvm-dev/2019-September/134946.html


This is related to automatic quantization and likely needs to be discussed more carefully before we send in code.

Let us remember that simply changing data types to fp16 won't work, as we will need mixed precision to accumulate in fp32. It is also hardware dependent: if the hardware's ALU does not support fp16, we want to default most ops to fp32 and only use fp16 for conv2d.

Eventually we hope to achieve two things:

  • Use the same pass for most mixed-precision float variants (bfloat, posit, fp16)
  • Have the pass do some re-adjustments and choose the precisions (like the quantization pass)

It would be great if @vinx13 can help manage this RFC, and @ziheng can help as well.

@tqchen Trying to understand.

Let us remember that simply changing data types to fp16 won't work, as we will need mixed precision to accumulate in fp32.

Is it true that we have to accumulate in FP32 for something like conv2d?

I would assume that tanh or exp might behave poorly in FP16, and for those ops we might want to keep FP32 even when there is FP16 support in the HW.

It is also hardware dependent: if the hardware's ALU does not support fp16, we want to default most ops to fp32 and only use fp16 for conv2d.

Why would we even want to use FP16 for conv2d if there is no support for FP16? For example, Intel machines do not support FP16 computation. Do we ever want to keep the inputs and parameters in FP16 there?

Both of these depend on the instruction set supported by the HW. Many hardware targets only have a restricted set of instructions that can operate on fp16 (e.g. tensor-core ops) while not supporting most other operations (that was what I meant by ALU). In such cases, we can only offload a limited set of ops to fp16 computation. This is quite typical for new accelerators.

Mixed-precision computation is quite common for low-precision real numbers in order to avoid accumulation overflow. For example, Nvidia devices support fp16-input, fp32-accumulate instructions. While it might be OK to accumulate in fp16 itself, we won't get the corresponding HW acceleration.

To summarize, due to the limited instructions supported by the HW, we need to convert the model to fit into the native instructions. For low-precision real numbers, that likely requires mixed precision and a per-op choice of dtypes.
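A quick way to see why fp32 accumulation matters, using a plain NumPy toy example (not TVM-specific): once the running fp16 sum gets large enough, adding a small value rounds to nothing and the sum stalls.

```python
import numpy as np

vals = np.full(50_000, 0.1, dtype=np.float16)

acc16 = np.float16(0.0)
for v in vals:                      # accumulate in fp16
    acc16 = np.float16(acc16 + v)

acc32 = np.float32(0.0)
for v in vals:                      # accumulate in fp32
    acc32 += np.float32(v)

print(acc16)  # stalls around 256 -- far from the true sum of ~5000
print(acc32)  # close to 5000
```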

Thanks @tqchen for the explanation.

@tqchen Thinking more about the problem and what the TOPI vs Relay division should be. For Nvidia mixed-precision compute, does it make sense to have a conv2d with FP16 inputs and outputs, while TOPI chooses to use FP32 accumulation? In this case, the hardware-specific attributes are hidden in TOPI and don't show up in the Relay graph.

I think it is good to align with quantization, which puts the output type of conv2d in Relay (followed by cast ops). Hiding the details in TOPI can make different platforms implicitly inconsistent.
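For illustration, that could look roughly like the Relay sketch below (an assumption about the eventual design, not the implemented pass): conv2d takes fp16 inputs, its `out_dtype` makes the fp32 accumulation explicit in the graph, and a cast brings the result back to fp16 for the following ops.

```python
from tvm import relay

data = relay.var("data", shape=(1, 16, 56, 56), dtype="float16")
weight = relay.var("weight", shape=(32, 16, 3, 3), dtype="float16")

# fp16 inputs, fp32 accumulation made explicit in the Relay graph via out_dtype ...
conv = relay.nn.conv2d(data, weight, channels=32, kernel_size=(3, 3),
                       padding=(1, 1), out_dtype="float32")
# ... followed by an explicit cast back to fp16 for downstream ops
out = relay.cast(conv, "float16")
func = relay.Function([data, weight], out)
```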


I see. Makes sense.

@vinx13 Are you thinking that we should modify the current automatic quantization pass and include FP16 support as well? (Personally, I think that the FP16 pass might be much simpler, and having a separate pass might not be a bad idea.)

Also, do we take any target-dependent steps in the quantize pass today? If not, we need a way to perform target-dependent transformations inside the quantize pass.

The difference between fp16 and quantization is that fp16 doesn't need quantize/requantize operations. But it is always good to try to reuse the code.
The quantization pass today doesn't take target-dependent steps; the differences between targets are mainly reflected in the quantization config (dtype, nbits, …).
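For reference, this is roughly how the existing quantization flow is driven by its config today (parameter names come from the current `relay.quantize` API and may differ across TVM versions); the open question above is whether an fp16 mode could be expressed through the same kind of config:

```python
from tvm import relay

# `mod` / `params` are assumed to hold an imported fp32 model.
# Existing int8 flow: the config object carries the dtype/nbit choices.
with relay.quantize.qconfig(nbit_input=8,
                            nbit_weight=8,
                            dtype_input="int8",
                            dtype_weight="int8",
                            skip_conv_layers=[0]):
    qmod = relay.quantize.quantize(mod, params)

# A hypothetical fp16 variant would swap in float dtypes here and drop the
# quantize/requantize steps -- that is exactly what this RFC needs to investigate.
```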

I see. So, can we scope the problem then? We can target the following:

  • Input DNN - MXNet/TF ResNet FP32 network
  • Targets - an ARM instance with FP16 support, Nvidia GPUs supporting mixed-precision computation
  • Code structure - Modify the quantization pass to downcast to FP16 (or bfloat, etc.). Investigate whether we need target-aware transformation infra, or whether it can be represented using the QConfig.

Is that a good enough scope for the first trial of this pass? Once we are happy with this, we can then expand model and target coverage.

Looks good. Another thing is to test model accuracy under fp16 (no mixed precision) to see which parts need mixed precision.

Thanks @vinx13. @xyzhou Can you please update the first post to reflect the changes?

@vinx13 Can we refer to the four lists provided by TensorFlow? Details can be found here:

Yes, you can refer to it.