Remove duplicated operators

A model parsed from PyTorch may contain many duplicated operators. Is there a pass that can remove them? For example:

%1 = relay.take(%0, 0, axis=1, mode="wrap");
%2 = relay.take(%0, 0, axis=1, mode="wrap");
%3 = relay.multiply(%1, %2);

I would like to optimize it into:

%1 = relay.take(%0, 0, axis=1, mode="wrap");
%3 = relay.multiply(%1, %1);

I found that the CommonSubexprEliminator pass (exposed in Python as relay.transform.EliminateCommonSubexpr) may be exactly what I need.
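To illustrate the idea behind common subexpression elimination, here is a minimal toy sketch in Python. It does not use TVM: it assumes a made-up SSA-style list of bindings instead of Relay's real data structures, and it assumes every operator is pure (a real pass must skip operators with side effects or randomness). The names `Binding` and `eliminate_common_subexpr` are hypothetical.

```python
from typing import Dict, List, Tuple

# Hypothetical toy SSA binding: (result_var, op_name, argument_vars, attributes).
# This stands in for Relay IR; it is not TVM's representation.
Binding = Tuple[str, str, Tuple, Tuple]


def eliminate_common_subexpr(bindings: List[Binding]) -> List[Binding]:
    """Drop bindings that recompute an identical (op, args, attrs) triple.

    Assumes all ops are pure, so two identical calls always give the same value.
    """
    seen: Dict[Tuple, str] = {}    # (op, args, attrs) -> first var that computed it
    rename: Dict[str, str] = {}    # eliminated var -> surviving var
    result: List[Binding] = []
    for var, op, args, attrs in bindings:
        # Point arguments at surviving vars so later duplicates are also detected.
        args = tuple(rename.get(a, a) for a in args)
        key = (op, args, attrs)
        if key in seen:
            # Same op, same inputs, same attributes: reuse the earlier result.
            rename[var] = seen[key]
        else:
            seen[key] = var
            result.append((var, op, args, attrs))
    return result


program = [
    ("%1", "take", ("%0", 0), (("axis", 1), ("mode", "wrap"))),
    ("%2", "take", ("%0", 0), (("axis", 1), ("mode", "wrap"))),
    ("%3", "multiply", ("%1", "%2"), ()),
]

for var, op, args, attrs in eliminate_common_subexpr(program):
    print(var, "=", op, args, dict(attrs))
```

The attributes are part of the key on purpose: two `take` calls that differ only in `axis` or `mode` compute different values and must not be merged.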
