[Relay] Simplify add(%x, meta[relay.Constant][0])

The recent PR #7731 in apache/tvm ("[Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr", by altanh) introduces new simplification patterns that eliminate identities (e.g., a * 1, a + 0), which is very useful, but they do not seem to apply to the result of constant folding. The following example is adapted from a Huggingface BERT snippet converted to Relay:

fn (%x: Tensor[(1, 1, 1, 128), float32]) -> Tensor[(1, 1, 1, 128), float32] {
  %0 = full(1 /* ty=int32 */, shape=[1, 128], dtype="float32") /* ty=Tensor[(1, 128), float32] */;
  %1 = expand_dims(%0, axis=1) /* ty=Tensor[(1, 1, 128), float32] */;
  %2 = expand_dims(%1, axis=2) /* ty=Tensor[(1, 1, 1, 128), float32] */;
  %3 = multiply(1f /* ty=float32 */, %2) /* ty=Tensor[(1, 1, 1, 128), float32] */;
  %4 = subtract(1f /* ty=float32 */, %3) /* ty=Tensor[(1, 1, 1, 128), float32] */;
  %5 = multiply(%4, -10000f /* ty=float32 */) /* ty=Tensor[(1, 1, 1, 128), float32] */;
  add(%x, %5) /* ty=Tensor[(1, 1, 1, 128), float32] */
}

where %5 is the default “attention_mask”. When users don’t provide an attention mask while constructing the model, Huggingface BERT creates an all-ones mask as a no-op placeholder. Since %0 is all ones, %3 is also all ones, so %4 = 1 - %3 is all zeros and %5 is all zeros as well. As a result, the final add becomes + 0, and the entire expression can be simplified to just %x.
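To make the arithmetic concrete, here is a minimal NumPy sketch that mirrors %0 through %5. It only illustrates the values involved and is not how Relay evaluates the graph; the variable names and the random stand-in for %x are my own.

```python
import numpy as np

# Mirror %0..%5 from the Relay function above with plain NumPy.
mask = np.full((1, 128), 1, dtype="float32")   # %0: full(1, shape=[1, 128])
mask = np.expand_dims(mask, axis=1)            # %1: shape (1, 1, 128)
mask = np.expand_dims(mask, axis=2)            # %2: shape (1, 1, 1, 128)
scaled = np.float32(1) * mask                  # %3: still all ones
inverted = np.float32(1) - scaled              # %4: all zeros
bias = inverted * np.float32(-10000)           # %5: still all zeros

# Hypothetical stand-in for the model input %x.
x = np.random.rand(1, 1, 1, 128).astype("float32")
out = x + bias                                 # the final add

# Adding the all-zero bias leaves x unchanged.
print(np.array_equal(out, x))
```

Because the bias has the same shape as %x, the add does not broadcast, and dropping it changes neither the values nor the output shape.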

I first applied constant folding to simplify the expression:

fn (%x: Tensor[(1, 1, 1, 128), float32]) -> Tensor[(1, 1, 1, 128), float32] {
  add(%x, meta[relay.Constant][0] /* ty=Tensor[(1, 1, 1, 128), float32] */) /* ty=Tensor[(1, 1, 1, 128), float32] */
}

It looks great, but the constant has now been folded into the constant pool, so the pattern is_op("add")(wildcard(), is_constant()) fails to catch it and the identity elimination cannot simplify it. Is it possible to look at the value in the constant pool during pattern matching, or is there any other solution?

cc @altanh @mbrookhart @masahi

Are constants in the constant pool treated differently from normal Constants, i.e., are they a different kind of Expr? If it’s still a normal Constant Expr, then we should still be able to examine the value normally. Note, however, that in this case the only way to be sure it is 0 is to check every element of the tensor, which the simplification pass currently doesn’t do. In an ideal world, we could have a simplification engine (like the stuff @mwillsey was working on) that just has simple rewrites (e.g. full(1) -> ones, subtract(x, x) -> 0, etc) which would compose to catch this subgraph without having to examine values for any intermediate expressions.
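To illustrate how such local rewrites could compose, here is a toy sketch in plain Python. It is not TVM code and not the engine mentioned above: the tuple encoding, the "ones"/"zeros" markers, and the rule set are all made up for illustration, and shapes and dtypes are ignored, which a real pass could not do.

```python
# Toy expressions: ("op", arg, ...) tuples, strings for variables, and the
# markers "ones" / "zeros" for constant-filled tensors (all hypothetical).

def rewrite_once(e):
    """Apply one local rule at the root of e, or return e unchanged."""
    if not isinstance(e, tuple):
        return e
    op, args = e[0], e[1:]
    if op == "full" and args[0] == 1:
        return "ones"                                      # full(1) -> ones
    if op == "expand_dims" and args[0] in ("ones", "zeros"):
        return args[0]                                     # fill value survives expand_dims
    if op == "multiply" and 1 in args:
        return args[1] if args[0] == 1 else args[0]        # 1 * a -> a
    if op == "multiply" and "zeros" in args:
        return "zeros"                                     # 0 * a -> 0
    if op == "subtract" and args == (1, "ones"):
        return "zeros"                                     # 1 - ones -> 0
    if op == "add" and "zeros" in args:
        return args[1] if args[0] == "zeros" else args[0]  # a + 0 -> a
    return e

def simplify(e):
    """Simplify the arguments first, then rewrite the root until it is stable."""
    if isinstance(e, tuple):
        e = (e[0],) + tuple(simplify(a) for a in e[1:])
    new = rewrite_once(e)
    return e if new == e else simplify(new)

# The attention-mask subgraph from the first post, in the toy encoding.
expr = ("add", "x",
        ("multiply",
         ("subtract", 1,
          ("multiply", 1,
           ("expand_dims", ("expand_dims", ("full", 1), 1), 2))),
         -10000))
result = simplify(expr)  # reduces to "x" without inspecting any tensor values
```

Each rule only looks at the operator and its immediate arguments, so no intermediate tensor ever has to be materialized or scanned.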

They are different. Constant nodes are just constants. You will directly see the constant value on the Relay graph. On the other hand, constant pool is just holding a set of constant tensors. From an op’s point of view, meta[relay.Constant][0] is also an input tensor but just coming from the constant pool instead of model input.

I think to Relay, the difference is mainly scalar constant vs. tensor constant. For this we could loop over all of the elements of the ndarray to check if they’re all zero. Might be something we need to do…
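A rough sketch of what that check could look like, written against NumPy only. The helper name add_is_identity is hypothetical, and I’m assuming the pass can get the constant’s values as a NumPy array (in TVM that would presumably be something like the Constant’s data converted with .numpy(), but treat that as an assumption). The shape check is there because dropping the add is only safe when the constant does not broadcast the result to a larger shape.

```python
import numpy as np

def add_is_identity(x_shape, const_values) -> bool:
    """Return True if add(x, const) can be replaced by x.

    Hypothetical helper: const_values is the constant tensor as an array.
    """
    arr = np.asarray(const_values)
    # Every element must be zero (this scans the whole tensor).
    if not np.all(arr == 0):
        return False
    # The constant must not enlarge the output through broadcasting.
    return np.broadcast_shapes(tuple(x_shape), arr.shape) == tuple(x_shape)

# The folded constant from this thread has the same shape as %x.
zeros = np.zeros((1, 1, 1, 128), dtype="float32")
print(add_is_identity((1, 1, 1, 128), zeros))
```

The scan is linear in the size of the constant, so it may be worth restricting it to adds whose second argument is already known to be a Constant.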
