Based on an offline discussion with @mbaret, users may need to match a pattern that involves constant nodes. For example, consider a conv2d op with two arguments:
%0 = nn.conv2d(%x, %w)
After binding the second argument to a constant, it becomes:
%0 = nn.conv2d(%x, meta[relay.Constant][0])
Users may want to match only the second form, for example when they only support conv2d with constant weights.
With the pattern language, we have several approaches to achieve this goal:
A.1 Using Check Functions
We can implement a check function that checks whether a specific argument in the matched subgraph is a constant node. This solution is already available upstream. An example can be found here:
The problem with this solution is that implementing the check function can be tedious if the pattern is complex.
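A check function receives the matched expression and has to walk it by hand to reach the argument it cares about. The sketch below is a self-contained toy model in plain Python, not TVM's API: the IR classes (`Var`, `Constant`, `Call`), the tuple-based pattern encoding, and `match_with_check` are all hypothetical. It illustrates why the check gets tedious: for `nn.bias_add(nn.conv2d(_, _), _)` the check has to hard-code that the weight lives at `expr.args[0].args[1]`.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple


# Hypothetical toy IR standing in for Relay expressions.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Constant:
    value: object


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[object, ...]


# Toy pattern encoding: None is a wildcard; (op_name, arg_patterns) matches a call.
Pattern = Optional[Tuple[str, tuple]]


def match(pattern: Pattern, expr: object) -> bool:
    """Purely structural match: checks op names and arity, not node kinds."""
    if pattern is None:  # wildcard matches any expression
        return True
    op, arg_pats = pattern
    return (
        isinstance(expr, Call)
        and expr.op == op
        and len(expr.args) == len(arg_pats)
        and all(match(p, a) for p, a in zip(arg_pats, expr.args))
    )


def match_with_check(pattern: Pattern, expr: object,
                     check: Callable[[object], bool]) -> bool:
    """Structural match first, then the user-supplied check on the matched root."""
    return match(pattern, expr) and check(expr)


# bias_add(conv2d(wildcard, wildcard), wildcard)
conv2d_bias = ("nn.bias_add", (("nn.conv2d", (None, None)), None))


def weight_is_const(expr: Call) -> bool:
    # The check hard-codes the path to the weight: root -> arg 0 -> arg 1.
    # Any change to the pattern's shape means updating this path by hand.
    conv = expr.args[0]
    return isinstance(conv.args[1], Constant)


x, w, b = Var("x"), Var("w"), Var("b")
with_var_weight = Call("nn.bias_add", (Call("nn.conv2d", (x, w)), b))
with_const_weight = Call("nn.bias_add", (Call("nn.conv2d", (x, Constant(0))), b))
```

Both expressions pass the structural match; only the check function tells them apart, and it only works as long as its hard-coded path stays in sync with the pattern.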
A.2 Supporting is_const
Similar to is_input, we may enhance the pattern language to support is_const so that we can write the following pattern:
conv2d = is_op('nn.conv2d')(wildcard(), is_const())
pattern = is_op('nn.bias_add')(conv2d, wildcard())
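To make the proposal concrete, here is a minimal toy sketch in plain Python (not TVM's implementation) of how is_const could sit next to wildcard and is_op. The IR classes and pattern classes are hypothetical stand-ins; the point is that the constant check lives inside the pattern itself, so the example pattern above is written unchanged and needs no separate check function.

```python
from dataclasses import dataclass
from typing import Tuple


# Hypothetical toy IR standing in for Relay expressions.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Constant:
    value: object


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[object, ...]


class Pattern:
    def match(self, expr) -> bool:
        raise NotImplementedError


class WildcardPattern(Pattern):
    def match(self, expr) -> bool:
        return True  # matches any expression


class ConstPattern(Pattern):
    def match(self, expr) -> bool:
        return isinstance(expr, Constant)  # only constant nodes


class CallPattern(Pattern):
    def __init__(self, op: str, args: Tuple[Pattern, ...]):
        self.op = op
        self.args = args

    def match(self, expr) -> bool:
        # Same op, same arity, and every argument matches its sub-pattern.
        return (
            isinstance(expr, Call)
            and expr.op == self.op
            and len(expr.args) == len(self.args)
            and all(p.match(a) for p, a in zip(self.args, expr.args))
        )


def is_op(name: str):
    # Returns a callable so patterns read like calls: is_op("f")(p1, p2).
    return lambda *args: CallPattern(name, args)


def wildcard() -> Pattern:
    return WildcardPattern()


def is_const() -> Pattern:
    return ConstPattern()


# The pattern from the proposal, written exactly as above.
conv2d = is_op('nn.conv2d')(wildcard(), is_const())
pattern = is_op('nn.bias_add')(conv2d, wildcard())
```

Compared with A.1, no code outside the pattern needs to know where the weight sits; requiring a constant in a different position only means moving is_const() to that argument.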
A.3 Supporting All Nodes in Patterns
This overlaps with A.2 but is a bit out of scope. A more general solution is to support all node types in patterns. For example, pattern nodes like TuplePattern and TupleGetItemPattern explicitly check the node type. We can extend the pattern nodes to cover all node types in the TVM node system, which solves this problem in a more general way:
conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern())
pattern = is_op('nn.bias_add')(conv2d, wildcard())
One minor extension to this solution is to create a consistent alias for every pattern node. In other words, we do not expect users to use TuplePattern / ConstantPattern directly, but rather is_tuple / is_const, etc.
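A.3 and the alias extension can be sketched together in the same toy style (plain Python, hypothetical names, not TVM's API). A single NodePattern checks the node's type and, optionally, its children, so ConstantPattern and TuplePattern become thin specializations of one mechanism instead of hand-written classes, and is_const / is_tuple are the user-facing aliases. The IR classes and the `children` helper are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Tuple


# Hypothetical toy IR standing in for Relay expressions.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Constant:
    value: object


@dataclass(frozen=True)
class TupleExpr:
    fields: Tuple[object, ...]


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[object, ...]


def children(expr) -> Tuple[object, ...]:
    # Sub-expressions of a node, in order; leaf nodes have none.
    if isinstance(expr, Call):
        return expr.args
    if isinstance(expr, TupleExpr):
        return expr.fields
    return ()


class NodePattern:
    """Matches any node of `node_type`. If child patterns are given, the node
    must have exactly that many children and each must match in order;
    with no child patterns, any arity is accepted."""

    def __init__(self, node_type: type, *child_pats: "NodePattern"):
        self.node_type = node_type
        self.child_pats = child_pats

    def match(self, expr) -> bool:
        if not isinstance(expr, self.node_type):
            return False
        if not self.child_pats:
            return True
        kids = children(expr)
        return len(kids) == len(self.child_pats) and all(
            p.match(k) for p, k in zip(self.child_pats, kids))


class CallPattern(NodePattern):
    """A call must additionally name the right operator."""

    def __init__(self, op: str, *arg_pats: NodePattern):
        super().__init__(Call, *arg_pats)
        self.op = op

    def match(self, expr) -> bool:
        # super().match() verifies expr is a Call before .op is read.
        return super().match(expr) and expr.op == self.op


# Per-node-type patterns (CamelCase mirrors the names used in the proposal).
def ConstantPattern() -> NodePattern:
    return NodePattern(Constant)


def TuplePattern(*field_pats: NodePattern) -> NodePattern:
    return NodePattern(TupleExpr, *field_pats)


# Consistent, user-facing aliases.
def wildcard() -> NodePattern:
    return NodePattern(object)  # every expression is an instance of object


def is_op(name: str):
    return lambda *args: CallPattern(name, *args)


is_const = ConstantPattern
is_tuple = TuplePattern

conv2d = is_op('nn.conv2d')(wildcard(), is_const())
pattern = is_op('nn.bias_add')(conv2d, wildcard())
```

With this design, supporting another node kind only requires a type check (NodePattern(SomeNode) works directly), and the aliases keep the user-facing vocabulary uniform regardless of how the pattern nodes are implemented.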
Any comments and suggestions are welcome.