Summary
This pre-RFC proposes improvements to dataflow pattern matching for Relay.
The improvement is two-fold. First, fix some bugs in the algorithm implementations of pattern matching; implementing correct algorithms may require introducing some new features. Second, improve some pattern-matching-based optimizations of the computational graph.
Motivation
Bugs in existing implementations
The existing implementations of pattern matching use greedy algorithms, which are not guaranteed to work correctly: some subgraphs that should match a pattern are not actually matched. Below I list two such bugs.
Bug with DominatorPattern
I posted some bugs with DominatorPattern in the reply.
Issues 1 and 2 are relatively easy to solve.
But Issue 3 may stem from the failure of a greedy matching algorithm.
Bug with AltPattern
Consider the following example.
When the left and right sides of the AltPattern are exchanged, the matching results differ.
This may also stem from the failure of a greedy matching algorithm.
import os

os.environ["TVM_LOG_DEBUG"] = "DEFAULT=1"

import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *


def create_model():
    data = relay.var("data", relay.TensorType((2, 8), "float32"))
    feat1 = relay.nn.relu(data)
    feat2 = relay.nn.relu(feat1)
    feat3 = relay.nn.relu(feat2)
    feat4 = relay.add(feat3, feat2)
    f = relay.Function([data], feat4)
    mod = tvm.IRModule()
    mod["main"] = f
    mod = relay.transform.InferType()(mod)
    return mod


def construct_pattern(mode):
    relu_pattern = is_op("nn.relu")(wildcard())
    if mode == 1:
        or_pattern = is_op("nn.relu")(is_op("nn.relu")(relu_pattern)) | is_op("nn.relu")(relu_pattern)
    else:
        or_pattern = is_op("nn.relu")(relu_pattern) | is_op("nn.relu")(is_op("nn.relu")(relu_pattern))
    return is_op("add")(or_pattern, relu_pattern)


if __name__ == "__main__":
    pattern1 = construct_pattern(1)
    pattern2 = construct_pattern(2)
    mod = create_model()
    kaka1 = pattern1.match(mod["main"].body)
    kaka2 = pattern2.match(mod["main"].body)
    print(kaka1)  # 0
    print(kaka2)  # 1
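To make the suspected cause concrete, here is a minimal pure-Python sketch of the same graph and patterns. It is a toy model, not TVM's matcher: the classes Expr, Wild, Call and Alt and the functions greedy and backtracking are hypothetical names introduced only for illustration. It assumes that a shared pattern node (relu_pattern) must bind to a single expression, and that the greedy matcher commits to the first alternative that matches locally.

```python
class Expr:
    """A node of a toy computation graph (hypothetical, not a Relay Expr)."""
    def __init__(self, op, *args):
        self.op, self.args = op, args

class Wild:
    """Toy wildcard: matches any expression."""

class Call:
    def __init__(self, op, *args):
        self.op, self.args = op, args

class Alt:
    def __init__(self, left, right):
        self.left, self.right = left, right

def greedy(pat, expr, memo):
    """Commit to the first alternative that matches locally; never revisit it."""
    if id(pat) in memo:  # a shared pattern node must bind the same expression
        return memo[id(pat)] is expr
    saved = dict(memo)
    if isinstance(pat, Wild):
        ok = True
    elif isinstance(pat, Alt):
        ok = greedy(pat.left, expr, memo) or greedy(pat.right, expr, memo)
    else:
        ok = (pat.op == expr.op and len(pat.args) == len(expr.args)
              and all(greedy(p, e, memo) for p, e in zip(pat.args, expr.args)))
    if ok:
        memo[id(pat)] = expr
    else:  # undo bindings made by the failed attempt
        memo.clear()
        memo.update(saved)
    return ok

def solutions(pat, expr, memo):
    """Yield every binding under which pat matches expr, extending memo."""
    if id(pat) in memo:
        if memo[id(pat)] is expr:
            yield memo
        return
    if isinstance(pat, Wild):
        yield {**memo, id(pat): expr}
    elif isinstance(pat, Alt):
        for branch in (pat.left, pat.right):  # try both sides, not just the first hit
            for m in solutions(branch, expr, memo):
                yield {**m, id(pat): expr}
    elif pat.op == expr.op and len(pat.args) == len(expr.args):
        def match_args(i, m):
            if i == len(pat.args):
                yield {**m, id(pat): expr}
                return
            for m2 in solutions(pat.args[i], expr.args[i], m):
                yield from match_args(i + 1, m2)
        yield from match_args(0, memo)

def backtracking(pat, expr):
    return next(solutions(pat, expr, {}), None) is not None

def build_graph():
    data = Expr("var")
    feat1 = Expr("relu", data)
    feat2 = Expr("relu", feat1)
    feat3 = Expr("relu", feat2)
    return Expr("add", feat3, feat2)

def build_pattern(mode):
    relu_pattern = Call("relu", Wild())
    short = Call("relu", relu_pattern)
    long = Call("relu", Call("relu", relu_pattern))
    alt = Alt(long, short) if mode == 1 else Alt(short, long)
    return Call("add", alt, relu_pattern)

for mode in (1, 2):
    graph, pattern = build_graph(), build_pattern(mode)
    print(mode, greedy(pattern, graph, {}), backtracking(pattern, graph))
# prints "1 False True" then "2 True True"
```

In this toy model the greedy matcher binds relu_pattern to feat1 through the longer alternative and then cannot match feat2 as the second argument of add, while the backtracking matcher retries the other alternative and succeeds in both orders.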
Computational graph optimizations in existing codebase
The TVM community developed a pattern matching language that is now widely used within the TVM project. The language consists of, among others, ExprPattern, CallPattern and AltPattern. While the current use of the pattern matching language is sufficient in many situations, there are some important scenarios where it is not powerful enough. Here are two examples:
- In the SimplifyClipAndConsecutiveCast pass, the pattern is:
clip_ = IsOp("clip")({IsWildcard()});
cast1_ = IsOp("cast")({clip_});
pattern_ = IsOp("cast")({cast1_});
which is restrictive.
What if there are more than two cast ops?
What if there are some transpose and reshape ops between clip and cast?
Such patterns seem hard to match with the current use of the pattern matching language.
Also, in the SimplifyClipAndConsecutiveCast pass, one cannot use the strategy of “remove only one cast op per rewrite”.
To see the restriction of the current SimplifyClipAndConsecutiveCast pass, consider the following piece of computation:
%45 = clip(%44, a_min=0f, a_max=255f) /* ty=Tensor[(1, 4096), float32] */;
%46 = cast(%45, dtype="uint8") /* ty=Tensor[(1, 4096), uint8] */;
%47 = cast(%46, dtype="int32") /* ty=Tensor[(1, 4096), int32] */;
%48 = cast(%47, dtype="float32") /* ty=Tensor[(1, 4096), float32] */;
The above computation should be simplified, but currently it is not.
- Consider the following computation, whose pattern is likely to arise when using quantization:
%2 = transpose(%1, axes=[1, 0, 2, 3]) /* ty=Tensor[(2, 8, 1, 1), float32] */;
%3 = multiply(%2, 16f /* ty=float32 */) /* ty=Tensor[(2, 8, 1, 1), float32] */;
%4 = round(%3) /* ty=Tensor[(2, 8, 1, 1), float32] */;
%5 = clip(%4, a_min=-127f, a_max=127f) /* ty=Tensor[(2, 8, 1, 1), float32] */;
%6 = cast(%5, dtype="int8") /* ty=Tensor[(2, 8, 1, 1), int8] */;
We would like to optimize the above computation so that transpose is computed after cast, and therefore in int8.
But transpose and cast are somewhat far apart, and it may not be a good idea to assume that the ops between transpose and cast are exactly multiply, round and clip.
We may want to do a general optimization from:
%2 = transpose(%1, axes=[1, 0, 2, 3]) /* ty=Tensor[(2, 8, 1, 1), float32] */;
%3 = elementwise(%2);
%4 = elementwise(%3);
...
%n = elementwise(%(n-1));
%(n+1) = cast(%n, dtype="int8") /* ty=Tensor[(2, 8, 1, 1), int8] */;
to
%2 = elementwise(%1); // float32
%3 = elementwise(%2);
...
%(n-1) = elementwise(%(n-2));
%(n) = cast(%(n-1), dtype="int8") /* ty=Tensor[(2, 8, 1, 1), int8] */;
%(n+1) = transpose(%n, axes=[1, 0, 2, 3]) /* ty=Tensor[(2, 8, 1, 1), int8] */;
Such a general optimization is hard to express with the current use of the pattern matching language.
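The intended rewrite can be sketched on a toy linear chain of ops in plain Python. This is only an illustration of the transformation, not a Relay pass: the function sink_transpose, the (op_name, attrs) tuple encoding and the ELEMENTWISE set are assumptions made for this example. In particular, multiply is modeled as multiplication by a scalar, since a broadcast multiply would not commute with transpose in general.

```python
# Assumption: ops treated as elementwise (and therefore commuting with
# transpose) in this toy model. "multiply_scalar" stands for multiply by a scalar.
ELEMENTWISE = {"multiply_scalar", "round", "clip", "relu"}

def sink_transpose(chain):
    """Move a leading transpose past a run of elementwise ops and the cast ending it.

    `chain` is a list of (op_name, attrs) tuples in execution order.
    Returns a new list; if the shape transpose -> elementwise* -> cast is not
    found, a copy of the input is returned unchanged.
    """
    if not chain or chain[0][0] != "transpose":
        return list(chain)
    i = 1
    while i < len(chain) and chain[i][0] in ELEMENTWISE:
        i += 1
    if i == len(chain) or chain[i][0] != "cast":
        return list(chain)
    # elementwise ops and the cast first, then the transpose, then the rest
    return chain[1:i + 1] + [chain[0]] + chain[i + 1:]

chain = [
    ("transpose", {"axes": [1, 0, 2, 3]}),
    ("multiply_scalar", {"rhs": 16.0}),
    ("round", {}),
    ("clip", {"a_min": -127.0, "a_max": 127.0}),
    ("cast", {"dtype": "int8"}),
]
print([op for op, _ in sink_transpose(chain)])
# prints ['multiply_scalar', 'round', 'clip', 'cast', 'transpose']
```

The loop over elementwise ops is exactly the "arbitrary number of ops in between" that a fixed-length pattern cannot describe.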
Guide-level explanation
We introduce a redirecting operation for the dataflow pattern graph, which redirects a WildcardPattern to a specified DFPattern. A WildcardPattern with no redirect behaves just like the previous WildcardPattern and can match an arbitrary Expr. A WildcardPattern redirected to a DFPattern, on the other hand, matches only what that DFPattern matches.
To illustrate the redirecting operation, consider the following example.
Suppose we would like to match an nn.dense followed by an arbitrary number (at least one) of cast ops.
We can define a pattern the_pattern as follows:
dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
wildcard_redirect = wildcard()
the_pattern = is_op("cast")(wildcard_redirect | dense_pattern)
wildcard_redirect.redirect_to(the_pattern)
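The intended semantics of this recursive pattern can be sketched with a small pure-Python model. It does not use TVM and is not the proposed implementation: the classes Expr, Wildcard, Call and Alt and the function matches are hypothetical stand-ins, and the matcher deliberately omits memoization because a pattern node on a cycle has to match several different expressions.

```python
class Expr:
    """A node of a toy computation graph (hypothetical, not a Relay Expr)."""
    def __init__(self, op, *args):
        self.op, self.args = op, args

class Wildcard:
    def __init__(self):
        self.target = None          # None means "match anything"

    def redirect_to(self, pattern):
        self.target = pattern       # from now on, behave like `pattern`

class Call:
    def __init__(self, op, *args):
        self.op, self.args = op, args

class Alt:
    def __init__(self, left, right):
        self.left, self.right = left, right

def matches(pat, expr):
    if isinstance(pat, Wildcard):
        return True if pat.target is None else matches(pat.target, expr)
    if isinstance(pat, Alt):
        return matches(pat.left, expr) or matches(pat.right, expr)
    # Call: each step descends into expr, so the recursion ends on a DAG
    return (pat.op == expr.op and len(pat.args) == len(expr.args)
            and all(matches(p, e) for p, e in zip(pat.args, expr.args)))

# Same construction as the_pattern above, in the toy model.
dense_pattern = Call("nn.dense", Wildcard(), Wildcard())
wildcard_redirect = Wildcard()
the_pattern = Call("cast", Alt(wildcard_redirect, dense_pattern))
wildcard_redirect.redirect_to(the_pattern)

dense = Expr("nn.dense", Expr("var"), Expr("var"))
three_casts = Expr("cast", Expr("cast", Expr("cast", dense)))

print(matches(the_pattern, three_casts))               # True
print(matches(the_pattern, dense))                     # False: no cast at all
print(matches(the_pattern, Expr("cast", Expr("var")))) # False: no nn.dense underneath
```

Note that the pattern graph now contains a cycle (the_pattern reaches itself through wildcard_redirect), which is what lets a finite pattern describe a chain of unbounded length.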
Reference-Level Explanation
I plan to implement the following in this pre-RFC.
- Add the redirecting operation API in C++ and Python, with tests.
- Fix PrettyPrint. Previously, if the dataflow pattern graph was not a DAG, PrettyPrint could fall into an infinite loop.
- Fix algorithmic bugs in pattern matching.
- Use the improved pattern matching language to enhance and simplify the passes in simplify_expr.
- Refactor the implementation of DominatorPattern via the redirecting operation.
- (Possible) Use the improved pattern matching language to refactor python/tvm/relay/quantize. I think that with the composition of patterns, the pattern matching language is powerful enough to rewrite the logic of quantize.
Rationale and alternatives
Note that in a computation graph, a node is not allowed to point to itself (since the graph is a DAG). A pattern graph, however, can be allowed to contain cycles. With this in mind, many aspects of the existing code are worth rethinking.
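One consequence is that any traversal of a pattern graph, such as a printer, has to remember which nodes it has already visited. The sketch below shows that idea on a toy graph in plain Python; Node and pretty are hypothetical names for this illustration and do not reflect how TVM's PrettyPrint is or will be implemented.

```python
class Node:
    """A toy pattern-graph node with a name and child nodes."""
    def __init__(self, name, *children):
        self.name = name
        self.children = list(children)

def pretty(node, names=None):
    """Print a possibly cyclic graph; revisited nodes are shown as references."""
    names = {} if names is None else names
    if id(node) in names:               # already printed: emit a back-reference
        return f"<ref {names[id(node)]}>"
    label = f"p{len(names)}"
    names[id(node)] = label
    inner = ", ".join(pretty(child, names) for child in node.children)
    return f"{label}:{node.name}({inner})"

# A cyclic graph shaped like the redirect example: cast(redirect | dense),
# where the redirected wildcard points back to the cast node.
redirect = Node("wildcard")
dense = Node("nn.dense", Node("wildcard"), Node("wildcard"))
the_pattern = Node("cast", Node("alt", redirect, dense))
redirect.children.append(the_pattern)

print(pretty(the_pattern))
# prints p0:cast(p1:alt(p2:wildcard(<ref p0>), p3:nn.dense(p4:wildcard(), p5:wildcard())))
```

Without the names table, the recursion would follow the back edge from the wildcard to the cast node forever.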
Prior Art
The composition of dataflow patterns is not a new idea. In this post, masahi described a situation where recursion is useful, and mbrookhart indicated that recursive patterns had at least been considered. However, it seems that the composition of patterns has never been used in the TVM codebase so far.
Future Possibilities
Any suggestions and comments are appreciated.