The recent PR "[Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr" (apache/tvm Pull Request #7731, by altanh) introduces new simplification patterns that eliminate identities (e.g., a * 1, a + 0, etc.). This is very useful, but it does not seem to apply to the result of constant folding. The following example is adapted from a Huggingface BERT snippet converted to Relay:
fn (%x: Tensor[(1, 1, 1, 128), float32]) -> Tensor[(1, 1, 1, 128), float32] {
%0 = full(1 /* ty=int32 */, shape=[1, 128], dtype="float32") /* ty=Tensor[(1, 128), float32] */;
%1 = expand_dims(%0, axis=1) /* ty=Tensor[(1, 1, 128), float32] */;
%2 = expand_dims(%1, axis=2) /* ty=Tensor[(1, 1, 1, 128), float32] */;
%3 = multiply(1f /* ty=float32 */, %2) /* ty=Tensor[(1, 1, 1, 128), float32] */;
%4 = subtract(1f /* ty=float32 */, %3) /* ty=Tensor[(1, 1, 1, 128), float32] */;
%5 = multiply(%4, -10000f /* ty=float32 */) /* ty=Tensor[(1, 1, 1, 128), float32] */;
add(%x, %5) /* ty=Tensor[(1, 1, 1, 128), float32] */
}
Here %5 is the default “attention_mask”. When users don’t provide an attention mask while constructing the model, Huggingface BERT uses ones to create a no-op mask as a placeholder. The mask is all zeros after %4 (1 - 1 = 0), so %5 is all zeros as well. As a result, the final add becomes + 0, and the entire expression can be simplified to just %x.
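To double-check the arithmetic, here is a minimal sketch that replays the same steps in plain Python. It is not TVM code: the names (ones, scaled, inverted, mask, x) and the sample input values are made up for illustration, and the tensors are modeled as flat lists because expand_dims does not change any values.

```python
# Plain-Python model of the mask arithmetic. expand_dims only changes the
# shape, never the values, so a flat list of 128 elements is enough here.
ones = [1.0] * 128                          # %0: full(1, shape=[1, 128])
scaled = [1.0 * v for v in ones]            # %3: multiply(1f, %2)
inverted = [1.0 - v for v in scaled]        # %4: subtract(1f, %3)
mask = [v * -10000.0 for v in inverted]     # %5: multiply(%4, -10000f)

# Every element of the would-be folded constant compares equal to zero.
mask_is_zero = all(v == 0.0 for v in mask)

# Adding the mask to a hypothetical input x leaves x unchanged.
x = [0.5 * i for i in range(128)]
add_is_identity = [a + b for a, b in zip(x, mask)] == x
```

Both mask_is_zero and add_is_identity should hold, which is why dropping the add is safe for this particular constant.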
I first applied constant folding to simplify the expression:
fn (%x: Tensor[(1, 1, 1, 128), float32]) -> Tensor[(1, 1, 1, 128), float32] {
add(%x, meta[relay.Constant][0] /* ty=Tensor[(1, 1, 1, 128), float32] */) /* ty=Tensor[(1, 1, 1, 128), float32] */
}
It looks great, but the constant has now been folded into the constant pool, so the pattern is_op("add")(wildcard(), is_constant()) fails to match and identity elimination cannot simplify the expression. Is it possible to look at the values in the constant pool during pattern matching, or is there another solution?