Heavily layout-sensitive operator not surrounded by layout_transform operators

Hello everyone. While using TVM, I noticed that the ConvertLayout pass places lightly layout-sensitive operators between layout_transform operators only when they come after a heavily layout-sensitive one. In my use case I have something like this:

  Input
    |
    v
max_pool2d
    |
    v
  conv2d

I’ve tried to make max_pool2d a heavily layout-sensitive operator so that I can pass its desired layout to the ConvertLayout pass. Here is some sample code that I wrote according to this article.

@reg.register_convert_op_layout("nn.max_pool2d")
def convert_max_pool2d(attrs, inputs, tinfos, desired_layouts):
    from tvm import relay

    data, = inputs

    layout_config = LayoutConfig.current
    if layout_config is not None:
        skip_layer = layout_config.check_skip()
        if skip_layer:
            return relay.nn.max_pool2d(data, **attrs)

    new_attrs = dict(attrs)
    assert len(desired_layouts) == 1, "One desired layout is expected for nn.max_pool2d"
    desired_layout, = map(str, desired_layouts)
    assert desired_layout != "default", "Data layout cannot be default"
    new_attrs["layout"] = desired_layout
    return relay.nn.max_pool2d(data, **new_attrs)

It should just replace the layout with the one passed in. This works, but layout_transform operators are not inserted before max_pool2d. Here are three tests that I ran and their results:

desired_layouts = {'nn.conv2d': ["NCHW", "HWIO"],
                   'nn.max_pool2d': ["NCHW"]}
  1. Only max_pool2d:
    data = relay.var("data", shape=(1, 4, 4, 9))
    out = relay.nn.max_pool2d(data, pool_size=(2,2), layout="NHWC")
    f = relay.Function(relay.analysis.free_vars(out), out)
    mod = tvm.IRModule.from_expr(f)

    seq = tvm.ir.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
    with tvm.transform.PassContext(opt_level=10):
        mod = seq(mod)
    print(mod)

Result:

def @main(%data: Tensor[(1, 4, 4, 9), float32]) -> Tensor[(1, 3, 3, 9), float32] {
  nn.max_pool2d(%data, pool_size=[2, 2], padding=[0, 0, 0, 0], layout="NHWC") /* ty=Tensor[(1, 3, 3, 9), float32] */
}
  2. conv2d → max_pool2d:
    data = relay.var("data", shape=(1, 4, 4, 9))
    weight = np.random.uniform(-10, 10, (3, 3, 9, 3)).astype('float32')
    w = relay.const(weight)

    conv = relay.nn.conv2d(data, w,
                          kernel_size=(3, 3),
                          padding=(1, 1),
                          channels=3,
                          groups=1,
                          data_layout="NHWC",
                          kernel_layout="HWIO")
    out = relay.nn.max_pool2d(conv, pool_size=(2,2), layout="NHWC")
    f = relay.Function(relay.analysis.free_vars(out), out)
    mod = tvm.IRModule.from_expr(f)

    seq = tvm.ir.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
    with tvm.transform.PassContext(opt_level=10):
        mod = seq(mod)
    print(mod)

Result:

def @main(%data: Tensor[(1, 4, 4, 9), float32]) -> Tensor[(1, 3, 3, 3), float32] {
  %0 = layout_transform(%data, src_layout="NHWC", dst_layout="NCHW") /* ty=Tensor[(1, 9, 4, 4), float32] */;
  %1 = nn.conv2d(%0, meta[relay.Constant][0] /* ty=Tensor[(3, 3, 9, 3), float32] */, padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], kernel_layout="HWIO") /* ty=Tensor[(1, 3, 4, 4), float32] */;
  %2 = nn.max_pool2d(%1, pool_size=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 3, 3, 3), float32] */;
  layout_transform(%2, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(1, 3, 3, 3), float32] */
}
  3. max_pool2d → conv2d:
    data = relay.var("data", shape=(1, 4, 4, 9))
    weight = np.random.uniform(-10, 10, (3, 3, 9, 3)).astype('float32')
    w = relay.const(weight)

    maxpool = relay.nn.max_pool2d(data, pool_size=(2,2), layout="NHWC")
    out = relay.nn.conv2d(maxpool, w,
                          kernel_size=(3, 3),
                          padding=(1, 1),
                          channels=3,
                          groups=1,
                          data_layout="NHWC",
                          kernel_layout="HWIO")
    f = relay.Function(relay.analysis.free_vars(out), out)
    mod = tvm.IRModule.from_expr(f)

    seq = tvm.ir.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
    with tvm.transform.PassContext(opt_level=10):
        mod = seq(mod)
    print(mod)

Result:

def @main(%data: Tensor[(1, 4, 4, 9), float32]) -> Tensor[(1, 3, 3, 3), float32] {
  %0 = nn.max_pool2d(%data, pool_size=[2, 2], padding=[0, 0, 0, 0], layout="NHWC") /* ty=Tensor[(1, 3, 3, 9), float32] */;
  %1 = layout_transform(%0, src_layout="NHWC", dst_layout="NCHW") /* ty=Tensor[(1, 9, 3, 3), float32] */;
  %2 = nn.conv2d(%1, meta[relay.Constant][0] /* ty=Tensor[(3, 3, 9, 3), float32] */, padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], kernel_layout="HWIO") /* ty=Tensor[(1, 3, 3, 3), float32] */;
  layout_transform(%2, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(1, 3, 3, 3), float32] */
}

As can be seen here, max_pool2d is placed between layout_transform operators only when it comes after the conv2d operator, which was already the case before I made it a heavily layout-sensitive operator. What can I do to have it included between the layout_transform operators generated by the ConvertLayout pass?
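To make the expected outcome for test 1 concrete, here is a minimal pure-Python sketch of the shape bookkeeping I would expect if layout_transform operators were wrapped around a lone max_pool2d. It does not use TVM at all: `transform_shape` and `pooled_shape` are hypothetical helpers written only for this illustration, and the stride of (1, 1) with no padding is an assumption taken from the printed results above (4 → 3 with a 2x2 window).

```python
def transform_shape(shape, src_layout, dst_layout):
    """Permute a shape from src_layout to dst_layout (e.g. "NHWC" -> "NCHW")."""
    return tuple(shape[src_layout.index(axis)] for axis in dst_layout)


def pooled_shape(shape, layout, pool_size=(2, 2), strides=(1, 1)):
    """Output shape of an unpadded 2-D pooling op (assumed stride 1, no padding)."""
    out = list(shape)
    for axis_name, k, s in zip("HW", pool_size, strides):
        axis = layout.index(axis_name)
        out[axis] = (shape[axis] - k) // s + 1
    return tuple(out)


data = (1, 4, 4, 9)                               # NHWC input from test 1
to_nchw = transform_shape(data, "NHWC", "NCHW")   # layout_transform before the pool
pooled = pooled_shape(to_nchw, "NCHW")            # max_pool2d in the desired layout
back = transform_shape(pooled, "NCHW", "NHWC")    # layout_transform after the pool

print(to_nchw, pooled, back)
```

The final shape agrees with the (1, 3, 3, 9) return type printed for test 1, so wrapping the pool this way would keep the function signature unchanged.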

Thanks in advance.

max_pool2d already has InferCorrectLayouts, so it’s possible that your ConvertOpLayout is not effective. You could try removing the InferCorrectLayouts in src/relay/nn/pooling.cc to verify whether that is the case.

Also cc @anijain2305

I agree this seems like a bug. One thing I would try is a layout transform from NCHW to NHWC, to see whether the same bug appears.

Thank you for your responses. @comaniac I can confirm that convert_max_pool2d was being called even with InferCorrectLayouts still in place. After removing it from pooling.cc, the ConvertLayout pass fails with the following error:

The Relay type checker is unable to show the following types match.
In particular dimension 1 conflicts: 8 does not match 9.dimension 3 conflicts: 4 does not match 3.
The Relay type checker is unable to show the following types match.
In particular `Tensor[(1, 9, 3, 3), float32]` does not match `Tensor[(1, 8, 3, 4), float32]`

In my example I used an input with shape (1, 9, 4, 4) and a pool_size of (2, 2). I checked PoolInferCorrectLayout and it looks like it only changes the layout to the new one, the same way my convert_max_pool2d does. So maybe other functions are doing something in the background that I’m not aware of, and because they are no longer called, the ConvertLayout pass produces the wrong shape. I compared this to the conv2d code and it doesn’t look like that code does anything extra. Am I missing something?
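One possible reading of the two shapes in the error, sketched in plain Python without TVM: if the layout attribute of max_pool2d is switched to NHWC while the (1, 9, 4, 4) NCHW data is left untransformed, the pool would treat axis 1 as height and axis 3 as channels. The helper `pooled_shape` below is hypothetical, and it assumes stride (1, 1) and no padding, which is what the printed results in this thread suggest.

```python
def pooled_shape(shape, layout, pool_size=(2, 2), strides=(1, 1)):
    """Output shape of an unpadded 2-D pooling op (assumed stride 1, no padding)."""
    out = list(shape)
    for axis_name, k, s in zip("HW", pool_size, strides):
        axis = layout.index(axis_name)
        out[axis] = (shape[axis] - k) // s + 1
    return tuple(out)


nchw_input = (1, 9, 4, 4)

# Pooling with the layout that actually matches the data.
correct = pooled_shape(nchw_input, "NCHW")

# Same data, but the attribute now claims NHWC and no layout_transform was inserted.
mislabeled = pooled_shape(nchw_input, "NHWC")

print(correct)     # (1, 9, 3, 3)
print(mislabeled)  # (1, 8, 3, 4)
```

These are the same two tensor shapes the type checker reports as conflicting, so the failure may simply mean that the attribute was rewritten without a matching layout_transform on the input. This is only an arithmetic check, not a confirmation of what the pass does internally.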

@masahi I tried swapping NCHW and NHWC, and the results are the same as before:

  1. Only max_pool2d:
def @main(%data: Tensor[(1, 9, 4, 4), float32]) -> Tensor[(1, 9, 3, 3), float32] {
  nn.max_pool2d(%data, pool_size=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 9, 3, 3), float32] */
}
  2. conv2d → max_pool2d:
def @main(%data: Tensor[(1, 9, 4, 4), float32]) -> Tensor[(1, 1, 3, 3), float32] {
  %0 = layout_transform(%data, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(1, 4, 4, 9), float32] */;
  %1 = nn.conv2d(%0, meta[relay.Constant][0] /* ty=Tensor[(3, 3, 9, 1), float32] */, padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO") /* ty=Tensor[(1, 4, 4, 1), float32] */;
  %2 = nn.max_pool2d(%1, pool_size=[2, 2], padding=[0, 0, 0, 0], layout="NHWC") /* ty=Tensor[(1, 3, 3, 1), float32] */;
  layout_transform(%2, src_layout="NHWC", dst_layout="NCHW") /* ty=Tensor[(1, 1, 3, 3), float32] */
}
  3. max_pool2d → conv2d:
def @main(%data: Tensor[(1, 9, 4, 4), float32]) -> Tensor[(1, 1, 4, 4), float32] {
  %0 = nn.max_pool2d(%data, pool_size=[1, 1], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 9, 4, 4), float32] */;
  %1 = layout_transform(%0, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(1, 4, 4, 9), float32] */;
  %2 = nn.conv2d(%1, meta[relay.Constant][0] /* ty=Tensor[(3, 3, 9, 1), float32] */, padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO") /* ty=Tensor[(1, 4, 4, 1), float32] */;
  layout_transform(%2, src_layout="NHWC", dst_layout="NCHW") /* ty=Tensor[(1, 1, 4, 4), float32] */
}

Hi @masahi @comaniac, appreciate your feedback. Do you have further comments on this thread?
