Since the partitioned function is an anonymous function, you cannot access it by name (e.g., `mod["func_1"]`).
I think the easiest way is to write a simple Relay expression visitor that collects it.
class FuncCollector(tvm.relay.ExprVisitor):
    """Collect anonymous functions produced by pattern-based partitioning.

    Partitioned subgraphs are wrapped in anonymous ``relay.Function`` objects
    that carry a ``PartitionedFromPattern`` attribute (see the example output
    below) and are invoked through ``Call`` nodes. They cannot be looked up by
    name, so this visitor walks an expression and gathers them.

    Attributes:
        funcs: set of the partitioned ``relay.Function`` objects found so far.
    """

    def __init__(self):
        super().__init__()
        self.funcs = set()

    def visit_call(self, call):
        """Record ``call.op`` if it is a partitioned function, then keep walking."""
        op = call.op
        # NOTE(review): a Function built without attributes may expose
        # ``attrs`` as None; guard before the membership test.
        if (
            isinstance(op, tvm.relay.Function)
            and op.attrs is not None
            and "PartitionedFromPattern" in op.attrs
        ):
            self.funcs.add(op)
        # Delegate to the base visitor so the callee and the call's arguments
        # are traversed as well. Without this, partitioned calls nested inside
        # the arguments of an outer call would never be visited.
        super().visit_call(call)
# Walk the partitioned expression and print every pattern-generated function.
collector = FuncCollector()
collector.visit(part_result)

print("Functions:")
for partitioned_fn in collector.funcs:
    print(partitioned_fn)
With your example:
Functions:
fn (%FunctionVar_0_0, %FunctionVar_0_1, PartitionedFromPattern="nn.conv2d_nn.relu_") {
%0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
nn.relu(%0)
}