Hi, I have created a simple NN model:
class ConvBNReLU(nn.Module):
    """Conv -> BatchNorm -> ReLU feature extractor with a linear classifier head.

    Assumes NCHW float input of shape (N, 3, 64, 64) -- the fully-connected
    layer flattens the 16x64x64 feature map down to 10 logits.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # 3 -> 16 channels, 3x3 kernel, stride 1, padding 1 (spatial size preserved).
        self.conv = nn.Conv2d(3, 16, 3, 1, 1, bias=True)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        # NOTE(review): 16*64*64 presumes a fixed 64x64 input resolution.
        self.fc = nn.Linear(16 * 64 * 64, 10)
        # Weight initialization helper defined elsewhere in the class (not shown here).
        self.__init_weights()
But when I use the `partition_for_arm_compute_lib` API with the following code:
from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib

# Import the ONNX graph into Relay and inspect it before partitioning.
mod, params = relay.frontend.from_onnx(onnx_model)
mod.show()

# Offload the ACL-supported regions, then inspect the partitioned module.
mod = partition_for_arm_compute_lib(mod, params)
mod.show()
The output is shown below. In the output we can clearly see that the conv2d operator is not dispatched to ACL. I am wondering why this happens, given that conv2d is supported in ACL?
# model before partition
/home/tvm/python/tvm/script/highlight.py:117: UserWarning: No module named 'black'
To print formatted TVM script, please install the formatter 'Black':
/usr/bin/python3 -m pip install "black==22.3.0" --upgrade --user
warnings.warn(
def @main(%input: Tensor[(1, 3, 64, 64), float32] /* ty=Tensor[(1, 3, 64, 64), float32] span=/conv/Conv.input:0:0 */) -> Tensor[(1, 10), float32] {
%0 = nn.conv2d(%input, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] span=/conv/Conv.onnx::Conv_23:0:0 */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 64, 64), float32] span=/conv/Conv:0:0 */;
%1 = nn.bias_add(%0, meta[relay.Constant][1] /* ty=Tensor[(16), float32] span=/conv/Conv.onnx::Conv_24:0:0 */) /* ty=Tensor[(1, 16, 64, 64), float32] span=/conv/Conv:0:0 */;
%2 = nn.relu(%1) /* ty=Tensor[(1, 16, 64, 64), float32] span=/relu/Relu:0:0 */;
%3 = reshape(%2, newshape=[1, -1]) /* ty=Tensor[(1, 65536), float32] span=/Reshape:0:0 */;
%4 = nn.dense(%3, meta[relay.Constant][2] /* ty=Tensor[(10, 65536), float32] span=/fc/Gemm.fc.weight:0:0 */, units=10) /* ty=Tensor[(1, 10), float32] span=/fc/Gemm:0:0 */;
add(%4, meta[relay.Constant][3] /* ty=Tensor[(10), float32] span=/fc/Gemm.fc.bias:0:0 */) /* ty=Tensor[(1, 10), float32] span=/fc/Gemm:0:0 */
}
# model after partition
/home/tvm/python/tvm/script/highlight.py:117: UserWarning: No module named 'black'
To print formatted TVM script, please install the formatter 'Black':
/usr/bin/python3 -m pip install "black==22.3.0" --upgrade --user
warnings.warn(
def @main(%input: Tensor[(1, 3, 64, 64), float32] /* ty=Tensor[(1, 3, 64, 64), float32] span=/conv/Conv.input:0:0 */) -> Tensor[(1, 10), float32] {
%0 = nn.conv2d(%input, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] span=/conv/Conv.onnx::Conv_23:0:0 */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 64, 64), float32] span=/conv/Conv:0:0 */;
%1 = nn.bias_add(%0, meta[relay.Constant][1] /* ty=Tensor[(16), float32] span=/conv/Conv.onnx::Conv_24:0:0 */) /* ty=Tensor[(1, 16, 64, 64), float32] span=/conv/Conv:0:0 */;
%2 = nn.relu(%1) /* ty=Tensor[(1, 16, 64, 64), float32] span=/relu/Relu:0:0 */;
%3 = @tvmgen_default_arm_compute_lib_main_0(%2) /* ty=Tensor[(1, 65536), float32] */;
%4 = @tvmgen_default_arm_compute_lib_main_1(%3) /* ty=Tensor[(1, 10), float32] */;
@tvmgen_default_arm_compute_lib_main_2(%4) /* ty=Tensor[(1, 10), float32] */
}
def @tvmgen_default_arm_compute_lib_main_0(%arm_compute_lib_0_i0: Tensor[(1, 16, 64, 64), float32] /* ty=Tensor[(1, 16, 64, 64), float32] */, Compiler="arm_compute_lib", Primitive=1, Inline=1, global_symbol="tvmgen_default_arm_compute_lib_main_0") -> Tensor[(1, 65536), float32] {
reshape(%arm_compute_lib_0_i0, newshape=[1, -1]) /* ty=Tensor[(1, 65536), float32] span=/Reshape:0:0 */
}
def @tvmgen_default_arm_compute_lib_main_1(%arm_compute_lib_1_i0: Tensor[(1, 65536), float32] /* ty=Tensor[(1, 65536), float32] */, Compiler="arm_compute_lib", Primitive=1, Inline=1, global_symbol="tvmgen_default_arm_compute_lib_main_1") -> Tensor[(1, 10), float32] {
%5 = fn (%FunctionVar_0_0: Tensor[(1, 65536), float32] /* ty=Tensor[(1, 65536), float32] */, PartitionedFromPattern="nn.dense_", Composite="arm_compute_lib.dense") -> Tensor[(1, 10), float32] {
nn.dense(%FunctionVar_0_0, meta[relay.Constant][2] /* ty=Tensor[(10, 65536), float32] span=/fc/Gemm.fc.weight:0:0 */, units=10) /* ty=Tensor[(1, 10), float32] span=/fc/Gemm:0:0 */
} /* ty=fn (Tensor[(1, 65536), float32]) -> Tensor[(1, 10), float32] */;
%5(%arm_compute_lib_1_i0) /* ty=Tensor[(1, 10), float32] */
}
def @tvmgen_default_arm_compute_lib_main_2(%arm_compute_lib_2_i0: Tensor[(1, 10), float32] /* ty=Tensor[(1, 10), float32] */, Compiler="arm_compute_lib", Primitive=1, Inline=1, global_symbol="tvmgen_default_arm_compute_lib_main_2") -> Tensor[(1, 10), float32] {
add(%arm_compute_lib_2_i0, meta[relay.Constant][3] /* ty=Tensor[(10), float32] span=/fc/Gemm.fc.bias:0:0 */) /* ty=Tensor[(1, 10), float32] span=/fc/Gemm:0:0 */
}