I have a graph like below.
def @main(%input: Tensor[(1, 3, 3, 3), float32]) -> Tensor[(1, 2, 3, 3), float32] {
%0 = nn.conv2d(%input, meta[relay.Constant][0], padding=[1, 1, 1, 1], channels=2, kernel_size=[3, 3]);
%1 = nn.bias_add(%0, meta[relay.Constant][1]);
%2 = nn.relu(%1);
%3 = nn.conv2d(%2, meta[relay.Constant][2], padding=[1, 1, 1, 1], channels=2, kernel_size=[3, 3]);
%4 = nn.bias_add(%3, meta[relay.Constant][3]);
nn.relu(%4)
}
I used a pattern (nn.conv2d → nn.bias_add → nn.relu) like the one below.
pattern = is_op("nn.conv2d")(wildcard(), is_constant())
pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
pattern = pattern.optional(is_op("nn.relu"))
Then I got the following graph after the MergeComposite & PartitionGraph passes (note that the constant indices have changed).
def @custom(...) -> Tensor[(1, 2, 3, 3), float32] {
%4 = fn (%FunctionVar_1_0: Tensor[(1, 3, 3, 3), float32], PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_", Composite="custom_lib.conv2d") -> Tensor[(1, 2, 3, 3), float32] {
%2 = nn.conv2d(%FunctionVar_1_0, meta[relay.Constant][2] /* ty=Tensor[(2, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=2, kernel_size=[3, 3]) /* ty=Tensor[(1, 2, 3, 3), float32] */;
%3 = nn.bias_add(%2, meta[relay.Constant][3] /* ty=Tensor[(2), float32] */) /* ty=Tensor[(1, 2, 3, 3), float32] */;
nn.relu(%3) /* ty=Tensor[(1, 2, 3, 3), float32] */
};
%5 = %4(%custom_lib_0_i0) /* ty=Tensor[(1, 2, 3, 3), float32] */;
%6 = fn (%FunctionVar_0_0: Tensor[(1, 2, 3, 3), float32], PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_", Composite="custom_lib.conv2d") -> Tensor[(1, 2, 3, 3), float32] {
%0 = nn.conv2d(%FunctionVar_0_0, meta[relay.Constant][0] /* ty=Tensor[(2, 2, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=2, kernel_size=[3, 3]) /* ty=Tensor[(1, 2, 3, 3), float32] */;
%1 = nn.bias_add(%0, meta[relay.Constant][1] /* ty=Tensor[(2), float32] */) /* ty=Tensor[(1, 2, 3, 3), float32] */;
nn.relu(%1) /* ty=Tensor[(1, 2, 3, 3), float32] */
};
%6(%5) /* ty=Tensor[(1, 2, 3, 3), float32] */
}
I was not able to run this graph using a custom JSON runtime similar to the arm_compute_lib JSON runtime, because const_idx_ in JSONRuntimeBase does not match the reordered constant indices shown above.
If I used wildcard() instead of is_constant() in the pattern, it worked, like below.
pattern = is_op("nn.conv2d")(wildcard(), wildcard())
pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, wildcard()))
pattern = pattern.optional(is_op("nn.relu"))
But I want to know how to modify the JSON runtime so it can run this graph. Why do the constant indices change after the MergeComposite pass, and how do I solve this problem?