A pattern language for Relay
This RFC introduces a pattern language for describing the topology of Relay graphs and matching against them. The RFC is structured as follows: motivation, examples, design, and then a small, mostly complete prototype with test cases.
There are many places in TVM where we identify pure data-flow sub-graphs of a Relay program and attempt to transform them in some way. Example passes include fusion, quantization, external code generation, and device-specific optimizations such as the bitpacking and layer slicing used by VTA.
Many of these passes today require a lot of boilerplate code to implement, and they force users to think in terms of visitors and AST matching. Many of these transformations can be described more naturally as graph rewrites. To build a rewriter or other advanced machinery, we first need a language of patterns that describes what we can match.
Such a language is useful not just for building a rewriter but also for providing extension points for existing passes. For example, the fusion pass could be parameterized by a set of fusion patterns that describe the capabilities of your hardware, and the quantization pass could take a set of patterns that describe which operators can be quantized on a given platform.
In the backend world, we could use the same machinery to build a higher-level API on top of bring your own code generation (BYOC). This API would take a set of patterns describing your hardware's capabilities together with an external compiler, providing a relatively smooth heterogeneous experience out of the box.
There has recently been a lot of discussion in the community about similar issues, and we would like to gather feedback and, hopefully, collaborate on a design that benefits everyone working in this space. This RFC focuses on the pattern language itself; applications of it will come later.
There are quite a few properties of operators that are worth matching. Below we examine how to match tree properties and expand on some use cases that are not fully explored in the prototype. The first example is a simple case where we want to match one operator OR another operator (here `add` or `subtract`). See the diagram below for a graphical representation, followed by the corresponding code.
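```python
from tvm import relay
# is_op, wildcard, is_input, and has_attr come from the prototype linked at the end of this RFC.


def test_match_op_or():
    is_add_or_sub = is_op('add') | is_op('subtract')
    assert is_add_or_sub.match(relay.op.op.get("add"))
    assert is_add_or_sub.match(relay.op.op.get("subtract"))
```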
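To make the semantics of `|` concrete, here is a minimal sketch in plain Python. It does not use TVM: `Op`, `Pattern`, `OpPattern`, and `AltPattern` are hypothetical stand-ins rather than the prototype's classes. Alternation simply tries the left pattern and falls back to the right one, and overloading `__or__` is what lets the `is_op('add') | is_op('subtract')` syntax build such a pattern.

```python
class Op:
    """Hypothetical stand-in for a Relay operator, identified by name."""

    def __init__(self, name):
        self.name = name


class Pattern:
    def match(self, expr):
        raise NotImplementedError

    def __or__(self, other):
        # `p1 | p2` builds an alternation pattern, mirroring the proposed syntax.
        return AltPattern(self, other)


class OpPattern(Pattern):
    """Matches an operator with a given name (toy analogue of is_op)."""

    def __init__(self, name):
        self.name = name

    def match(self, expr):
        return isinstance(expr, Op) and expr.name == self.name


class AltPattern(Pattern):
    """Matches if either the left or the right pattern matches."""

    def __init__(self, left, right):
        self.left = left
        self.right = right

    def match(self, expr):
        return self.left.match(expr) or self.right.match(expr)


def is_op(name):
    return OpPattern(name)


is_add_or_sub = is_op("add") | is_op("subtract")
print(is_add_or_sub.match(Op("add")))       # True
print(is_add_or_sub.match(Op("subtract")))  # True
print(is_add_or_sub.match(Op("multiply")))  # False
```

Because `AltPattern` is itself a `Pattern`, alternations chain naturally, e.g. `is_op("a") | is_op("b") | is_op("c")`.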
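The next example checks an operator attribute. The pattern matches a call to `nn.dense` only if the operator is also marked element-wise (its `TOpPattern` attribute equals `K_ELEMWISE`); since dense is not an element-wise operator, the pattern should not match a dense call.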
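```python
from tvm import relay
# Pattern helpers and K_ELEMWISE (the TOpPattern value for element-wise ops) are
# assumed to come from the prototype.


def test_no_match_attr():
    op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE)
    op_pat = op(wildcard(), wildcard())
    x = relay.var('x')
    y = relay.var('y')
    assert not op_pat.match(relay.op.nn.dense(x, y))
```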
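The attribute check can be sketched the same way. In this self-contained toy (no TVM), `OP_ATTRS` is a hypothetical stand-in for the operator registry, and the numeric `TOpPattern` values are illustrative only. `has_attr` wraps an operator pattern and additionally compares one registry attribute, so `nn.dense` is rejected while `add` is accepted.

```python
# Illustrative attribute values; a real compiler would read these from its operator registry.
K_ELEMWISE = 0
K_OUT_ELEMWISE_FUSABLE = 4
OP_ATTRS = {
    "add": {"TOpPattern": K_ELEMWISE},
    "nn.relu": {"TOpPattern": K_ELEMWISE},
    "nn.dense": {"TOpPattern": K_OUT_ELEMWISE_FUSABLE},
}


class Op:
    def __init__(self, name):
        self.name = name


class Var:
    def __init__(self, name):
        self.name = name


class Call:
    def __init__(self, op, *args):
        self.op = op
        self.args = args


class Pattern:
    def __call__(self, *arg_patterns):
        # `op_pattern(p1, ..., pN)` builds a call pattern.
        return CallPattern(self, arg_patterns)

    def has_attr(self, key, value):
        return AttrPattern(self, key, value)


class Wildcard(Pattern):
    def match(self, expr):
        return True


class OpPattern(Pattern):
    def __init__(self, name):
        self.name = name

    def match(self, expr):
        return isinstance(expr, Op) and expr.name == self.name


class AttrPattern(Pattern):
    """Matches an operator that matches `inner` and whose attribute `key` equals `value`."""

    def __init__(self, inner, key, value):
        self.inner, self.key, self.value = inner, key, value

    def match(self, expr):
        if not isinstance(expr, Op) or not self.inner.match(expr):
            return False
        return OP_ATTRS.get(expr.name, {}).get(self.key) == self.value


class CallPattern(Pattern):
    def __init__(self, op_pattern, arg_patterns):
        self.op_pattern, self.arg_patterns = op_pattern, arg_patterns

    def match(self, expr):
        return (
            isinstance(expr, Call)
            and self.op_pattern.match(expr.op)
            and len(expr.args) == len(self.arg_patterns)
            and all(p.match(a) for p, a in zip(self.arg_patterns, expr.args))
        )


x, y = Var("x"), Var("y")
dense_elemwise = OpPattern("nn.dense").has_attr("TOpPattern", K_ELEMWISE)(Wildcard(), Wildcard())
any_elemwise = Wildcard().has_attr("TOpPattern", K_ELEMWISE)(Wildcard(), Wildcard())

print(dense_elemwise.match(Call(Op("nn.dense"), x, y)))  # False: dense is not element-wise
print(any_elemwise.match(Call(Op("add"), x, y)))         # True
```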
The next example matches a diamond: a conv2d with two inputs at the top, whose result feeds both a relu and a leaky_relu that are then combined by an add.
```python
from tvm import relay
# Pattern helpers come from the prototype.


def test_match_diamond():
    # Pattern
    is_conv2d = is_op('nn.conv2d')(is_input(), is_input())
    path1 = is_op('nn.relu')(is_conv2d)
    path2 = is_op('nn.leaky_relu')(is_conv2d)
    diamond = is_op('add')(path1, path2)

    # Expr
    inp = relay.var('input')
    weight = relay.var('weight')
    conv2d = relay.op.nn.conv2d(inp, weight)
    relu = relay.op.nn.relu(conv2d)
    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
    out = relu + leaky_relu

    # Check
    assert diamond.match(out)
```
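One subtlety in the diamond is that `is_conv2d` appears on both paths, and both uses should refer to the same conv2d node rather than to two structurally similar ones. The sketch below shows one way a matcher could enforce this, assuming that reusing a pattern object means "the same node": it records which node each pattern object matched and rejects a second, different node. The classes are hypothetical and independent of TVM; the sketch does no backtracking, which is sufficient here because it has no alternation.

```python
class Var:
    def __init__(self, name):
        self.name = name


class Call:
    def __init__(self, op, *args):
        self.op = op  # operator name, e.g. "nn.conv2d"
        self.args = args


class Pattern:
    def match(self, expr):
        # Each top-level match starts with an empty pattern -> node binding table.
        return self.visit(expr, {})

    def visit(self, expr, env):
        # A pattern object reused in several places must bind to one node every time.
        if self in env:
            return env[self] is expr
        if not self.match_node(expr, env):
            return False
        env[self] = expr
        return True


class InputPattern(Pattern):
    """Toy analogue of is_input(): matches a variable."""

    def match_node(self, expr, env):
        return isinstance(expr, Var)


class CallPattern(Pattern):
    def __init__(self, op_name, *arg_patterns):
        self.op_name = op_name
        self.arg_patterns = arg_patterns

    def match_node(self, expr, env):
        return (
            isinstance(expr, Call)
            and expr.op == self.op_name
            and len(expr.args) == len(self.arg_patterns)
            and all(p.visit(a, env) for p, a in zip(self.arg_patterns, expr.args))
        )


is_conv2d = CallPattern("nn.conv2d", InputPattern(), InputPattern())
diamond = CallPattern(
    "add", CallPattern("nn.relu", is_conv2d), CallPattern("nn.leaky_relu", is_conv2d)
)

inp, weight = Var("input"), Var("weight")
conv2d = Call("nn.conv2d", inp, weight)
out = Call("add", Call("nn.relu", conv2d), Call("nn.leaky_relu", conv2d))
print(diamond.match(out))  # True

# Two distinct conv2d nodes do not form a diamond, even though each matches is_conv2d.
other = Call("nn.conv2d", inp, weight)
not_diamond = Call("add", Call("nn.relu", conv2d), Call("nn.leaky_relu", other))
print(diamond.match(not_diamond))  # False
```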
The final example, which is not yet implemented in the prototype, matches diamonds with a post-dominator relationship. We plan to embed dominator analysis as a kind of matching in the pattern language, so that we can match graphs whose topology is not known in advance. This is important because we want to be able to use the language to describe fusion patterns, such as a conv2d followed by element-wise operations.
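```python
from tvm import relay
# Pattern helpers come from the prototype; `dominates` is proposed and not yet implemented.


def test_match_dom_diamond():
    # Pattern
    is_conv2d = is_op('nn.conv2d')(is_input(), is_input())
    # Any single-input operator marked element-wise (proposed use of has_attr on a wildcard).
    is_elemwise = wildcard().has_attr("TOpPattern", K_ELEMWISE)(wildcard())
    reduction = is_op('add')(wildcard(), wildcard())
    diamond = dominates(is_conv2d, is_elemwise, reduction)

    # Expr
    inp = relay.var('input')
    weight = relay.var('weight')
    conv2d = relay.op.nn.conv2d(inp, weight)
    relu = relay.op.nn.relu(conv2d)
    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
    out = relu + leaky_relu

    # Check
    assert diamond.match(out)
```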
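Since `dominates` is not implemented yet, here is a rough sketch of the check it could perform, written with plain predicates instead of patterns to keep it short. It is a simplification under stated assumptions: starting from the node accepted by the child predicate, every path back toward the inputs must pass only through nodes accepted by the path predicate and end at a single shared node accepted by the parent predicate. Unlike a real post-dominator analysis, it does not verify that intermediate results have no consumers outside the diamond. `Node`, `ELEMWISE_OPS`, and the predicate names are hypothetical.

```python
class Node:
    def __init__(self, op, *inputs):
        self.op = op
        self.inputs = inputs


def dominates(is_parent, is_path, is_child, root):
    """Return the parent node if `root` closes a dominator diamond, else None.

    Simplified check: every path from `root` back toward the graph inputs must pass
    only through nodes accepted by `is_path` and end at one shared node accepted by
    `is_parent`. Consumers outside the diamond are not checked.
    """
    if not is_child(root) or not root.inputs:
        return None
    parents = set()

    def walk(node):
        if is_parent(node):
            parents.add(node)
            return True
        # A path node must match `is_path` and must itself lead to the parent.
        return is_path(node) and bool(node.inputs) and all(walk(i) for i in node.inputs)

    if all(walk(i) for i in root.inputs) and len(parents) == 1:
        return parents.pop()
    return None


ELEMWISE_OPS = {"nn.relu", "nn.leaky_relu", "add"}  # illustrative


def is_conv2d(node):
    return node.op == "nn.conv2d"


def is_elemwise(node):
    return node.op in ELEMWISE_OPS


def is_add(node):
    return node.op == "add"


inp, weight = Node("var"), Node("var")
conv = Node("nn.conv2d", inp, weight)

# The diamond from the example above.
out = Node("add", Node("nn.relu", conv), Node("nn.leaky_relu", conv))
print(dominates(is_conv2d, is_elemwise, is_add, out) is conv)  # True

# Paths of unknown length still match: relu -> relu on one branch.
longer = Node("add", Node("nn.relu", Node("nn.relu", conv)), Node("nn.leaky_relu", conv))
print(dominates(is_conv2d, is_elemwise, is_add, longer) is conv)  # True

# A non-element-wise op on one branch breaks the match.
broken = Node("add", Node("nn.relu", conv), Node("nn.softmax", conv))
print(dominates(is_conv2d, is_elemwise, is_add, broken))  # None
```

Matching paths of arbitrary length is exactly what a fixed-shape pattern like the previous diamond cannot express, which is the motivation for building dominance into the language.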
The proposed pattern language is designed to mirror Relay's IR, with additional support for common scenarios. Its goal is to provide a regular-expression-like capability for matching data-flow graphs and rewriting them.
At a high level, the design introduces a language of patterns; for now we propose the following grammar:
```
Pattern ::= expr
          | *
          | pattern(pattern1, ... patternN)
          | has_type(pattern, type)
          | has_attr(pattern, attr, attr_value)
          | is_input
          | named(name, pattern)
          | pattern1 `|` pattern2
          | dominates(parent_pattern, path_pattern, child_pattern)
```
This language then provides a matching interface that can both select sub-graphs and verify that a graph matches the pattern.
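The two halves of that interface, verifying a whole expression and selecting matching sub-graphs, can be sketched as follows. The nested-tuple pattern encoding (`"*"`, `"input"`, `(op_name, p1, ..., pN)`) and the names `match` and `find_all` are assumptions for illustration, not the proposed API; selection here is simply verification applied to every distinct node of the graph.

```python
class Var:
    def __init__(self, name):
        self.name = name


class Call:
    def __init__(self, op, *args):
        self.op = op
        self.args = args


# Hypothetical encoding for this sketch: "*" is a wildcard, "input" matches a
# variable, and a tuple (op_name, p1, ..., pN) matches a call with N arguments.
def match(pattern, expr):
    """Verify: does the whole of `expr` match `pattern`?"""
    if pattern == "*":
        return True
    if pattern == "input":
        return isinstance(expr, Var)
    op_name, *arg_patterns = pattern
    return (
        isinstance(expr, Call)
        and expr.op == op_name
        and len(expr.args) == len(arg_patterns)
        and all(match(p, a) for p, a in zip(arg_patterns, expr.args))
    )


def find_all(pattern, root):
    """Select: return each distinct sub-expression of `root` matching `pattern`, in post-order."""
    seen, found = set(), []

    def visit(expr):
        if id(expr) in seen:  # shared nodes (a DAG) are visited once
            return
        seen.add(id(expr))
        for arg in getattr(expr, "args", ()):
            visit(arg)
        if match(pattern, expr):
            found.append(expr)

    visit(root)
    return found


x, w = Var("x"), Var("w")
conv_a = Call("nn.conv2d", x, w)
conv_b = Call("nn.conv2d", x, w)
graph = Call("add", Call("nn.relu", conv_a), Call("nn.relu", conv_b))

relu_of_conv = ("nn.relu", ("nn.conv2d", "input", "input"))
print(match(relu_of_conv, graph))           # False: the root is an add
print(len(find_all(relu_of_conv, graph)))   # 2
```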
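- `expr`: match a literal expression.
- `*`: match any expression.
- `pattern(pattern1, ... patternN)`: match a call whose operator matches `pattern` and whose arguments match `pattern1` through `patternN`.
- `has_type(pattern, type)`: check that the expression matched by the nested pattern has a particular type.
- `has_attr(pattern, attr, attr_value)`: check that the operator matched by the pattern has an attribute with a particular value.
- `is_input`: check that the expression is an input, i.e., a variable with no parents.
- `named(name, pattern)`: name a pattern group so that the matched sub-graph can be referred to later.
- ``pattern1 `|` pattern2``: match either the first pattern or the second pattern.
- `dominates(parent_pattern, path_pattern, child_pattern)`: match the parent pattern at the root node, then check that the path pattern holds for every node along the domination path down to the node matched by the child pattern.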
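Named patterns are not yet in the prototype, so the following is only one hypothetical way they could work: `named(name, pattern)` matches its inner pattern and records the matched expression under `name`, and a successful match returns the recorded groups. All class and function names here are illustrative and independent of TVM.

```python
class Var:
    def __init__(self, name):
        self.name = name


class Call:
    def __init__(self, op, *args):
        self.op = op
        self.args = args


class Wildcard:
    def match(self, expr, groups):
        return True


class IsInput:
    """Matches a variable, i.e. a node with no parents in the data-flow graph."""

    def match(self, expr, groups):
        return isinstance(expr, Var)


class CallPattern:
    def __init__(self, op_name, *arg_patterns):
        self.op_name = op_name
        self.arg_patterns = arg_patterns

    def match(self, expr, groups):
        return (
            isinstance(expr, Call)
            and expr.op == self.op_name
            and len(expr.args) == len(self.arg_patterns)
            and all(p.match(a, groups) for p, a in zip(self.arg_patterns, expr.args))
        )


class Named:
    """named(name, pattern): match `pattern`, then record the matched expression under `name`."""

    def __init__(self, name, pattern):
        self.name = name
        self.pattern = pattern

    def match(self, expr, groups):
        if not self.pattern.match(expr, groups):
            return False
        groups[self.name] = expr
        return True


def match_groups(pattern, expr):
    """Return the named sub-groups if `pattern` matches `expr`, else None."""
    groups = {}
    return groups if pattern.match(expr, groups) else None


x, w = Var("x"), Var("w")
conv = Call("nn.conv2d", x, w)
expr = Call("nn.relu", conv)

pat = CallPattern(
    "nn.relu", Named("conv", CallPattern("nn.conv2d", Named("data", IsInput()), Wildcard()))
)
groups = match_groups(pat, expr)
print(groups["conv"] is conv, groups["data"] is x)  # True True
print(match_groups(pat, Call("nn.relu", x)))        # None
```

A rewriter could then use these groups to locate the sub-graphs it needs to replace without re-walking the expression.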
The list above enumerates the current working design and the types of patterns. I have built an initial prototype that implements most of the patterns but lacks named patterns and dominator patterns. Matthew Brookhart (now at OctoML) is spearheading the implementation, and we hope to contribute a work-in-progress implementation soon.
We can grow the language in a variety of ways, including support for multiple outputs. For example, we could match a finite number of outputs using a design like the following.
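```python
from tvm import relay
# Pattern helpers come from the prototype; indexing a pattern is part of the proposed design.


def test_multi_output():
    # Pattern: `pattern[i]` selects output i of a multi-output call.
    is_2outputs = is_op('2outputs')(is_input(), is_input())
    is_output1 = is_2outputs[0]
    is_output2 = is_2outputs[1]

    # Expr: `two_outputs` is a hypothetical wrapper for an operator '2outputs' returning a tuple.
    inp = relay.var('input')
    weight = relay.var('weight')
    result = two_outputs(inp, weight)
    output1 = relay.TupleGetItem(result, 0)
    output2 = relay.TupleGetItem(result, 1)

    # Check
    assert is_output1.match(output1)
    assert is_output2.match(output2)
```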
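Indexing a pattern for multiple outputs could be implemented by a projection pattern that matches a tuple-field access with the right index and then recurses into the tuple-producing expression. In this hypothetical toy, `TupleGetItem` mimics the role of Relay's tuple projection, `__getitem__` on `Pattern` provides the `pattern[i]` syntax, and `2outputs` is the same made-up operator name used in the test.

```python
class Var:
    def __init__(self, name):
        self.name = name


class Call:
    def __init__(self, op, *args):
        self.op = op
        self.args = args


class TupleGetItem:
    """Selects output `index` of a tuple-valued expression."""

    def __init__(self, tup, index):
        self.tup = tup
        self.index = index


class Pattern:
    def __getitem__(self, index):
        # `pattern[i]` selects output i of a multi-output call.
        return GetItemPattern(self, index)


class IsInput(Pattern):
    def match(self, expr):
        return isinstance(expr, Var)


class CallPattern(Pattern):
    def __init__(self, op_name, *arg_patterns):
        self.op_name = op_name
        self.arg_patterns = arg_patterns

    def match(self, expr):
        return (
            isinstance(expr, Call)
            and expr.op == self.op_name
            and len(expr.args) == len(self.arg_patterns)
            and all(p.match(a) for p, a in zip(self.arg_patterns, expr.args))
        )


class GetItemPattern(Pattern):
    def __init__(self, tuple_pattern, index):
        self.tuple_pattern = tuple_pattern
        self.index = index

    def match(self, expr):
        return (
            isinstance(expr, TupleGetItem)
            and expr.index == self.index
            and self.tuple_pattern.match(expr.tup)
        )


is_2outputs = CallPattern("2outputs", IsInput(), IsInput())
is_output1 = is_2outputs[0]
is_output2 = is_2outputs[1]

inp, weight = Var("input"), Var("weight")
result = Call("2outputs", inp, weight)  # made-up operator returning two outputs
output1 = TupleGetItem(result, 0)
output2 = TupleGetItem(result, 1)
print(is_output1.match(output1), is_output2.match(output2))  # True True
print(is_output1.match(output2))  # False: wrong output index
```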
For a working version of the pattern language, see this GitHub Gist: https://gist.github.com/jroesch/de9648556b1f7e1bb3db490e0f708813