Continuing from the previous thread: it seems `nn.batch_norm` is always decomposed during `relay.build_module.optimize()`, even inside a function that has been partitioned for an external codegen. Is this expected? How can I preserve `batch_norm` until codegen?
cc @comaniac
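For anyone trying to reproduce this, here is a minimal sketch. I'm using the current BYOC helper passes (`AnnotateTarget` / `MergeCompilerRegions` / `PartitionGraph`) as a stand-in for however the module gets annotated, so the printed attributes (e.g. `ExternalSymbol` vs. `global_symbol`) and whether passes descend into the partitioned function may differ across TVM versions:

```python
import tvm
from tvm import relay
from tvm.relay.op.contrib import dnnl  # noqa: F401 -- importing registers conv2d/batch_norm/relu for "dnnl"

# conv2d -> batch_norm -> relu, matching the dumps below
data = relay.var("data", shape=(1, 3, 224, 224))
weight = relay.var("layer1_weight", shape=(16, 3, 3, 3))
gamma = relay.var("layer1_bn_gamma", shape=(16,))
beta = relay.var("layer1_bn_beta", shape=(16,))
mean = relay.var("layer1_bn_mean", shape=(16,))
var = relay.var("layer1_bn_var", shape=(16,))

conv = relay.nn.conv2d(data, weight, padding=(1, 1), channels=16, kernel_size=(3, 3))
bn = relay.nn.batch_norm(conv, gamma, beta, mean, var)[0]  # index 0 = the normalized output
out = relay.nn.relu(bn)
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(out), out))

# Offload the whole pattern to the external "dnnl" codegen
mod = relay.transform.AnnotateTarget("dnnl")(mod)
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.PartitionGraph()(mod)
print(mod)  # the "Before" dump: batch_norm is still intact inside the dnnl function

opt_mod, _ = relay.build_module.optimize(mod, target="llvm")
print(opt_mod)  # on my build, batch_norm is already decomposed here (the "After" dump)
```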
Before `optimize()`:

```
fn (%data: Tensor[(1, 3, 224, 224), float32], %layer1_weight: Tensor[(16, 3, 3, 3), float32], %layer1_bn_gamma: Tensor[(16), float32], %layer1_bn_beta: Tensor[(16), float32], %layer1_bn_mean: Tensor[(16), float32], %layer1_bn_var: Tensor[(16), float32]) -> Tensor[(1, 16, 224, 224), float32] {
%3 = fn (%dnnl_input0: Tensor[(1, 3, 224, 224), float32], %dnnl_input1: Tensor[(16, 3, 3, 3), float32], %dnnl_input2: Tensor[(16), float32], %dnnl_input3: Tensor[(16), float32], %dnnl_input4: Tensor[(16), float32], %dnnl_input5: Tensor[(16), float32], Compiler="dnnl", ExternalSymbol="dnnl_0", Primitive=1) -> Tensor[(1, 16, 224, 224), float32] {
%0 = nn.conv2d(%dnnl_input0, %dnnl_input1, padding=[1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
%1 = nn.batch_norm(%0, %dnnl_input2, %dnnl_input3, %dnnl_input4, %dnnl_input5) /* ty=(Tensor[(1, 16, 224, 224), float32], Tensor[(16), float32], Tensor[(16), float32]) */;
%2 = %1.0;
nn.relu(%2) /* ty=Tensor[(1, 16, 224, 224), float32] */
};
%3(%data, %layer1_weight, %layer1_bn_gamma, %layer1_bn_beta, %layer1_bn_mean, %layer1_bn_var) /* ty=Tensor[(1, 16, 224, 224), float32] */
}
```
After `optimize()`:

```
fn (%data: Tensor[(1, 3, 224, 224), float32], %layer1_weight: Tensor[(16, 3, 3, 3), float32], %layer1_bn_gamma: Tensor[(16), float32], %layer1_bn_beta: Tensor[(16), float32], %layer1_bn_mean: Tensor[(16), float32], %layer1_bn_var: Tensor[(16), float32]) -> Tensor[(1, 16, 224, 224), float32] {
%12 = fn (%dnnl_input0: Tensor[(1, 3, 224, 224), float32], %dnnl_input1: Tensor[(16, 3, 3, 3), float32], %dnnl_input2: Tensor[(16), float32], %dnnl_input3: Tensor[(16), float32], %dnnl_input4: Tensor[(16), float32], %dnnl_input5: Tensor[(16), float32], Compiler="dnnl", ExternalSymbol="dnnl_0", Primitive=1) -> Tensor[(1, 16, 224, 224), float32] {
%0 = nn.conv2d(%dnnl_input0, %dnnl_input1, padding=[1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
%1 = add(%dnnl_input5, 1e-05f /* ty=float32 */) /* ty=Tensor[(16), float32] */;
%2 = sqrt(%1) /* ty=Tensor[(16), float32] */;
%3 = divide(1f /* ty=float32 */, %2) /* ty=Tensor[(16), float32] */;
%4 = multiply(%3, %dnnl_input2) /* ty=Tensor[(16), float32] */;
%5 = expand_dims(%4, axis=1, num_newaxis=2) /* ty=Tensor[(16, 1, 1), float32] */;
%6 = multiply(%0, %5) /* ty=Tensor[(1, 16, 224, 224), float32] */;
%7 = negative(%dnnl_input4) /* ty=Tensor[(16), float32] */;
%8 = multiply(%7, %4) /* ty=Tensor[(16), float32] */;
%9 = add(%8, %dnnl_input3) /* ty=Tensor[(16), float32] */;
%10 = expand_dims(%9, axis=1, num_newaxis=2) /* ty=Tensor[(16, 1, 1), float32] */;
%11 = add(%6, %10) /* ty=Tensor[(1, 16, 224, 224), float32] */;
nn.relu(%11) /* ty=Tensor[(1, 16, 224, 224), float32] */
};
%12(%data, %layer1_weight, %layer1_bn_gamma, %layer1_bn_beta, %layer1_bn_mean, %layer1_bn_var) /* ty=Tensor[(1, 16, 224, 224), float32] */
}
```
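For reference, the expanded sequence is just the standard inference-time folding of batch norm into a per-channel scale and shift: `scale = gamma / sqrt(var + eps)` (%1-%4) and `shift = beta - mean * scale` (%7-%9), with the `expand_dims` calls (%5, %10) broadcasting the `(16,)` vectors against the NCHW tensor. A quick numpy sketch of that equivalence (random data, shapes taken from the dump):

```python
import numpy as np

x = np.random.randn(1, 16, 224, 224).astype("float32")
gamma = np.random.randn(16).astype("float32")
beta = np.random.randn(16).astype("float32")
mean = np.random.randn(16).astype("float32")
var = np.random.uniform(0.5, 1.5, 16).astype("float32")  # variances are positive
eps = 1e-5  # the 1e-05f constant added in %1

def bcast(v):
    # (16,) -> (16, 1, 1), i.e. expand_dims(axis=1, num_newaxis=2)
    return v.reshape(16, 1, 1)

# nn.batch_norm at inference time
direct = bcast(gamma) * (x - bcast(mean)) / bcast(np.sqrt(var + eps)) + bcast(beta)

# The decomposed form from the "After" dump
scale = (1.0 / np.sqrt(var + eps)) * gamma    # %1-%4
shift = -mean * scale + beta                  # %7-%9
decomposed = x * bcast(scale) + bcast(shift)  # %5-%6, %10-%11

assert np.allclose(direct, decomposed, rtol=1e-4, atol=1e-4)
```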