I test a simple pytorch GRU model, and when I built the converted relay model, a stack overflow error occured.
It may be a bug within the TVM, and how to fix it?Can you share your complete script?
I mean, can you paste your code so that I can run it.
import torch
import torch.nn as nn
import os
os.environ["path"] = os.getenv("path") + ";" + "C:/Program Files/LLVM/bin"
import sys
sys.path.append("D:/workspace/tvm/python")
import tvm
from tvm import relay
from tvm.contrib import graph_executor
class GruTest(nn.Module):
def __init__(self):
super(GruTest, self).__init__()
self.rnn = nn.GRU(100, 200, bidirectional=False)
#self.embed = nn.Embedding(2, 100)
def forward(self, input):
#tmp = self.embed(input)
out, hidden_state = self.rnn(input)
return out
if __name__ == "__main__":
model = GruTest()
#tmp = torch.ones(500, 3, dtype=torch.int32)
tmp = torch.randn(500, 3, 100)
traced_model = torch.jit.trace(model, (tmp,))
shape_list = [(i.debugName().split('.')[0],
(i.type().sizes(), str(i.type().dtype()).split('.')[1])) for i in
list(traced_model.graph.inputs())[1:]]
mod, params = tvm.relay.frontend.pytorch.from_pytorch(traced_model, shape_list)
target = "llvm"
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params=params)
print(mod)`Preformatted text`
I see. The output of your model looks like
%12508 = (%1034, %1057, %1080, %1103, %1126, %1149, %1172, %1195, %1218, %1241, %1264, %1287, %1310, %1333, %1356, %1379, %1402, %1425, %1448, %1471, %1494, %1517, %1540, %1563, %1586, %1609, %1632, %1655, %1678, %1701, %1724, %1747, %1770, %1793, %1816, %1839, %1862, %1885, %1908, %1931, %1954, %1977, %2000, %2023, %2046, %2069, %2092, %2115, %2138, %2161, %2184, %2207, %2230, %2253, %2276, %2299, %2322, %2345, %2368, %2391, %2414, %2437, %2460, %2483, %2506, %2529, %2552, %2575, %2598, %2621, %2644, %2667, %2690, %2713, %2736, %2759, %2782, %2805, %2828, %2851, %2874, %2897, %2920, %2943, %2966, %2989, %3012, %3035, %3058, %3081, %3104, %3127, %3150, %3173, %3196, %3219, %3242, %3265, %3288, %3311, %3334, %3357, %3380, %3403, %3426, %3449, %3472, %3495, %3518, %3541, %3564, %3587, %3610, %3633, %3656, %3679, %3702, %3725, %3748, %3771, %3794, %3817, %3840, %3863, %3886, %3909, %3932, %3955, %3978, %4001, %4024, %4047, %4070, %4093, %4116, %4139, %4162, %4185, %4208, %4231, %4254, %4277, %4300, %4323, %4346, %4369, %4392, %4415, %4438, %4461, %4484, %4507, %4530, %4553, %4576, %4599, %4622, %4645, %4668, %4691, %4714, %4737, %4760, %4783, %4806, %4829, %4852, %4875, %4898, %4921, %4944, %4967, %4990, %5013, %5036, %5059, %5082, %5105, %5128, %5151, %5174, %5197, %5220, %5243, %5266, %5289, %5312, %5335, %5358, %5381, %5404, %5427, %5450, %5473, %5496, %5519, %5542, %5565, %5588, %5611, %5634, %5657, %5680, %5703, %5726, %5749, %5772, %5795, %5818, %5841, %5864, %5887, %5910, %5933, %5956, %5979, %6002, %6025, %6048, %6071, %6094, %6117, %6140, %6163, %6186, %6209, %6232, %6255, %6278, %6301, %6324, %6347, %6370, %6393, %6416, %6439, %6462, %6485, %6508, %6531, %6554, %6577, %6600, %6623, %6646, %6669, %6692, %6715, %6738, %6761, %6784, %6807, %6830, %6853, %6876, %6899, %6922, %6945, %6968, %6991, %7014, %7037, %7060, %7083, %7106, %7129, %7152, %7175, %7198, %7221, %7244, %7267, %7290, %7313, %7336, %7359, %7382, %7405, %7428, %7451, %7474, %7497, %7520, %7543, %7566, %7589, %7612, %7635, %7658, %7681, %7704, %7727, %7750, %7773, %7796, %7819, %7842, %7865, %7888, %7911, %7934, %7957, %7980, %8003, %8026, %8049, %8072, %8095, %8118, %8141, %8164, %8187, %8210, %8233, %8256, %8279, %8302, %8325, %8348, %8371, %8394, %8417, %8440, %8463, %8486, %8509, %8532, %8555, %8578, %8601, %8624, %8647, %8670, %8693, %8716, %8739, %8762, %8785, %8808, %8831, %8854, %8877, %8900, %8923, %8946, %8969, %8992, %9015, %9038, %9061, %9084, %9107, %9130, %9153, %9176, %9199, %9222, %9245, %9268, %9291, %9314, %9337, %9360, %9383, %9406, %9429, %9452, %9475, %9498, %9521, %9544, %9567, %9590, %9613, %9636, %9659, %9682, %9705, %9728, %9751, %9774, %9797, %9820, %9843, %9866, %9889, %9912, %9935, %9958, %9981, %10004, %10027, %10050, %10073, %10096, %10119, %10142, %10165, %10188, %10211, %10234, %10257, %10280, %10303, %10326, %10349, %10372, %10395, %10418, %10441, %10464, %10487, %10510, %10533, %10556, %10579, %10602, %10625, %10648, %10671, %10694, %10717, %10740, %10763, %10786, %10809, %10832, %10855, %10878, %10901, %10924, %10947, %10970, %10993, %11016, %11039, %11062, %11085, %11108, %11131, %11154, %11177, %11200, %11223, %11246, %11269, %11292, %11315, %11338, %11361, %11384, %11407, %11430, %11453, %11476, %11499, %11522, %11545, %11568, %11591, %11614, %11637, %11660, %11683, %11706, %11729, %11752, %11775, %11798, %11821, %11844, %11867, %11890, %11913, %11936, %11959, %11982, %12005, %12028, %12051, %12074, %12097, %12120, %12143, %12166, %12189, %12212, %12235, %12258, %12281, %12304, %12327, %12350, %12373, %12396, %12419, %12442, %12465, %12488, %12507);
stack(%12508)
}
Our PyTorch converter unrolls all layers in GRU, so we end up with a huge graph like yours. So we either hit a stack overflow error or extremely long compile time.
Unfortunately, this is not easy to solve currently. Relax might have a better story for RNN / LSTM type of ops. @yuchenj
@yongwww has added support for compiling and running LSTM model on Relax, and now he is working on improving LSTM performance.
The current while_loop
support in Relax is through recursion (similar to Relay), but we plan to add imperative intrinsics such as while
and foreach
in the future which should be able to solve the stack overflow issue caused by recursion, stay tuned! Here is the development plan: https://github.com/tlc-pack/relax/issues/79.