InferCorrectLayout

My impression is actually the opposite of yours: if you return Undef or an incompatible layout, then a layout_transform will be inserted to guarantee correctness. For example:

conv2d(NCHW, OIHW) -(NCHW)-> transpose(axis=[0, 2, 3, 1]) -(NHWC)->

The output layout of conv2d is NCHW, and transpose converts it to NHWC. If we now want to convert conv2d to NHWC, we have

conv2d(NHWC, HWIO) -(NHWC)-> transpose(?) -->

In this case, if transpose has no InferCorrectLayout registered, or its InferCorrectLayout returns Undef, then the ConvertLayout pass will keep the original transpose op and insert a layout_transform in front of it, so that transpose still takes NCHW in and produces NHWC out:

conv2d(NHWC, HWIO) -(NHWC)-> layout_transform -(NCHW)-> transpose(axis=[0, 2, 3, 1]) -(NHWC)->
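To sanity-check that chain, here is a minimal sketch in plain Python that treats both ops as axis permutations on a shape. This is not TVM code: `layout_perm` and `apply_perm` are hypothetical helpers I made up, and the shape is just an example.

```python
def layout_perm(src, dst):
    """Axes that reorder a tensor from layout `src` to layout `dst`."""
    return [src.index(dim) for dim in dst]


def apply_perm(shape, axes):
    """Shape after a transpose with the given axes (out[i] = in[axes[i]])."""
    return tuple(shape[a] for a in axes)


# Hypothetical conv2d output shape in NHWC: N=1, H=56, W=28, C=64
nhwc_shape = (1, 56, 28, 64)

# The inserted layout_transform: NHWC -> NCHW
to_nchw = layout_perm("NHWC", "NCHW")
nchw_shape = apply_perm(nhwc_shape, to_nchw)

# The original, unmodified transpose
out_shape = apply_perm(nchw_shape, [0, 2, 3, 1])

print(to_nchw)     # [0, 3, 1, 2]
print(nchw_shape)  # (1, 64, 56, 28)
print(out_shape)   # (1, 56, 28, 64)
```

The final shape is the NHWC shape we started with, so the two ops together are a round trip: correct, but redundant.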

On the other hand, if transpose's InferCorrectLayout is powerful enough, it will adjust its attributes to accept the new input layout and return the new layout:

conv2d(NHWC, HWIO) -(NHWC)-> transpose(axis=[0, 1, 2, 3]) -(NHWC)->

In this case, layout_transform won’t be inserted.
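As a rough illustration of what "adjusting the attributes" means for transpose, the new axes can be derived by looking up each originally selected dimension in the new input layout. This is only a sketch of the idea in plain Python, not the actual InferCorrectLayout implementation, and `adjust_transpose_axes` is a hypothetical name:

```python
def adjust_transpose_axes(axes, old_layout, new_layout):
    """Rewrite transpose axes so the op yields the same output layout
    when its input arrives in `new_layout` instead of `old_layout`."""
    # old_layout[a] names the dimension the original op picked for this
    # output position; find where that dimension sits in the new input.
    return [new_layout.index(old_layout[a]) for a in axes]


new_axes = adjust_transpose_axes([0, 2, 3, 1], "NCHW", "NHWC")
print(new_axes)  # [0, 1, 2, 3]
```

For this example the adjusted transpose is the identity permutation, which matches the axis=[0, 1, 2, 3] case above.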
