Hi all, we have completed a workable draft of bfloat16 (bf16) support in TVM. We add bfloat16 as a new type named “bf16” in the frontend:
- Use int16 as the storage type
- Add legalization to enable computations on bf16
- Add runtime frontend support (e.g. allow converting numpy’s uint16 array to bf16 NDArray)
bfloat16 is a 16-bit floating-point data type. You can easily get a bfloat16 value by truncating an fp32 number and keeping its higher-order 16 bits. bfloat16 halves the memory footprint of fp32 and is friendlier to memory-bound applications. It also requires no special hardware instructions, as we can lower a computation on bf16 to a cast to fp32 followed by the computation in fp32. Thus, we bring the bfloat16 data type into TVM.
Details on legalization
Since most hardware has no native support for computation on bf16, we add a pass,
BF16Legalization, to use fp32 to compute on bf16 data. It adds a
`cast_to_fp32()` before each Op involving bf16 operands and uses the fp32 version of the Op to compute. Finally, it adds a `cast_to_bf16()` after each Op that is altered, e.g. `add(a, b)` becomes `cast16(add(cast32(a), cast32(b)))`.
We call this phase “BF16Promotion”. It is a sub-pass of the BF16Legalization pass.
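The promotion step can be sketched on a toy expression tree; the class and function names here are illustrative, not TVM's actual IR classes:

```python
from dataclasses import dataclass

@dataclass
class Expr:
    op: str        # e.g. "var", "add", "neg", "cast32", "cast16"
    args: tuple = ()
    dtype: str = "bf16"

def promote_bf16(e: Expr) -> Expr:
    """BF16Promotion sketch: compute every bf16 Op in fp32 by casting
    operands up and casting the result back down to bf16."""
    if e.op == "var":
        return e
    new_args = tuple(promote_bf16(a) for a in e.args)
    if e.dtype != "bf16":
        return Expr(e.op, new_args, e.dtype)
    # cast32() each operand, run the Op in fp32, cast16() the result.
    up = tuple(Expr("cast32", (a,), "fp32") for a in new_args)
    return Expr("cast16", (Expr(e.op, up, "fp32"),), "bf16")
```

Applied to `add(a, b)`, this yields `cast16(add(cast32(a), cast32(b)))`.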
We note that this will add redundant casting, e.g.

add(a, neg(b)) => cast16(add(cast32(a), cast32(cast16(neg(cast32(b))))))

Here the inner cast32(cast16(some_fp32_value)) can be simplified to some_fp32_value, accepting the small precision change from the bf16 round-trip.
Thus, we add an optimization pass after “BF16Promotion” in the
BF16Legalization pass, which eliminates these redundant casts.
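The cast-elimination step can be sketched on a toy nested-tuple IR (`("opname", child, ...)`); the representation and function name are ours for illustration:

```python
def eliminate_redundant_casts(e):
    """Fold cast32(cast16(x)) into x, the redundant pattern that
    BF16Promotion introduces around inner bf16 Ops."""
    if not isinstance(e, tuple):
        return e  # leaf (a variable name)
    op, *args = e
    args = [eliminate_redundant_casts(a) for a in args]
    if op == "cast32" and isinstance(args[0], tuple) and args[0][0] == "cast16":
        # cast32(cast16(x)) == x only up to bf16 rounding; the pass
        # accepts that precision tradeoff by design.
        return args[0][1]
    return (op, *args)
```

On the example above, this rewrites `add(cast32(cast16(x)), y)` into `add(x, y)`.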
After the BF16Legalization pass, there is no bf16-related computation left in the AST, except casts between fp32 and bf16, and bf16 value comparison and assignment.
Casting between fp32 and bf16
We follow PyTorch’s bf16 casting implementation.
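PyTorch narrows fp32 to bf16 with round-to-nearest-even rather than plain truncation; a numpy sketch of that rounding (helper name ours, NaN/Inf handling omitted):

```python
import numpy as np

def fp32_to_bf16_rne(x):
    """Narrow fp32 to bf16 (stored as uint16) with round-to-nearest-even,
    the rounding PyTorch's bf16 cast uses. Sketch only; NaN/Inf omitted."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    # Add 0x7FFF plus the LSB of the kept mantissa so ties round to even.
    rounding_bias = ((bits >> 16) & 1) + np.uint32(0x7FFF)
    return ((bits + rounding_bias) >> 16).astype(np.uint16)
```

Widening bf16 back to fp32 needs no rounding at all: shifting the 16 stored bits left by 16 reproduces an exact fp32 value.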
Design choices in legalization
Please see @tqchen's post below.