Improving dataflow pattern matching for relay

Summary

This pre-RFC is devoted to improving dataflow pattern matching for relay.

The improvement is two-fold. First, fix bugs in the algorithms that implement pattern matching. Implementing correct algorithms may require introducing some new features. Second, improve some pattern-matching-based optimizations of the computational graph.

Motivation

Bugs in existing implementations

The current implementations of pattern matching use greedy algorithms that are not guaranteed to be correct: some subgraphs that should match a pattern are not matched. Below I list two such bugs.

bug with DominatorPattern

I describe several bugs with DominatorPattern in a reply below. Issues 1 and 2 are relatively easy to solve, but Issue 3 may stem from the failure of a greedy matching algorithm.

bug with AltPattern

Consider the following example. Swapping the left and right branches of the AltPattern changes the matching result. This may also stem from the failure of a greedy matching algorithm.

import os
os.environ["TVM_LOG_DEBUG"] = "DEFAULT=1"
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *

def create_model():
    data = relay.var("data", relay.TensorType((2, 8), "float32"))
    feat1 = relay.nn.relu(data)
    feat2 = relay.nn.relu(feat1)
    feat3 = relay.nn.relu(feat2)
    feat4 = relay.add(feat3, feat2)
    f = relay.Function([data], feat4)
    mod = tvm.IRModule()
    mod["main"] = f
    mod = relay.transform.InferType()(mod)
    return mod

def construct_pattern(mode):
    relu_pattern = is_op("nn.relu")(wildcard())
    if mode == 1:
        or_pattern = is_op("nn.relu")(is_op("nn.relu")(relu_pattern)) | is_op("nn.relu")(relu_pattern)
    else:
        or_pattern = is_op("nn.relu")(relu_pattern) | is_op("nn.relu")(is_op("nn.relu")(relu_pattern))
    return is_op("add")(or_pattern, relu_pattern)

if __name__ == "__main__":
    pattern1 = construct_pattern(1)
    pattern2 = construct_pattern(2)
    mod = create_model()
    matched1 = pattern1.match(mod["main"].body)
    matched2 = pattern2.match(mod["main"].body)
    print(matched1)  # 0: no match
    print(matched2)  # 1: match
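The mechanism behind this discrepancy can be reproduced outside TVM. The sketch below is a toy re-creation on a hypothetical mini IR. The classes `Var`, `Call`, `Wildcard`, `OpPat` and `Alt`, and the functions `match_all` and `matches`, are illustrative stand-ins and not TVM's API. Like the example above, it assumes that a pattern node reused in two places must bind to the same expression. With `backtrack=False`, an `Alt` commits to its first matching branch, which reproduces the 0/1 difference. With `backtrack=True`, a failure later in the match retries the other branch, so both orderings succeed.

```python
# Toy re-creation of the AltPattern example on a hypothetical mini IR (not TVM's API).
# A pattern node, once bound to an expression, must match that same expression
# wherever it is reused. This is why the branch order of an Alt can matter.

class Var:
    def __init__(self, name):
        self.name, self.op, self.args = name, None, ()

class Call:
    def __init__(self, op, *args):
        self.op, self.args = op, args

class Wildcard:
    pass

class OpPat:
    def __init__(self, op, *args):
        self.op, self.args = op, args

    def __or__(self, other):
        return Alt(self, other)

class Alt:
    def __init__(self, left, right):
        self.left, self.right = left, right

def match_all(pat, expr, memo, backtrack):
    """Yield each binding dict {id(pattern): expr} under which `pat` matches `expr`."""
    key = id(pat)
    if key in memo:  # pattern already bound: it must be the same expression
        if memo[key] is expr:
            yield memo
        return
    if isinstance(pat, Wildcard):
        yield {**memo, key: expr}
    elif isinstance(pat, OpPat):
        if expr.op == pat.op and len(expr.args) == len(pat.args):
            for m in match_args(pat.args, expr.args, memo, backtrack):
                yield {**m, key: expr}
    elif isinstance(pat, Alt):
        for branch in (pat.left, pat.right):
            results = match_all(branch, expr, memo, backtrack)
            if backtrack:
                for m in results:  # offer every alternative to the caller
                    yield {**m, key: expr}
            else:
                first = next(results, None)
                if first is not None:
                    yield {**first, key: expr}
                    return  # greedy: commit to the first branch that matches

def match_args(pats, exprs, memo, backtrack):
    """Match argument patterns left to right, threading the bindings through."""
    if not pats:
        yield memo
        return
    for m in match_all(pats[0], exprs[0], memo, backtrack):
        yield from match_args(pats[1:], exprs[1:], m, backtrack)

def matches(pat, expr, backtrack):
    return next(match_all(pat, expr, {}, backtrack), None) is not None

def relu(x):
    return Call("nn.relu", x)

# Same graph as create_model(): add(relu(relu(relu(data))), relu(relu(data)))
data = Var("data")
feat1 = relu(data)
feat2 = relu(feat1)
feat3 = relu(feat2)
feat4 = Call("add", feat3, feat2)

def construct_pattern(mode):
    relu_pattern = OpPat("nn.relu", Wildcard())
    deep = OpPat("nn.relu", OpPat("nn.relu", relu_pattern))
    shallow = OpPat("nn.relu", relu_pattern)
    or_pattern = deep | shallow if mode == 1 else shallow | deep
    return OpPat("add", or_pattern, relu_pattern)

for mode in (1, 2):
    p = construct_pattern(mode)
    # prints "1 False True", then "2 True True"
    print(mode, matches(p, feat4, backtrack=False), matches(p, feat4, backtrack=True))
```

Backtracking removes the order dependence. In the worst case, however, it can re-explore the same subgraph many times, so it fixes correctness at a potential cost.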

Computational graph optimizations in existing codebase

The TVM community developed a pattern matching language that is now widely used within the TVM project. The language consists of ExprPattern, CallPattern and AltPattern, among others. While the current use of the pattern language is sufficient in many situations, there are important scenarios where it is not powerful enough. Let me take two examples:

  1. The SimplifyClipAndConsecutiveCast pass matches:
clip_ = IsOp("clip")({IsWildcard()});
cast1_ = IsOp("cast")({clip_});
pattern_ = IsOp("cast")({cast1_});

which is restrictive. What if there are more than two cast ops? What if there are transpose or reshape ops between clip and cast? Such patterns seem hard to match with the current use of the pattern language. Also, the SimplifyClipAndConsecutiveCast pass cannot use the strategy of “remove only one cast op per rewrite”.

To see the limitation of the current SimplifyClipAndConsecutiveCast pass, consider the following piece of computation:

 %45 = clip(%44, a_min=0f, a_max=255f) /* ty=Tensor[(1, 4096), float32] */;
 %46 = cast(%45, dtype="uint8") /* ty=Tensor[(1, 4096), uint8] */;
 %47 = cast(%46, dtype="int32") /* ty=Tensor[(1, 4096), int32] */;
 %48 = cast(%47, dtype="float32") /* ty=Tensor[(1, 4096), float32] */;

The above computation should be simplified, but currently it is not.

  2. Consider the following computation, whose pattern is likely to arise when using quantization:
  %2 = transpose(%1, axes=[1, 0, 2, 3]) /* ty=Tensor[(2, 8, 1, 1), float32] */;
  %3 = multiply(%2, 16f /* ty=float32 */) /* ty=Tensor[(2, 8, 1, 1), float32] */;
  %4 = round(%3) /* ty=Tensor[(2, 8, 1, 1), float32] */;
  %5 = clip(%4, a_min=-127f, a_max=127f) /* ty=Tensor[(2, 8, 1, 1), float32] */;
  %6 = cast(%5, dtype="int8") /* ty=Tensor[(2, 8, 1, 1), int8] */;

We would like to optimize the above computation so that transpose is computed after cast, that is, in int8. But transpose and cast are relatively far apart, and it may not be a good idea to assume that the ops between them are exactly multiply, round and clip. We may want a general optimization from:

  %2 = transpose(%1, axes=[1, 0, 2, 3]) /* ty=Tensor[(2, 8, 1, 1), float32] */;
  %3 = elementwise(%2);
  %4 = elementwise(%3);
  ...
  %n = elementwise(%(n-1));
  %(n+1) = cast(%n, dtype="int8") /* ty=Tensor[(2, 8, 1, 1), int8] */;

to

  %2 = elementwise(%1); // float32
  %3 = elementwise(%2);
  ...
  %(n-1) = elementwise(%(n-2));
  %(n) = cast(%(n-1), dtype="int8") /* ty=Tensor[(2, 8, 1, 1), int8] */;
  %(n+1) = transpose(%n, axes=[1, 0, 2, 3]) /* ty=Tensor[(2, 8, 1, 1), int8] */;

Such a general optimization cannot be expressed with the current use of the pattern language.
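For reference, the rewrite itself is simple to state as a hand-written traversal. What is missing is a way to express the "any number of elementwise ops" part as a pattern. The sketch below uses a hypothetical mini IR (not Relay's API). It assumes that each op in the illustrative `ELEMENTWISE` set takes its tensor input as the first argument and otherwise only scalars, such as the 16f constant, so that the op commutes with a transpose. A real pass would also have to check that the intermediate nodes have no other users.

```python
# Hypothetical mini IR for illustration only (not Relay's Python API).
class Var:
    def __init__(self, name):
        self.name, self.op, self.args, self.attrs = name, None, (), {}

class Const:
    def __init__(self, value):
        self.value, self.op, self.args, self.attrs = value, None, (), {}

class Call:
    def __init__(self, op, *args, **attrs):
        self.op, self.args, self.attrs = op, args, attrs

# Assumption: these ops take their only tensor operand as args[0]; any other
# operands are scalars, so each op commutes with a transpose of args[0].
ELEMENTWISE = {"multiply", "round", "clip"}

def sink_transpose_below_cast(cast):
    """Rewrite cast(ew_n(...ew_1(transpose(x)))) into transpose(cast(ew_n(...ew_1(x))))."""
    if cast.op != "cast":
        return cast
    chain, node = [], cast.args[0]
    while node.op in ELEMENTWISE:  # walk up through any number of elementwise ops
        chain.append(node)
        node = node.args[0]
    if node.op != "transpose":
        return cast  # shape not found: leave the graph unchanged
    rebuilt = node.args[0]  # the transpose input (%1 in the example)
    for ew in reversed(chain):  # re-apply the elementwise ops, innermost first
        rebuilt = Call(ew.op, rebuilt, *ew.args[1:], **ew.attrs)
    rebuilt = Call("cast", rebuilt, **cast.attrs)
    return Call("transpose", rebuilt, **node.attrs)

def show(e):
    """Render an expression as nested calls (attributes omitted)."""
    if isinstance(e, Var):
        return e.name
    if isinstance(e, Const):
        return repr(e.value)
    return f"{e.op}({', '.join(show(a) for a in e.args)})"

x1 = Var("x1")
t = Call("transpose", x1, axes=[1, 0, 2, 3])
y = Call("cast",
         Call("clip", Call("round", Call("multiply", t, Const(16.0))),
              a_min=-127.0, a_max=127.0),
         dtype="int8")
out = sink_transpose_below_cast(y)
print(show(out))  # transpose(cast(clip(round(multiply(x1, 16.0)))))
```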

Guide-level explanation

We introduce a redirecting operation to the dataflow pattern graph, which redirects a WildcardPattern to a specified DFPattern. A WildcardPattern without redirection behaves like the existing WildcardPattern and matches an arbitrary Expr. A WildcardPattern redirected to a DFPattern instead matches whatever that DFPattern matches.

To illustrate the redirecting operation, suppose we would like to match an nn.dense followed by one or more cast ops. We can define a pattern the_pattern as follows:

dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
wildcard_redirect = wildcard()
the_pattern = is_op("cast")(wildcard_redirect | dense_pattern)
wildcard_redirect.redirect_to(the_pattern)
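To make the intended semantics concrete, here is a small sketch of a matcher that follows a redirected wildcard to its target. Everything in it is hypothetical. The classes borrow the names of TVM's pattern types, and `redirect_to` mirrors the API proposed here, but none of this is TVM code. The sketch also omits memoization, so it sidesteps the question of how a pattern node reused across recursion levels should bind.

```python
# Hypothetical mini IR and pattern classes; `redirect_to` mirrors the proposed API.
class Var:
    def __init__(self, name):
        self.name, self.op, self.args = name, None, ()

class Call:
    def __init__(self, op, *args):
        self.op, self.args = op, args

class Pattern:
    def __or__(self, other):
        return AltPattern(self, other)

class WildcardPattern(Pattern):
    def __init__(self):
        self.target = None  # None means a plain wildcard

    def redirect_to(self, pattern):
        self.target = pattern

class CallPattern(Pattern):
    def __init__(self, op, args):
        self.op, self.args = op, args

class AltPattern(Pattern):
    def __init__(self, left, right):
        self.left, self.right = left, right

def wildcard():
    return WildcardPattern()

def is_op(op):
    return lambda *args: CallPattern(op, args)

def match(pat, expr):
    if isinstance(pat, WildcardPattern):
        # A plain wildcard matches anything; a redirected one matches what its target matches.
        return pat.target is None or match(pat.target, expr)
    if isinstance(pat, AltPattern):
        return match(pat.left, expr) or match(pat.right, expr)
    if isinstance(pat, CallPattern):
        return (expr.op == pat.op and len(expr.args) == len(pat.args)
                and all(match(p, e) for p, e in zip(pat.args, expr.args)))
    raise TypeError(f"unknown pattern {pat!r}")

dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
wildcard_redirect = wildcard()
the_pattern = is_op("cast")(wildcard_redirect | dense_pattern)
wildcard_redirect.redirect_to(the_pattern)

dense = Call("nn.dense", Var("x"), Var("w"))
print(match(the_pattern, Call("cast", Call("cast", dense))))  # True
print(match(the_pattern, dense))  # False: at least one cast is required
```

Matching terminates here because the cycle in the pattern graph passes through a `CallPattern`, which moves to a strictly smaller sub-expression at each step. A wildcard redirected into a cycle that contains no call pattern (for example, to an alternative that contains the wildcard itself) would recurse forever, so some restriction of this kind seems necessary.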

Reference-Level Explanation

I plan to implement the following in this pre-RFC:

  • Add the redirecting operation API in c++ and python, with tests.
  • Fix PrettyPrint. Currently, if the dataflow pattern graph is not a DAG, PrettyPrint may fall into an infinite loop.
  • Fix algorithmic bugs in pattern matching
  • Use the improved pattern matching language to enhance and simplify the passes in simplify_expr
  • Refactor the implementation of dominate pattern via redirecting operation.
  • (Possibly) use the improved pattern matching language to refactor python/tvm/relay/quantize. I think that, with composition of patterns, the pattern matching language is powerful enough to rewrite the quantize logic.

Rationale and alternatives

Note that in a computational graph a node is not allowed to point to itself, since the graph is a DAG. A pattern graph, however, may contain cycles. With this in mind, many aspects of the existing code are worth rethinking.
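One practical consequence is that any visitor over a pattern graph must track visited nodes. As an illustration (not the code in the PrettyPrint PR, which may differ), here is a sketch of a printer over hypothetical pattern classes. It names each node once and prints a reference when it reaches a node again, so a cycle created by a redirect terminates instead of recursing forever.

```python
# Hypothetical pattern classes for illustration (not TVM's).
class WildcardPattern:
    def __init__(self, target=None):
        self.target = target  # set when the wildcard is redirected

class CallPattern:
    def __init__(self, op, *args):
        self.op, self.args = op, args

class AltPattern:
    def __init__(self, left, right):
        self.left, self.right = left, right

def pretty_print(root):
    """Print each pattern node once as '%k = ...'; revisited nodes appear only by name."""
    names, lines = {}, []

    def visit(pat):
        if id(pat) in names:  # seen before (possibly still being printed): reference it
            return names[id(pat)]
        name = f"%{len(names)}"
        names[id(pat)] = name  # name the node before visiting children, so cycles stop here
        if isinstance(pat, WildcardPattern):
            body = "*" if pat.target is None else f"redirect({visit(pat.target)})"
        elif isinstance(pat, CallPattern):
            body = f"{pat.op}({', '.join(visit(a) for a in pat.args)})"
        elif isinstance(pat, AltPattern):
            body = f"{visit(pat.left)} | {visit(pat.right)}"
        else:
            raise TypeError(f"unknown pattern {pat!r}")
        lines.append(f"{name} = {body}")
        return name

    visit(root)
    return "\n".join(lines)

redirect = WildcardPattern()
dense = CallPattern("nn.dense", WildcardPattern(), WildcardPattern())
the_pattern = CallPattern("cast", AltPattern(redirect, dense))
redirect.target = the_pattern  # close the cycle: cast(redirect | dense)
print(pretty_print(the_pattern))
```

Because nodes are emitted after their children, a line inside a cycle can refer to a name that is defined later (here `%2 = redirect(%0)` is printed before `%0 = cast(%1)`).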

Prior Art

The composition of dataflow patterns is not a new idea. In this post, masahi described a situation where recursion is useful, and mbrookhart mentioned that recursive patterns were at least considered. However, pattern composition does not seem to have been used in the TVM codebase so far.

Future Possibilities

Any suggestions and comments are appreciated.

The first PR has been filed. It fixes the PrettyPrint issue for composed patterns.

The second PR has been filed. It improves SimplifyClipAndConsecutiveCast via recursive dataflow patterns.

The third PR has been filed. It adds the redirecting operation, which allows one to define recursive patterns.

Sorry for the late reply. I think one thing that needs to be part of the tradeoff here is the overall algorithmic complexity. Recursion is indeed powerful, but if not managed well it can also make matching increasingly expensive.

Ideally we should restrict ourselves to cases that avoid unbounded recursion. For example, in your case it might be easier to specify things as

dominate(parent=cast, path=I, child=dense)

and make pattern I a cast pattern. The domination pattern should then handle the case where all intermediate results along the path follow pattern I. The idea is to have a pattern that contains a begin, an end, and a path between them. This is effectively one level of recursive checking, but it should be sufficient for cases like consecutive casts.

Overall, it would be good to express patterns with a bounded amount of recursion, which simplifies both reasoning and matching.

See some examples of the dominate pattern: https://github.com/apache/tvm/blob/main/tests/python/relay/test_dataflow_pattern.py#L605

Thank you for your reply.

Some of your concerns are interesting and deep, and I would like to respond to them later.

For now, I would like to respond to your concern about the dominate pattern.

The problem

I agree that the dominate pattern may be more convenient in some scenarios. But its current implementation may not be optimal. In fact, there may be some bugs in the current logic of the dominate pattern.

Consider the test case you pointed to:

# Pattern
P = is_op("nn.conv2d")(wildcard(), wildcard())  # 'parent'
I = is_op("nn.relu")(wildcard())  # 'intermediate' ('path' in the code)
C = is_op("add")(wildcard(), wildcard())  # 'child'
pattern = dominates(P, I, C)
#       n6(P)
#      /  \
#     n7   \
#    /      \
#    n8(P)  n10(I)
#    \      /
#    n9(I) /
#      \  /
#      n11(C)
x = relay.var("x")
w = relay.var("w")
n6 = relay.op.nn.conv2d(x, w)  # matches P
n7 = relay.op.tanh(n6)  # does not match I
n8 = relay.op.nn.conv2d(n7, w)  # matches P
n9 = relay.op.nn.relu(n8)  # matches I
n10 = relay.op.nn.relu(n6)  # matches I
n11 = relay.add(n9, n10)  # matches C
# Does not match: Can't match the parent pattern P at both 8 and 6.
# Note that if we did allow P to be used twice the implementation would
# need to be changed to not 'jump over' n7.
assert not pattern.match(n11)

The above test passes as expected.

However, if we remove n9 and make n8 the first argument of n11, bugs occur:

# Pattern
P = is_op("nn.conv2d")(wildcard(), wildcard())  # 'parent'
I = is_op("nn.relu")(wildcard())  # 'intermediate' ('path' in the code)
C = is_op("add")(wildcard(), wildcard())  # 'child'
pattern = dominates(P, I, C)
#       n6(P)
#      /  \
#     n7   \
#    /      \
#    n8(P)  n10(I)
#    \      /
#     \    /
#      \  /
#      n11(C)
x = relay.var("x")
w = relay.var("w")
n6 = relay.op.nn.conv2d(x, w)  # matches P
n7 = relay.op.tanh(n6)  # does not match I
n8 = relay.op.nn.conv2d(n7, w)  # matches P
# n9 = relay.op.nn.relu(n8)  # matches I
n10 = relay.op.nn.relu(n6)  # matches I
n11 = relay.add(n8, n10)  # matches C

pattern.match(n11) # !!! True
pattern.partition(n11) #  !!! InternalError: Check failed: ....

That is, two bugs occur: first, the matching result is wrong; second, partition raises an error.

I found that a similar error was reported previously.

After a quick investigation, I think the root cause may lie in two places: DFPatternMatcher::MatchesPath, and the special-case logic for DominatorPattern in PatternGrouper.

So the current implementation of the dominate pattern complicates the logic of both matching and grouping.

There may also be some confusion about the definition of domination; see this post.

How to solve them

One possible approach is to clarify the definition of the dominate pattern and fix all the bugs. But I am afraid this may not be the most elegant way.

In fact, with the redirecting operation proposed in this pre-RFC, the dominate pattern can be implemented as syntactic sugar for a particular kind of recursion. With this approach, matching and grouping may need no special treatment for DominatorPattern.

I would like to refactor the dominate pattern as syntactic sugar in this pre-RFC. What do you think?

Thanks for the reply.

The main tradeoff here is the complexity of the matching and search. There is no disagreement that recursion brings more expressiveness.

The tradeoff is that it also brings complexity, both in how we implement it (e.g. whether matching involves nested recursion) and in overall reasoning (a two-level match is easier to reason about than a recursive one).

So when it is possible to solve a problem with a limited level of recursion (like the dominator pattern), I think we should try to improve in that direction instead of moving to a more general mechanism.

Thank you for your reply.

You mention two complexities: the complexity of implementing recursive patterns, and the complexity of reasoning about them. To minimize them, I will adjust the plan to do the following first:

  1. Solve the bugs in the current version of the dominate pattern.
  2. Revisit the optimizations I have done via pattern matching, and express the patterns with the dominate pattern where possible.
  3. Use the redirecting operation only when a pattern cannot be expressed in the existing pattern language.

Algorithmic efficiency is a deep topic. Note that GroupMatches, Partition and Rewriting all require traversing the full computational graph. In the current implementation, in the best case, where every attempt to match a subgraph succeeds and every AltPattern matches its left branch, the algorithm may need to visit each node of the computational graph only once. From this view, efficiency can perhaps be characterized by (1) the probability of a failed match (including failure to match a branch of an AltPattern) and (2) the average length of the subgraph that fails to match.

A recursive pattern does not necessarily have a higher probability of failure, nor a longer average failing subgraph. So I think recursion by itself may not directly affect algorithmic efficiency. An interesting future research direction is faster matching algorithms (perhaps algorithms for regular expression matching can shed some light).

Now I discuss the dominator pattern. The aim is to identify the bugs in its current version and clarify its definition. Before we start, note that (if I am not mistaken) the dominator pattern is not used anywhere in the codebase except in tests.

Case 1 (a relatively easy one)

Consider the following example, a slightly modified version of the test case:

# Pattern
P = is_op("nn.conv2d")(wildcard(), wildcard())  # 'parent'
I = is_op("nn.relu")(wildcard())  # 'intermediate' ('path' in the code)
C = is_op("add")(wildcard(), wildcard())  # 'child'
pattern = dominates(P, I, C)
#       n6(P)
#      /  \
#     n7   \
#    /      \
#    n8(P)  n10(I)
#    \      /
#     \    /
#      \  /
#      n11(C)
x = relay.var("x")
w = relay.var("w")
n6 = relay.op.nn.conv2d(x, w)  # matches P
n7 = relay.op.tanh(n6)  # does not match I
n8 = relay.op.nn.conv2d(n7, w)  # matches P
# n9 = relay.op.nn.relu(n8)  # matches I
n10 = relay.op.nn.relu(n6)  # matches I
n11 = relay.add(n8, n10)  # matches C
pattern.match(n11) # !!! True
pattern.partition(n11) #  !!! InternalError: Check failed: ....

Two issues arise:

issue 1.

Currently, n11 seems to match the pattern with n11 as the child and n8 as the parent. Is this the behavior we want? In general, should the dominator pattern require every arg of the child to match the path pattern? The current logic, written in DFPatternMatcher::MatchesPath, works as follows. For an op on the path, if its first arg matches the (possibly memoized) parent pattern, the op counts as a matched path no matter what its second arg is. If its first arg does not match the parent pattern, the first arg must match the path for the overall match to succeed. This rule treats the first and second args unequally, which may not be desirable; it may be a design bug or an implementation bug. In the PR, the fixed logic is: every arg must match either the parent pattern or the path pattern (and the path must eventually reach a parent pattern).

issue 2.

Why does partition fail? It turns out that PatternGrouper::CreateGroup assumes the path pattern matches at least one op. This assumption does not necessarily hold under the current matching rule. So should we require the path pattern to match at least one op?

In the PR, the fixed logic does not require the path pattern to match at least one op.

Issues 1 and 2 are not a big deal; they can be fixed immediately once we clarify the desired behavior of the dominate pattern.

Case 2 (a serious one)

@maartenvds reported an interesting bug.

This case may involve three bugs: Issue 1, Issue 2, and a new, serious one.

To illustrate the new bug simply, consider the following example:

Suppose we would like to match a dominator pattern:

relu = is_op("nn.relu")(wildcard())
add = is_op("add")(wildcard(), wildcard())
pattern = dominates(relu, wildcard(), add)

Consider the computational graph:

x = relay.var('x', relay.TensorType([1, 1, 32, 32], 'float32'))
x1 = relay.nn.relu(x)
a = relay.nn.relu(x1)
b = relay.nn.relu(x1)
y = a + b
#       x
#       |
#      x1(P)
#      /  \
#     /    \
#    /      \
#   a        b
#    \      /
#     \    /
#      \  /
#       y(C)
pattern.partition(y) # WRONG RESULT!!!

issue 3.

Even if we solve Issue 1 and Issue 2, it may still get wrong result in this case.

In fact, when visiting a, the matcher memoizes a as the parent. Then, since branch b does not eventually reach a, the match fails. However, the real parent is x1. This means that the current algorithm for the dominator pattern is incorrect.
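To pin down the intended semantics, here is a brute-force reference sketch that finds the parent by trying each candidate instead of committing to the first node that matches. It is neither the matcher's algorithm nor the redirect-based fix proposed here. It assumes the fixed rules described above: every arg must reach the parent through path nodes, and the path may be empty. It also simplifies each pattern to a predicate on a single node's op; `is_op` and `wildcard` below are toy stand-ins, not TVM's functions. There is no memoization, so it is only meant for small examples.

```python
from collections import deque

# Hypothetical mini IR for illustration (not Relay's API).
class Var:
    def __init__(self, name):
        self.name, self.op, self.args = name, None, ()

class Call:
    def __init__(self, op, *args):
        self.op, self.args = op, args

def is_op(name):
    # Simplification: a pattern is a predicate on one node's op; argument sub-patterns are ignored.
    return lambda e: e.op == name

def wildcard():
    return lambda e: True

def reaches(e, parent, path):
    """True if every upstream route from e ends at `parent`, passing only through `path` nodes."""
    if e is parent:
        return True
    if not e.args:  # reached a graph input without meeting the parent
        return False
    return path(e) and all(reaches(a, parent, path) for a in e.args)

def dominating_parent(parent, path, child, expr):
    """Return the node playing `parent` in dominates(parent, path, child) at expr, or None."""
    if not child(expr):
        return None
    # Breadth-first, so nearer candidates are tried first; when several nodes
    # qualify, the closest one is returned.
    seen, queue = set(), deque(expr.args)
    while queue:
        cand = queue.popleft()
        if id(cand) in seen:
            continue
        seen.add(id(cand))
        if parent(cand) and all(reaches(a, cand, path) for a in expr.args):
            return cand
        queue.extend(cand.args)
    return None

# The Issue 3 graph: x -> x1 -> (a, b) -> y
x = Var("x")
x1 = Call("nn.relu", x)
a = Call("nn.relu", x1)
b = Call("nn.relu", x1)
y = Call("add", a, b)
print(dominating_parent(is_op("nn.relu"), wildcard(), is_op("add"), y) is x1)  # True
```

Candidate a is rejected because the route through b never reaches a, and the search moves on to x1 instead of failing. Run on the Case 1 graph (with P = nn.conv2d and I = nn.relu), the same function finds no parent, which is the expected answer there.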

Issue 3 may not be easy to solve with the current infrastructure. I would like to solve it via the newly introduced redirecting operation.

Since redirecting has not been merged yet, currently I only solve Issue 1 and Issue 2.

@tqchen @masahi @mbrookhart @comaniac What do you think about Issue 3? Is there a more elegant way to solve it?