MergeComposite and Pad with multiple consumers

Hi, sorry for the long post.

I have studied the information in the linked threads ([Relay] Improved graph partitioning algorithm, [BYOC] Use pattern language to create composite functions), but have not found a solution.

Q1: MergeComposite behavior on a Pad with multiple consumers

There is a network that, among other things, contains the following operator structure:

              pad
             /   \
DepthwiseConv2D   DepthwiseConv2D

Below is a small test example that shows what happens. (In the test example, the final qnn.add operator is present only so that there is a single output. In my real network I cannot use a diamond-shaped pattern, because there is only branching without a later union.)

Initial Relay:

def @main(%x: Tensor[(1, 214, 227, 2), int8] /* ty=Tensor[(1, 214, 227, 2), int8] span=x:0:0 */, output_tensor_names=["Identity"]) -> Tensor[(1, 215, 228, 3), int8] {
  %0 = nn.pad(%x, -128f /* ty=float32 span=Pad:0:0 */, pad_width=[[0, 0], [0, 1], [0, 1], [0, 0]]) /* ty=Tensor[(1, 215, 228, 2), int8] span=Pad:0:0 */;
  %1 = qnn.conv2d(%0, meta[relay.Constant][0] /* ty=Tensor[(3, 2, 2, 3), int8] */, -128 /* ty=int32 span=Conv2D1:0:0 */, 0 /* ty=int32 span=Conv2D1:0:0 */, 0.00392157f /* ty=float32 span=Conv2D1:0:0 */, meta[relay.Constant][1] /* ty=Tensor[(3), float32] span=Conv2D1:0:0 */, padding=[2, 0, 2, 1], dilation=[2, 1], channels=3, kernel_size=[3, 2], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D1:0:0 */;
  %2 = nn.bias_add(%1, meta[relay.Constant][2] /* ty=Tensor[(3), int32] */, axis=3) /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D1:0:0 */;
  %3 = qnn.conv2d(%0, meta[relay.Constant][4] /* ty=Tensor[(3, 2, 2, 3), int8] */, -128 /* ty=int32 span=Conv2D_12:0:0 */, 0 /* ty=int32 span=Conv2D_12:0:0 */, 0.00392157f /* ty=float32 span=Conv2D_12:0:0 */, meta[relay.Constant][5] /* ty=Tensor[(3), float32] span=Conv2D_12:0:0 */, padding=[2, 0, 2, 1], dilation=[2, 1], channels=3, kernel_size=[3, 2], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D_12:0:0 */;
  %4 = nn.bias_add(%3, meta[relay.Constant][6] /* ty=Tensor[(3), int32] */, axis=3) /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D_12:0:0 */;
  %5 = qnn.requantize(%2, meta[relay.Constant][3] /* ty=Tensor[(3), float32] span=Conv2D1:0:0 */, 0 /* ty=int32 span=Conv2D1:0:0 */, 0.0258486f /* ty=float32 span=Conv2D1:0:0 */, -128 /* ty=int32 span=Conv2D1:0:0 */, axis=3, out_dtype="int8") /* ty=Tensor[(1, 215, 228, 3), int8] span=Conv2D1:0:0 */;
  %6 = qnn.requantize(%4, meta[relay.Constant][7] /* ty=Tensor[(3), float32] span=Conv2D_12:0:0 */, 0 /* ty=int32 span=Conv2D_12:0:0 */, 0.0217672f /* ty=float32 span=Conv2D_12:0:0 */, -128 /* ty=int32 span=Conv2D_12:0:0 */, axis=3, out_dtype="int8") /* ty=Tensor[(1, 215, 228, 3), int8] span=Conv2D_12:0:0 */;
  qnn.add(%5, %6, 0.0258486f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */, 0.0217672f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */, 0.046443f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */) /* ty=Tensor[(1, 215, 228, 3), int8] span=Identity:0:0 */
}

Relay after the MergeComposite pass:

def @main(%x: Tensor[(1, 214, 227, 2), int8] /* ty=Tensor[(1, 214, 227, 2), int8] span=x:0:0 */, output_tensor_names=["Identity"]) -> Tensor[(1, 215, 228, 3), int8] {
  %0 = fn (%FunctionVar_0_01: Tensor[(1, 214, 227, 2), int8] /* ty=Tensor[(1, 214, 227, 2), int8] */, PartitionedFromPattern="nn.pad_", Composite="ethos-u.pad2d") -> Tensor[(1, 215, 228, 2), int8] {
    nn.pad(%FunctionVar_0_01, -128f /* ty=float32 span=Pad:0:0 */, pad_width=[[0, 0], [0, 1], [0, 1], [0, 0]]) /* ty=Tensor[(1, 215, 228, 2), int8] span=Pad:0:0 */
  } /* ty=fn (Tensor[(1, 214, 227, 2), int8]) -> Tensor[(1, 215, 228, 2), int8] */;
  %1 = %0(%x) /* ty=Tensor[(1, 215, 228, 2), int8] */;
  %2 = qnn.conv2d(%1, meta[relay.Constant][0] /* ty=Tensor[(3, 2, 2, 3), int8] */, -128 /* ty=int32 span=Conv2D1:0:0 */, 0 /* ty=int32 span=Conv2D1:0:0 */, 0.00392157f /* ty=float32 span=Conv2D1:0:0 */, meta[relay.Constant][1] /* ty=Tensor[(3), float32] span=Conv2D1:0:0 */, padding=[2, 0, 2, 1], dilation=[2, 1], channels=3, kernel_size=[3, 2], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D1:0:0 */;
  %3 = nn.bias_add(%2, meta[relay.Constant][2] /* ty=Tensor[(3), int32] */, axis=3) /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D1:0:0 */;
  %4 = qnn.conv2d(%1, meta[relay.Constant][4] /* ty=Tensor[(3, 2, 2, 3), int8] */, -128 /* ty=int32 span=Conv2D_12:0:0 */, 0 /* ty=int32 span=Conv2D_12:0:0 */, 0.00392157f /* ty=float32 span=Conv2D_12:0:0 */, meta[relay.Constant][5] /* ty=Tensor[(3), float32] span=Conv2D_12:0:0 */, padding=[2, 0, 2, 1], dilation=[2, 1], channels=3, kernel_size=[3, 2], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D_12:0:0 */;
  %5 = nn.bias_add(%4, meta[relay.Constant][6] /* ty=Tensor[(3), int32] */, axis=3) /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D_12:0:0 */;
  %6 = qnn.requantize(%3, meta[relay.Constant][3] /* ty=Tensor[(3), float32] span=Conv2D1:0:0 */, 0 /* ty=int32 span=Conv2D1:0:0 */, 0.0258486f /* ty=float32 span=Conv2D1:0:0 */, -128 /* ty=int32 span=Conv2D1:0:0 */, axis=3, out_dtype="int8") /* ty=Tensor[(1, 215, 228, 3), int8] span=Conv2D1:0:0 */;
  %7 = qnn.requantize(%5, meta[relay.Constant][7] /* ty=Tensor[(3), float32] span=Conv2D_12:0:0 */, 0 /* ty=int32 span=Conv2D_12:0:0 */, 0.0217672f /* ty=float32 span=Conv2D_12:0:0 */, -128 /* ty=int32 span=Conv2D_12:0:0 */, axis=3, out_dtype="int8") /* ty=Tensor[(1, 215, 228, 3), int8] span=Conv2D_12:0:0 */;
  %8 = fn (%FunctionVar_0_0: Tensor[(1, 215, 228, 3), int8] /* ty=Tensor[(1, 215, 228, 3), int8] */, %FunctionVar_0_1: Tensor[(1, 215, 228, 3), int8] /* ty=Tensor[(1, 215, 228, 3), int8] */, PartitionedFromPattern="qnn.add_", Composite="ethos-u.add") -> Tensor[(1, 215, 228, 3), int8] {
    qnn.add(%FunctionVar_0_0, %FunctionVar_0_1, 0.0258486f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */, 0.0217672f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */, 0.046443f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */) /* ty=Tensor[(1, 215, 228, 3), int8] span=Identity:0:0 */
  } /* ty=fn (Tensor[(1, 215, 228, 3), int8], Tensor[(1, 215, 228, 3), int8]) -> Tensor[(1, 215, 228, 3), int8] */;
  %8(%6, %7) /* ty=Tensor[(1, 215, 228, 3), int8] */
}

As you can see, only the nn.pad operator is wrapped in a function corresponding to the pattern and will be offloaded to the NPU; the qnn.conv2d operators will not be offloaded.

pattern_table:

    (
        QnnConv2DParams.composite_name,
        qnn_conv2d_pattern(),
        lambda pat: QnnConv2DParams(pat).is_valid(),
    ),
    (
        QnnDepthwiseConv2DParams.composite_name,
        qnn_depthwise_conv2d_pattern(),
        lambda pat: QnnDepthwiseConv2DParams(pat).is_valid(),
    ),
	...
	(
        PadParams.composite_name,
        pad_pattern(),
        lambda pat: PadParams(pat).is_valid(),
    ),
	...
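
For context, this is roughly how the table is consumed. A minimal sketch, assuming `pattern_table()` returns the tuples above; in the actual Ethos-U flow MergeComposite is invoked inside partition_for_ethosu:

    from tvm import relay

    # Run MergeComposite directly with the table above to reproduce the composite
    # functions shown in the dumps. InferType first, since the is_valid() checkers
    # rely on type information.
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.MergeComposite(pattern_table())(mod)
    print(mod["main"])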

So far I have found only one temporary solution: split qnn_depthwise_conv2d_pattern into (1) pad + depthwise_conv2d and (2) depthwise_conv2d, and put them into the pattern table in that order instead of the existing qnn_depthwise_conv2d_pattern (a rough sketch is shown below).
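
A rough sketch of that split, assuming a structure similar to the upstream qnn_depthwise_conv2d_pattern; the function names are mine, the attribute checks are omitted, and reusing the existing QnnDepthwiseConv2DParams checker for both entries is an assumption:

    from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard

    def qnn_depthwise_conv2d_with_pad_pattern():
        # Variant that requires the preceding nn.pad, so pad + depthwise conv
        # are fused into one composite even if the pad has other consumers.
        pad = is_op("nn.pad")(wildcard(), is_constant())
        conv = is_op("qnn.conv2d")(
            pad, is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
        )
        bias = is_op("nn.bias_add")(conv, is_constant())
        req = is_op("qnn.requantize")(
            bias, is_constant(), is_constant(), is_constant(), is_constant()
        )
        return req.optional(is_op("clip"))

    def qnn_depthwise_conv2d_no_pad_pattern():
        # Variant without nn.pad, for depthwise convs whose input is not padded.
        conv = is_op("qnn.conv2d")(
            wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
        )
        bias = is_op("nn.bias_add")(conv, is_constant())
        req = is_op("qnn.requantize")(
            bias, is_constant(), is_constant(), is_constant(), is_constant()
        )
        return req.optional(is_op("clip"))

    # Registered in this order, replacing the single qnn_depthwise_conv2d_pattern entry:
    #     (QnnDepthwiseConv2DParams.composite_name, qnn_depthwise_conv2d_with_pad_pattern(), check),
    #     (QnnDepthwiseConv2DParams.composite_name, qnn_depthwise_conv2d_no_pad_pattern(), check),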

Another option is a pass that duplicates pads in the Relay graph, so that each depthwise_conv2d has its "own" pad. I am starting to get this working, but in some cases I run into an SRAM problem at run time.

Q2: What am I missing when copying the Pad operator?

(1) - Successful minimal case: inference runs all the way to the end without errors.

Relay after the pass that duplicates Pads:

def @main(%x: Tensor[(1, 214, 227, 2), int8] /* ty=Tensor[(1, 214, 227, 2), int8] span=x:0:0 */, output_tensor_names=["Identity"]) -> Tensor[(1, 215, 228, 3), int8] {
  %0 = nn.pad(%x, -128f /* ty=float32 span=Pad:0:0 */, pad_width=[[0, 0], [0, 1], [0, 1], [0, 0]]) /* ty=Tensor[(1, 215, 228, 2), int8] span=Pad:0:0 */;
  %1 = qnn.conv2d(%0, meta[relay.Constant][0] /* ty=Tensor[(3, 2, 2, 3), int8] */, -128 /* ty=int32 span=Conv2D1:0:0 */, 0 /* ty=int32 span=Conv2D1:0:0 */, 0.00392157f /* ty=float32 span=Conv2D1:0:0 */, meta[relay.Constant][1] /* ty=Tensor[(3), float32] span=Conv2D1:0:0 */, padding=[2, 0, 2, 1], dilation=[2, 1], channels=3, kernel_size=[3, 2], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D1:0:0 */;
  %2 = nn.bias_add(%1, meta[relay.Constant][2] /* ty=Tensor[(3), int32] */, axis=3) /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D1:0:0 */;
  %3 = nn.pad(%x, -128f /* ty=float32 span=Pad:0:0 */, pad_width=[[0, 0], [0, 1], [0, 1], [0, 0]]) /* ty=Tensor[(1, 215, 228, 2), int8] span=Pad:0:0 */;
  %4 = qnn.conv2d(%3, meta[relay.Constant][4] /* ty=Tensor[(3, 2, 2, 3), int8] */, -128 /* ty=int32 span=Conv2D_12:0:0 */, 0 /* ty=int32 span=Conv2D_12:0:0 */, 0.00392157f /* ty=float32 span=Conv2D_12:0:0 */, meta[relay.Constant][5] /* ty=Tensor[(3), float32] span=Conv2D_12:0:0 */, padding=[2, 0, 2, 1], dilation=[2, 1], channels=3, kernel_size=[3, 2], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D_12:0:0 */;
  %5 = nn.bias_add(%4, meta[relay.Constant][6] /* ty=Tensor[(3), int32] */, axis=3) /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D_12:0:0 */;
  %6 = qnn.requantize(%2, meta[relay.Constant][3] /* ty=Tensor[(3), float32] span=Conv2D1:0:0 */, 0 /* ty=int32 span=Conv2D1:0:0 */, 0.0258486f /* ty=float32 span=Conv2D1:0:0 */, -128 /* ty=int32 span=Conv2D1:0:0 */, axis=3, out_dtype="int8") /* ty=Tensor[(1, 215, 228, 3), int8] span=Conv2D1:0:0 */;
  %7 = qnn.requantize(%5, meta[relay.Constant][7] /* ty=Tensor[(3), float32] span=Conv2D_12:0:0 */, 0 /* ty=int32 span=Conv2D_12:0:0 */, 0.0217672f /* ty=float32 span=Conv2D_12:0:0 */, -128 /* ty=int32 span=Conv2D_12:0:0 */, axis=3, out_dtype="int8") /* ty=Tensor[(1, 215, 228, 3), int8] span=Conv2D_12:0:0 */;
  qnn.add(%6, %7, 0.0258486f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */, 0.0217672f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */, 0.046443f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */) /* ty=Tensor[(1, 215, 228, 3), int8] span=Identity:0:0 */
}

Relay after MergeComposite (pad + conv2d are merged into one function, as expected):

def @main(%x: Tensor[(1, 214, 227, 2), int8] /* ty=Tensor[(1, 214, 227, 2), int8] span=x:0:0 */, output_tensor_names=["Identity"]) -> Tensor[(1, 215, 228, 3), int8] {
  %3 = fn (%FunctionVar_1_0: Tensor[(1, 214, 227, 2), int8] /* ty=Tensor[(1, 214, 227, 2), int8] */, PartitionedFromPattern="nn.pad_qnn.conv2d_nn.bias_add_qnn.requantize_", Composite="ethos-u.qnn_conv2d") -> Tensor[(1, 215, 228, 3), int8] {
    %0 = nn.pad(%FunctionVar_1_0, -128f /* ty=float32 span=Pad:0:0 */, pad_width=[[0, 0], [0, 1], [0, 1], [0, 0]]) /* ty=Tensor[(1, 215, 228, 2), int8] span=Pad:0:0 */;
    %1 = qnn.conv2d(%0, meta[relay.Constant][0] /* ty=Tensor[(3, 2, 2, 3), int8] */, -128 /* ty=int32 span=Conv2D1:0:0 */, 0 /* ty=int32 span=Conv2D1:0:0 */, 0.00392157f /* ty=float32 span=Conv2D1:0:0 */, meta[relay.Constant][1] /* ty=Tensor[(3), float32] span=Conv2D1:0:0 */, padding=[2, 0, 2, 1], dilation=[2, 1], channels=3, kernel_size=[3, 2], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D1:0:0 */;
    %2 = nn.bias_add(%1, meta[relay.Constant][2] /* ty=Tensor[(3), int32] */, axis=3) /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D1:0:0 */;
    qnn.requantize(%2, meta[relay.Constant][3] /* ty=Tensor[(3), float32] span=Conv2D1:0:0 */, 0 /* ty=int32 span=Conv2D1:0:0 */, 0.0258486f /* ty=float32 span=Conv2D1:0:0 */, -128 /* ty=int32 span=Conv2D1:0:0 */, axis=3, out_dtype="int8") /* ty=Tensor[(1, 215, 228, 3), int8] span=Conv2D1:0:0 */
  } /* ty=fn (Tensor[(1, 214, 227, 2), int8]) -> Tensor[(1, 215, 228, 3), int8] */;
  %7 = fn (%FunctionVar_0_01: Tensor[(1, 214, 227, 2), int8] /* ty=Tensor[(1, 214, 227, 2), int8] */, PartitionedFromPattern="nn.pad_qnn.conv2d_nn.bias_add_qnn.requantize_", Composite="ethos-u.qnn_conv2d") -> Tensor[(1, 215, 228, 3), int8] {
    %4 = nn.pad(%FunctionVar_0_01, -128f /* ty=float32 span=Pad:0:0 */, pad_width=[[0, 0], [0, 1], [0, 1], [0, 0]]) /* ty=Tensor[(1, 215, 228, 2), int8] span=Pad:0:0 */;
    %5 = qnn.conv2d(%4, meta[relay.Constant][4] /* ty=Tensor[(3, 2, 2, 3), int8] */, -128 /* ty=int32 span=Conv2D_12:0:0 */, 0 /* ty=int32 span=Conv2D_12:0:0 */, 0.00392157f /* ty=float32 span=Conv2D_12:0:0 */, meta[relay.Constant][5] /* ty=Tensor[(3), float32] span=Conv2D_12:0:0 */, padding=[2, 0, 2, 1], dilation=[2, 1], channels=3, kernel_size=[3, 2], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D_12:0:0 */;
    %6 = nn.bias_add(%5, meta[relay.Constant][6] /* ty=Tensor[(3), int32] */, axis=3) /* ty=Tensor[(1, 215, 228, 3), int32] span=Conv2D_12:0:0 */;
    qnn.requantize(%6, meta[relay.Constant][7] /* ty=Tensor[(3), float32] span=Conv2D_12:0:0 */, 0 /* ty=int32 span=Conv2D_12:0:0 */, 0.0217672f /* ty=float32 span=Conv2D_12:0:0 */, -128 /* ty=int32 span=Conv2D_12:0:0 */, axis=3, out_dtype="int8") /* ty=Tensor[(1, 215, 228, 3), int8] span=Conv2D_12:0:0 */
  } /* ty=fn (Tensor[(1, 214, 227, 2), int8]) -> Tensor[(1, 215, 228, 3), int8] */;
  %8 = %3(%x) /* ty=Tensor[(1, 215, 228, 3), int8] */;
  %9 = %7(%x) /* ty=Tensor[(1, 215, 228, 3), int8] */;
  %10 = fn (%FunctionVar_0_0: Tensor[(1, 215, 228, 3), int8] /* ty=Tensor[(1, 215, 228, 3), int8] */, %FunctionVar_0_1: Tensor[(1, 215, 228, 3), int8] /* ty=Tensor[(1, 215, 228, 3), int8] */, PartitionedFromPattern="qnn.add_", Composite="ethos-u.add") -> Tensor[(1, 215, 228, 3), int8] {
    qnn.add(%FunctionVar_0_0, %FunctionVar_0_1, 0.0258486f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */, 0.0217672f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */, 0.046443f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */) /* ty=Tensor[(1, 215, 228, 3), int8] span=Identity:0:0 */
  } /* ty=fn (Tensor[(1, 215, 228, 3), int8], Tensor[(1, 215, 228, 3), int8]) -> Tensor[(1, 215, 228, 3), int8] */;
  %10(%8, %9) /* ty=Tensor[(1, 215, 228, 3), int8] */
}

(2) - An unsuccessful case with an error; it differs from the successful one only by a single extra Pad operator at the beginning.

Initial Relay:

def @main(%x: Tensor[(1, 214, 227, 2), int8] /* span=x:0:0 */, output_tensor_names=["Identity"]) {
  %0 = nn.pad(%x, -128f /* span=Pad:0:0 */, pad_width=[[0, 0], [0, 1], [0, 1], [0, 0]]) /* span=Pad:0:0 */;
  %1 = nn.pad(%0, -128f /* span=Pad_1:0:0 */, pad_width=[[0, 0], [0, 1], [0, 1], [0, 0]]) /* span=Pad_1:0:0 */;
  %2 = qnn.conv2d(%1, meta[relay.Constant][0], -128 /* span=Conv2D1:0:0 */, 0 /* span=Conv2D1:0:0 */, 0.00392157f /* span=Conv2D1:0:0 */, meta[relay.Constant][1] /* span=Conv2D1:0:0 */, padding=[2, 0, 2, 1], dilation=[2, 1], channels=3, kernel_size=[3, 2], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* span=Conv2D1:0:0 */;
  %3 = nn.bias_add(%2, meta[relay.Constant][2], axis=3) /* span=Conv2D1:0:0 */;
  %4 = qnn.conv2d(%1, meta[relay.Constant][4], -128 /* span=Conv2D_12:0:0 */, 0 /* span=Conv2D_12:0:0 */, 0.00392157f /* span=Conv2D_12:0:0 */, meta[relay.Constant][5] /* span=Conv2D_12:0:0 */, padding=[2, 0, 2, 1], dilation=[2, 1], channels=3, kernel_size=[3, 2], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* span=Conv2D_12:0:0 */;
  %5 = nn.bias_add(%4, meta[relay.Constant][6], axis=3) /* span=Conv2D_12:0:0 */;
  %6 = qnn.requantize(%3, meta[relay.Constant][3] /* span=Conv2D1:0:0 */, 0 /* span=Conv2D1:0:0 */, 0.0258486f /* span=Conv2D1:0:0 */, -128 /* span=Conv2D1:0:0 */, axis=3, out_dtype="int8") /* span=Conv2D1:0:0 */;
  %7 = qnn.requantize(%5, meta[relay.Constant][7] /* span=Conv2D_12:0:0 */, 0 /* span=Conv2D_12:0:0 */, 0.0217672f /* span=Conv2D_12:0:0 */, -128 /* span=Conv2D_12:0:0 */, axis=3, out_dtype="int8") /* span=Conv2D_12:0:0 */;
  qnn.add(%6, %7, 0.0258486f /* span=Identity:0:0 */, -128 /* span=Identity:0:0 */, 0.0217672f /* span=Identity:0:0 */, -128 /* span=Identity:0:0 */, 0.046443f /* span=Identity:0:0 */, -128 /* span=Identity:0:0 */) /* span=Identity:0:0 */
}

Relay after the pass that duplicates Pads:

def @main(%x: Tensor[(1, 214, 227, 2), int8] /* ty=Tensor[(1, 214, 227, 2), int8] span=x:0:0 */, output_tensor_names=["Identity"]) -> Tensor[(1, 216, 229, 3), int8] {
  %0 = nn.pad(%x, -128f /* ty=float32 span=Pad:0:0 */, pad_width=[[0, 0], [0, 1], [0, 1], [0, 0]]) /* ty=Tensor[(1, 215, 228, 2), int8] span=Pad:0:0 */;
  %1 = nn.pad(%0, -128f /* ty=float32 span=Pad_1:0:0 */, pad_width=[[0, 0], [0, 1], [0, 1], [0, 0]]) /* ty=Tensor[(1, 216, 229, 2), int8] span=Pad_1:0:0 */;
  %2 = qnn.conv2d(%1, meta[relay.Constant][0] /* ty=Tensor[(3, 2, 2, 3), int8] */, -128 /* ty=int32 span=Conv2D1:0:0 */, 0 /* ty=int32 span=Conv2D1:0:0 */, 0.00392157f /* ty=float32 span=Conv2D1:0:0 */, meta[relay.Constant][1] /* ty=Tensor[(3), float32] span=Conv2D1:0:0 */, padding=[2, 0, 2, 1], dilation=[2, 1], channels=3, kernel_size=[3, 2], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* ty=Tensor[(1, 216, 229, 3), int32] span=Conv2D1:0:0 */;
  %3 = nn.bias_add(%2, meta[relay.Constant][2] /* ty=Tensor[(3), int32] */, axis=3) /* ty=Tensor[(1, 216, 229, 3), int32] span=Conv2D1:0:0 */;
  %4 = nn.pad(%0, -128f /* ty=float32 span=Pad_1:0:0 */, pad_width=[[0, 0], [0, 1], [0, 1], [0, 0]]) /* ty=Tensor[(1, 216, 229, 2), int8] span=Pad_1:0:0 */;
  %5 = qnn.conv2d(%4, meta[relay.Constant][4] /* ty=Tensor[(3, 2, 2, 3), int8] */, -128 /* ty=int32 span=Conv2D_12:0:0 */, 0 /* ty=int32 span=Conv2D_12:0:0 */, 0.00392157f /* ty=float32 span=Conv2D_12:0:0 */, meta[relay.Constant][5] /* ty=Tensor[(3), float32] span=Conv2D_12:0:0 */, padding=[2, 0, 2, 1], dilation=[2, 1], channels=3, kernel_size=[3, 2], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* ty=Tensor[(1, 216, 229, 3), int32] span=Conv2D_12:0:0 */;
  %6 = nn.bias_add(%5, meta[relay.Constant][6] /* ty=Tensor[(3), int32] */, axis=3) /* ty=Tensor[(1, 216, 229, 3), int32] span=Conv2D_12:0:0 */;
  %7 = qnn.requantize(%3, meta[relay.Constant][3] /* ty=Tensor[(3), float32] span=Conv2D1:0:0 */, 0 /* ty=int32 span=Conv2D1:0:0 */, 0.0258486f /* ty=float32 span=Conv2D1:0:0 */, -128 /* ty=int32 span=Conv2D1:0:0 */, axis=3, out_dtype="int8") /* ty=Tensor[(1, 216, 229, 3), int8] span=Conv2D1:0:0 */;
  %8 = qnn.requantize(%6, meta[relay.Constant][7] /* ty=Tensor[(3), float32] span=Conv2D_12:0:0 */, 0 /* ty=int32 span=Conv2D_12:0:0 */, 0.0217672f /* ty=float32 span=Conv2D_12:0:0 */, -128 /* ty=int32 span=Conv2D_12:0:0 */, axis=3, out_dtype="int8") /* ty=Tensor[(1, 216, 229, 3), int8] span=Conv2D_12:0:0 */;
  qnn.add(%7, %8, 0.0258486f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */, 0.0217672f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */, 0.046443f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */) /* ty=Tensor[(1, 216, 229, 3), int8] span=Identity:0:0 */
}

Relay after MergeComposite:

def @main(%x: Tensor[(1, 214, 227, 2), int8] /* ty=Tensor[(1, 214, 227, 2), int8] span=x:0:0 */, output_tensor_names=["Identity"]) -> Tensor[(1, 216, 229, 3), int8] {
  %3 = fn (%FunctionVar_0_01: Tensor[(1, 214, 227, 2), int8] /* ty=Tensor[(1, 214, 227, 2), int8] */, PartitionedFromPattern="nn.pad_", Composite="ethos-u.pad2d") -> Tensor[(1, 215, 228, 2), int8] {
    nn.pad(%FunctionVar_0_01, -128f /* ty=float32 span=Pad:0:0 */, pad_width=[[0, 0], [0, 1], [0, 1], [0, 0]]) /* ty=Tensor[(1, 215, 228, 2), int8] span=Pad:0:0 */
  } /* ty=fn (Tensor[(1, 214, 227, 2), int8]) -> Tensor[(1, 215, 228, 2), int8] */;
  %4 = %3(%x) /* ty=Tensor[(1, 215, 228, 2), int8] */;
  %5 = fn (%FunctionVar_1_0: Tensor[(1, 215, 228, 2), int8] /* ty=Tensor[(1, 215, 228, 2), int8] */, PartitionedFromPattern="nn.pad_qnn.conv2d_nn.bias_add_qnn.requantize_", Composite="ethos-u.qnn_conv2d") -> Tensor[(1, 216, 229, 3), int8] {
    %0 = nn.pad(%FunctionVar_1_0, -128f /* ty=float32 span=Pad_1:0:0 */, pad_width=[[0, 0], [0, 1], [0, 1], [0, 0]]) /* ty=Tensor[(1, 216, 229, 2), int8] span=Pad_1:0:0 */;
    %1 = qnn.conv2d(%0, meta[relay.Constant][0] /* ty=Tensor[(3, 2, 2, 3), int8] */, -128 /* ty=int32 span=Conv2D1:0:0 */, 0 /* ty=int32 span=Conv2D1:0:0 */, 0.00392157f /* ty=float32 span=Conv2D1:0:0 */, meta[relay.Constant][1] /* ty=Tensor[(3), float32] span=Conv2D1:0:0 */, padding=[2, 0, 2, 1], dilation=[2, 1], channels=3, kernel_size=[3, 2], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* ty=Tensor[(1, 216, 229, 3), int32] span=Conv2D1:0:0 */;
    %2 = nn.bias_add(%1, meta[relay.Constant][2] /* ty=Tensor[(3), int32] */, axis=3) /* ty=Tensor[(1, 216, 229, 3), int32] span=Conv2D1:0:0 */;
    qnn.requantize(%2, meta[relay.Constant][3] /* ty=Tensor[(3), float32] span=Conv2D1:0:0 */, 0 /* ty=int32 span=Conv2D1:0:0 */, 0.0258486f /* ty=float32 span=Conv2D1:0:0 */, -128 /* ty=int32 span=Conv2D1:0:0 */, axis=3, out_dtype="int8") /* ty=Tensor[(1, 216, 229, 3), int8] span=Conv2D1:0:0 */
  } /* ty=fn (Tensor[(1, 215, 228, 2), int8]) -> Tensor[(1, 216, 229, 3), int8] */;
  %9 = fn (%FunctionVar_0_02: Tensor[(1, 215, 228, 2), int8] /* ty=Tensor[(1, 215, 228, 2), int8] */, PartitionedFromPattern="nn.pad_qnn.conv2d_nn.bias_add_qnn.requantize_", Composite="ethos-u.qnn_conv2d") -> Tensor[(1, 216, 229, 3), int8] {
    %6 = nn.pad(%FunctionVar_0_02, -128f /* ty=float32 span=Pad_1:0:0 */, pad_width=[[0, 0], [0, 1], [0, 1], [0, 0]]) /* ty=Tensor[(1, 216, 229, 2), int8] span=Pad_1:0:0 */;
    %7 = qnn.conv2d(%6, meta[relay.Constant][4] /* ty=Tensor[(3, 2, 2, 3), int8] */, -128 /* ty=int32 span=Conv2D_12:0:0 */, 0 /* ty=int32 span=Conv2D_12:0:0 */, 0.00392157f /* ty=float32 span=Conv2D_12:0:0 */, meta[relay.Constant][5] /* ty=Tensor[(3), float32] span=Conv2D_12:0:0 */, padding=[2, 0, 2, 1], dilation=[2, 1], channels=3, kernel_size=[3, 2], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* ty=Tensor[(1, 216, 229, 3), int32] span=Conv2D_12:0:0 */;
    %8 = nn.bias_add(%7, meta[relay.Constant][6] /* ty=Tensor[(3), int32] */, axis=3) /* ty=Tensor[(1, 216, 229, 3), int32] span=Conv2D_12:0:0 */;
    qnn.requantize(%8, meta[relay.Constant][7] /* ty=Tensor[(3), float32] span=Conv2D_12:0:0 */, 0 /* ty=int32 span=Conv2D_12:0:0 */, 0.0217672f /* ty=float32 span=Conv2D_12:0:0 */, -128 /* ty=int32 span=Conv2D_12:0:0 */, axis=3, out_dtype="int8") /* ty=Tensor[(1, 216, 229, 3), int8] span=Conv2D_12:0:0 */
  } /* ty=fn (Tensor[(1, 215, 228, 2), int8]) -> Tensor[(1, 216, 229, 3), int8] */;
  %10 = %5(%4) /* ty=Tensor[(1, 216, 229, 3), int8] */;
  %11 = %9(%4) /* ty=Tensor[(1, 216, 229, 3), int8] */;
  %12 = fn (%FunctionVar_0_0: Tensor[(1, 216, 229, 3), int8] /* ty=Tensor[(1, 216, 229, 3), int8] */, %FunctionVar_0_1: Tensor[(1, 216, 229, 3), int8] /* ty=Tensor[(1, 216, 229, 3), int8] */, PartitionedFromPattern="qnn.add_", Composite="ethos-u.add") -> Tensor[(1, 216, 229, 3), int8] {
    qnn.add(%FunctionVar_0_0, %FunctionVar_0_1, 0.0258486f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */, 0.0217672f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */, 0.046443f /* ty=float32 span=Identity:0:0 */, -128 /* ty=int32 span=Identity:0:0 */) /* ty=Tensor[(1, 216, 229, 3), int8] span=Identity:0:0 */
  } /* ty=fn (Tensor[(1, 216, 229, 3), int8], Tensor[(1, 216, 229, 3), int8]) -> Tensor[(1, 216, 229, 3), int8] */;
  %12(%10, %11) /* ty=Tensor[(1, 216, 229, 3), int8] */
}

The error (the SRAM region overflows at link time):

/tmp/tmpjnspu08b/test/build/test.c:34:89: warning: control reaches end of non-void function [-Wreturn-type]
E              34 | TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {}
E                 |                                                                                         ^
E           /opt/arm/gcc-arm-none-eabi/bin/../lib/gcc/arm-none-eabi/10.2.1/../../../../arm-none-eabi/bin/ld: /tmp/tmpjnspu08b/test/build/aot_test_runner section `.bss.noinit' will not fit in region `SRAM'
E           /opt/arm/gcc-arm-none-eabi/bin/../lib/gcc/arm-none-eabi/10.2.1/../../../../arm-none-eabi/bin/ld: region `SRAM' overflowed by 270144 bytes
E           collect2: error: ld returned 1 exit status
E           make: *** [/root/tvm/python/tvm/testing/../../../tests/python/relay/aot/corstone300.mk:154: /tmp/tmpjnspu08b/test/build/aot_test_runner] Error 1

PadsWithMultipleConsumersReplicator pass

import tvm
from tvm import relay
from tvm.relay import expr as _expr
from tvm.relay.expr import Call
from tvm.relay.expr_functor import ExprMutator


class PadsWithMultipleConsumersReplicator(ExprMutator):
    """A pass to handle the situation when an nn.pad operator has more than one
    consumer among operators such as qnn.conv2d. Pads are duplicated so that
    each pad has only one consumer.
    """

    def __init__(self):
        ExprMutator.__init__(self)
        # Structural hashes of nn.pad calls that already have a consumer assigned.
        self.hashes = set()

    def visit_call(self, call):
        if (
            isinstance(call.op, tvm.ir.Op)
            and isinstance(call.args[0], Call)
            and isinstance(call.args[0].op, tvm.ir.Op)
            and call.op == relay.op.get("qnn.conv2d")
            and call.args[0].op == relay.op.get("nn.pad")
        ):
            if tvm.ir.structural_hash(call.args[0]) not in self.hashes:
                # First consumer of this pad: remember the pad and fall through
                # to the default rewrite below.
                self.hashes.add(tvm.ir.structural_hash(call.args[0]))
            else:
                # Another consumer of an already-seen pad: build a fresh copy of
                # the pad and feed it to this conv2d instead of the shared one.
                used_pad = self.visit(call.args[0])
                used_pad_args = [self.visit(arg) for arg in used_pad.args]
                new_pad = Call(
                    used_pad.op, used_pad_args, used_pad.attrs, used_pad.type_args, used_pad.span
                )
                new_conv2d_args = [self.visit(new_pad)]
                new_conv2d_args += [self.visit(arg) for arg in call.args[1:]]
                new_conv2d_op = self.visit(call.op)
                return _expr.CallWithFields(
                    call, new_conv2d_op, new_conv2d_args, call.attrs, call.type_args, None, call.span
                )

        new_args = [self.visit(arg) for arg in call.args]
        new_op = self.visit(call.op)
        return _expr.CallWithFields(
            call, new_op, new_args, call.attrs, call.type_args, None, call.span
        )

def ReplicatePadsWithMultipleConsumers(mod):
    """Traverses the Relay graph and replicates nn.pad operators if they have
    multiple consumers among operators such as qnn.conv2d.

    This removes the situation where e.g. pad + conv2d matches
    qnn_conv2d_pattern but cannot be grouped, because several conv2d operators
    use the same pad operation.

    Parameters
    ----------
    mod : tvm.ir.IRModule
        The IRModule that gets generated from a relay frontend.

    Returns
    -------
    tvm.ir.IRModule
        The IRModule without nn.pad operators that have multiple consumers.
    """
    replicator = PadsWithMultipleConsumersReplicator()
    for global_var, func in mod.functions.items():
        func = replicator.visit(func)
        mod.update_func(global_var, func)
    return mod
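
Usage is straightforward; a hedged sketch (whether right before partition_for_ethosu is the correct place in the flow is an assumption on my part):

    from tvm.relay.op.contrib.ethosu import partition_for_ethosu

    # Duplicate shared pads first, then partition as usual so that each
    # qnn.conv2d can be matched together with its own nn.pad.
    mod = ReplicatePadsWithMultipleConsumers(mod)
    mod = partition_for_ethosu(mod, params)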

@comaniac

I would appreciate any advice!

I figured out what the problem is: I use a large tensor for the test, and when I add an extra pad there simply is not enough memory for three intermediate buffers:

@I.ir_module
class Module:
    @T.prim_func
    def main(ethos_u_0_i0: T.Buffer((1, 214, 227, 2), "int8"), ethosu_write: T.Buffer((1, 216, 229, 3), "int8")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
        p2_global = T.allocate([48], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)})
        p5_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)})
        p8_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin": T.bool(True)})
        ethosu_write_1 = T.allocate([784320], "int8", "global", annotations={"disable_lower_builtin": T.bool(True)})
        ethosu_write_2 = T.allocate([791424], "int8", "global", annotations={"disable_lower_builtin": T.bool(True)})
        ethosu_write_3 = T.allocate([791424], "int8", "global", annotations={"disable_lower_builtin": T.bool(True)})
        ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
        p2_global_1 = T.Buffer((48,), "uint8", data=p2_global)
        with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused, None, "DataPar", ""), "pragma_compute_cycles_hint", 352):
            p1_encoded = T.Buffer((48,), "uint8")
            T.call_extern("handle", "ethosu_copy", p1_encoded[0], 48, p2_global_1[0])
        ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
        p5_global_1 = T.Buffer((128,), "uint8", data=p5_global)
        with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_1, None, "DataPar", ""), "pragma_compute_cycles_hint", 1056):
            p4_encoded = T.Buffer((128,), "uint8")
            T.call_extern("handle", "ethosu_copy", p4_encoded[0], 128, p5_global_1[0])
        nn = T.int32()
        ethosu_write_4 = T.Buffer((784320,), "int8", data=ethosu_write_1)
        with T.attr(T.iter_var(nn, None, "DataPar", ""), "pragma_compute_cycles_hint", T.int64(196416)):
            ethos_u_0_i0_1 = T.Buffer((97156,), "int8", data=ethos_u_0_i0.data)
            T.call_extern("handle", "ethosu_depthwise_conv2d", "int8", 214, 227, 2, 214, 0, 227, ethos_u_0_i0_1[0], 0, 0, 0, T.float32(1), -128, "NHWC", 454, 2, 1, "int8", 215, 228, 2, 215, 0, 228, ethosu_write_4[0], 0, 0, 0, T.float32(1), -128, "NHCWB16", 3648, 16, 1, 1, 1, 1, 1, 1, 1, p2_global_1[0], 16, 0, p2_global_1[16], 32, 0, 0, 1, 1, "NONE", 0, 0, "TFL", "NONE", 14, 12, 16)
        ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
        p8_global_1 = T.Buffer((128,), "uint8", data=p8_global)
        with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_2, None, "DataPar", ""), "pragma_compute_cycles_hint", 1056):
            p7_encoded = T.Buffer((128,), "uint8")
            T.call_extern("handle", "ethosu_copy", p7_encoded[0], 128, p8_global_1[0])
        nn_1 = T.int32()
        ethosu_write_5 = T.Buffer((791424,), "int8", data=ethosu_write_2)
        with T.attr(T.iter_var(nn_1, None, "DataPar", ""), "pragma_compute_cycles_hint", T.int64(198880)):
            T.call_extern("handle", "ethosu_conv2d", "int8", 215, 228, 2, 215, 0, 228, ethosu_write_4[0], 0, 0, 0, T.float32(0.0039215679280459881), -128, "NHCWB16", 3648, 16, 1, "int8", 216, 229, 3, 216, 0, 229, ethosu_write_5[0], 0, 0, 0, T.float32(0.025848578661680222), -128, "NHCWB16", 3664, 16, 1, 2, 3, 1, 1, 1, 2, p5_global_1[0], 96, T.int8(-1), T.int8(-1), 0, p5_global_1[96], 32, T.int8(-1), T.int8(-1), 2, 0, 3, 2, "NONE", 0, 0, "TFL", "NONE", 32, 8, 16)
        nn_2 = T.int32()
        ethosu_write_6 = T.Buffer((791424,), "int8", data=ethosu_write_3)
        with T.attr(T.iter_var(nn_2, None, "DataPar", ""), "pragma_compute_cycles_hint", T.int64(198880)):
            T.call_extern("handle", "ethosu_conv2d", "int8", 215, 228, 2, 215, 0, 228, ethosu_write_4[0], 0, 0, 0, T.float32(0.0039215679280459881), -128, "NHCWB16", 3648, 16, 1, "int8", 216, 229, 3, 216, 0, 229, ethosu_write_6[0], 0, 0, 0, T.float32(0.021767223253846169), -128, "NHCWB16", 3664, 16, 1, 2, 3, 1, 1, 1, 2, p8_global_1[0], 96, T.int8(-1), T.int8(-1), 0, p8_global_1[96], 32, T.int8(-1), T.int8(-1), 2, 0, 3, 2, "NONE", 0, 0, "TFL", "NONE", 32, 8, 16)
        nn_3 = T.int32()
        T.attr(T.iter_var(nn_3, None, "DataPar", ""), "pragma_compute_cycles_hint", T.int64(49464))
        ethosu_write_7 = T.Buffer((148392,), "int8", data=ethosu_write.data)
        T.call_extern("handle", "ethosu_binary_elementwise", "int8", 216, 229, 3, 216, 0, 229, ethosu_write_5[0], 0, 0, 0, T.float32(0.025848578661680222), -128, "NHCWB16", 3664, 16, 1, "int8", 216, 229, 3, 216, 0, 229, ethosu_write_6[0], 0, 0, 0, T.float32(0.021767223253846169), -128, "NHCWB16", 3664, 16, 1, "int8", 216, 229, 3, 216, 0, 229, ethosu_write_7[0], 0, 0, 0, T.float32(0.046443037688732147), -128, "NHWC", 687, 3, 1, "ADD", 0, "NONE", 0, 0, "TFL", 8, 64, 8, 0, 0, 0)
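
For reference, the sizes of the three intermediate buffers above are exactly the padded/convolved feature maps in NHCWB16 layout (channels rounded up to a brick of 16), so duplicating the pad costs an extra full-size feature map. Just the arithmetic from the TIR dump; how much of this ends up in the SRAM region depends on the Corstone-300 memory map:

    # Intermediate buffer sizes from the TIR above (int8, NHCWB16 layout:
    # channels are stored in bricks of 16, so 2 or 3 channels still occupy 16).
    pad_out    = 215 * 228 * 16   # ethosu_write_1 -> 784320 bytes
    conv_out_1 = 216 * 229 * 16   # ethosu_write_2 -> 791424 bytes
    conv_out_2 = 216 * 229 * 16   # ethosu_write_3 -> 791424 bytes
    print(pad_out + conv_out_1 + conv_out_2)  # 2367168 bytes of intermediate data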