Hi All,
My goal is to partition a Relay IRModule at its control flow into separate Relay functions or IRModules, as follows:
- The condition, the true branch, and the false branch must each be a separate IRModule or Relay function.
- The expressions that follow the control flow must be a separate IRModule or Relay function. For example, in the IR below, how can I get the expressions %12 to %16 into a separate Relay IRModule?
Let us assume I have the following Relay IR module. It is straightforward to write a Relay pass that gets the true branch and the false branch; I have attached the code below. However, I am not able to get the expressions following the control flow into a separate Relay function (or module).
Relay IR:
def @main(%input: Tensor[(1, 1, 28, 28), float32], %condition: Tensor[(1), bool], %v10: Tensor[(32, 32, 3, 3), float32], %v11: Tensor[(32), float32], %v12: Tensor[(32, 32, 3, 3), float32], %v13: Tensor[(32), float32], %v2: Tensor[(32, 1, 3, 3), float32], %v3: Tensor[(32), float32], %v4: Tensor[(32, 32, 1, 1), float32], %v5: Tensor[(32), float32], %v6: Tensor[(32, 64, 1, 1), float32], %v7: Tensor[(32), float32], %v8: Tensor[(64, 32, 1, 1), float32], %v9: Tensor[(64), float32]) {
  %0 = take(%condition, 0);
  %11 = if (%0) {
    %1 = nn.conv2d(%input, %v2, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
    %2 = nn.bias_add(%1, %v3);
    %3 = nn.relu(%2);
    %4 = nn.conv2d(%3, %v8, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
    %5 = nn.bias_add(%4, %v9);
    %6 = nn.relu(%5);
    %7 = nn.conv2d(%6, %v6, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
    %8 = nn.bias_add(%7, %v7);
    nn.relu(%8)
  } else {
    %9 = nn.conv2d(%3, %v4, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
    %10 = nn.bias_add(%9, %v5);
    nn.relu(%10)
  };
  %12 = nn.conv2d(%11, %v10, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
  %13 = nn.bias_add(%12, %v11);
  %14 = nn.relu(%13);
  %15 = nn.conv2d(%14, %v12, strides=[3, 3], padding=[0, 0, 0, 0], kernel_size=[3, 3]);
  %16 = nn.bias_add(%15, %v13);
  nn.relu(%16)
}
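To make the second requirement concrete, here is a minimal sketch of the split on a toy expression tree. It is plain Python and does not use TVM: the Var, Call and If classes, split_at_if, free_vars and the placeholder name "if_result" are all hypothetical names made up for illustration. The idea it shows is that the part after the control flow is just the main body with the if node swapped for a fresh variable, and the free variables of that rewritten body become the parameters of the new function. In TVM the corresponding pieces would presumably be a relay.var placeholder and relay.analysis.free_vars, but that mapping is an assumption and is not tested here.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union


# Toy stand-ins for Relay nodes (hypothetical, NOT the TVM classes).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple["Expr", ...]


@dataclass(frozen=True)
class If:
    cond: "Expr"
    true_branch: "Expr"
    false_branch: "Expr"


Expr = Union[Var, Call, If]


def split_at_if(body: Expr, hole: Var) -> Tuple[Optional[If], Expr]:
    """Replace the first If reached in `body` with `hole`.

    Returns (if_node, continuation). The continuation is the part of the
    graph that follows the control flow, with `hole` standing in for the
    value the If would have produced.
    """
    found: List[If] = []

    def rewrite(e: Expr) -> Expr:
        if isinstance(e, If):
            if not found:
                found.append(e)
            # Every use of the same If maps to the same placeholder;
            # any other If is left untouched in this sketch.
            return hole if e == found[0] else e
        if isinstance(e, Call):
            return Call(e.op, tuple(rewrite(a) for a in e.args))
        return e

    continuation = rewrite(body)
    return (found[0] if found else None), continuation


def free_vars(e: Expr) -> List[str]:
    """Variable names used in `e`, in first-use order (future parameters)."""
    names: List[str] = []

    def walk(x: Expr) -> None:
        if isinstance(x, Var):
            if x.name not in names:
                names.append(x.name)
        elif isinstance(x, Call):
            for a in x.args:
                walk(a)
        else:  # If
            walk(x.cond)
            walk(x.true_branch)
            walk(x.false_branch)

    walk(e)
    return names


# A much smaller graph with the same shape as the IR above.
inp = Var("input")
ite = If(
    Call("take", (Var("condition"),)),
    Call("relu", (Call("conv2d", (inp, Var("v2"))),)),
    Call("relu", (Call("conv2d", (inp, Var("v4"))),)),
)
main_body = Call(
    "relu",
    (Call("bias_add", (Call("conv2d", (ite, Var("v10"))), Var("v11"))),),
)

if_node, tail = split_at_if(main_body, Var("if_result"))
tail_params = free_vars(tail)
```

Here if_node still holds the condition and both branches, and tail together with tail_params is what would become the separate function for the expressions after the if. A graph with several or nested ifs would need the split applied repeatedly, which this sketch does not attempt.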
Here is the simple Relay pass that extracts the true and false branches. Any advice on how to get the Relay expressions following the control flow (%12 to %16) into a separate Relay function or IRModule?
from tvm.relay.expr_functor import ExprMutator


class ParitioningPass(ExprMutator):
    """This pass partitions the graph based on control flow."""

    def __init__(self):
        super().__init__()
        self.inputs = []

    def visit_if(self, ite):
        """Collect the graphs around the if/else: the condition, the true
        branch, and the false branch.

        :param ite: the relay.If node being visited
        :return:
        """
        print("If")
        tb = ite.true_branch
        fb = ite.false_branch
        cond = ite.cond