Hi All,
I am wondering how I can extract the Relay expressions that follow an If expression in Relay. I gave an example below. The expressions that follow the If will vary, which makes it difficult to use the Relay pattern language to partition them.
Nevertheless, I have at least two solutions, but I am not sure whether they are efficient ways to extract these expressions.
Here are my current approaches:
1. Use the Relay pattern language with a hardcoded pattern. I tested it, and it works.
2. Use a Relay dominator pattern. This is better than option 1 because as long as we know the start and end expressions, we can create a pattern. I have not tested it yet.
Are there other, more efficient solutions? Maybe traversing the AST and finding the expressions that follow the If?
I figured out that if I were able to insert annotations into that part (the part highlighted in blue in the figure above), I could partition the expressions into different functions. However, so far I have not been able to insert annotations at the exact location. My goal is to insert annotations into the region that immediately follows the If. Any ideas, @mbrookhart @masahi @rkimball?
BTW, I could register all the ops in the current example as follows and was able to annotate the given Relay IR. However, this approach does not work for my real example, so I need a way to insert annotations at specific locations. I know the op names and op parameters.
Oh wait, I misread on the first pass. That would allow you to get the body of the If, but it looks like you want everything after the If? Is there anything special about what comes after, or is it just that it follows an If?
I'd like to get everything (all Relay expressions) after the If.
There is nothing special after the If. What comes after it are regular Relay expressions. So all I want is to extract everything after the If into a Relay function.
In fact, I already tried it. This is another approach I am exploring. I can partition the given (above) example because it is simple. However, partitioning failed for more complex ones. I can post more details if the other approaches fail.
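To make "extract everything after the If into a function" concrete, here is a minimal sketch on a toy graph rather than real Relay; `Node`, `outline_after_if`, and the parameter name `if_output` are all hypothetical. Assuming the graph contains exactly one If, it copies everything downstream of it, turns the If's result into a new function parameter, and collects the remaining free variables (weights, biases) as parameters too. In actual Relay, the same idea would presumably be an `ExprMutator` that returns a fresh `relay.var` when it visits the If, but that mapping is an assumption.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass(eq=False)  # eq=False keeps identity hashing, like references to IR nodes
class Node:
    """Hypothetical stand-in for a Relay expression (not TVM's API)."""
    op: str                                           # "var", "if", or an operator name
    args: List["Node"] = field(default_factory=list)  # operands; for "if": cond, then, else
    name: str = ""                                    # label for "var" nodes


def outline_after_if(output: Node) -> Tuple[List[Node], Node]:
    """Copy the graph rooted at `output`, cutting it at the (single) If node.

    The If's result becomes the first parameter of the new function; every
    free variable reached along the way becomes a parameter as well.
    """
    if_param = Node("var", name="if_output")  # hypothetical name for the new input
    params: List[Node] = [if_param]
    memo: Dict[int, Node] = {}  # id(old node) -> rebuilt node, so shared nodes are copied once

    def rebuild(node: Node) -> Node:
        if id(node) in memo:
            return memo[id(node)]
        if node.op == "if":
            new = if_param          # cut here: the If and its branches are not copied
        elif node.op == "var":
            new = node              # reuse the variable and record it as a parameter
            params.append(node)
        else:
            new = Node(node.op, [rebuild(a) for a in node.args])
        memo[id(node)] = new
        return new

    body = rebuild(output)
    return params, body


# Toy version of: If -> conv2d -> bias_add -> relu
c, x, y = (Node("var", name=n) for n in ("c", "x", "y"))
if_node = Node("if", [c, Node("add", [x, y]), Node("multiply", [x, y])])
weight, bias = Node("var", name="weight"), Node("var", name="bias")
out = Node("nn.relu", [Node("nn.bias_add", [Node("nn.conv2d", [if_node, weight]), bias])])

params, body = outline_after_if(out)
print([p.name for p in params])  # ['if_output', 'weight', 'bias']
```

The variables used only inside the If (`c`, `x`, `y`) do not appear as parameters, because the copy never crosses the If.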
How about a dominator pattern, with the If as the parent and a wildcard for both the body and the child, and then use the rewrite API, since you don't want to partition the If?
I have tried using a dominator pattern, but I have not used the If as the parent yet. I will try it. Could you please elaborate on that a little bit?
I just want to make sure that I made my problem clear. In this particular case, I am looking to partition from %12 all the way to the end (which is after %16). Specifically, I am looking to extract the following expressions.
I poked around at this for a while this morning, and I think your requirements might be a little too fuzzy for the pattern matcher. I think I could do it, but it requires editing the partitioner to ignore fuzzy If patterns, the same way I did for Function.
If you only expect to have one If in the graph, you could do this with a pass: traverse the graph, store the pre-order topological (reverse) sort in your pass's state, and stop the traversal when you hit an If. When you unwind the recursion, your topological sort only contains the nodes that follow the If.
I'm not sure I understand the use case, though. There might be a better solution if you can share what you're trying to do at a slightly higher level.
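The pass described above can be sketched on a toy graph; the `Node` class and `topo_after_if` are hypothetical and do not use TVM's API. It is a memoized depth-first walk from the output that refuses to descend into an If and records each node as the recursion unwinds. Because nodes are appended after their operands, the resulting list is already topologically sorted. In TVM this would plausibly correspond to an `ExprVisitor` subclass whose `visit_if` does not recurse, but that mapping is an assumption.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass(eq=False)  # eq=False keeps identity hashing, like references to IR nodes
class Node:
    """Hypothetical stand-in for a Relay expression (not TVM's API)."""
    op: str                                           # "var", "if", or an operator name
    args: List["Node"] = field(default_factory=list)  # operands; for "if": cond, then, else
    name: str = ""                                    # label for "var" nodes


def topo_after_if(output: Node) -> List[Node]:
    """Depth-first walk from the output that stops at the If.

    Nodes are appended while the recursion unwinds (post-order), so operands
    always precede their users: the list is a topological sort of everything
    reachable from `output` without passing through the If.
    """
    order: List[Node] = []
    visited = set()

    def visit(node: Node) -> None:
        if id(node) in visited:
            return
        visited.add(id(node))
        if node.op == "if":
            return  # stop: neither the If nor anything feeding it is recorded
        for arg in node.args:
            visit(arg)
        order.append(node)

    visit(output)
    return order


# Toy version of the example: If -> conv2d -> bias_add -> relu -> conv2d -> bias_add -> relu
c, x, y = (Node("var", name=n) for n in ("c", "x", "y"))
if_node = Node("if", [c, Node("add", [x, y]), Node("multiply", [x, y])])
w1, w2, bias = (Node("var", name=n) for n in ("weight1", "weight2", "bias"))
relu1 = Node("nn.relu", [Node("nn.bias_add", [Node("nn.conv2d", [if_node, w1]), bias])])
out = Node("nn.relu", [Node("nn.bias_add", [Node("nn.conv2d", [relu1, w2]), bias])])

after = [n.op for n in topo_after_if(out) if n.op != "var"]
print(after)
# ['nn.conv2d', 'nn.bias_add', 'nn.relu', 'nn.conv2d', 'nn.bias_add', 'nn.relu']
```

The walk also collects nodes that merely avoid the If, such as the weight and bias variables or any branch running in parallel to it, which is why the suggestion is limited to graphs with a single If. Very deep graphs may also hit Python's recursion limit.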
Thank you very much. I have created an example (Relay program) to explain what I am trying to do.
At a high level, I'd like to partition the Relay graph based on control flow. That means I need to create one partition for the true branch (see the inline comment in the code: graph 1), one for the false branch (graph 2), and one for the Relay expressions that follow the If (graph 3).
Here is an example; please see the comments in the code. It is straightforward to extract the true and false branches, but I am stuck on creating a partition for graph 3. I know the start and end nodes of graph 3, but there can be any number of nodes in between. The dominator pattern seems like a good fit, but I was not successful using it (as you mentioned).
```python
import tvm
from tvm import relay

# graph 1: Extract true branch
x1 = relay.var('x', shape=(1, 1, 28, 28))
y1 = relay.var('y', shape=(1, 1, 28, 28))
f1 = relay.op.add(x1, y1)
# graph 2: Extract false branch
x2 = relay.var('x', shape=(1, 1, 28, 28))
y2 = relay.var('y', shape=(1, 1, 28, 28))
f2 = relay.op.multiply(x2, y2)
# Relay requires the If condition to be a scalar bool tensor
cond = relay.var("c", shape=(), dtype='bool')
result = relay.If(cond, true_branch=f1, false_branch=f2)
# graph 3: Extract the part that follows the If
# I know the starting NODE (operator) and the END operator.
# I was under the assumption that the fuzzy or dominator pattern should work, but I have
# not been able to use the dominator pattern successfully for this example yet.
weight1 = relay.var('weight1', shape=(32, 1, 3, 3))
# the second conv2d sees 32 input channels, so it needs its own weight
weight2 = relay.var('weight2', shape=(32, 32, 3, 3))
bias = relay.var('bias', shape=(32,))
# graph 3 consumes the If expression directly (not a Function wrapping it)
conv2d = relay.op.nn.conv2d(result, weight1, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3])
bias_add = relay.op.nn.bias_add(conv2d, bias)
relu = relay.op.nn.relu(bias_add)
conv2d = relay.op.nn.conv2d(relu, weight2, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3])
bias_add = relay.op.nn.bias_add(conv2d, bias)
relu = relay.op.nn.relu(bias_add)

# wrap the whole graph in a function whose parameters are its free variables
func = relay.Function(relay.analysis.free_vars(relu), relu)
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
```
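For graph 3, where only the start node (the first `conv2d`) and the end node (the last `relu`) are known, the nodes in between can be collected with plain reachability instead of a pattern. This sketch uses a hypothetical toy `Node` class rather than Relay objects, and `nodes_between` is an illustrative helper, not a TVM function. It keeps every node that is reachable backwards from `end` and transitively depends on `start`, which yields exactly the nodes on some start-to-end path, in topological order. This is ordinary graph search, not TVM's `DominatorPattern`.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass(eq=False)  # eq=False keeps identity hashing, like references to IR nodes
class Node:
    """Hypothetical stand-in for a Relay expression (not TVM's API)."""
    op: str                                           # "var", "if", or an operator name
    args: List["Node"] = field(default_factory=list)  # operands; for "if": cond, then, else
    name: str = ""                                    # label for "var" nodes


def nodes_between(start: Node, end: Node) -> List[Node]:
    """All nodes on some path from `start` to `end` (both inclusive), topologically sorted."""
    memo: Dict[int, bool] = {}  # id(node) -> does it transitively depend on `start`?
    order: List[Node] = []

    def depends_on_start(node: Node) -> bool:
        if id(node) in memo:
            return memo[id(node)]
        if node is start:
            result = True  # do not look above `start`
        else:
            # evaluate every operand (no short-circuit) so no qualifying branch is skipped
            flags = [depends_on_start(a) for a in node.args]
            result = any(flags)
        memo[id(node)] = result
        if result:
            order.append(node)  # post-order: operands are appended before their users
        return result

    depends_on_start(end)
    return order


# Toy version of the example graph
c, x, y = (Node("var", name=n) for n in ("c", "x", "y"))
if_node = Node("if", [c, Node("add", [x, y]), Node("multiply", [x, y])])
w1, w2, bias = (Node("var", name=n) for n in ("weight1", "weight2", "bias"))
start = Node("nn.conv2d", [if_node, w1])
relu1 = Node("nn.relu", [Node("nn.bias_add", [start, bias])])
end = Node("nn.relu", [Node("nn.bias_add", [Node("nn.conv2d", [relu1, w2]), bias])])

between = nodes_between(start, end)
print([n.op for n in between])
# ['nn.conv2d', 'nn.bias_add', 'nn.relu', 'nn.conv2d', 'nn.bias_add', 'nn.relu']
```

Because the check is pure reachability, it works for any number of intermediate nodes and for non-linear shapes such as diamonds or residual connections, and it leaves out weights, biases, and the If itself.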
I have made some attempts as you suggested (doing a reverse topological sort and getting the part after the If). Here is my understanding of your suggestion, which I also agree with; please let me know if I got it wrong. These are the steps I assumed I should follow to extract the part after the If:
1. Traverse the graph in reverse topological order.
2. When I hit the If, I know I have already traversed the part that I need to extract.
However, I was not able to traverse the Relay expressions in reverse order. Could you please elaborate on this?
I have tried to use ExprMutator and relay.analysis.post_order_visit to compute a reverse topological sort.
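On the reverse-order question: a post-order traversal (which, as far as I know, is what `relay.analysis.post_order_visit(expr, fvisit)` provides, calling `fvisit` on each node after its operands) is already a topological order, so reversing that list gives the reverse topological order. To pick out the nodes after the If, it is simpler to walk the post-order list forwards and mark every node that consumes the If or an already-marked node. The sketch uses a hypothetical toy `Node` class and a `post_order` helper that imitates `post_order_visit`; it does not call TVM.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass(eq=False)  # eq=False keeps identity hashing, like references to IR nodes
class Node:
    """Hypothetical stand-in for a Relay expression (not TVM's API)."""
    op: str                                           # "var", "if", or an operator name
    args: List["Node"] = field(default_factory=list)  # operands; for "if": cond, then, else
    name: str = ""                                    # label for "var" nodes


def post_order(root: Node) -> List[Node]:
    """Toy imitation of post_order_visit: each node once, operands before users."""
    order: List[Node] = []
    seen = set()

    def visit(node: Node) -> None:
        if id(node) in seen:
            return
        seen.add(id(node))
        for arg in node.args:
            visit(arg)
        order.append(node)

    visit(root)
    return order


def downstream_of_if(root: Node) -> List[Node]:
    """Nodes that (transitively) consume an If, in topological order."""
    order = post_order(root)
    marked = set()
    for node in order:  # operands are always decided before their users
        if any(arg.op == "if" or id(arg) in marked for arg in node.args):
            marked.add(id(node))
    return [n for n in order if id(n) in marked]


# Toy version of the example graph
c, x, y = (Node("var", name=n) for n in ("c", "x", "y"))
if_node = Node("if", [c, Node("add", [x, y]), Node("multiply", [x, y])])
w1, w2, bias = (Node("var", name=n) for n in ("weight1", "weight2", "bias"))
relu1 = Node("nn.relu", [Node("nn.bias_add", [Node("nn.conv2d", [if_node, w1]), bias])])
out = Node("nn.relu", [Node("nn.bias_add", [Node("nn.conv2d", [relu1, w2]), bias])])

rev = [n.op for n in reversed(post_order(out))]  # reverse topological order: output first
after = downstream_of_if(out)
print([n.op for n in after])
# ['nn.conv2d', 'nn.bias_add', 'nn.relu', 'nn.conv2d', 'nn.bias_add', 'nn.relu']
```

Unlike stopping the reversed walk at the If, the marking pass stays exact even when some nodes do not depend on the If, because membership is decided by actual data dependence rather than by position in the list.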