[solved] Segfault on MergeCompositeFunctions

Why am I getting a segfault on MergeCompositeFunctions?

Here is the code:

import torch
import torch.nn as nn
import torch.fx as fx

import tvm
from tvm import relax
import tvm.relax.frontend.torch
from tvm.relax.dpl import is_op, wildcard

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()  # must run before any submodule is assigned
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.Linear(512, 512),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

network = NeuralNetwork()
input_shape = (1, 1, 28, 28)
input_dtype = "float32"

traced_network = fx.symbolic_trace(network)
input_info = [(input_shape, input_dtype)]
mod = relax.frontend.torch.from_fx(traced_network, input_info)

# Offload every relax.matmul into a "tensorrt.matmul" composite function.
patterns = [
    ("tensorrt.matmul", is_op("relax.matmul")(wildcard(), wildcard())),
]
mod1 = relax.transform.FuseOpsByPattern(patterns)(mod)
mod2 = relax.transform.MergeCompositeFunctions()(mod1)
# segfault

Sorry about that, I’m taking a look.

For now, you can use the fix in commit apache/tvm@7150f23.

The problem was that the pass did not handle Constant and ShapeExpr nodes properly: it assumed that every argument of a CallNode is itself a CallNode.
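To make that failure mode concrete, here is a toy Python model of it. Everything in it (the `Var`/`Constant`/`ShapeExpr`/`Call` classes, `as_call`, `parent_groups_buggy`, `parent_groups_fixed`) is hypothetical and is not TVM's API. The real pass is written in C++. There, a checked downcast like `as<CallNode>()` returns a null pointer when the node is not a call, and dereferencing that pointer is one plausible way the segfault arises. Python surfaces the same mistake as an `AttributeError` instead.

```python
from dataclasses import dataclass
from typing import List, Optional

# Hypothetical toy IR nodes, NOT TVM's real classes.
@dataclass(eq=False)
class Var:
    name: str

@dataclass(eq=False)
class Constant:
    value: object

@dataclass(eq=False)
class ShapeExpr:
    values: tuple

@dataclass(eq=False)
class Call:
    op: str
    args: list
    group: Optional[str] = None  # composite group this call was fused into

def as_call(expr) -> Optional[Call]:
    # Like a checked downcast: returns None when expr is not a Call.
    return expr if isinstance(expr, Call) else None

def parent_groups_buggy(call: Call) -> List[Optional[str]]:
    # Assumes every argument is a Call. A Constant or ShapeExpr argument
    # makes as_call() return None, and reading .group from None fails.
    return [as_call(arg).group for arg in call.args]

def parent_groups_fixed(call: Call) -> List[Optional[str]]:
    groups = []
    for arg in call.args:
        arg_call = as_call(arg)
        if arg_call is None:
            # Constants, shape expressions and variables are not part of
            # any composite group, so skip them instead of dereferencing.
            continue
        groups.append(arg_call.group)
    return groups

# A matmul against a constant weight, similar to what nn.Linear lowers to.
x = Var("x")
weight = Constant("W")
mm = Call("relax.matmul", [x, weight], group="tensorrt.matmul")
add = Call("relax.add", [mm, Constant("b")], group="tensorrt.add")
reshape = Call("relax.reshape", [add, ShapeExpr((1, 10))])

print(parent_groups_fixed(add))      # ['tensorrt.matmul']
print(parent_groups_fixed(reshape))  # ['tensorrt.add']
try:
    parent_groups_buggy(add)
except AttributeError:
    print("buggy version crashed on the Constant argument")
```

In this toy model, skipping non-call arguments is the simplest correct choice, because only calls can carry a composite group. How the actual commit handles Constant and ShapeExpr may differ in detail.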


Thank you!
