Remember Exprs through Passes

I have an Expr that I would like to keep track of after applying a pass.

from tvm import relay

myExpr = relay.const(0)
myFunc = relay.Function([], relay.nn.relu(myExpr))
# The mutator creates a new Expr for many node types, here the Call to relu.
myFuncPost = relay.ExprMutator().visit(myFunc)
# The following does not work, because the Constant in the new Function is a new Expr.
assert myExpr == myFuncPost.body.args[0]
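The underlying distinction is object identity versus structural equality. Below is a minimal sketch that does not depend on TVM: the classes are a hypothetical toy IR standing in for Relay's Constant and Call, and `rebuild` stands in for a mutator that reconstructs every node it visits. In Relay, the two checks correspond roughly to `same_as` (identity) and `tvm.ir.structural_equal` (structure). The latter avoids building strings, but like the repr trick it cannot tell two identical subtrees apart.

```python
from dataclasses import dataclass
from typing import Tuple


# Hypothetical toy IR: immutable nodes whose == compares fields (structure).
@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[object, ...]


def rebuild(expr):
    """Toy mutator: returns a fresh copy of every node instead of reusing it."""
    if isinstance(expr, Const):
        return Const(expr.value)
    return Call(expr.op, tuple(rebuild(a) for a in expr.args))


my_expr = Const(0)
func_body = Call("nn.relu", (my_expr,))
post_body = rebuild(func_body)

same_object = post_body.args[0] is my_expr      # identity check
same_structure = post_body.args[0] == my_expr   # structural check
print(same_object, same_structure)  # -> False True
```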

My current workaround is

assert repr(myExpr) == repr(myFuncPost.body.args[0])

with the following issues:

  • Generating and comparing these strings becomes very expensive for large models.
  • There may be collisions. Here, any Constant(0) in the model would trivially match.
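When the pass is a single mutator you control, one way around both issues is to have the mutator record a map from each old node to the node it produced. Lookup is then a dictionary access, and two identical subtrees cannot be confused because the key is the node's identity. This is a toy sketch with a hypothetical IR, not Relay's API. TVM's Python `ExprMutator` appears to keep a similar internal cache (`memo_map`); relying on it is an assumption to check against your TVM version, and it does not help with C++ passes such as DefuseOps.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


# eq=False: nodes compare and hash by identity, so two Const(0) are distinct keys.
@dataclass(frozen=True, eq=False)
class Const:
    value: int


@dataclass(frozen=True, eq=False)
class Call:
    op: str
    args: Tuple[object, ...]


class RecordingMutator:
    """Hypothetical toy mutator that rebuilds nodes and remembers old -> new."""

    def __init__(self):
        # Keyed by the old node itself, which also keeps it alive while mapped.
        self.memo: Dict[object, object] = {}

    def visit(self, expr):
        if expr in self.memo:
            return self.memo[expr]
        if isinstance(expr, Const):
            new = Const(expr.value)
        else:
            new = Call(expr.op, tuple(self.visit(a) for a in expr.args))
        self.memo[expr] = new
        return new


a = Const(0)
b = Const(0)  # structurally identical to a, but a different node
body = Call("add", (a, b))

mutator = RecordingMutator()
new_body = mutator.visit(body)

new_a = mutator.memo[a]
print(new_a is new_body.args[0], new_a is new_body.args[1])  # -> True False
```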

Is there a better way to do this?


Why not just search for my Expr after the pass has been applied? Because I need the model structure from before the pass, which no longer exists after it. For example:

  • Find an Expr in the fused (transform.FuseOps) Relay representation, based on the structure of the fused functions.
  • Apply transform.DefuseOps.
  • Recover the originally found expression.
  • Apply transformations that only make sense in the defused Relay representation.

Since DefuseOps changes the structure significantly, I cannot use the workaround from above. Instead I do the following:

  • Replace the found expr with relay.annotation.stop_fusion(myExpr) to “mark” it.
  • Apply DefuseOps.
  • Search for the annotation and extract its argument: myExpr = stopFuseExpr.args[0]
  • Remove the annotation. This again makes myExpr invalid in the latest version of the graph, exactly like in my original problem above.
  • Use the workaround from above to recover myExpr again.
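The last two steps could be merged so that the repr lookup is no longer needed: the pass that strips the marker can report the node it substitutes for it, which by construction is the node in the latest graph. A toy sketch under that assumption, with a hypothetical IR: `Mark` plays the role of `stop_fusion` and `rebuild` the role of DefuseOps. In Relay this would presumably be a Python `relay.ExprMutator` subclass whose `visit_call` recognizes the annotation; treat that mapping as untested.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True, eq=False)
class Const:
    value: int


@dataclass(frozen=True, eq=False)
class Call:
    op: str
    args: Tuple[object, ...]


@dataclass(frozen=True, eq=False)
class Mark:
    """Hypothetical marker node, standing in for annotation.stop_fusion."""
    arg: object


def rebuild(expr):
    """Stand-in for DefuseOps: copies every node, keeping markers in place."""
    if isinstance(expr, Const):
        return Const(expr.value)
    if isinstance(expr, Mark):
        return Mark(rebuild(expr.arg))
    return Call(expr.op, tuple(rebuild(a) for a in expr.args))


def unmark(expr, found):
    """Remove markers and append to `found` the exact node that replaces each
    marker in the returned tree (Consts are reused, Calls are rebuilt)."""
    if isinstance(expr, Mark):
        inner = unmark(expr.arg, found)
        found.append(inner)
        return inner
    if isinstance(expr, Call):
        return Call(expr.op, tuple(unmark(a, found) for a in expr.args))
    return expr


target = Const(0)
other = Const(0)
body = Call("add", (Mark(target), other))  # mark the target

body = rebuild(body)                        # the transforming pass
found = []
body = unmark(body, found)                  # find and remove the marker

recovered = found[0]
print(len(found), recovered is body.args[0], recovered is body.args[1])  # -> 1 True False
```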

Giving the expression some kind of user-defined attribute would probably solve this, but there does not seem to be any such mechanism.
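What such an attribute could look like, sketched on a hypothetical toy IR rather than Relay: a `tag` field that does not take part in structural equality and that a pass copies from each old node onto its replacement. The catch is that every pass, including the built-in ones, would have to propagate it, and existing passes make no such promise.

```python
from dataclasses import dataclass, field, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class Node:
    """Hypothetical toy IR node; `tag` is ignored by == (compare=False)."""
    op: str
    args: Tuple["Node", ...] = ()
    tag: Optional[str] = field(default=None, compare=False)


def rewrite(node):
    """Toy pass: rebuilds every node; dataclasses.replace copies the tag over."""
    new_args = tuple(rewrite(a) for a in node.args)
    return replace(node, args=new_args)


def find_tagged(node, tag):
    """Depth-first search for the first node carrying `tag`."""
    if node.tag == tag:
        return node
    for arg in node.args:
        hit = find_tagged(arg, tag)
        if hit is not None:
            return hit
    return None


a = Node("const0", tag="mine")
b = Node("const0")  # structurally identical, untagged
body = Node("add", (a, b))

new_body = rewrite(body)
hit = find_tagged(new_body, "mine")
print(hit is new_body.args[0], hit is a, hit == new_body.args[1])  # -> True False True
```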

Do you have any suggestions how to achieve this in a more robust way?

Expecting assert myExpr == myFuncPost.body.args[0] to hold doesn’t make sense to me, and assert repr(myExpr) == repr(myFuncPost.body.args[0]) is not a robust workaround either. Consider the following case:

from tvm import relay

myExpr1 = relay.const(0)
myExpr2 = relay.const(0)
myFunc = relay.Function([], myExpr1 + myExpr2)
myFuncPost = relay.ExprMutator().visit(myFunc)
assert repr(myExpr1) == repr(myFuncPost.body.args[0]) # holds
assert repr(myExpr1) == repr(myFuncPost.body.args[1]) # holds, too

Note that you’ve applied a mutation pass, which is supposed to “mutate” the IR: the output IR is not expected to keep any connection to the input IR in terms of node addresses. If you want to keep some fused functions intact through DefuseOps, you should instead try to improve the DefuseOps pass so that it ignores them.
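A toy sketch of that suggestion, assuming a hypothetical per-function `keep` flag: the flag, the classes and `defuse` are invented for illustration, and the real DefuseOps pass is not known to offer such an option. `defuse` inlines each call to a fused function by substituting the call's arguments for the function's parameters, and leaves flagged functions untouched.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True, eq=False)
class Var:
    name: str


@dataclass(frozen=True, eq=False)
class Call:
    op: object                 # str for a primitive op, Function for a fused group
    args: Tuple[object, ...]


@dataclass(frozen=True, eq=False)
class Function:
    params: Tuple[Var, ...]
    body: object
    keep: bool = False         # hypothetical "do not defuse" flag


def substitute(expr, env: Dict[Var, object]):
    """Replace parameter Vars by the call arguments bound to them."""
    if isinstance(expr, Var):
        return env.get(expr, expr)
    if isinstance(expr, Call):
        return Call(expr.op, tuple(substitute(a, env) for a in expr.args))
    return expr


def defuse(expr):
    """Toy DefuseOps: inline calls to fused Functions unless flagged keep."""
    if isinstance(expr, Call):
        args = tuple(defuse(a) for a in expr.args)
        fn = expr.op
        if isinstance(fn, Function) and not fn.keep:
            return substitute(fn.body, dict(zip(fn.params, args)))
        return Call(fn, args)
    return expr


x = Var("x")
p, q = Var("p"), Var("q")
fused = Function((p,), Call("nn.relu", (Call("add", (p, p)),)))
kept = Function((q,), Call("exp", (q,)), keep=True)

prog = Call(kept, (Call(fused, (x,)),))
out = defuse(prog)

print(out.op is kept)                      # -> True (kept group survives)
print(out.args[0].op)                      # -> nn.relu (fused group inlined)
print(out.args[0].args[0].args[0] is x)    # -> True
```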