How to extract a Relay function when partitioning with a Relay pattern?

Since the partitioned function is anonymous (the partitioner calls it inline instead of binding it to a global name), you cannot access it by name, e.g. mod["func_1"]. I think the easiest way is to write a simple Relay pass that collects it.

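import tvm
from tvm import relay

class FuncCollector(relay.ExprVisitor):
    """Collect every function created by pattern partitioning."""

    def __init__(self):
        super().__init__()
        self.funcs = set()

    def visit_call(self, call):
        op = call.op
        # Partitioned functions are called inline and carry the
        # "PartitionedFromPattern" attribute; attrs may be None.
        if (isinstance(op, relay.Function) and op.attrs is not None
                and "PartitionedFromPattern" in op.attrs):
            self.funcs.add(op)
        # Keep traversing so nested calls and their arguments are visited too.
        super().visit_call(call)

# part_result is the expression returned by pattern.partition(...)
collector = FuncCollector()
collector.visit(part_result)
print("Functions:")
for func in collector.funcs:
    print(func)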

With your example, this prints:

Functions:
fn (%FunctionVar_0_0, %FunctionVar_0_1, PartitionedFromPattern="nn.conv2d_nn.relu_") {
  %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
  nn.relu(%0)
}