Segmentation fault (core dumped) for relay.build

When I executed the following simple script, it crashed with “Segmentation fault (core dumped)”.

Is there a bug in my script or an internal error in TVM?

Any comments would be appreciated. Thanks!

script

import tvm
from tvm import relay

mod = tvm.IRModule()
var_0 = relay.var("var_0", dtype = "float64", shape = (8,)) # shape=(8,)
var_1 = relay.log(var_0.astype('float64')) # shape=(8,)

func = relay.Function([var_0,], var_1)
output = relay.Tuple([func,])
F = relay.Function([], output)
mod['main'] = F
#mod = relay.transform.InferType()(mod)
print(mod.astext(show_meta_data=False))
graph, lib, params = relay.build(mod, target='llvm')

Yes, the segfault happens at https://github.com/apache/tvm/blob/2f937801de21757a409d9d193f43fe5c0ab5a202/src/relay/backend/utils.cc#L141-L142 because your tuple element is a function, not a tensor. cc @manupa-arm @mbs-octoml

But why do you want to put a function in a tuple?

I have run into a similar problem when putting functions into a relay.Tuple. The reason is that I want to return more than one function instead of only one. For example, below is the Relay program I want to construct with a Python script:

def @main() {
  %2 = fn (%var_0: Tensor[(8), float64]) {
    %0 = cast(%var_0, dtype="float64");
    log(%0)
  };
  %3 = fn (%var_01: Tensor[(7), float64]) {
    %1 = cast(%var_01, dtype="float64");
    log(%1)
  };
  (%2, %3)
}

The corresponding Python script, as I understand the Relay API, is:

import tvm
from tvm import relay

mod = tvm.IRModule()
var_0 = relay.var("var_0", dtype = "float64", shape = (8,)) # shape=(8,)
var_1 = relay.log(var_0.astype('float64')) # shape=(8,)

func = relay.Function([var_0,], var_1)

var_2 = relay.var("var_2", dtype = "float64", shape = (7,)) # shape=(7,)
var_3 = relay.log(var_2.astype('float64')) # shape=(7,)

func2 = relay.Function([var_2,], var_3)

output = relay.Tuple([func,func2])
F = relay.Function([], output)
mod['main'] = F
#mod = relay.transform.InferType()(mod)
print(mod.astext(show_meta_data=False))
graph, lib, params = relay.build(mod, target='llvm')

It crashes in the same way sqchao showed.

Can you try the VM compiler and runtime? I don’t think the graph runtime (the relay.build flow) can return functions.

I am not familiar with the graph runtime. In my experience, relay.build does not support the VM (sorry, I am not very sure about this). Do you mean creating an executor with the VM, like relay.build_module.create_executor('vm', mod, tvm.device('cuda', 0), 'cuda')?

BTW, I just found the following test case in test_analysis_basic_block_normal_form.py:

@pytest.mark.xfail(raises=tvm.error.TVMError)
def test_func():
    x = relay.var("x", shape=(1,), dtype="float32")  # , a)
    y = relay.var("y", shape=(1,), dtype="float32")  # , a)
    z = relay.var("z", shape=(1,), dtype="float32")  # , a)
    x2 = relay.add(x, x)
    func_a = relay.Function([y], relay.add(x2, y))  # , a, [a])
    func_b = relay.Function([z], relay.add(x2, z))  # , a, [a])
    body = relay.Tuple([func_a, func_b])
    body = relay.Function([x], body)
    """
    fn (%x: Tensor[(1), float32]) {
      %1 = fn (%y: Tensor[(1), float32]) {
        %0 = add(%x, %x);
        add(%0, %y)
      };
      %2 = fn (%z: Tensor[(1), float32]) {
        add(%0, %z)
      };
      (%1, %2)
    }
    """
    check_basic_block_normal_form(body)

I think TVM was originally meant to support returning functions, but something goes wrong in relay.build.


Yes, relay.build means you are using the graph runtime, and relay.build_module.create_executor('vm', mod, ...) is the correct way to use the VM.

OK, thanks for your quick reply. I will give it a try.

Thank you @masahi @Haoyang

Following your advice, I replaced relay.build with relay.build_module.create_executor and ran the following three statements:

  • relay.build_module.create_executor('vm', mod, tvm.device('cuda', 0), 'cuda')
  • relay.build_module.create_executor('debug', mod, tvm.device('cuda', 0), 'cuda')
  • relay.build_module.create_executor('graph', mod, tvm.device('cuda', 0), 'cuda')

All three executors run without crashing. Therefore, I think the crash is caused by a bug inside the relay.build() API.

Wow, much quicker than me, lol.

@masahi Hi, some days have passed since our last discussion. Do you think this is a TVM bug? Although I have read some of the TVM source code, I cannot find the root cause and could use some suggestions.

I have submitted a pull request: “support returned function in relay.build” by haoyang9804, Pull Request #10502, apache/tvm (github.com). Please take a look.