Maybe a bit more context will help: my model is an LSTM with an input sequence length of 25. Looking at the Relay function, the model weights appear to be duplicated there after the LSTM loop is unrolled: there are exactly 25 references to a (20, 7) tensor containing (what I assume are) the weights, each with a different index into the relay.Constant list. Each of these seems to be allocated separately, so the generated code ends up using much more storage than necessary.
Has anyone seen something similar when importing LSTM models? Is this to be expected? Any hints would be appreciated.
Here is a part of the Relay function, where the same (20, 7) weight tensor appears as three distinct constants (indices 1, 5 and 6), used in %6, %27 and %48:
def @main(%seq: Tensor[(25, 2), float32] /* ty=Tensor[(25, 2), float32] span=/Reshape.seq:0:0 */) ->
Tensor[(1), float32] {
%0 = reshape(%seq, newshape=[25, 1, -1]) /* ty=Tensor[(25, 1, 2), float32] span=/Reshape:0:0 */;
%1 = split(%0, indices_or_sections=25) /* ty=(Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32],
Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32],
Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32],
Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32],
Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32],
Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32],
Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32], Tensor[(1, 1, 2), float32]) span=/lstm/LSTM:0:0 */;
%2 = %1.0 /* ty=Tensor[(1, 1, 2), float32] span=/lstm/LSTM:0:0 */;
%3 = squeeze(%2, axis=[0]) /* ty=Tensor[(1, 2), float32] span=/lstm/LSTM:0:0 */;
%4 = (%3, meta[relay.Constant][0] /* ty=Tensor[(1, 5), float32] span=/lstm/LSTM:0:0 */) /* ty=
(Tensor[(1, 2), float32], Tensor[(1, 5), float32]) span=/lstm/LSTM:0:0 */;
%5 = concatenate(%4, axis=1) /* ty=Tensor[(1, 7), float32] span=/lstm/LSTM:0:0 */;
%6 = nn.dense(%5, meta[relay.Constant][1] /* ty=Tensor[(20, 7), float32] span=/lstm/LSTM:0:0 */, units=None) /* ty=Tensor[(1, 20), float32] span=/lstm/LSTM:0:0 */;
...
%27 = nn.dense(%26, meta[relay.Constant][5] /* ty=Tensor[(20, 7), float32] span=/lstm/LSTM:0:0 */,
units=None) /* ty=Tensor[(1, 20), float32] span=/lstm/LSTM:0:0 */;
...
%48 = nn.dense(%47, meta[relay.Constant][6] /* ty=Tensor[(20, 7), float32] span=/lstm/LSTM:0:0 */,
units=None) /* ty=Tensor[(1, 20), float32] span=/lstm/LSTM:0:0 */;
...
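In case it helps the discussion: the workaround I'm considering is to deduplicate constants by content before compilation, so that the 25 unrolled copies share one allocation. Below is a minimal sketch of the idea in plain Python (not actual TVM API; the function name `dedup_constants` and the bytes-based tensor model are my own stand-ins). In Relay one would instead walk the expression tree with an ExprMutator and compare the constants' underlying data:

```python
# Sketch (not TVM API): map constants with identical contents to one
# canonical instance, so duplicated unrolled LSTM weights share storage.
# Tensors are modeled as raw bytes here; in Relay you would compare the
# data held by each relay.Constant instead.

def dedup_constants(constants):
    """Return, for each constant, the index of the first constant
    with identical contents."""
    canonical = {}   # content bytes -> index of first occurrence
    remap = []       # original index -> canonical index
    for idx, data in enumerate(constants):
        key = bytes(data)
        if key not in canonical:
            canonical[key] = idx
        remap.append(canonical[key])
    return remap

# Example: three copies of the same weight blob plus one distinct bias.
weight = b"\x01\x02" * 70   # stand-in for a 20x7 tensor's bytes
bias = b"\x03" * 5
remap = dedup_constants([weight, bias, weight, weight])
print(remap)  # -> [0, 1, 0, 0]: the duplicates collapse onto index 0
```

Whether something like this is safe to do before codegen, or whether an existing Relay pass already covers it, is exactly what I'm hoping someone can confirm.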