[Arm Compute Lib] Failed to dispatch conv2d operator

Hi, I have created a simple NN model:

class ConvBNReLU(nn.Module):
    """Minimal test model: 3x3 Conv2d -> BatchNorm2d -> ReLU -> Linear(10).

    The convolution maps 3 input channels to 16 with kernel size 3, stride 1
    and padding 1, so the spatial size is preserved. The classifier head is
    sized for a 64x64 input (16 * 64 * 64 flattened features).

    NOTE(review): torch.nn.Conv2d operates on NCHW tensors, so a graph
    exported from this model is NCHW. Arm Compute Library only offloads NHWC
    conv2d, so the layout has to be converted before ACL partitioning.
    NOTE(review): only __init__ is shown in this snippet; forward() and
    __init_weights() are presumably defined in the full model.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Forward any nn.Module constructor arguments unchanged.
        super().__init__(*args, **kwargs)
        # Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1).
        self.conv = nn.Conv2d(3, 16, 3, 1, 1, bias=True)
        # Normalizes the 16 output channels produced by self.conv.
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        # Flattened conv features (16 channels x 64 x 64) -> 10 outputs;
        # assumes a 64x64 spatial input.
        self.fc = nn.Linear(16*64*64, 10)

        # The double underscore makes Python mangle this name to
        # self._ConvBNReLU__init_weights(); the method is not shown here.
        self.__init_weights()

But when I partition it with the partition_for_arm_compute_lib API, using the following code:

# Import the ONNX export of the model into a Relay module and print it.
mod, params = relay.frontend.from_onnx(onnx_model)
mod.show()
from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib
# Split the graph into ACL-offloaded functions and host code, then print it.
# NOTE(review): nn.conv2d here is still in the NCHW layout produced by the
# ONNX import, and ACL only offloads NHWC conv2d, so it stays on the host
# unless a ConvertLayout pass to NHWC is run beforehand.
mod = partition_for_arm_compute_lib(mod, params)
mod.show()

The output is shown below. In it we can clearly see that the conv2d operator is not dispatched to ACL (only reshape, dense and add are offloaded). Why does this happen, given that conv2d is supported by ACL?

# model before partition
 /home/tvm/python/tvm/script/highlight.py:117: UserWarning: No module named 'black'
        To print formatted TVM script, please install the formatter 'Black':
        /usr/bin/python3 -m pip install "black==22.3.0" --upgrade --user
          warnings.warn(
    def @main(%input: Tensor[(1, 3, 64, 64), float32] /* ty=Tensor[(1, 3, 64, 64), float32] span=/conv/Conv.input:0:0 */) -> Tensor[(1, 10), float32] {
      %0 = nn.conv2d(%input, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] span=/conv/Conv.onnx::Conv_23:0:0 */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 64, 64), float32] span=/conv/Conv:0:0 */;
      %1 = nn.bias_add(%0, meta[relay.Constant][1] /* ty=Tensor[(16), float32] span=/conv/Conv.onnx::Conv_24:0:0 */) /* ty=Tensor[(1, 16, 64, 64), float32] span=/conv/Conv:0:0 */;
      %2 = nn.relu(%1) /* ty=Tensor[(1, 16, 64, 64), float32] span=/relu/Relu:0:0 */;
      %3 = reshape(%2, newshape=[1, -1]) /* ty=Tensor[(1, 65536), float32] span=/Reshape:0:0 */;
      %4 = nn.dense(%3, meta[relay.Constant][2] /* ty=Tensor[(10, 65536), float32] span=/fc/Gemm.fc.weight:0:0 */, units=10) /* ty=Tensor[(1, 10), float32] span=/fc/Gemm:0:0 */;
      add(%4, meta[relay.Constant][3] /* ty=Tensor[(10), float32] span=/fc/Gemm.fc.bias:0:0 */) /* ty=Tensor[(1, 10), float32] span=/fc/Gemm:0:0 */
    }
# model after partition
    /home/tvm/python/tvm/script/highlight.py:117: UserWarning: No module named 'black'
    To print formatted TVM script, please install the formatter 'Black':
    /usr/bin/python3 -m pip install "black==22.3.0" --upgrade --user
      warnings.warn(
def @main(%input: Tensor[(1, 3, 64, 64), float32] /* ty=Tensor[(1, 3, 64, 64), float32] span=/conv/Conv.input:0:0 */) -> Tensor[(1, 10), float32] {
  %0 = nn.conv2d(%input, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] span=/conv/Conv.onnx::Conv_23:0:0 */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 64, 64), float32] span=/conv/Conv:0:0 */;
  %1 = nn.bias_add(%0, meta[relay.Constant][1] /* ty=Tensor[(16), float32] span=/conv/Conv.onnx::Conv_24:0:0 */) /* ty=Tensor[(1, 16, 64, 64), float32] span=/conv/Conv:0:0 */;
  %2 = nn.relu(%1) /* ty=Tensor[(1, 16, 64, 64), float32] span=/relu/Relu:0:0 */;
  %3 = @tvmgen_default_arm_compute_lib_main_0(%2) /* ty=Tensor[(1, 65536), float32] */;
  %4 = @tvmgen_default_arm_compute_lib_main_1(%3) /* ty=Tensor[(1, 10), float32] */;
  @tvmgen_default_arm_compute_lib_main_2(%4) /* ty=Tensor[(1, 10), float32] */
}

def @tvmgen_default_arm_compute_lib_main_0(%arm_compute_lib_0_i0: Tensor[(1, 16, 64, 64), float32] /* ty=Tensor[(1, 16, 64, 64), float32] */, Compiler="arm_compute_lib", Primitive=1, Inline=1, global_symbol="tvmgen_default_arm_compute_lib_main_0") -> Tensor[(1, 65536), float32] {
  reshape(%arm_compute_lib_0_i0, newshape=[1, -1]) /* ty=Tensor[(1, 65536), float32] span=/Reshape:0:0 */
}

def @tvmgen_default_arm_compute_lib_main_1(%arm_compute_lib_1_i0: Tensor[(1, 65536), float32] /* ty=Tensor[(1, 65536), float32] */, Compiler="arm_compute_lib", Primitive=1, Inline=1, global_symbol="tvmgen_default_arm_compute_lib_main_1") -> Tensor[(1, 10), float32] {
  %5 = fn (%FunctionVar_0_0: Tensor[(1, 65536), float32] /* ty=Tensor[(1, 65536), float32] */, PartitionedFromPattern="nn.dense_", Composite="arm_compute_lib.dense") -> Tensor[(1, 10), float32] {
    nn.dense(%FunctionVar_0_0, meta[relay.Constant][2] /* ty=Tensor[(10, 65536), float32] span=/fc/Gemm.fc.weight:0:0 */, units=10) /* ty=Tensor[(1, 10), float32] span=/fc/Gemm:0:0 */
  } /* ty=fn (Tensor[(1, 65536), float32]) -> Tensor[(1, 10), float32] */;
  %5(%arm_compute_lib_1_i0) /* ty=Tensor[(1, 10), float32] */
}

def @tvmgen_default_arm_compute_lib_main_2(%arm_compute_lib_2_i0: Tensor[(1, 10), float32] /* ty=Tensor[(1, 10), float32] */, Compiler="arm_compute_lib", Primitive=1, Inline=1, global_symbol="tvmgen_default_arm_compute_lib_main_2") -> Tensor[(1, 10), float32] {
  add(%arm_compute_lib_2_i0, meta[relay.Constant][3] /* ty=Tensor[(10), float32] span=/fc/Gemm.fc.bias:0:0 */) /* ty=Tensor[(1, 10), float32] span=/fc/Gemm:0:0 */
}

Hi @digital-nomad-cheng,

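Arm Compute Library only offloads NHWC convolutions, so the input graph must be in NHWC format. The graph you are trying to partition is in NCHW format: the input is `Tensor[(1, 3, 64, 64), float32]`, which is the layout PyTorch exports to ONNX. Please try running the ConvertLayout pass before partition_for_arm_compute_lib, for example:

    desired_layouts = {"nn.conv2d": ["NHWC", "default"]}
    with tvm.transform.PassContext(opt_level=3):
        mod = relay.transform.ConvertLayout(desired_layouts)(mod)
    mod = partition_for_arm_compute_lib(mod, params)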

1 Like

Thanks for the reply, that solved my problem.