How to extract Relay expressions following If statement

Hi All, I am wondering how I can extract the Relay expressions that follow an If statement in Relay. I gave an example below. The expressions that follow the If statement will vary, which makes it difficult to use the Relay pattern language to partition them.

Nevertheless, I have at least two candidate solutions, though I am not sure whether they are efficient ways to extract these expressions.

Here are my current approaches:

  1. We can use the Relay pattern language with a hardcoded pattern ==> I tested this and it works (see the sketch after this list).
  2. We can use a Relay dominator pattern (better than 1, because as long as we know the start and end expressions we can create a pattern). I have not tested it yet.
  3. Are there other, more efficient solutions? Maybe traversing the AST and finding the expressions that follow the If statement?
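
Here is a rough sketch of what I mean by a hardcoded pattern in approach 1. It just spells out the conv2d -> bias_add -> relu chain that follows the If in the example below (a sketch only, not necessarily the exact pattern I tested):

from tvm.relay.dataflow_pattern import is_op, wildcard

# Approach 1 sketch: hardcode the exact chain that follows the If
# (conv2d -> bias_add -> relu, repeated twice).
conv1 = is_op("nn.conv2d")(wildcard(), wildcard())
bias1 = is_op("nn.bias_add")(conv1, wildcard())
relu1 = is_op("nn.relu")(bias1)
conv2 = is_op("nn.conv2d")(relu1, wildcard())
bias2 = is_op("nn.bias_add")(conv2, wildcard())
relu2 = is_op("nn.relu")(bias2)
hardcoded_pattern = relu2

# hardcoded_pattern.partition(module["main"]) would then lift the matched
# chain into its own function.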

Hi All,

I figured out that if I could insert annotations into that part (the BLUE part in the figure above), I could partition the expressions into different functions. However, I have not been able to insert annotations at the exact location so far. My goal is to insert annotations into the region that immediately follows the If statement. Any ideas @mbrookhart @masahi @rkimball ?

def @main(%input: Tensor[(1, 1, 28, 28), float32], %condition: Tensor[(1), bool], %v10: Tensor[(32, 32, 3, 3), float32], %v11: Tensor[(32), float32], %v12: Tensor[(32, 32, 3, 3), float32], %v13: Tensor[(32), float32], %v2: Tensor[(32, 1, 3, 3), float32], %v3: Tensor[(32), float32], %v4: Tensor[(32, 32, 1, 1), float32], %v5: Tensor[(32), float32], %v6: Tensor[(32, 64, 1, 1), float32], %v7: Tensor[(32), float32], %v8: Tensor[(64, 32, 1, 1), float32], %v9: Tensor[(64), float32]) {
  %0 = take(%condition, 0);
  %11 = if (%0) {
    %1 = nn.conv2d(%input, %v2, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
    %2 = nn.bias_add(%1, %v3);
    %3 = nn.relu(%2);
    %4 = nn.conv2d(%3, %v8, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
    %5 = nn.bias_add(%4, %v9);
    %6 = nn.relu(%5);
    %7 = nn.conv2d(%6, %v6, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
    %8 = nn.bias_add(%7, %v7);
    nn.relu(%8)
  } else {
    %9 = nn.conv2d(%3, %v4, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
    %10 = nn.bias_add(%9, %v5);
    nn.relu(%10)
  };
  %12 = nn.conv2d(%11, %v10, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
  %13 = nn.bias_add(%12, %v11);
  %14 = nn.relu(%13);
  %15 = nn.conv2d(%14, %v12, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
  %16 = nn.bias_add(%15, %v13);
  nn.relu(%16)
}

BTW, I could register all the ops in the current example as follows and annotate the given Relay IR. However, this approach does not work for my real example, so I need a way to insert annotations at specific locations. I know the op name and the op parameters.

import tvm


def _register_external_op_helper(op_name, supported=True):
    """Mark op_name as supported by the external "special" target."""

    @tvm.ir.register_op_attr(op_name, "target.special")
    def _func_wrapper(expr):
        return supported

    return _func_wrapper


_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.bias_add")
_register_external_op_helper("nn.relu")
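
What I am really after is something along these lines: a rough, untested sketch where the AnnotateRegion mutator, the "special" target name, and the calls_to_annotate list are all made up for illustration. Only the calls I pick (for example, the ones that follow the If) would get wrapped, instead of every registered op in the module.

from tvm import relay
from tvm.relay.expr_functor import ExprMutator
from tvm.relay.op.annotation import compiler_begin, compiler_end


class AnnotateRegion(ExprMutator):
    """Sketch: wrap a chosen list of calls with compiler_begin/compiler_end
    so that MergeCompilerRegions + PartitionGraph can split them out."""

    def __init__(self, calls_to_annotate, target="special"):
        super().__init__()
        self.calls = calls_to_annotate
        self.target = target

    def visit_call(self, call):
        new_call = super().visit_call(call)
        if any(call is c for c in self.calls):
            # Annotate every input and the output of the selected call.
            args = [compiler_begin(a, self.target) for a in new_call.args]
            wrapped = relay.Call(new_call.op, args, new_call.attrs)
            return compiler_end(wrapped, self.target)
        return new_call


# Usage sketch:
# body = AnnotateRegion(calls_to_annotate).visit(mod["main"].body)
# mod["main"] = relay.Function(mod["main"].params, body)
# mod = relay.transform.MergeCompilerRegions()(mod)
# mod = relay.transform.PartitionGraph()(mod)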

Hmm, I did something like this with functions recently.

I wasn’t able to get it to work with partitioning, but it did work with the rewrite API. Perhaps we could extend that to If?

Oh wait, I misread on the first pass. That would let you get the body of the If, but it looks like you want everything after the If? Is there anything special about what comes after, or just that it follows an If?


I’d like to get everything (all Relay expressions) after the If. There is nothing special about what comes after the If; they are just regular Relay expressions. So all I want is to extract everything after the If into a Relay function.

In fact, I already tried that. This is another approach I am exploring. I can partition the given (above) example because it is simple, but the partitioning failed for more complex ones. I can post more details if the other approaches fail.

How about a dominator pattern, with the parent being the If, a wildcard for both the path and the child, and using the rewrite API, since you don’t want to partition the If itself?
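
Something like this, maybe (completely untested, and assuming an is_if pattern helper, which we may still need to add):

from tvm.relay.dataflow_pattern import (
    DFPatternCallback, dominates, is_if, rewrite, wildcard,
)


class AfterIfCallback(DFPatternCallback):
    """Untested sketch: the If is the parent, the path and the child are
    both wildcards, and the rewrite callback gets the matched region."""

    def __init__(self):
        super().__init__()
        parent = is_if(wildcard(), wildcard(), wildcard())
        path = wildcard()
        child = wildcard()
        self.pattern = dominates(parent, path, child)

    def callback(self, pre, post, node_map):
        # Rebuild the matched region here, e.g. wrap it in a function.
        return post


# new_body = rewrite(AfterIfCallback(), mod["main"].body)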

Actually, I have tried using a dominator pattern, but I have not used the If as the parent yet; I will try that. Could you please elaborate on that a little bit?

Here is what I have done so far:

from tvm.relay.dataflow_pattern import dominates, is_op, wildcard


def create_pattern_for_mnist_tvm_dominator():
    """A simple dominator pattern creator."""

    # parent: conv2d + bias_add + relu
    pattern0 = is_op("nn.conv2d")(wildcard(), wildcard())
    pattern1 = is_op("nn.bias_add")(pattern0, wildcard())
    parent = is_op("nn.relu")(pattern1)

    # any number of nodes between the parent and the child
    anything_in_between = wildcard()

    # child: the final relu
    child = is_op("nn.relu")(wildcard())

    return dominates(parent, anything_in_between, child)


# Create the dominator pattern and try to partition the main function.
dominator_pattern = create_pattern_for_mnist_tvm_dominator()
partitioned_expression = dominator_pattern.partition(module["main"], {"name": "subgraph"})


This does not partition the expressions after the If. In fact, it just returns the original expression, as follows:

print(partitioned_expression)
fn (%input: Tensor[(1, 1, 28, 28), float32], %condition: Tensor[(1), bool], %v10: Tensor[(32, 32, 3, 3), float32], %v11: Tensor[(32), float32], %v12: Tensor[(32, 32, 3, 3), float32], %v13: Tensor[(32), float32], %v2: Tensor[(32, 1, 3, 3), float32], %v3: Tensor[(32), float32], %v4: Tensor[(32, 32, 1, 1), float32], %v5: Tensor[(32), float32], %v6: Tensor[(32, 64, 1, 1), float32], %v7: Tensor[(32), float32], %v8: Tensor[(64, 32, 1, 1), float32], %v9: Tensor[(64), float32]) {
  %0 = take(%condition, 0);
  %11 = if (%0) {
    %1 = nn.conv2d(%input, %v2, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
    %2 = nn.bias_add(%1, %v3);
    %3 = nn.relu(%2);
    %4 = nn.conv2d(%3, %v8, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
    %5 = nn.bias_add(%4, %v9);
    %6 = nn.relu(%5);
    %7 = nn.conv2d(%6, %v6, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
    %8 = nn.bias_add(%7, %v7);
    nn.relu(%8)
  } else {
    %9 = nn.conv2d(%3, %v4, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
    %10 = nn.bias_add(%9, %v5);
    nn.relu(%10)
  };
  %12 = nn.conv2d(%11, %v10, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
  %13 = nn.bias_add(%12, %v11);
  %14 = nn.relu(%13);
  %15 = nn.conv2d(%14, %v12, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
  %16 = nn.bias_add(%15, %v13);
  nn.relu(%16)
}

One more thing,

I just want to make sure I made my problem clear. In this particular case, I am looking to partition everything from %12 through the end (the final nn.relu(%16)). Specifically, I am looking to extract the following expressions.

  %12 = nn.conv2d(%11, %v10, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
  %13 = nn.bias_add(%12, %v11);
  %14 = nn.relu(%13);
  %15 = nn.conv2d(%14, %v12, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
  %16 = nn.bias_add(%15, %v13);
  nn.relu(%16)

I poked around at this for a while this morning. I think your requirements might be a little too fuzzy for the pattern matcher. I think I could do it, but it requires editing the partitioner to ignore fuzzy If patterns, like I did for Function.

If you only expect to have one If in the graph, you could do this with a pass: traverse the graph, store the pre-order (reverse topological) sort in state in your pass, and stop the traversal when you hit an If. When you unwind the iteration, your topological sort only contains the nodes that follow the If.
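
One (untested) way to get the same effect is an ExprVisitor that simply refuses to descend into the If, so only the nodes that follow it get recorded (mod here stands for the IRModule from the first post):

from tvm.relay.expr_functor import ExprVisitor


class CollectAfterIf(ExprVisitor):
    """Untested sketch: walk from the output and stop at the first If,
    so only the nodes that follow the If get recorded."""

    def __init__(self):
        super().__init__()
        self.after_if = []   # calls that follow the If
        self.if_node = None  # the If that ends the traversal

    def visit_if(self, op):
        # Do not recurse into the condition or the branches.
        self.if_node = op

    def visit_call(self, call):
        self.after_if.append(call)
        super().visit_call(call)


# collector = CollectAfterIf()
# collector.visit(mod["main"].body)
# collector.after_if now holds the conv2d/bias_add/relu calls after the If.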

I’m not sure I understand the use, though. There might be a better solution if you can share what you’re trying to do at a slightly higher level?

Hi @mbrookhart,

Thank you very much. I have created an example (relay program) to explain what I am trying to do.

On a high level, I’d like to partition the Relay graph based on control flow. That means I need to create one partition for the true branch (see the inlined comment in the code: graph 1), one partition for the false branch (see the inlined comment: graph 2), and one partition for the Relay expressions that follow the If (see the inlined comment: graph 3).

Here is an example; please see the comments in the code. It is straightforward to extract the true and false branches, but I am stuck on creating a partition for graph 3. I know the start and end nodes of graph 3, but there can be any number of nodes in between. The dominator pattern seems like a good fit, but I was not successful with it (as you mentioned).

    # graph 1: the true branch
    x1 = relay.var('x', shape=(1, 1, 28, 28))
    y1 = relay.var('y', shape=(1, 1, 28, 28))
    f1 = relay.op.add(x1, y1)

    # graph 2: the false branch
    x2 = relay.var('x', shape=(1, 1, 28, 28))
    y2 = relay.var('y', shape=(1, 1, 28, 28))
    f2 = relay.op.multiply(x2, y2)

    cond = relay.var("c", shape=(1, 1, 1, 1), dtype='uint8')
    result = relay.If(cond, true_branch=f1, false_branch=f2)
    f = relay.Function([], result)

    # graph 3: the part that follows the If; it consumes the result of the If.
    # I know the starting node (operator) and the end operator.
    # I was under the assumption that a fuzzy or dominator pattern should work,
    # but I have not been able to use the dominator pattern successfully for this example.
    data = relay.var('input', shape=(1, 1, 28, 28))
    weight = relay.var('weight', shape=(32, 1, 3, 3))
    bias = relay.var('bias', shape=(32,))

    conv2d = relay.op.nn.conv2d(result, weight, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3])
    bias_add = relay.op.nn.bias_add(conv2d, bias)
    relu = relay.op.nn.relu(bias_add)

    conv2d = relay.op.nn.conv2d(relu, weight, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3])
    bias_add = relay.op.nn.bias_add(conv2d, bias)
    relu = relay.op.nn.relu(bias_add)

I created the following in case it helps explain graph 3. What we know about graph 3 is:

  • It follows the If.

  • It consumes the result of the If (one of its inputs comes from either the true branch or the false branch).

  • We know the start and end nodes.

Hi @mbrookhart,

I have made some attempts at what you suggested (doing a reverse topological sort and taking the part after the If). Here is my understanding of your suggestion, which I agree with; please let me know if I got it wrong. These are the steps I assumed I should follow to extract the part after the If:

  1. Traverse the expressions in reverse topological order.
  2. When I hit the If, I know I have already traversed the part that I need to extract.

I have not been able to traverse the Relay expressions in reverse order. Could you please elaborate on this?

I have tried using ExprMutator and relay.analysis.post_order_visit to build the reverse topological order.
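
Here is roughly what I am attempting (a sketch, assuming mod is the IRModule from earlier in the thread; I am not sure this is what you had in mind):

from tvm import relay

# post_order_visit yields a topological order (inputs before users), so the
# reversed list walks from the output back toward the inputs.
nodes = []
relay.analysis.post_order_visit(mod["main"].body, lambda e: nodes.append(e))

after_if = []
for node in reversed(nodes):
    if isinstance(node, relay.If):
        break  # everything collected so far comes after the If
    if isinstance(node, relay.Call):
        after_if.append(node)

# after_if should now hold %12 through nn.relu(%16) from the printout above.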