What is the FoldScaleAxis optimization?

Hi, I’m new to TVM.

I’m exploring the FoldScaleAxis pass, but I don’t yet understand what it is.

Is there a conceptual example of this optimization that shows why it is needed and how it works?

Thank you.

FoldScaleAxis folds an axis-wise scaling (a per-channel multiply) into the weight of an adjacent conv2d or dense op. For example, this enables pre-multiplying the scaling in batch norm into its preceding conv2d weight. The combination of SimplifyInference, FoldScaleAxis, and FoldConstant completely eliminates batch norm from an inference graph.
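
Here is a minimal sketch of that pipeline; the tiny conv2d + batch_norm graph, shapes, and constants are made up for illustration, but the three passes are the real Relay passes:

```python
import numpy as np
import tvm
from tvm import relay

# A tiny conv2d + batch_norm graph (shapes and constants are arbitrary).
data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.const(np.random.randn(8, 3, 3, 3).astype("float32"))
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))

gamma = relay.const(np.random.randn(8).astype("float32"))
beta = relay.const(np.random.randn(8).astype("float32"))
mean = relay.const(np.random.randn(8).astype("float32"))
var = relay.const(np.abs(np.random.randn(8)).astype("float32"))
bn = relay.nn.batch_norm(conv, gamma, beta, mean, var)[0]

mod = tvm.IRModule.from_expr(bn)

seq = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.SimplifyInference(),  # batch_norm -> per-channel multiply + add
    relay.transform.FoldScaleAxis(),      # multiply folded into the conv2d weight
    relay.transform.FoldConstant(),       # new weight computed once, at compile time
])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)

print(mod)  # no batch_norm left: just conv2d with a rewritten weight plus an add
```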


Thank you for the explanation. As far as I understand, it seems similar to operator fusion: in the aforementioned example, the multiply and add ops of batch norm can be merged into the preceding conv2d op based on a mathematical property.
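
For instance, that property can be checked numerically. A quick numpy sketch (arbitrary shapes and a naive convolution, not TVM code) showing that a per-output-channel scale commutes with conv2d:

```python
import numpy as np

# Property: scaling the conv output per channel equals convolving with a
# pre-scaled weight:  conv(x, W) * s  ==  conv(x, W * s)
x = np.random.randn(1, 3, 8, 8).astype("float32")   # NCHW input
W = np.random.randn(4, 3, 3, 3).astype("float32")   # OIHW weight
s = np.random.randn(4).astype("float32")            # per-output-channel scale

def conv2d(x, W):
    # Naive NCHW convolution, stride 1, no padding (illustration only).
    n, c, h, w = x.shape
    o, _, kh, kw = W.shape
    out = np.zeros((n, o, h - kh + 1, w - kw + 1), dtype=x.dtype)
    for i in range(out.shape[2]):
        for j in range(out.shape[3]):
            patch = x[:, :, i:i + kh, j:j + kw]
            out[:, :, i, j] = np.tensordot(patch, W, axes=([1, 2, 3], [1, 2, 3]))
    return out

lhs = conv2d(x, W) * s[None, :, None, None]  # scale applied after the conv
rhs = conv2d(x, W * s[:, None, None, None])  # scale folded into the weight
print(np.allclose(lhs, rhs, atol=1e-4))      # True
```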

So now I think of FoldScaleAxis as one way to perform operator fusion. Why is FoldScaleAxis needed along with FuseOps? (I understood the example, but I am confused about operator fusion.)

FuseOps can only combine multiple ops into a single kernel; it cannot eliminate the computation itself. FoldScaleAxis is a graph rewrite that moves the scaling onto the weight, so the transformed weight (the original weight with the multiply / add applied) can be pre-computed at compile time.
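
To make the contrast concrete, here's a sketch on a conv2d followed by an explicit per-channel multiply (the form batch norm takes after SimplifyInference); the graph and shapes are made up for illustration:

```python
import numpy as np
import tvm
from tvm import relay

# conv2d followed by a per-channel multiply, i.e. what batch norm lowers
# to after SimplifyInference (shapes and constants are arbitrary).
data = relay.var("data", shape=(1, 3, 8, 8), dtype="float32")
weight = relay.const(np.random.randn(4, 3, 3, 3).astype("float32"))
scale = relay.const(np.random.randn(4, 1, 1).astype("float32"))
out = relay.multiply(relay.nn.conv2d(data, weight, kernel_size=(3, 3)), scale)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(out))

# FuseOps: conv2d and multiply become one fused function, but the multiply
# is still executed on every inference call.
print(relay.transform.FuseOps(fuse_opt_level=2)(mod))

# FoldScaleAxis + FoldConstant: the multiply disappears from the graph;
# the scaled weight is computed once at compile time.
with tvm.transform.PassContext(opt_level=3):
    folded = tvm.transform.Sequential([
        relay.transform.FoldScaleAxis(),
        relay.transform.FoldConstant(),
    ])(mod)
print(folded)
```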


Now I understand the concept of FoldScaleAxis. Thank you all for the replies!