Partitioning a graph with control flow

Hi All,

My goal is to partition a Relay IRModule, based on its control flow, into Relay functions or IRModules as follows:

  1. The condition, the true branch, and the false branch must each become a separate IRModule or Relay function.
  2. The expressions that follow the control flow must become a separate IRModule or Relay function. For example, in the module below, how can I get expressions %12 to %16 into their own Relay IRModule?

Assume I have the following Relay IRModule. Writing a Relay pass that extracts the true branch and the false branch is straightforward, and my code for it is below. However, I have not been able to get the expressions that follow the control flow into a separate Relay function (or module).

Relay IR:

def @main(%input: Tensor[(1, 1, 28, 28), float32], %condition: Tensor[(1), bool], %v10: Tensor[(32, 32, 3, 3), float32], %v11: Tensor[(32), float32], %v12: Tensor[(32, 32, 3, 3), float32], %v13: Tensor[(32), float32], %v2: Tensor[(32, 1, 3, 3), float32], %v3: Tensor[(32), float32], %v4: Tensor[(32, 32, 1, 1), float32], %v5: Tensor[(32), float32], %v6: Tensor[(32, 64, 1, 1), float32], %v7: Tensor[(32), float32], %v8: Tensor[(64, 32, 1, 1), float32], %v9: Tensor[(64), float32]) {
  %0 = take(%condition, 0);
  %11 = if (%0) {
    %1 = nn.conv2d(%input, %v2, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
    %2 = nn.bias_add(%1, %v3);
    %3 = nn.relu(%2);
    %4 = nn.conv2d(%3, %v8, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
    %5 = nn.bias_add(%4, %v9);
    %6 = nn.relu(%5);
    %7 = nn.conv2d(%6, %v6, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
    %8 = nn.bias_add(%7, %v7);
    nn.relu(%8)
  } else {
    %9 = nn.conv2d(%3, %v4, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
    %10 = nn.bias_add(%9, %v5);
    nn.relu(%10)
  };
  %12 = nn.conv2d(%11, %v10, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
  %13 = nn.bias_add(%12, %v11);
  %14 = nn.relu(%13);
  %15 = nn.conv2d(%14, %v12, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
  %16 = nn.bias_add(%15, %v13);
  nn.relu(%16)
}
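One possible approach, sketched in plain Python on a toy expression tree rather than with TVM's real classes: replace the If node with a fresh placeholder variable, then wrap whatever remains (the continuation, %12 to %16) in a function whose parameters are the placeholder plus the other free variables it uses (%v10 to %v13). The condition and both branches are lifted the same way. The names Var, Call, If, Function, free_vars, lift, split_at_if and block are hypothetical, and because the toy has no let bindings, every variable it references counts as free. In TVM, the analogous steps would presumably be returning a new relay.var (typed with the If node's checked_type after running InferType) from visit_if, then building relay.Function(relay.analysis.free_vars(new_body), new_body); that translation is an assumption and has not been checked against a specific TVM version.

```python
from dataclasses import dataclass
from typing import Dict, Tuple, Union

# Toy IR node types (hypothetical stand-ins for relay.Var, relay.Call, relay.If).
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple["Expr", ...]

@dataclass(frozen=True)
class If:
    cond: "Expr"
    true_branch: "Expr"
    false_branch: "Expr"

Expr = Union[Var, Call, If]

@dataclass(frozen=True)
class Function:
    params: Tuple[Var, ...]
    body: Expr

def free_vars(expr: Expr) -> Tuple[Var, ...]:
    """Collect variables in first-use order. The toy has no binders, so all are free."""
    seen = []

    def visit(e: Expr) -> None:
        if isinstance(e, Var):
            if e not in seen:
                seen.append(e)
        elif isinstance(e, Call):
            for arg in e.args:
                visit(arg)
        elif isinstance(e, If):
            visit(e.cond)
            visit(e.true_branch)
            visit(e.false_branch)

    visit(expr)
    return tuple(seen)

def lift(expr: Expr) -> Function:
    """Wrap an expression in a function parameterized by its free variables."""
    return Function(free_vars(expr), expr)

def split_at_if(body: Expr, placeholder_name: str = "if_out") -> Dict[str, Function]:
    """Split body at its single If node into four standalone functions."""
    found = []
    placeholder = Var(placeholder_name)

    def replace(e: Expr) -> Expr:
        if isinstance(e, If):
            found.append(e)  # remember the If; a variable stands in for its result
            return placeholder
        if isinstance(e, Call):
            return Call(e.op, tuple(replace(arg) for arg in e.args))
        return e

    continuation = replace(body)
    if len(found) != 1:
        raise ValueError(f"expected exactly one If node, found {len(found)}")
    ite = found[0]
    return {
        "condition": lift(ite.cond),
        "true_branch": lift(ite.true_branch),
        "false_branch": lift(ite.false_branch),
        "continuation": lift(continuation),
    }

def block(x: Expr, w: str, b: str) -> Expr:
    """Hypothetical helper: conv2d -> bias_add -> relu, attributes omitted."""
    return Call("nn.relu", (Call("nn.bias_add", (Call("nn.conv2d", (x, Var(w))), Var(b))),))

# A simplified stand-in for @main (shapes and attributes dropped).
x = Var("input")
ite = If(
    Call("take", (Var("condition"),)),
    block(block(block(x, "v2", "v3"), "v8", "v9"), "v6", "v7"),
    block(x, "v4", "v5"),
)
main_body = block(block(ite, "v10", "v11"), "v12", "v13")

parts = split_at_if(main_body)
print([p.name for p in parts["continuation"].params])
# ['if_out', 'v10', 'v11', 'v12', 'v13']
```

To recompose the original computation, call the continuation with the If node's result bound to the if_out parameter and the original weights bound to the remaining parameters.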

Here is the simple Relay pass that extracts the true/false branches. Any advice on how to get the Relay expressions that follow the control flow (%12 to %16) into a separate Relay function or IRModule?

from tvm.relay.expr_functor import ExprMutator


class ParitioningPass(ExprMutator):
    """This pass partitions the graph based on control flow

    """
    def __init__(self):
        super().__init__()
        self.inputs = []


    def visit_if(self, ite):
        """Return the sub-graphs around an If/Else node: the condition, the
        true branch, and the false branch.

        :param ite: the relay.If node being visited
        :return:
        """

        print("If")
        tb = ite.true_branch
        fb = ite.false_branch
        cond = ite.cond