Thanks for this much-needed contribution!
Can you elaborate on the design you imagine for the following?

> there needs to be some control over the output data types for converted operations. Some FP16 operations might accumulate results into FP32 for numerical reasons and some might produce an FP16 number. In the rewrite of the graph, we will provide some control over this variable.
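For concreteness, one shape this control could take is a per-operator rule that specifies both an accumulation dtype and an output dtype. This is a minimal sketch, and all of the names (`ConversionRule`, `CONVERSION_RULES`) are hypothetical, not an existing API:

```python
# Hypothetical sketch only: a per-operator rule giving the pass control
# over both the accumulation dtype and the output dtype. None of these
# names come from the RFC; they just illustrate the control I'm asking about.
from typing import NamedTuple

class ConversionRule(NamedTuple):
    accumulation_dtype: str  # dtype used for the op's internal accumulation
    output_dtype: str        # dtype of the tensor handed to consumers

CONVERSION_RULES = {
    # conv2d accumulates in fp32 for numerical stability but advertises
    # an fp16 result, so the rewrite owes its consumers a downcast.
    "conv2d":   ConversionRule(accumulation_dtype="float32", output_dtype="float16"),
    # Elementwise ops are numerically safe in fp16 end to end.
    "multiply": ConversionRule(accumulation_dtype="float16", output_dtype="float16"),
    "bias_add": ConversionRule(accumulation_dtype="float16", output_dtype="float16"),
    "relu":     ConversionRule(accumulation_dtype="float16", output_dtype="float16"),
    "max_pool": ConversionRule(accumulation_dtype="float16", output_dtype="float16"),
}
```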
In some edge use cases, it is desirable for all parameters to be stored as fp16 to limit the storage footprint. In that context, take the following graph as an example:
```
conv2d -> multiply -> bias_add -> relu -> max_pool

Greenlist{conv2d, max_pool}
Graylist{elemwise}
```
If conv2d should accumulate in fp32 but the subsequent elemwise operators should run in fp16, how would a user express this? In this case I would expect the final graph to be:
```
[fp16] -> conv2d -> [fp32] -> cast(fp16) -> multiply -> bias_add -> relu -> max_pool
```

which, after fusion, would become:

```
fused_conv2d_cast_multiply_bias_add_relu -> max_pool
```
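To check my understanding, here is a purely illustrative rewrite over a straight-line graph that would produce exactly the shape above, assuming the hypothetical `CONVERSION_RULES` table from my earlier sketch is in scope. This is not an existing pass, just the behavior I would expect:

```python
def rewrite_linear_graph(ops, rules):
    """Annotate a straight-line graph with dtypes and explicit casts.

    ops: operator names in execution order.
    rules: the hypothetical ConversionRule table sketched above.
    Returns a printable representation of the rewritten graph.
    """
    nodes = ["[float16]"]  # parameters are stored as fp16
    for op in ops:
        rule = rules[op]
        nodes.append(op)
        if rule.accumulation_dtype != rule.output_dtype:
            # Accumulation happened in a wider dtype, so materialize the
            # downcast as an explicit cast node that fusion can later
            # fold into the producer.
            nodes.append(f"[{rule.accumulation_dtype}]")
            nodes.append(f"cast({rule.output_dtype})")
    return " -> ".join(nodes)

# Prints: [float16] -> conv2d -> [float32] -> cast(float16) -> multiply
#           -> bias_add -> relu -> max_pool
print(rewrite_linear_graph(
    ["conv2d", "multiply", "bias_add", "relu", "max_pool"], CONVERSION_RULES))
```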
An alternative option would be to add an `accumulator_dtype`, separate from `output_dtype`, to various operators and rewrite based on that field.
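For contrast, a minimal sketch of that alternative, with the dtype split carried on the operator's own attributes. Field names like `accumulator_dtype` are illustrative here, not an existing API:

```python
from dataclasses import dataclass

@dataclass
class OpNode:
    """Hypothetical operator carrying the dtype split as attributes."""
    name: str
    output_dtype: str = "float16"       # dtype handed to consumers
    accumulator_dtype: str = "float16"  # dtype of the internal accumulator

def needs_downcast(op: OpNode) -> bool:
    # The rewrite pass only reads the two fields: a cast is inserted
    # exactly when the accumulator is wider than the advertised output.
    return op.accumulator_dtype != op.output_dtype

conv = OpNode("conv2d", output_dtype="float16", accumulator_dtype="float32")
assert needs_downcast(conv)                    # conv2d: accumulate fp32, emit fp16
assert not needs_downcast(OpNode("multiply"))  # pure fp16 elemwise op
```

Both can work, but I'd like to hear more about how you envision doing this with the mixed precision transform in the above context. Thanks!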