Failed to build the GRU model from relay. stack overflow occured

I test a simple pytorch GRU model, and when I built the converted relay model, a stack overflow error occured.

It may be a bug within the TVM, and how to fix it?

Can you share your complete script?

Just this simple script.

I mean, can you paste your code so that I can run it.

import torch
import torch.nn as nn
import os
os.environ["path"] = os.getenv("path") + ";" + "C:/Program Files/LLVM/bin"
import sys
sys.path.append("D:/workspace/tvm/python")
import tvm
from tvm import relay
from tvm.contrib import graph_executor


class GruTest(nn.Module):
    def __init__(self):
        super(GruTest, self).__init__()
        self.rnn = nn.GRU(100, 200, bidirectional=False)
        #self.embed = nn.Embedding(2, 100)

    def forward(self, input):
        #tmp = self.embed(input)
        out, hidden_state = self.rnn(input)
        return out


if __name__ == "__main__":
    model = GruTest()
    #tmp = torch.ones(500, 3, dtype=torch.int32)
    tmp = torch.randn(500, 3, 100)
    traced_model = torch.jit.trace(model, (tmp,))
    shape_list = [(i.debugName().split('.')[0],
                   (i.type().sizes(), str(i.type().dtype()).split('.')[1])) for i in
                  list(traced_model.graph.inputs())[1:]]
    mod, params = tvm.relay.frontend.pytorch.from_pytorch(traced_model, shape_list)
    target = "llvm"
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)
    print(mod)`Preformatted text`

I see. The output of your model looks like

  %12508 = (%1034, %1057, %1080, %1103, %1126, %1149, %1172, %1195, %1218, %1241, %1264, %1287, %1310, %1333, %1356, %1379, %1402, %1425, %1448, %1471, %1494, %1517, %1540, %1563, %1586, %1609, %1632, %1655, %1678, %1701, %1724, %1747, %1770, %1793, %1816, %1839, %1862, %1885, %1908, %1931, %1954, %1977, %2000, %2023, %2046, %2069, %2092, %2115, %2138, %2161, %2184, %2207, %2230, %2253, %2276, %2299, %2322, %2345, %2368, %2391, %2414, %2437, %2460, %2483, %2506, %2529, %2552, %2575, %2598, %2621, %2644, %2667, %2690, %2713, %2736, %2759, %2782, %2805, %2828, %2851, %2874, %2897, %2920, %2943, %2966, %2989, %3012, %3035, %3058, %3081, %3104, %3127, %3150, %3173, %3196, %3219, %3242, %3265, %3288, %3311, %3334, %3357, %3380, %3403, %3426, %3449, %3472, %3495, %3518, %3541, %3564, %3587, %3610, %3633, %3656, %3679, %3702, %3725, %3748, %3771, %3794, %3817, %3840, %3863, %3886, %3909, %3932, %3955, %3978, %4001, %4024, %4047, %4070, %4093, %4116, %4139, %4162, %4185, %4208, %4231, %4254, %4277, %4300, %4323, %4346, %4369, %4392, %4415, %4438, %4461, %4484, %4507, %4530, %4553, %4576, %4599, %4622, %4645, %4668, %4691, %4714, %4737, %4760, %4783, %4806, %4829, %4852, %4875, %4898, %4921, %4944, %4967, %4990, %5013, %5036, %5059, %5082, %5105, %5128, %5151, %5174, %5197, %5220, %5243, %5266, %5289, %5312, %5335, %5358, %5381, %5404, %5427, %5450, %5473, %5496, %5519, %5542, %5565, %5588, %5611, %5634, %5657, %5680, %5703, %5726, %5749, %5772, %5795, %5818, %5841, %5864, %5887, %5910, %5933, %5956, %5979, %6002, %6025, %6048, %6071, %6094, %6117, %6140, %6163, %6186, %6209, %6232, %6255, %6278, %6301, %6324, %6347, %6370, %6393, %6416, %6439, %6462, %6485, %6508, %6531, %6554, %6577, %6600, %6623, %6646, %6669, %6692, %6715, %6738, %6761, %6784, %6807, %6830, %6853, %6876, %6899, %6922, %6945, %6968, %6991, %7014, %7037, %7060, %7083, %7106, %7129, %7152, %7175, %7198, %7221, %7244, %7267, %7290, %7313, %7336, %7359, %7382, %7405, %7428, %7451, %7474, %7497, %7520, %7543, %7566, %7589, %7612, %7635, %7658, %7681, %7704, %7727, %7750, %7773, %7796, %7819, %7842, %7865, %7888, %7911, %7934, %7957, %7980, %8003, %8026, %8049, %8072, %8095, %8118, %8141, %8164, %8187, %8210, %8233, %8256, %8279, %8302, %8325, %8348, %8371, %8394, %8417, %8440, %8463, %8486, %8509, %8532, %8555, %8578, %8601, %8624, %8647, %8670, %8693, %8716, %8739, %8762, %8785, %8808, %8831, %8854, %8877, %8900, %8923, %8946, %8969, %8992, %9015, %9038, %9061, %9084, %9107, %9130, %9153, %9176, %9199, %9222, %9245, %9268, %9291, %9314, %9337, %9360, %9383, %9406, %9429, %9452, %9475, %9498, %9521, %9544, %9567, %9590, %9613, %9636, %9659, %9682, %9705, %9728, %9751, %9774, %9797, %9820, %9843, %9866, %9889, %9912, %9935, %9958, %9981, %10004, %10027, %10050, %10073, %10096, %10119, %10142, %10165, %10188, %10211, %10234, %10257, %10280, %10303, %10326, %10349, %10372, %10395, %10418, %10441, %10464, %10487, %10510, %10533, %10556, %10579, %10602, %10625, %10648, %10671, %10694, %10717, %10740, %10763, %10786, %10809, %10832, %10855, %10878, %10901, %10924, %10947, %10970, %10993, %11016, %11039, %11062, %11085, %11108, %11131, %11154, %11177, %11200, %11223, %11246, %11269, %11292, %11315, %11338, %11361, %11384, %11407, %11430, %11453, %11476, %11499, %11522, %11545, %11568, %11591, %11614, %11637, %11660, %11683, %11706, %11729, %11752, %11775, %11798, %11821, %11844, %11867, %11890, %11913, %11936, %11959, %11982, %12005, %12028, %12051, %12074, %12097, %12120, %12143, %12166, %12189, %12212, %12235, %12258, %12281, %12304, %12327, %12350, %12373, %12396, %12419, %12442, %12465, %12488, %12507);
  stack(%12508)
}

Our PyTorch converter unrolls all layers in GRU, so we end up with a huge graph like yours. So we either hit a stack overflow error or extremely long compile time.

Unfortunately, this is not easy to solve currently. Relax might have a better story for RNN / LSTM type of ops. @yuchenj

@yongwww has added support for compiling and running LSTM model on Relax, and now he is working on improving LSTM performance.

The current while_loop support in Relax is through recursion (similar to Relay), but we plan to add imperative intrinsics such as while and foreach in the future which should be able to solve the stack overflow issue caused by recursion, stay tuned! Here is the development plan: https://github.com/tlc-pack/relax/issues/79.