Hi all, we have completed a workable draft of bfloat16 (bf16) support in TVM.
We add bfloat16 as a new type named “bf16” in the frontend:
- Use int16 as the storage type
- Add legalization to enable computations on bf16
- Add runtime frontend support (e.g. allow converting numpy’s uint16 array to bf16 NDArray)
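As a rough usage sketch of that last point (hypothetical: the “bf16” dtype string and the use of tvm.nd.empty/copyfrom here are assumptions for illustration, not necessarily the exact interface of the PR):

```python
import numpy as np
import tvm

# bf16 values are carried on the NumPy side as raw uint16 bit patterns.
bits = np.array([0x3F80, 0x4000, 0x4040], dtype=np.uint16)   # 1.0, 2.0, 3.0 in bf16

# Hypothetical: allocate a bf16 NDArray and copy the uint16 bit patterns into it.
nd = tvm.nd.empty(bits.shape, "bf16")
nd.copyfrom(bits)
```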
Motivation
bfloat16 is a 16-bit floating-point data type. A bfloat16 value can be obtained from an fp32 value simply by truncating it and keeping the higher-order 16 bits. bfloat16 has lower memory consumption and is more friendly to memory-bound applications. It also requires no special hardware instructions, since computation on bf16 can be lowered to casting to fp32 and doing the computation in fp32. Thus, we bring the bfloat16 data type into TVM.
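For concreteness, here is a minimal NumPy sketch of the truncation described above (illustration only; the rounding rule actually used when casting is discussed later in this post):

```python
import numpy as np

x = np.array([1.0, 3.1415927, 1e30], dtype=np.float32)

# Truncate: keep only the high-order 16 bits of each fp32 word.
bf16_bits = (x.view(np.uint32) >> 16).astype(np.uint16)   # bf16 stored in 16-bit integers

# Widen back to fp32 by shifting the bits into the high half; the low 16 bits become zero.
x_back = (bf16_bits.astype(np.uint32) << 16).view(np.float32)
print(x_back)   # same exponent range as fp32, roughly 3 decimal digits of precision
```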
Details on legalization
Since most hardware has no native support for computing on bf16, we add a BF16Legalization pass that carries out bf16 computations in fp32. It inserts a cast_to_fp32() before each Op involving bf16 operands, uses the fp32 version of the Op to compute, and finally inserts a cast_to_bf16() after each altered Op, e.g.
add(a, b)
=> cast16(add(cast32(a), cast32(b)))
We call this phase “BF16Promotion”. It is a sub-pass of the BF16Legalization pass.
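As a purely illustrative sketch of the promotion rule (plain Python on a toy expression tree, not TVM's IR or pass infrastructure; the Var/Cast/Add classes are made up for the example):

```python
# Toy illustration of the BF16Promotion idea, NOT TVM's implementation:
# wrap every Op that has bf16 operands in fp32 casts.

class Var:
    def __init__(self, name, dtype):
        self.name, self.dtype = name, dtype
    def __repr__(self):
        return self.name

class Cast:
    def __init__(self, value, dtype):
        self.value, self.dtype = value, dtype
    def __repr__(self):
        return ("cast16" if self.dtype == "bf16" else "cast32") + f"({self.value!r})"

class Add:
    def __init__(self, a, b):
        self.a, self.b, self.dtype = a, b, a.dtype
    def __repr__(self):
        return f"add({self.a!r}, {self.b!r})"

def promote(e):
    """Rewrite every Add on bf16 operands to compute in fp32."""
    if isinstance(e, Add):
        a, b = promote(e.a), promote(e.b)
        if e.dtype == "bf16":
            # cast operands up to fp32, compute, cast the result back down
            return Cast(Add(Cast(a, "fp32"), Cast(b, "fp32")), "bf16")
        return Add(a, b)
    return e

a, b = Var("a", "bf16"), Var("b", "bf16")
print(promote(Add(a, b)))   # cast16(add(cast32(a), cast32(b)))
```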
We note that this adds redundant casts, e.g.
add(a, neg(b))
=> cast16(add(cast32(a), cast32(cast16(neg(cast32(b))))))
The pattern cast32(cast16(some_fp32_value)) can be simplified to some_fp32_value. Thus, we add an optimization pass after “BF16Promotion” within the BF16Legalization pass, which eliminates such redundant casts.
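Continuing the toy sketch above, the cleanup rule simply collapses cast32(cast16(x)) back to x when x is already an fp32 expression; again, this is only an illustration, not the TVM pass itself:

```python
def eliminate_redundant_casts(e):
    """Collapse cast32(cast16(x)) back to x when x is already fp32 (illustration only)."""
    if isinstance(e, Cast):
        inner = eliminate_redundant_casts(e.value)
        if (e.dtype == "fp32" and isinstance(inner, Cast)
                and inner.dtype == "bf16" and inner.value.dtype == "fp32"):
            return inner.value
        return Cast(inner, e.dtype)
    if isinstance(e, Add):
        return Add(eliminate_redundant_casts(e.a), eliminate_redundant_casts(e.b))
    return e

expr = promote(Add(a, Add(a, b)))        # nested add produces a redundant cast32(cast16(...))
print(expr)
print(eliminate_redundant_casts(expr))   # cast16(add(cast32(a), add(cast32(a), cast32(b))))
```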
After the BF16Legalization pass, there is no bf16-related computation left in the AST, except for casts between fp32 and bf16, bf16 value comparisons, and assignments.
Casting between fp32 and bf16
We follow PyTorch’s bf16 casting implementation, which rounds to nearest even rather than simply truncating.
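A minimal NumPy sketch of that round-to-nearest-even scheme (a rounding bias is added to the low 16 bits before truncating); this is an illustration of the rounding rule, with NaN handling omitted for brevity:

```python
import numpy as np

def fp32_to_bf16_bits(x):
    """fp32 -> bf16 (round to nearest, ties to even), returned as uint16 bit patterns."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    # Adding 0x7FFF plus the lowest kept bit before truncating makes ties round to even.
    rounding_bias = ((bits >> 16) & 1) + np.uint32(0x7FFF)
    return ((bits + rounding_bias) >> 16).astype(np.uint16)

def bf16_bits_to_fp32(b):
    """Widen bf16 bit patterns back to fp32 by shifting them into the high 16 bits."""
    return (np.asarray(b, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)

x = np.array([1.0, 3.1415927, 0.1], dtype=np.float32)
print(bf16_bits_to_fp32(fp32_to_bf16_bits(x)))   # values rounded to bf16 precision
```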
Pull request
Design choices in legalization
Please see @tqchen’s post below.