I have implemented it in https://github.com/apache/tvm/pull/7014, following how `half` is implemented in CUDA.
However, I have some questions:
- Since `bfloat16` is not a built-in dtype in numpy, we cannot use `asnumpy` to get a representation directly. Is there a proper way to do this? Cast to float?
- Should we use a pass config to disable `BF16Legalize` when targeting a GPU with compute capability >= sm_80?
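On the first question: one workaround, since `bfloat16` shares its exponent layout with `float32`, is to hold the values as raw `uint16` bit patterns and reinterpret them in numpy by shifting into the high half of a `uint32`. This is only a sketch of the bit-level idea, not part of the PR; the helper names here are hypothetical:

```python
import numpy as np

def bf16_to_float32(raw: np.ndarray) -> np.ndarray:
    # 'raw' holds bfloat16 bit patterns as uint16; placing them in the
    # high 16 bits of a uint32 yields the exact float32 equivalent.
    assert raw.dtype == np.uint16
    return (raw.astype(np.uint32) << 16).view(np.float32)

def float32_to_bf16(arr: np.ndarray) -> np.ndarray:
    # Truncate the low 16 mantissa bits (round-toward-zero; a real
    # implementation would likely round-to-nearest-even instead).
    assert arr.dtype == np.float32
    return (arr.view(np.uint32) >> 16).astype(np.uint16)
```

Conversion is lossless in the bf16-to-float32 direction, so it could back an `asnumpy`-style accessor that returns `float32`.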