Segmentation fault when compiling with the Relay Virtual Machine

Hello, all!

I followed the official tutorial Compile PyTorch Object Detection Models to optimize a TensorFlow object detection model with the TVM virtual machine. The model is ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03, downloaded from the TensorFlow 1 Detection Model Zoo. While calling relay.vm.compile, the program raised a segmentation fault.

The logs are:

WARNING:autotvm:Cannot find config for target=cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32, workload=('conv2d_nhwc.cuda', ('TENSOR', (1, 80, 80, 256), 'float32'), ('TENSOR', (3, 3, 256, 546), 'float32'), (1, 1), (1, 1, 1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32, workload=('conv2d_nhwc.cuda', ('TENSOR', (1, 40, 40, 256), 'float32'), ('TENSOR', (3, 3, 256, 546), 'float32'), (1, 1), (1, 1, 1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32, workload=('conv2d_nhwc.cuda', ('TENSOR', (1, 20, 20, 256), 'float32'), ('TENSOR', (3, 3, 256, 546), 'float32'), (1, 1), (1, 1, 1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32, workload=('conv2d_nhwc.cuda', ('TENSOR', (1, 10, 10, 256), 'float32'), ('TENSOR', (3, 3, 256, 546), 'float32'), (1, 1), (1, 1, 1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32, workload=('conv2d_nhwc.cuda', ('TENSOR', (1, 5, 5, 256), 'float32'), ('TENSOR', (3, 3, 256, 546), 'float32'), (1, 1), (1, 1, 1, 1), (1, 1), 'float32'). A fallback configuration is used, which may bring great performance regression.
[16:05:17] /root/code/optimize_object_detection_by_tvm/tvm/src/te/schedule/bound.cc:119: not in feed graph consumer = hybrid(_squeeze_shape_func, 0x7e59d40)
[16:05:17] /root/code/optimize_object_detection_by_tvm/tvm/src/te/schedule/bound.cc:119: not in feed graph consumer = hybrid(_broadcast_shape_func, 0x7e5a7f0)
[16:05:17] /root/code/optimize_object_detection_by_tvm/tvm/src/te/schedule/bound.cc:119: not in feed graph consumer = hybrid(_broadcast_shape_func, 0x7e5bf90)
[16:05:17] /root/code/optimize_object_detection_by_tvm/tvm/src/te/schedule/bound.cc:119: not in feed graph consumer = hybrid(_split_shape_func, 0xd936350)
[16:05:17] /root/code/optimize_object_detection_by_tvm/tvm/src/te/schedule/bound.cc:119: not in feed graph consumer = hybrid(_split_shape_func, 0x7e678a0)
[16:05:17] /root/code/optimize_object_detection_by_tvm/tvm/src/te/schedule/bound.cc:119: not in feed graph consumer = hybrid(_broadcast_shape_func, 0x7e5ecc0)
[16:05:17] /root/code/optimize_object_detection_by_tvm/tvm/src/te/schedule/bound.cc:119: not in feed graph consumer = hybrid(_split_shape_func, 0xd943dc0)
[16:05:17] /root/code/optimize_object_detection_by_tvm/tvm/src/te/schedule/bound.cc:119: not in feed graph consumer = hybrid(_split_shape_func, 0x7e69970)
Segmentation fault (core dumped)

Only the last few lines of the log are copied here.

Then I opened the core file with gdb to inspect the call stack. It showed:

#0  0x00007f5dc437d575 in dmlc::LogCheck_EQ<unsigned long, int> (x=<error reading variable: Cannot access memory at address 0x7ffe49abef60>, 
    y=<error reading variable: Cannot access memory at address 0x7ffe49abef58>)
    at /root/code/optimize_object_detection_by_tvm/tvm/3rdparty/dmlc-core/include/dmlc/logging.h:217
#1  0x00007f5dc511ee80 in tvm::relay::DeDupMutator::Fresh (this=0x7ffe4a2ba050, v=...)
    at /root/code/optimize_object_detection_by_tvm/tvm/src/relay/transforms/de_duplicate.cc:43
#2  0x00007f5dc511f3b8 in tvm::relay::DeDupMutator::VisitExpr_ (this=0x7ffe4a2ba050, op=0x10128950)
    at /root/code/optimize_object_detection_by_tvm/tvm/src/relay/transforms/de_duplicate.cc:64
#3  0x00007f5dc44fac09 in tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)#7}::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) const (
    __closure=0x0, n=..., self=0x7ffe4a2ba058) at /root/code/optimize_object_detection_by_tvm/tvm/include/tvm/relay/expr_functor.h:126
#4  0x00007f5dc44fac64 in tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)#7}::_FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) ()
    at /root/code/optimize_object_detection_by_tvm/tvm/include/tvm/relay/expr_functor.h:126
#5  0x00007f5dc44fb72c in tvm::NodeFunctor<tvm::RelayExpr (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) const (
    this=0x7f5dc9290a60 <tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)::vtable>, n=..., args#0=0x7ffe4a2ba058)
    at /root/code/optimize_object_detection_by_tvm/tvm/include/tvm/node/functor.h:97
#6  0x00007f5dc44f9a98 in tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&) (this=0x7ffe4a2ba058, n=...)
    at /root/code/optimize_object_detection_by_tvm/tvm/include/tvm/relay/expr_functor.h:92
#7  0x00007f5dc52682a6 in tvm::relay::ExprMutator::VisitExpr (this=0x7ffe4a2ba058, expr=...)
    at /root/code/optimize_object_detection_by_tvm/tvm/src/relay/ir/expr_functor.cc:192
#8  0x00007f5dc511f227 in tvm::relay::DeDupMutator::DispatchVisitExpr (this=0x7ffe4a2ba050, e=...)
    at /root/code/optimize_object_detection_by_tvm/tvm/src/relay/transforms/de_duplicate.cc:51
#9  0x00007f5dc5267ee2 in tvm::relay::MixedModeMutator::VisitLeaf (this=0x7ffe4a2ba058, expr=...)
    at /root/code/optimize_object_detection_by_tvm/tvm/src/relay/ir/expr_functor.cc:145
#10 0x00007f5dc5268052 in tvm::relay::MixedModeMutator::<lambda(const Expr&)>::operator()(const tvm::relay::Expr &) const (__closure=0x7ffe49abf4b0, expr=...)
    at /root/code/optimize_object_detection_by_tvm/tvm/src/relay/ir/expr_functor.cc:162
#11 0x00007f5dc526ca50 in tvm::relay::ExpandDataflow<tvm::relay::MixedModeMutator::VisitExpr(const Expr&)::<lambda(const Expr&)>, tvm::relay::MixedModeMutator::VisitExpr(const Expr&)::<lambda(const Expr&)> >(tvm::relay::Expr, tvm::relay::MixedModeMutator::<lambda(const Expr&)>, tvm::relay::MixedModeMutator::<lambda(const Expr&)>) (
    expr=..., fcheck_visited=..., fvisit_leaf=...) at /root/code/optimize_object_detection_by_tvm/tvm/src/relay/ir/expr_functor.cc:98
#12 0x00007f5dc52680fe in tvm::relay::MixedModeMutator::VisitExpr (this=0x7ffe4a2ba058, expr=...)
    at /root/code/optimize_object_detection_by_tvm/tvm/src/relay/ir/expr_functor.cc:166

The full backtrace contains thousands of frames; only the first dozen or so are shown here.

It seems that the error was raised at tvm/3rdparty/dmlc-core/include/dmlc/logging.h:217.

My code is:

import os
from PIL import Image
import numpy as np
import tensorflow as tf

import tvm
from tvm import relay
from tvm import te
from tvm.contrib import graph_runtime
from tvm.runtime.vm import VirtualMachine
from tvm.contrib.download import download

from tvm_serving.frontend.tensorflow import graph_utils

work_dir = "/root/code/optimize_object_detection_by_tvm/tvm_serving/examples/mask_rcnn"

image_name = "COCO_train2014_000000581921.jpg"

output_node_names = ["num_detections", "detection_boxes", "detection_scores", "detection_classes"]


def get_input():
  image_path = os.path.join(work_dir, image_name)
  image = np.array(Image.open(image_path)) # dtype=uint8
  image = np.expand_dims(image, axis=0)
  return image # shape=(1, 427, 640, 3)


def optimize_by_tvm():
  graph_def_path = './frozen_inference_graph.pb'
  graph_def = tf.GraphDef()
  with open(graph_def_path, 'rb') as f:
    graph_def.ParseFromString(f.read())  
  
  x = get_input()
  input_name = "image_tensor"
  dtype = "uint8"
  shape_dict = {input_name: x.shape} # shape=(1, 427, 640, 3)
  dtype_dict = {input_name: dtype} # dtype=uint8

  target = "llvm"
  target_host = "llvm"
  layout = "NCHW"
  ctx = tvm.cpu()
  
  mod, params = relay.frontend.from_tensorflow(graph_def, layout=layout, 
                                               shape=shape_dict, outputs=output_node_names)
  with tvm.transform.PassContext(opt_level=3):
    vm_exec = relay.vm.compile(mod, target=target, params=params)
  vm = VirtualMachine(vm_exec, ctx)
  vm.set_input("image_tensor", **{input_name: img})
  tvm_res = vm.run()
  

if __name__ == '__main__':
  optimize_by_tvm()
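
For completeness, once compilation works I plan to read the detections back roughly like this (only a sketch; I am assuming the outputs come back in the order listed in output_node_names, which I cannot verify yet because of the crash):

# Sketch of reading the VM outputs; the index -> tensor mapping below is my
# assumption based on output_node_names and is not verified yet.
tvm_res = vm.run()
num_detections = int(tvm_res[0].asnumpy()[0])
boxes = tvm_res[1].asnumpy()[0, :num_detections]    # normalized [ymin, xmin, ymax, xmax]
scores = tvm_res[2].asnumpy()[0, :num_detections]
classes = tvm_res[3].asnumpy()[0, :num_detections]
print(num_detections, boxes.shape, scores[:5], classes[:5])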

Could anybody please help me with this problem? Thanks!

Thanks for reporting the issue!

Can you include what version of TF you’re using? I’m trying to reproduce, but TF 2.3.1 doesn’t like your script.

Traceback (most recent call last):
  File "test.py", line 55, in <module>
    optimize_by_tvm()
  File "test.py", line 30, in optimize_by_tvm
    graph_def = tf.GraphDef()
AttributeError: module 'tensorflow' has no attribute 'GraphDef'

Thanks for the reply!
My TensorFlow version is 1.15.3. I downloaded the source code from the TensorFlow releases and built a .whl from source.
If your TensorFlow is 2.3.1, I think tf.compat.v1.GraphDef() will work fine.
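
If it helps you reproduce on 2.3.1, the loading part of my script could be rewritten like this (just a sketch; I have only tested the TF 1.15 path myself):

import tensorflow as tf

# Fall back to the tf.compat.v1 API when available so the frozen graph
# loads on both TF 1.15 and TF 2.x.
try:
    tf_v1 = tf.compat.v1
except AttributeError:
    tf_v1 = tf  # plain TF 1.x without the compat module

graph_def = tf_v1.GraphDef()
with open('./frozen_inference_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())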

Try increasing the stack size limit?

Thank you! I solved this problem by increasing the stack size limit.

How do you increase the stack size limit?

My OS is Ubuntu 16.04. I ran ulimit -s 65535 to raise the stack size limit to roughly 64 MB (ulimit -s takes the value in KB).
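
If you prefer to do it from inside the script, I believe the Python equivalent is to raise RLIMIT_STACK with the resource module before compilation starts (a sketch; I only tested the shell ulimit form myself):

import resource

# Raise the soft stack limit to ~64 MB (values are in bytes here),
# capped at the current hard limit.
soft, hard = resource.getrlimit(resource.RLIMIT_STACK)
new_soft = 64 * 1024 * 1024
if hard != resource.RLIM_INFINITY:
    new_soft = min(new_soft, hard)
resource.setrlimit(resource.RLIMIT_STACK, (new_soft, hard))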

Hmm, I wonder how we’re getting a stack overflow with the mixed mode mutator; I wouldn’t expect that deep a recursion in this pass.
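
Until the recursion depth in that pass is reduced, another workaround I’ve seen is to run the compilation in a worker thread created with a larger stack (a rough sketch, not verified against this particular model):

import sys
import threading

def run_with_big_stack(fn, stack_size_mb=64):
    # threading.stack_size only affects threads created afterwards,
    # so set it before spawning the worker.
    sys.setrecursionlimit(10000)
    threading.stack_size(stack_size_mb * 1024 * 1024)
    worker = threading.Thread(target=fn)
    worker.start()
    worker.join()

# e.g. run_with_big_stack(optimize_by_tvm)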