Segmentation fault after creating ACL library with relay arm_compute_lib

Hi,

I am trying to build an Arm library using ACL and a PyTorch scripted model. I can create the library without problems using the aarch64-linux-gnu-gcc compiler, but when I try to run the library on the Raspberry Pi 4 I get a segmentation fault, and I don't understand what the problem is. Here is my code:

To create the library in my PC:

import tvm
from tvm import relay
import torch.nn as nn
import torch.nn.functional as F
import torch

class CovNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

data_type = "float32"
data_shape = (1, 3, 32, 32)
input_name = "input"  # the input name can be arbitrary for PyTorch frontend.
input_shapes = [(input_name, data_shape)]

input_data = torch.randn(data_shape, dtype=torch.float32)
net = CovNet().eval()
script_module = torch.jit.trace(net, input_data)  # trace the module itself rather than net.forward

mod, params = relay.frontend.from_pytorch(script_module, input_shapes)

from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib
module = partition_for_arm_compute_lib(mod)

target = "llvm -mtriple=aarch64-linux-gnu -mattr=+neon -mcpu=cortex-a72"

with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
    lib = relay.build(module, target=target)

lib_path = 'lib_acl_conv_net.so'
cross_compile = 'aarch64-linux-gnu-gcc'
lib.export_library(lib_path, cc=cross_compile)

To run the library in RPi4:

import time
import tvm
from tvm.contrib import graph_executor
import numpy as np

data_type = "float32"
data_shape = (1, 3, 32, 32)

dev = tvm.cpu(0)
loaded_lib = tvm.runtime.load_module('lib_acl_conv_net.so')
gen_module = graph_executor.GraphModule(loaded_lib['default'](dev))
d_data = np.random.uniform(0, 1, data_shape).astype(data_type)
map_inputs = {'input': d_data}  # must match the input name the graph was built with
gen_module.set_input(**map_inputs)

timeList = []
for i in range(15):
    now = time.time()

    gen_module.run()

    timeList.append(time.time() - now)

floats_array = np.array(timeList)
np.set_printoptions(precision=3)
print('Execution list:', floats_array)
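For completeness, the result can be read back after run() with the standard graph executor API:

# Fetch the first (and only) output of the network as a NumPy array.
tvm_output = gen_module.get_output(0).asnumpy()
print('Output shape:', tvm_output.shape)  # expected: (1, 10)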

Any help is very much appreciated!

Can someone give me some help? Thank you.

Really, nobody knows?

Hi @i02fesea, it looks as though the input graph is in NCHW format, which is not supported by the Arm Compute Library (ACL) integration. It may be that there aren't sufficient checks in the integration, so these incompatible operations are being offloaded to ACL rather than falling back to TVM as they should.

I hope to get some time to debug this in more detail next week, but in the meantime you could try running an NHWC graph, or converting the layout of the current graph before partitioning for ACL using something similar to:

mod = relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]})(mod)
module = partition_for_arm_compute_lib(mod)
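To confirm what actually gets offloaded, you can print the global functions of the partitioned module; anything whose name contains arm_compute_lib will be executed by ACL. A quick sketch:

# List the partitions created by partition_for_arm_compute_lib.
for gv in module.get_global_vars():
    print(gv.name_hint)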

Thanks a lot for your answer. I have tried your modification, but I get the following error when calling ConvertLayout:

Check failed: (checked_type_.defined()) is false: internal error: the type checker has not populated the checked_type field for Var(input, ty=TensorType([1, 3, 32, 32], float32))

My apologies, type inference also needs to be run before the convert layout pass. Could you try the following?

mod = relay.transform.InferType()(mod)
mod = relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]})(mod)
module = partition_for_arm_compute_lib(mod)

Thanks Luke, I have added your change to generate the library:

mod = relay.transform.InferType()(mod)
mod = relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]})(mod)

from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib
module = partition_for_arm_compute_lib(mod)
target = "llvm -mtriple=aarch64-linux-gnu -mattr=+neon -mcpu=cortex-a72"

with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
    lib = relay.build(module, target=target)

lib_path = 'lib_acl_conv_net.so'
cross_compile = 'aarch64-linux-gnu-gcc'
lib.export_library(lib_path, cc=cross_compile)

and I have changed the input to load and run the library:

data_shape = (1, 32, 32, 3)

dev = tvm.cpu(0)
loaded_lib = tvm.runtime.load_module('lib_acl_conv_net.so')
gen_module = tvm.contrib.graph_executor.GraphModule(loaded_lib['default'](dev))
d_data = np.random.uniform(0, 1, data_shape).astype(data_type)
map_inputs = {'input': d_data}
gen_module.set_input(**map_inputs)
gen_module.run()

but I am still getting the segmentation fault error.

Any chance you can share the backtrace of the segmentation fault?

Do you know how to run in verbose mode? It doesn't show anything apart from Segmentation fault when I call this line: gen_module = tvm.contrib.graph_executor.GraphModule(loaded_lib['default'](dev))

Hi Luke, do you have any suggestions for getting more error information?

Hi @i02fesea,

Apologies for my delayed response. I finally found some time to test and run your code example, although I don't have access to an RPi 4, so there may be some subtle differences on the runtime side of things. It turns out that since the graph is originally in NCHW format, a few hacks are needed to get it into the format that ACL expects (note that the integration is only expected to handle NHWC input graphs). So I ended up preprocessing the graph like this:

from tvm.relay.build_module import bind_params_by_name

# Bind constants so they can be folded later, then run type inference so ConvertLayout can be used.
mod["main"] = bind_params_by_name(mod["main"], params)
mod = relay.transform.InferType()(mod)
# Convert the layout to NHWC; OHWI is the weight format that ACL expects (the weight layout
# conversion should happen in partition_for_arm_compute_lib, but it seems as though a bug
# is preventing this from happening correctly).
mod = relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"], "nn.max_pool2d": ["NHWC"]})(mod)
# Simplify nn.bias_add into an add operation (works around a bug), then fold the weight
# layout transforms so the pattern matching can correctly identify offloadable ops.
mod = relay.transform.CanonicalizeOps()(mod)
mod = relay.transform.FoldConstant()(mod)
module = partition_for_arm_compute_lib(mod)
...

Unfortunately this still isn't perfect, as quite a few layout transforms remain in the output:

def @main(%input: Tensor[(1, 3, 32, 32), float32] /* ty=Tensor[(1, 3, 32, 32), float32] */) -> Tensor[(1, 10), float32] {
  %0 = layout_transform(%input, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(1, 32, 32, 3), float32] */;
  %1 = @tvmgen_default_arm_compute_lib_main_0(%0) /* ty=Tensor[(1, 28, 28, 6), float32] */;
  %2 = layout_transform(%1, src_layout="NHWC", dst_layout="NCHW") /* ty=Tensor[(1, 6, 28, 28), float32] */;
  %3 = @tvmgen_default_arm_compute_lib_main_1(%2) /* ty=Tensor[(1, 6, 28, 28), float32] */;
  %4 = nn.relu(%3) /* ty=Tensor[(1, 6, 28, 28), float32] */;
  %5 = nn.max_pool2d(%4, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 6, 14, 14), float32] */;
  %6 = layout_transform(%5, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(1, 14, 14, 6), float32] */;
  %7 = @tvmgen_default_arm_compute_lib_main_3(%6) /* ty=Tensor[(1, 10, 10, 16), float32] */;
  %8 = layout_transform(%7, src_layout="NHWC", dst_layout="NCHW") /* ty=Tensor[(1, 16, 10, 10), float32] */;
  %9 = @tvmgen_default_arm_compute_lib_main_4(%8) /* ty=Tensor[(1, 16, 10, 10), float32] */;
  %10 = nn.relu(%9) /* ty=Tensor[(1, 16, 10, 10), float32] */;
  %11 = nn.max_pool2d(%10, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 5, 5), float32] */;
  %12 = @tvmgen_default_arm_compute_lib_main_6(%11) /* ty=Tensor[(1, 400, 1, 1), float32] */;
  %13 = squeeze(%12, axis=[2, 3]) /* ty=Tensor[(1, 400), float32] */;
  %14 = @tvmgen_default_arm_compute_lib_main_7(%13) /* ty=Tensor[(1, 120), float32] */;
  %15 = @tvmgen_default_arm_compute_lib_main_8(%14) /* ty=Tensor[(1, 120), float32] */;
  %16 = nn.relu(%15) /* ty=Tensor[(1, 120), float32] */;
  %17 = @tvmgen_default_arm_compute_lib_main_10(%16) /* ty=Tensor[(1, 84), float32] */;
  %18 = @tvmgen_default_arm_compute_lib_main_11(%17) /* ty=Tensor[(1, 84), float32] */;
  %19 = nn.relu(%18) /* ty=Tensor[(1, 84), float32] */;
  %20 = @tvmgen_default_arm_compute_lib_main_13(%19) /* ty=Tensor[(1, 10), float32] */;
  @tvmgen_default_arm_compute_lib_main_14(%20) /* ty=Tensor[(1, 10), float32] */
}

To get the best performance, I would recommend building your input graph in NHWC format from the start rather than in NCHW.
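One way to do that is to start from a framework or frontend that is natively NHWC instead of converting from PyTorch. As a rough sketch, the Keras model below is a hypothetical equivalent of your network, and this assumes a TVM build whose Keras frontend accepts tf.keras models:

from tensorflow import keras
from tvm import relay

# Hypothetical NHWC equivalent of CovNet; Keras defaults to channels-last layout.
model = keras.Sequential([
    keras.layers.Conv2D(6, 5, activation="relu", input_shape=(32, 32, 3)),
    keras.layers.MaxPool2D(2),
    keras.layers.Conv2D(16, 5, activation="relu"),
    keras.layers.MaxPool2D(2),
    keras.layers.Flatten(),
    keras.layers.Dense(120, activation="relu"),
    keras.layers.Dense(84, activation="relu"),
    keras.layers.Dense(10),
])

# from_keras takes a layout argument, so the Relay graph is built in NHWC directly.
shape_dict = {model.input_names[0]: (1, 32, 32, 3)}
mod, params = relay.frontend.from_keras(model, shape_dict, layout="NHWC")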

I suspect the segmentation fault you were seeing before was because the weights were in the wrong data-layout format; this is a bug and I'll look at fixing it when I have some free cycles. To get more information out of the segmentation fault, I recommend building TVM with debug symbols and using a tool such as gdb to view the backtrace.
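On the Python side, the standard-library faulthandler module is also a cheap first step: it prints the Python-level traceback when the interpreter receives a fatal signal, which at least shows which TVM call crashed:

import faulthandler
faulthandler.enable()  # dump the Python traceback on SIGSEGV/SIGABRT/etc.

# ...then load and run the module as before; if it segfaults, the last
# frame printed points at the call that triggered the crash.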

Hope this helps!

Thanks a lot Luke, that will definitely help. I will test your code and also make the modification you suggested.