Why am I getting a segfault in `MergeCompositeFunctions`?
Here is the code:
import torch
import torch.nn as nn
import torch.fx as fx
import tvm
from tvm import relax
import tvm.relax.frontend.torch
from tvm.relax.dpl import is_op, wildcard
class NeuralNetwork(nn.Module):
    """Small fully connected classifier for 28x28 inputs.

    Layers, in order::

        Flatten -> Linear(784, 512) -> ReLU
                -> Linear(512, 512) -> ReLU
                -> Linear(512, 10)
    """

    def __init__(self):
        super().__init__()
        # nn.Flatten defaults to start_dim=1: keeps the batch dimension and
        # collapses everything after it into one feature axis.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Run the MLP on a batch.

        Args:
            x: Input batch. All dimensions after the first are flattened,
                so they must multiply to 784 (e.g. shape ``(N, 1, 28, 28)``).

        Returns:
            Raw, unnormalized scores of shape ``(N, 10)``; no softmax is
            applied here.
        """
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
# Build the PyTorch model and capture it as an FX graph for the Relax
# torch frontend.
network = NeuralNetwork()
# Static input signature for from_fx: a list with one (shape, dtype) pair
# per graph input.
input_shape = (1, 1, 28, 28)
# NOTE(review): TVM appears to parse the dtype string "float" with its default
# 32-bit width; "float32" would be more explicit. Confirm against your build.
input_dtype = "float"
traced_network = fx.symbolic_trace(network)
input_info = [(input_shape, input_dtype)]
# Import the traced FX graph into a Relax IRModule.
mod = relax.frontend.torch.from_fx(traced_network, input_info)
# Pattern table with a single entry: any relax.matmul call with two
# arbitrary arguments, labelled "tensorrt.matmul". The "tensorrt." prefix
# presumably selects the BYOC target name. TODO confirm.
patterns = [
("tensorrt.matmul", is_op("relax.matmul")(wildcard(), wildcard())),
]
# Apply the pattern table. Inspect the result (e.g. mod1.show()) to see the
# functions this pass produced before they reach MergeCompositeFunctions.
mod1 = relax.transform.FuseOpsByPattern(patterns)(mod)
# NOTE(review): the reported segfault happens in this call, and its cause
# cannot be determined from this snippet alone. A native crash, rather than
# a Python-level TVMError from a failed internal check, suggests a defect
# inside the pass for this particular IR. To narrow it down:
#   1. print mod1 (mod1.show()) and check the composite functions look sane;
#   2. rerun with `python -X faulthandler` to get the Python stack at the crash;
#   3. retry on a newer TVM build, and if it still crashes, report upstream
#      with mod1's printed IR attached.
mod2 = relax.transform.MergeCompositeFunctions()(mod1)
# ^ segfault reported in the call above