My impression is actually the opposite of yours: if you return Undef or an incompatible layout, then layout_transform will be inserted to guarantee correctness. For example:
conv2d(NCHW, OIHW) -(NCHW)-> transpose(axis=[0, 2, 3, 1]) -(NHWC)->
The output layout of conv2d is NCHW, and transpose converts it to NHWC. If we now want to convert conv2d to NHWC, then we have
conv2d(NHWC, HWIO) -(NHWC)-> transpose(?) -->
In this case, if transpose has no InferCorrectLayout registered, or InferCorrectLayout returns Undef, then the ConvertLayout pass will keep the original transpose op and insert layout_transform to keep NCHW in and NHWC out:
conv2d(NHWC, HWIO) -(NHWC)-> layout_transform -(NCHW)-> transpose(axis=[0, 2, 3, 1]) -(NHWC)->
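To make the fallback path concrete, here is a minimal pure-Python sketch that tracks only the layout strings through the chain. The helpers `permute` and `layout_transform_axes` are hypothetical names for illustration, not TVM APIs, and this does not reflect how ConvertLayout is actually implemented:

```python
def permute(layout, axes):
    """Return the layout string obtained by transposing `layout` with `axes`."""
    return "".join(layout[a] for a in axes)


def layout_transform_axes(src, dst):
    """Axes that reorder a tensor from layout `src` to layout `dst` (hypothetical helper)."""
    return [src.index(d) for d in dst]


# conv2d now emits NHWC, but the untouched transpose still expects NCHW,
# so a layout_transform NHWC -> NCHW is placed in front of it.
to_nchw = layout_transform_axes("NHWC", "NCHW")
after_transform = permute("NHWC", to_nchw)

# The original transpose(axis=[0, 2, 3, 1]) is kept unchanged.
after_transpose = permute(after_transform, [0, 2, 3, 1])

print(to_nchw, after_transform, after_transpose)
```

The final layout is NHWC again, so downstream consumers see the same layout as in the original graph, at the cost of one extra layout_transform.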
On the other hand, if transpose's InferCorrectLayout is powerful enough, it will adjust its attributes to accept the new input layout and return the new layout:
conv2d(NHWC, HWIO) -(NHWC)-> transpose(axis=[0, 1, 2, 3]) -(NHWC)->
In this case, layout_transform won’t be inserted.
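As a rough illustration of what "adjusting the attributes" could mean for transpose, the sketch below recomputes the axes so that the output layout stays the same when the input layout changes. `adjust_transpose_axes` is a hypothetical helper, not TVM's actual InferCorrectLayout implementation, and it assumes both layouts use the same set of dimension names:

```python
def adjust_transpose_axes(old_in, new_in, axes):
    """Recompute transpose axes so the output layout is preserved.

    old_in: input layout the original axes were written for, e.g. "NCHW"
    new_in: input layout after conversion, e.g. "NHWC"
    axes:   original transpose axes
    """
    # Output layout produced by the original transpose.
    out_layout = [old_in[a] for a in axes]
    # For each output dimension, find where it now lives in the new input.
    return [new_in.index(d) for d in out_layout]


new_axes = adjust_transpose_axes("NCHW", "NHWC", [0, 2, 3, 1])
print(new_axes)
```

For the example above this yields the identity permutation, which matches transpose(axis=[0, 1, 2, 3]) in the converted graph.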