CUDA bfloat16 support

I have implemented it in https://github.com/apache/tvm/pull/7014, following how ‘half’ is implemented for CUDA.

However, I have some questions:

  • Since ‘bfloat16’ is not a built-in dtype in numpy, we cannot use ‘asnumpy’ to get a representation. Is there a proper way to do this? Cast to float? (A possible workaround is sketched in the first snippet below.)
  • Should we use a pass config to disable ‘BF16Legalize’ when we have a GPU with compute arch >= sm_80? (See the second sketch below.)
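
On the first question, here is a minimal sketch of the "cast to float" workaround, assuming the bfloat16 data can be obtained as its raw 16-bit words (uint16). The helper name `bf16_to_float32` is hypothetical, not part of TVM's API; it only shows the bit manipulation.

```python
import numpy as np

def bf16_to_float32(bf16_bits: np.ndarray) -> np.ndarray:
    """Widen raw bfloat16 bit patterns (uint16) to float32.

    bfloat16 is the upper 16 bits of an IEEE-754 float32, so shifting the
    bits into the high half of a uint32 and reinterpreting them gives the
    exact float32 value.
    """
    assert bf16_bits.dtype == np.uint16
    return (bf16_bits.astype(np.uint32) << 16).view(np.float32)

# Example: 0x3F80 is bfloat16 for 1.0, 0x4000 is 2.0
print(bf16_to_float32(np.array([0x3F80, 0x4000], dtype=np.uint16)))  # [1. 2.]
```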
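On the second question, a sketch of what a user-facing opt-out could look like, using `tvm.transform.PassContext`'s `disabled_pass` list. The `"arch"` target attribute lookup and the pass name string `"tir.BF16Legalize"` are assumptions here, and whether the lowering pipeline actually honors `disabled_pass` for this pass would need to be wired up as part of the change.

```python
import tvm

target = tvm.target.Target("cuda -arch=sm_80")

# Assumption: the CUDA target carries its compute arch in the "arch" attribute,
# e.g. "sm_80". Native bf16 arithmetic is available from sm_80 onwards.
arch = target.attrs.get("arch", "")
native_bf16 = arch.startswith("sm_") and int(arch.split("_")[1]) >= 80

with tvm.transform.PassContext(
    opt_level=3,
    # Skip software legalization when the hardware handles bf16 natively.
    disabled_pass=["tir.BF16Legalize"] if native_bf16 else [],
):
    # build as usual, e.g. tvm.build(s, args, target=target)
    pass
```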