[float16][BERT] Getting all NaN in output after transforming the module with ToMixedPrecision

Hi, I noticed that the mixed precision PR mentions that BERT was tested and verified. However, when I apply the ToMixedPrecision pass to my fp32 module, the inference result I get from the fp16_mod is full of NaNs. Has anyone ever bumped into this problem?

from tvm.relay.transform import InferType, ToMixedPrecision

if dtype == "float16":
    mod = InferType()(mod)  # the pass needs type information first
    fp16_mod = ToMixedPrecision(dtype)(mod)

@AndrewZhaoLuo @anijain2305 @masahi @lhutton1

Hey @mgeek. I have run into this, and unfortunately it can have a lot of causes, from overflow/underflow to an error in the order of operations.

One thing you can do is try to dump the values of the intermediate tensors and see where things start to diverge. It’s been a while since I’ve done this but an example of how to do this can be found here: https://github.com/AndrewZhaoLuo/TVM-Sandbox/blob/main/relay/graph_debugger_example.py.
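For reference, a rough sketch of dumping per-node outputs with TVM's debug executor; the target, the input name `input_ids`, and the `params`/`fp16_mod` variables are placeholders for your own setup, and the exact build/create calls may differ between TVM versions:

import tvm
from tvm import relay
from tvm.contrib.debugger import debug_executor

# Build the converted module and run it under the debug executor,
# which writes every intermediate tensor below dump_root.
lib = relay.build(fp16_mod, target="llvm", params=params)
dev = tvm.cpu()
m = debug_executor.create(lib.get_graph_json(), lib.get_lib(), dev, dump_root="/tmp/tvmdbg")
m.set_input(**lib.get_params())          # load the weights
m.set_input("input_ids", input_ids)      # placeholder input name/value
m.run()
out = m.get_output(0).numpy()

You can then diff the dumped intermediates against an fp32 run of the same module to see where the values first blow up.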

Thanks for the reply! I see. Even weirder: when I move everything, unchanged, to a new Arm server, it somehow produces results that make sense; however, the absolute error compared to the ground truth is still considerably large.

Actually, I am thinking of turning the fp16 pass on and off for each operator to see which operator contributes most to the final error. Could you tell me where to find the lists that control the so-called green/grey/red markers for the ToMixedPrecision pass?

Check here for where the lists are defined: https://github.com/apache/tvm/blob/main/python/tvm/relay/transform/mixed_precision.py#L34

You can change these directly or you can override them like so:
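A minimal sketch of the override route, assuming the register_mixed_precision_conversion hook in tvm.relay.op, the category constants from mixed_precision.py, and that the default rules are registered at the standard level 10; "add" is just an example op:

from tvm.relay.op import register_mixed_precision_conversion
from tvm.relay.transform.mixed_precision import MIXED_PRECISION_NEVER

# Example override: keep "add" in fp32 regardless of the default lists.
# level=11 outranks the default registrations, so this rule wins.
@register_mixed_precision_conversion("add", level=11)
def add_never_fp16(call_node, mixed_precision_type):
    # Return [conversion category, accumulation dtype, output dtype].
    return [MIXED_PRECISION_NEVER, "float32", "float32"]

Registering a rule per op like this is a convenient way to run your one-op-at-a-time experiment without editing the lists in the TVM source tree.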


Hi Andrew,

Thanks for the detailed info!

I tried isolating the fp16 ops one by one and found that almost every operator (dense, batch_matmul, transpose, reshape, add, etc.) contributes to the final error to a non-negligible extent. Has this ever happened to you?

However, if I run native float16 with PyTorch there is barely any accuracy drop, and judging from the params almost everything was cast to fp16 there as well. Any idea what explains the accuracy difference between PyTorch's fp16 and TVM's fp16?

Hi @mgeek, I have not seen this before, where the accuracy degradation is very high compared to PyTorch. Do you have any other info on the platform you are running this model on?

It’s worth trying setting the output dtype of dense and batch_matmul to fp32. Otherwise we do all accumulation in fp16, which may be affecting accuracy. I think PyTorch uses fp32 accumulation everywhere.
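One way to try this, sketched under the same assumptions as the override example above; the two dtype fields are interpreted as the accumulation dtype (which becomes the op's out_dtype) and the dtype the result is cast to afterwards, per the rule format in mixed_precision.py:

from tvm.relay.op import register_mixed_precision_conversion
from tvm.relay.transform.mixed_precision import MIXED_PRECISION_ALWAYS

def fp32_accumulation_rule(call_node, mixed_precision_type):
    # Take fp16 inputs, accumulate in fp32, then cast the result back
    # to fp16 so the rest of the graph stays in mixed precision.
    return [MIXED_PRECISION_ALWAYS, "float32", "float16"]

for op_name in ["nn.dense", "nn.batch_matmul"]:
    register_mixed_precision_conversion(op_name, func=fp32_accumulation_rule, level=11)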

Thanks for the advice @masahi! Doing the accumulation in fp32 does help with accuracy (it improves the absolute error against the fp32 reference from 1e-2 to 1e-3), but it significantly affects inference speed (it becomes even much slower than the pure fp32 module).