[Performance] Can softmax be optimized away for inference?

I saw this comment on nn.softmax:

This operator can be optimized away for inference
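
If I understand it correctly, this might refer to softmax being monotonic: when the network output only feeds an argmax (e.g. a classification head), dropping the softmax does not change the prediction. A minimal NumPy sketch of that identity (my own assumption about what the comment means, not taken from the TVM docs):

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

logits = np.random.randn(128, 1000).astype("float32")

# softmax preserves the ordering along the reduced axis, so the
# predicted class is identical with or without it
assert (softmax(logits).argmax(axis=-1) == logits.argmax(axis=-1)).all()
```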

Currently, the BERT performance bottleneck is related to softmax. What exactly does this comment mean, and how can this op be optimized away? The IR looks like the following:

%1579 = fn (%p0218: Tensor[(128, 12, 128, 128), float32], Primitive=1, hash="2bf1f4aba825ef91") -> Tensor[(128, 12, 128, 128), float32] {
  nn.softmax(%p0218) /* ty=Tensor[(128, 12, 128, 128), float32] */
};
%1580 = %1579(%1578) /* ty=Tensor[(128, 12, 128, 128), float32] */;
%1581 = fn (%p0217: Tensor[(128, 12, 128, 128), float32], Primitive=1, relay.reshape_only=1, hash="cb182a507ee11eec") -> Tensor[(1536, 128, 128), float32] {
  reshape(%p0217, newshape=[-1, 128, 128]) /* ty=Tensor[(1536, 128, 128), float32] */
};
%1594 = fn (%p0296: Tensor[(16384, 768), int8], %p1234: Tensor[(768, 768), int8], %p2153: Tensor[(16384, 1), int32], %p379: int32, %p479: Tensor[(768), int32], %p553: int32, %p653: float32, Primitive=1, hash="fc3402299aaf823e") -> Tensor[(16384, 768), float32] {
  %1586 = nn.dense(%p0296, %p1234, units=768, out_dtype="int32") /* ty=Tensor[(16384, 768), int32] */;
  %1587 = multiply(%p379, %p479) /* ty=Tensor[(768), int32] */;
  %1588 = subtract(%p553, %1587) /* ty=Tensor[(768), int32] */;
  %1589 = subtract(%1586, %p2153) /* ty=Tensor[(16384, 768), int32] */;
  %1590 = expand_dims(%1588, axis=0) /* ty=Tensor[(1, 768), int32] */;
  %1591 = add(%1589, %1590) /* ty=Tensor[(16384, 768), int32] */;
  %1592 = cast(%1591, dtype="float32") /* ty=Tensor[(16384, 768), float32] */;
  %1593 = multiply(%p653, 0.00167063f /* ty=float32 */) /* ty=float32 */;
  multiply(%1592, %1593) /* ty=Tensor[(16384, 768), float32] */
};
%1595 = %1594(%1545, meta[relay.Constant][33] /* ty=Tensor[(768, 768), int8] */, %1553, %1551, meta[relay.Constant][34] /* ty=Tensor[(768), int32] */, %1554, %1540) /* ty=Tensor[(16384, 768), float32] */;
%1596 = fn (%p0295: Tensor[(16384, 768), float32], %p1233: Tensor[(768), float32], Primitive=1, hash="9fb7058ef5653f52") -> Tensor[(1536, 64, 128), float32] {
  %1582 = reshape(%p0295, newshape=[128, 128, 768]) /* ty=Tensor[(128, 128, 768), float32] */;
  %1583 = add(%1582, %p1233) /* ty=Tensor[(128, 128, 768), float32] */;
  %1584 = reshape(%1583, newshape=[128, 128, 12, 64]) /* ty=Tensor[(128, 128, 12, 64), float32] */;
  %1585 = transpose(%1584, axes=[0, 2, 3, 1]) /* ty=Tensor[(128, 12, 64, 128), float32] */;
  reshape(%1585, newshape=[-1, 64, 128]) /* ty=Tensor[(1536, 64, 128), float32] */
};
%1597 = %1581(%1580) /* ty=Tensor[(1536, 128, 128), float32] */;
%1598 = %1596(%1595, meta[relay.Constant][35] /* ty=Tensor[(768), float32] */) /* ty=Tensor[(1536, 64, 128), float32] */;
%1599 = fn (%p0216: Tensor[(1536, 128, 128), float32], %p1185: Tensor[(1536, 64, 128), float32], Primitive=1, hash="ee1827ff1631f589") -> Tensor[(1536, 128, 64), float32] {
  nn.batch_matmul(%p0216, %p1185, transpose_b=True) /* ty=Tensor[(1536, 128, 64), float32] */
};
%1600 = %1599(%1597, %1598) /* ty=Tensor[(1536, 128, 64), float32] */;
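
For reference, here is my reading of what this fragment computes, as a NumPy sketch (shapes copied from the IR above; the quantized dense chain that produces the second operand is collapsed into a random float tensor for brevity):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# %1578: attention scores, (batch=128, heads=12, 128, 128)
scores = np.random.randn(128, 12, 128, 128).astype("float32")

# %1579 + %1581: softmax over the last axis, then fold batch*heads
attn = softmax(scores).reshape(-1, 128, 128)      # (1536, 128, 128)

# %1594 + %1596: the projection, reshaped/transposed to (1536, 64, 128)
v = np.random.randn(1536, 64, 128).astype("float32")

# %1599: nn.batch_matmul(attn, v, transpose_b=True)
out = attn @ v.transpose(0, 2, 1)                 # (1536, 128, 64)
assert out.shape == (1536, 128, 64)
```

Since the softmax result feeds a batch_matmul here rather than an argmax, it is not obvious to me that it can simply be dropped.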

Can the softmax be optimized away in this case, and if so, how?
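
As a side note, would something like relay.transform.FastMath help with the softmax bottleneck here? A sketch of how I would try it, assuming FastMath rewrites nn.softmax into the approximate nn.fast_softmax (I have not verified this on this model, and `mod`/`params` below stand in for the BERT module shown above):

```python
import tvm
from tvm import relay

# mod, params: the Relay IRModule and weights for the BERT graph above
with tvm.transform.PassContext(opt_level=3, required_pass=["FastMath"]):
    lib = relay.build(mod, target="llvm", params=params)
```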