Hi all,
I am trying to extract the function produced by the Relay pattern partition primitive.
Here is an example in which I use Relay pattern partitioning to partition the graph. My goal is to extract the partitioned function into a standalone function or IRModule.
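```python
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

# Define a pattern matching conv2d followed by relu.
pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))

# Create a graph of Relay expressions: add -> conv2d -> relu
x1 = relay.var('x1', shape=(1, 1))
y1 = relay.var('y1', shape=(1, 1))
add1 = relay.op.add(x1, y1)
x = add1
#x = relay.var('input', shape=(1, 1))
w = relay.var('weight', shape=(1, 1))
conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
# Bind the free variables (x1, y1, weight) as the function's parameters.
f = relay.Function(relay.analysis.free_vars(relu), relu)

print("-------------------Partition----------------")
part_result = pattern.partition(relu)
# Note: this module wraps the original, unpartitioned function f.
mod_part = tvm.IRModule({"main": f})
```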
This is the partitioned graph. How can I extract `%2 = fn` from it?

```python
print(part_result)
```

```
free_var %x1: Tensor[(1, 1), float32];
free_var %y1: Tensor[(1, 1), float32];
%0 = add(%x1, %y1);
free_var %weight: Tensor[(1, 1), float32];
%2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, PartitionedFromPattern="nn.conv2d_nn.relu_") {
  %1 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
  nn.relu(%1)
};
%2(%0, %weight)
```