I have an Expr that I would like to keep track of after applying a pass.
```python
from tvm import relay

myExpr = relay.const(0)
myFunc = relay.Function([], relay.nn.relu(myExpr))
# The mutator creates new Exprs for many expression types, here the Call to relu.
myFuncPost = relay.ExprMutator().visit(myFunc)
# The following assertion fails, because the Constant in the new Function is a new Expr.
assert myExpr == myFuncPost.body.args[0]
```
My current workaround is

```python
assert repr(myExpr) == repr(myFuncPost.body.args[0])
```

It has the following issues:
- Generating and comparing these strings becomes very computationally expensive for large models.
- There may be collisions. Here, for example, any `Constant(0)` anywhere in the model would trivially match.
Is there a better way to do this?
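Both issues can be reproduced without a TVM install in a small toy model. This is a hypothetical illustration, not Relay code: the `Node` class stands in for a Relay expression (compared by identity, like a node reference), `rebuild` stands in for a mutator that re-creates every node, and `repr` plays the role of the string workaround.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True, eq=False)  # eq=False keeps identity-based ==, like a node reference
class Node:
    op: str
    args: Tuple["Node", ...] = ()
    value: object = None

    def __repr__(self) -> str:
        if self.op == "const":
            return f"Constant({self.value})"
        return f"{self.op}({', '.join(repr(a) for a in self.args)})"


def rebuild(node: Node) -> Node:
    """Toy 'pass': returns a structurally identical graph made of brand-new nodes."""
    return Node(node.op, tuple(rebuild(a) for a in node.args), node.value)


my_expr = Node("const", value=0)
func = Node("add", (Node("relu", (my_expr,)), Node("const", value=0)))
post = rebuild(func)

same_position = post.args[0].args[0]  # where my_expr "should" be after the pass
unrelated = post.args[1]              # a different Constant(0)

print(same_position is my_expr)              # False: the old reference is stale
print(repr(same_position) == repr(my_expr))  # True: the string workaround matches...
print(repr(unrelated) == repr(my_expr))      # True: ...but so does the unrelated constant
```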
Why not just search for my Expr after the pass has been applied? Because I identify it using the model structure before the pass, and that structure no longer exists after the pass. For example:
- Find an Expr in a fused (`transform.FuseOps`) Relay representation based on the structure of the fused groups.
- Apply `transform.DefuseOps`.
- Recover the originally found expression.
- Do transformations that only make sense in the defused Relay representation.
Since `DefuseOps` changes the structure significantly, I cannot use the workaround from above. Instead I do the following:
- Replace the found expr with `relay.annotation.stop_fusion(myExpr)` to “mark” it.
- Apply `DefuseOps`.
- Search for the annotation (`stopFuseExpr`) and extract its argument: `myExpr = stopFuseExpr.args[0]`
- Remove the annotation. This again makes `myExpr` invalid in the latest version of the graph, exactly like in my original problem above.
- Use the workaround from above to recover `myExpr` again.
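The marker-based steps above can be sketched in a TVM-independent toy model. This is a hypothetical illustration, not Relay code: `mark` stands in for wrapping the expression in `relay.annotation.stop_fusion`, `rebuild` stands in for `DefuseOps` (the real pass also restructures the graph, which this toy ignores), and `find_by_repr` is the string workaround. It shows where the reference goes stale and how the workaround recovers it.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True, eq=False)  # identity-based ==, like a node reference
class Node:
    op: str
    args: Tuple["Node", ...] = ()
    value: object = None

    def __repr__(self) -> str:
        if self.op == "const":
            return f"Constant({self.value})"
        return f"{self.op}({', '.join(repr(a) for a in self.args)})"


def rebuild(node: Node) -> Node:
    """Toy stand-in for DefuseOps: re-creates every node."""
    return Node(node.op, tuple(rebuild(a) for a in node.args), node.value)


def mark(node: Node, target: Node) -> Node:
    """Step 1: wrap `target` in a 'marker' node (stand-in for stop_fusion)."""
    if node is target:
        return Node("marker", (node,))
    return Node(node.op, tuple(mark(a, target) for a in node.args), node.value)


def find_marker(node: Node) -> Optional[Node]:
    """Step 3: depth-first search for the marker node."""
    if node.op == "marker":
        return node
    for a in node.args:
        found = find_marker(a)
        if found is not None:
            return found
    return None


def strip_markers(node: Node) -> Node:
    """Step 4: remove markers; like a mutator, this re-creates the surrounding nodes."""
    if node.op == "marker":
        return strip_markers(node.args[0])
    return Node(node.op, tuple(strip_markers(a) for a in node.args), node.value)


def find_by_repr(node: Node, wanted: str) -> Optional[Node]:
    """Step 5: the string workaround; returns the first node whose repr matches."""
    if repr(node) == wanted:
        return node
    for a in node.args:
        found = find_by_repr(a, wanted)
        if found is not None:
            return found
    return None


my_expr = Node("relu", (Node("const", value=0),))
graph = Node("add", (my_expr, Node("const", value=1)))

marked = mark(graph, my_expr)                   # 1. mark
defused = rebuild(marked)                       # 2. "DefuseOps"
my_expr = find_marker(defused).args[0]          # 3. extract the argument
final = strip_markers(defused)                  # 4. remove the marker
stale = my_expr is final.args[0]                # reference is stale again
recovered = find_by_repr(final, repr(my_expr))  # 5. string workaround

print(stale)                         # False
print(recovered is final.args[0])    # True
```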
Attaching some kind of user-defined attribute to the expression might solve this, but there does not seem to be anything like that.
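The user-defined-attribute idea can be sketched in the same kind of TVM-independent toy model. This is purely hypothetical and does not describe an existing Relay feature: each node carries an optional `tag` field, and every pass is assumed to copy the tag onto the nodes it re-creates. Under that assumption the node can be found again by tag, without a repr scan and without collisions.

```python
from dataclasses import dataclass
from typing import Hashable, Optional, Tuple


@dataclass(frozen=True, eq=False)  # identity-based ==, like a node reference
class Node:
    op: str
    args: Tuple["Node", ...] = ()
    value: object = None
    tag: Optional[Hashable] = None  # hypothetical user-defined attribute


def rebuild(node: Node) -> Node:
    """Toy pass that re-creates every node but copies the tag onto the new node."""
    return Node(node.op, tuple(rebuild(a) for a in node.args), node.value, node.tag)


def find_by_tag(node: Node, tag: Hashable) -> Optional[Node]:
    """Depth-first search for the first node carrying `tag`."""
    if node.tag == tag:
        return node
    for a in node.args:
        found = find_by_tag(a, tag)
        if found is not None:
            return found
    return None


my_expr = Node("const", value=0, tag="my-expr")  # unique, user-chosen tag
graph = Node("relu", (Node("add", (my_expr, Node("const", value=0))),))

post = rebuild(rebuild(graph))  # any number of tag-preserving passes
found = find_by_tag(post, "my-expr")

print(found is my_expr)                # False: it is a new node...
print(found is post.args[0].args[0])   # True: ...but it is the right one
```

Everything hinges on the assumption that every pass propagates the tag: a single pass that re-creates nodes without copying it, or that folds the tagged node away, loses the link.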
Do you have any suggestions on how to achieve this more robustly?