How to extract a Relay function when partitioning with a Relay pattern?

Hi All,

I am trying to extract the function that is created by the Relay pattern partition primitive.

Here is an example in which I use Relay pattern partitioning to partition a graph. My goal is to extract the partitioned function into a standalone function or IRModule.

import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

# Define a pattern matching conv2d followed by relu.
pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))

# Create a graph of relay expressions
x1 = relay.var('x1', shape=(1, 1))
y1 = relay.var('y1', shape=(1, 1))
add1 = relay.op.add(x1, y1)

x = add1
#x = relay.var('input', shape=(1, 1))
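w = relay.var('weight', shape=(1, 1))
# Note: nn.conv2d normally expects 4-D (NCHW / OIHW) tensors. These (1, 1) shapes
# are enough for pattern matching and partitioning, but would fail type inference.
conv2d = relay.op.nn.conv2d(x, w)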
relu = relay.op.nn.relu(conv2d)

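# Bind the free variables (x1, y1, weight) as the function's parameters.
f = relay.Function(relay.analysis.free_vars(relu), relu)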

print("-------------------Partition----------------")
part_result = pattern.partition(relu)
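# Wrap the partitioned expression (not the original f) in a module.
mod_part = tvm.IRModule({"main": relay.Function(relay.analysis.free_vars(part_result), part_result)})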

This is the partitioned graph. How can I extract the %2 = fn (...) function from it?

print(part_result)
free_var %x1: Tensor[(1, 1), float32];
free_var %y1: Tensor[(1, 1), float32];
%0 = add(%x1, %y1);
free_var %weight: Tensor[(1, 1), float32];
%2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, PartitionedFromPattern="nn.conv2d_nn.relu_") {
  %1 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
  nn.relu(%1)
};
%2(%0, %weight)

Hi @comaniac, @masahi and @jroesch,

Any suggestions for the problem above? How can I extract the partitioned Relay function from a Relay IRModule?

Since the partitioned function is anonymous, you cannot access it by name, e.g. mod["func_1"]. I think the easiest way is to write a simple Relay pass that collects it.

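import tvm
from tvm import relay

class FuncCollector(tvm.relay.ExprVisitor):
    """Collect the anonymous functions created by pattern.partition()."""

    def __init__(self):
        super().__init__()
        self.funcs = set()

    def visit_call(self, call):
        op = call.op
        # Partitioned functions are called directly and carry the
        # PartitionedFromPattern attribute; other functions may have no attrs.
        if (isinstance(op, relay.Function) and op.attrs is not None
                and "PartitionedFromPattern" in op.attrs):
            self.funcs.add(op)
        # Keep traversing the callee and arguments so nested partitions are found too.
        super().visit_call(call)

collector = FuncCollector()
collector.visit(part_result)
print("Functions:")
for func in collector.funcs:
    print(func)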

With your example, this prints:

Functions:
fn (%FunctionVar_0_0, %FunctionVar_0_1, PartitionedFromPattern="nn.conv2d_nn.relu_") {
  %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
  nn.relu(%0)
}
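One thing to watch when overriding visit_call: in TVM's ExprVisitor it is the base visit_call that recurses into the callee and the arguments, so an override that never hands control back only inspects the outermost call. The sketch below illustrates this with hypothetical toy classes (Var, Op, Function, Call, and a recurse flag), not TVM's API. It builds two chained conv2d+relu partitions and compares a shallow collector against a recursive one.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Union

# Hypothetical toy stand-ins for Relay nodes (NOT TVM classes).
# eq=False keeps identity-based hashing, like memoized visitors rely on.
@dataclass(eq=False)
class Var:
    name: str

@dataclass(eq=False)
class Op:
    name: str

@dataclass(eq=False)
class Function:
    params: List[Var]
    body: "Expr"
    attrs: Dict[str, str] = field(default_factory=dict)

@dataclass(eq=False)
class Call:
    op: Union[Op, Function]
    args: List["Expr"]

Expr = Union[Var, Op, Function, Call]


class FuncCollector:
    """Toy visitor collecting called functions tagged PartitionedFromPattern."""

    def __init__(self, recurse=True):
        self.recurse = recurse  # hypothetical flag to show the difference
        self.funcs = []
        self._seen = set()      # visit each node once

    def visit(self, expr):
        if expr in self._seen:
            return
        self._seen.add(expr)
        if isinstance(expr, Call):
            self.visit_call(expr)
        elif isinstance(expr, Function):
            self.visit(expr.body)

    def visit_call(self, call):
        if isinstance(call.op, Function) and "PartitionedFromPattern" in call.op.attrs:
            self.funcs.append(call.op)
        if self.recurse:
            # Equivalent of calling the base visit_call: visit callee and args.
            self.visit(call.op)
            for arg in call.args:
                self.visit(arg)


def conv_relu_fn():
    a, b = Var("FunctionVar_0_0"), Var("FunctionVar_0_1")
    body = Call(Op("nn.relu"), [Call(Op("nn.conv2d"), [a, b])])
    return Function([a, b], body, {"PartitionedFromPattern": "nn.conv2d_nn.relu_"})


# Two partitions, the first feeding the second.
x, w1, w2 = Var("x"), Var("w1"), Var("w2")
inner = Call(conv_relu_fn(), [x, w1])
outer = Call(conv_relu_fn(), [inner, w2])

shallow = FuncCollector(recurse=False)
shallow.visit(outer)
collector = FuncCollector()
collector.visit(outer)
print(len(shallow.funcs))    # 1
print(len(collector.funcs))  # 2
```

In the single-partition example from this thread the difference does not show up, because the partitioned function is called at the top level, but any graph with more than one match needs the recursion.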