Summary
This pre-RFC aims to improve dataflow pattern matching in Relay.
The improvement is two-fold. First, fix bugs in the algorithms that implement pattern matching; implementing correct algorithms may require introducing some new features. Second, improve the pattern-matching-based optimizations of the computational graph.
Motivation
Bugs in existing implementations
The current implementations of pattern matching use greedy algorithms, which are not guaranteed to work correctly: some subgraphs that should match a pattern are not actually matched. Below I list two such bugs.
Bug with DominatorPattern
I have posted several bugs with DominatorPattern in the reply.
Issues 1 and 2 are relatively easy to solve.
Issue 3, however, may stem from the failure of a greedy matching algorithm.
Bug with AltPattern
Consider the following example. Swapping the left and right operands of the AltPattern changes the matching result. This may also stem from the failure of a greedy matching algorithm.
```python
import os
os.environ["TVM_LOG_DEBUG"] = "DEFAULT=1"

import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *


def create_model():
    data = relay.var("data", relay.TensorType((2, 8), "float32"))
    feat1 = relay.nn.relu(data)
    feat2 = relay.nn.relu(feat1)
    feat3 = relay.nn.relu(feat2)
    feat4 = relay.add(feat3, feat2)
    f = relay.Function([data], feat4)
    mod = tvm.IRModule()
    mod["main"] = f
    mod = relay.transform.InferType()(mod)
    return mod


def construct_pattern(mode):
    relu_pattern = is_op("nn.relu")(wildcard())
    if mode == 1:
        # Longer alternative first
        or_pattern = is_op("nn.relu")(is_op("nn.relu")(relu_pattern)) | is_op("nn.relu")(relu_pattern)
    else:
        # Shorter alternative first
        or_pattern = is_op("nn.relu")(relu_pattern) | is_op("nn.relu")(is_op("nn.relu")(relu_pattern))
    return is_op("add")(or_pattern, relu_pattern)


if __name__ == "__main__":
    pattern1 = construct_pattern(1)
    pattern2 = construct_pattern(2)
    mod = create_model()
    matched1 = pattern1.match(mod["main"].body)
    matched2 = pattern2.match(mod["main"].body)
    print(matched1)  # 0 (no match)
    print(matched2)  # 1 (match)
```
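The asymmetry can be reproduced outside TVM with a toy matcher. The sketch below illustrates the suspected mechanism under stated assumptions and is not TVM's implementation: tuples stand in for Relay expressions, and the hypothetical classes `Wildcard`, `CallP`, and `Alt` stand in for DFPattern nodes. Both matchers enforce the rule that a shared pattern node (here `relu_p`, which appears twice) binds a single expression. `greedy_match` commits to the first branch of an `Alt` that matches locally, which is our assumed model of the current behavior; `all_matches` backtracks, so a failure later in the pattern can send it back to try the other branch.

```python
# Toy model only: tuples stand in for Relay expressions and three small classes
# stand in for DFPattern nodes. None of these names are TVM APIs.
class Wildcard:
    pass


class CallP:
    def __init__(self, op, args):
        self.op, self.args = op, list(args)


class Alt:
    def __init__(self, left, right):
        self.left, self.right = left, right


def call(op, *args):
    return ("call", op, args)


def op_matches(p, e):
    """Same operator and arity; arguments are checked by the callers."""
    return e[0] == "call" and e[1] == p.op and len(e[2]) == len(p.args)


def greedy_match(p, e, memo):
    """Commit to the first Alt branch that matches locally; never revisit it."""
    if p in memo:  # a shared pattern node must bind a single expression
        return memo[p] is e
    if isinstance(p, Wildcard):
        ok = True
    elif isinstance(p, CallP):
        ok = op_matches(p, e) and all(greedy_match(a, x, memo) for a, x in zip(p.args, e[2]))
    else:  # Alt: fall back to the right branch only if the left branch fails here
        saved = dict(memo)
        ok = greedy_match(p.left, e, memo)
        if not ok:
            memo.clear()
            memo.update(saved)
            ok = greedy_match(p.right, e, memo)
    if ok:
        memo[p] = e
    return ok


def all_matches(p, e, memo):
    """Yield every consistent binding under which p matches e (backtracking)."""
    if p in memo:
        if memo[p] is e:
            yield memo
        return
    if isinstance(p, Wildcard):
        yield {**memo, p: e}
    elif isinstance(p, CallP):
        if op_matches(p, e):
            for m in match_args(list(zip(p.args, e[2])), memo):
                yield {**m, p: e}
    else:  # Alt: both branches remain candidates for later constraints
        for branch in (p.left, p.right):
            for m in all_matches(branch, e, memo):
                yield {**m, p: e}


def match_args(pairs, memo):
    """Match argument patterns left to right, threading bindings through."""
    if not pairs:
        yield memo
        return
    (a, x), rest = pairs[0], pairs[1:]
    for m in all_matches(a, x, memo):
        yield from match_args(rest, m)


def construct_pattern(mode):
    relu_p = CallP("relu", [Wildcard()])
    two = CallP("relu", [CallP("relu", [relu_p])])  # relu(relu(relu_p))
    one = CallP("relu", [relu_p])                   # relu(relu_p)
    alt = Alt(two, one) if mode == 1 else Alt(one, two)
    return CallP("add", [alt, relu_p])


data = ("var", "data")
feat1 = call("relu", data)
feat2 = call("relu", feat1)
feat3 = call("relu", feat2)
body = call("add", feat3, feat2)

for mode in (1, 2):
    p = construct_pattern(mode)
    greedy = greedy_match(p, body, {})
    backtracking = next(all_matches(p, body, {}), None) is not None
    print(mode, greedy, backtracking)
# 1 False True
# 2 True True
```

With `mode == 1`, the greedy matcher takes the longer branch, binds `relu_p` to `feat1`, and then cannot match the second argument of `add` (`feat2`); the backtracking matcher recovers by trying the shorter branch. Naive backtracking can take exponential time in the worst case, so a real fix would still need to bound or memoize the search.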
Computational graph optimizations in the existing codebase
The TVM community has developed a pattern matching language that is now widely used within the TVM project. Among other constructs, the language provides ExprPattern, CallPattern, and AltPattern. While the current use of the pattern matching language is sufficient in many situations, there are important scenarios where it is not powerful enough. Here are two examples:
- The SimplifyClipAndConsecutiveCast pass matches the following pattern:
```cpp
clip_ = IsOp("clip")({IsWildcard()});
cast1_ = IsOp("cast")({clip_});
pattern_ = IsOp("cast")({cast1_});
```
This pattern is restrictive. What if there are more than two cast ops? What if there are transpose and reshape ops between clip and cast? Such patterns are hard to express with the current use of the pattern matching language.
Moreover, the SimplifyClipAndConsecutiveCast pass cannot use the strategy of "remove only one cast op per rewrite".
To see the limitation of the current SimplifyClipAndConsecutiveCast pass, consider the following computation:
```
%45 = clip(%44, a_min=0f, a_max=255f) /* ty=Tensor[(1, 4096), float32] */;
%46 = cast(%45, dtype="uint8") /* ty=Tensor[(1, 4096), uint8] */;
%47 = cast(%46, dtype="int32") /* ty=Tensor[(1, 4096), int32] */;
%48 = cast(%47, dtype="float32") /* ty=Tensor[(1, 4096), float32] */;
```
This computation should be simplified, but currently it is not.
- Consider the following computation, whose pattern is likely to arise when using quantization:
```
%2 = transpose(%1, axes=[1, 0, 2, 3]) /* ty=Tensor[(2, 8, 1, 1), float32] */;
%3 = multiply(%2, 16f /* ty=float32 */) /* ty=Tensor[(2, 8, 1, 1), float32] */;
%4 = round(%3) /* ty=Tensor[(2, 8, 1, 1), float32] */;
%5 = clip(%4, a_min=-127f, a_max=127f) /* ty=Tensor[(2, 8, 1, 1), float32] */;
%6 = cast(%5, dtype="int8") /* ty=Tensor[(2, 8, 1, 1), int8] */;
```
We would like to optimize this computation so that transpose is computed after cast, i.e., in int8. However, transpose and cast are several ops apart, and it may not be a good idea to assume that the ops between them are exactly multiply, round, and clip.
We may instead want a general optimization that rewrites
```
%2 = transpose(%1, axes=[1, 0, 2, 3]) /* ty=Tensor[(2, 8, 1, 1), float32] */;
%3 = elementwise(%2);
%4 = elementwise(%3);
...
%n = elementwise(%(n-1));
%(n+1) = cast(%n, dtype="int8") /* ty=Tensor[(2, 8, 1, 1), int8] */;
```
into
```
%2 = elementwise(%1); // float32
%3 = elementwise(%2);
...
%(n-1) = elementwise(%(n-2));
%n = cast(%(n-1), dtype="int8") /* ty=Tensor[(2, 8, 1, 1), int8] */;
%(n+1) = transpose(%n, axes=[1, 0, 2, 3]) /* ty=Tensor[(2, 8, 1, 1), int8] */;
```
Such a general optimization cannot easily be expressed with the current use of the pattern matching language.
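Both rewrites discussed above can be sketched on a hypothetical, much-simplified IR: a linear chain of unary ops applied to one tensor. Nothing here is TVM API; `Op`, `remove_one_cast`, `simplify_casts`, `sink_transpose`, the dtype table, and the set of elementwise ops are illustrative assumptions, and the toy does not try to reproduce what TVM's passes currently do. `remove_one_cast` drops a single redundant integer cast per rewrite (sound because casting an in-range value through one integer type into another gives the same integer as casting directly), and `simplify_casts` repeats it to a fixed point, so any number of casts is handled. `sink_transpose` moves a transpose past a run of ops assumed to be elementwise with scalar operands only.

```python
# Hypothetical toy IR: a linear chain of unary ops applied to one tensor.
# Op, the dtype table, and the elementwise set are illustrative assumptions.
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    name: str
    attrs: tuple = ()  # hashable (key, value) pairs

    def attr(self, key):
        return dict(self.attrs)[key]


INT_RANGES = {"int8": (-128, 127), "uint8": (0, 255),
              "int16": (-32768, 32767), "int32": (-2**31, 2**31 - 1)}
# Assumed elementwise, with scalar operands only, so they commute with transpose.
ELEMENTWISE = {"multiply", "add", "round", "clip", "relu"}


def fits(lo, hi, dtype):
    """True if every value in [lo, hi] is representable in an integer dtype."""
    return dtype in INT_RANGES and INT_RANGES[dtype][0] <= lo and hi <= INT_RANGES[dtype][1]


def remove_one_cast(chain):
    """One rewrite: clip -> cast(T1) -> cast(T2) becomes clip -> cast(T2)
    when the clip range fits in both integer dtypes. Returns (chain, changed)."""
    for i, op in enumerate(chain[:-2]):
        a, b = chain[i + 1], chain[i + 2]
        if op.name == "clip" and a.name == "cast" and b.name == "cast":
            lo, hi = op.attr("a_min"), op.attr("a_max")
            if fits(lo, hi, a.attr("dtype")) and fits(lo, hi, b.attr("dtype")):
                return chain[: i + 1] + chain[i + 2:], True
    return chain, False


def simplify_casts(chain):
    """Repeat the one-cast rewrite until nothing changes."""
    changed = True
    while changed:
        chain, changed = remove_one_cast(chain)
    return chain


def sink_transpose(chain):
    """Move a transpose past a run of elementwise ops to just after the next cast."""
    for i, op in enumerate(chain):
        if op.name != "transpose":
            continue
        j = i + 1
        while j < len(chain) and chain[j].name in ELEMENTWISE:
            j += 1
        if j < len(chain) and chain[j].name == "cast":
            return chain[:i] + chain[i + 1: j + 1] + [op] + chain[j + 1:]
    return chain


def cast(dtype):
    return Op("cast", (("dtype", dtype),))


clip_chain = [Op("clip", (("a_min", 0.0), ("a_max", 255.0))),
              cast("uint8"), cast("int32"), cast("float32")]
print([op.attr("dtype") for op in simplify_casts(clip_chain)[1:]])  # ['int32', 'float32']

quant_chain = [Op("transpose", (("axes", (1, 0, 2, 3)),)), Op("multiply", (("scalar", 16.0),)),
               Op("round"), Op("clip", (("a_min", -127.0), ("a_max", 127.0))), cast("int8")]
print([op.name for op in sink_transpose(quant_chain)])
# ['multiply', 'round', 'clip', 'cast', 'transpose']
```

The trailing float32 cast is kept because the toy rule only fires for integer dtypes, which keeps it conservative about the truncation the int32 cast performs.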
Guide-level explanation
We introduce a redirecting operation on the dataflow pattern graph, which redirects a WildcardPattern to a specified DFPattern. A WildcardPattern that is not redirected behaves exactly like the existing WildcardPattern and matches an arbitrary Expr. A WildcardPattern redirected to a DFPattern instead matches whatever that DFPattern matches.
To illustrate the redirecting operation, consider the following example. Suppose we would like to match an nn.dense followed by one or more cast ops. We can define a pattern the_pattern as follows:
```python
from tvm.relay.dataflow_pattern import is_op, wildcard

dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
wildcard_redirect = wildcard()
the_pattern = is_op("cast")(wildcard_redirect | dense_pattern)
# Proposed API: from now on, wildcard_redirect matches whatever the_pattern matches.
wildcard_redirect.redirect_to(the_pattern)
```
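To make the intended semantics concrete, here is a minimal pure-Python model of the proposal. It is a sketch under stated assumptions, not the TVM implementation: expressions are tuples, and `Wildcard`, `CallP`, `Alt`, and `redirect_to` are hypothetical stand-ins that only mirror the names used in the proposal.

```python
# Hypothetical pure-Python model of the proposed redirecting operation.
# Expressions are tuples ("var", name) or ("call", op, args); nothing here is TVM API.
class Pattern:
    def __or__(self, other):  # mimic the `a | b` syntax for AltPattern
        return Alt(self, other)


class Wildcard(Pattern):
    """Matches any expression until redirected; then matches what its target matches."""

    def __init__(self):
        self.target = None

    def redirect_to(self, target):
        self.target = target

    def match(self, e):
        return True if self.target is None else self.target.match(e)


class CallP(Pattern):
    def __init__(self, op, *args):
        self.op, self.args = op, args

    def match(self, e):
        return (e[0] == "call" and e[1] == self.op and len(e[2]) == len(self.args)
                and all(p.match(x) for p, x in zip(self.args, e[2])))


class Alt(Pattern):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def match(self, e):
        return self.left.match(e) or self.right.match(e)


def call(op, *args):
    return ("call", op, args)


# The example above: nn.dense followed by one or more casts.
dense_pattern = CallP("nn.dense", Wildcard(), Wildcard())
wildcard_redirect = Wildcard()
the_pattern = CallP("cast", wildcard_redirect | dense_pattern)
wildcard_redirect.redirect_to(the_pattern)  # closes the cycle in the pattern graph

x, wt = ("var", "x"), ("var", "w")
d = call("nn.dense", x, wt)
c3 = call("cast", call("cast", call("cast", d)))
print(the_pattern.match(c3))                                # True
print(the_pattern.match(d))                                 # False: needs at least one cast
print(the_pattern.match(call("cast", call("nn.relu", x))))  # False
```

This toy matcher deliberately keeps no memo table: with a cycle, the same pattern node (`the_pattern`) matches a different expression at each level of recursion, so a "one pattern node, one expression" consistency rule cannot be applied naively. Recursion terminates because every path around the cycle passes through a `CallP` that descends one level into the expression; a cycle made only of wildcards and alternatives would recurse forever, which an implementation would need to detect or reject.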
Reference-Level Explanation
I plan to implement the following in this pre-RFC:
- Add the redirecting operation API in C++ and Python, with tests.
- Fix PrettyPrint. Currently, if the dataflow pattern graph is not a DAG, PrettyPrint may fall into an infinite loop.
- Fix the algorithmic bugs in pattern matching.
- Use the improved pattern matching language to enhance and simplify the passes in simplify_expr.
- Refactor the implementation of DominatorPattern using the redirecting operation.
- (Possibly) use the improved pattern matching language to refactor python/tvm/relay/quantize. I think that, with pattern composition, the pattern matching language is powerful enough to rewrite the quantization logic.
Rationale and alternatives
Note that a computation graph is a DAG, so no node can point to itself, directly or indirectly. A pattern graph, however, may contain cycles. Once this is recognized, many aspects of the existing code are worth rethinking, since code that assumes the pattern graph is acyclic (such as PrettyPrint) can loop forever.
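One concrete consequence concerns the PrettyPrint fix in the plan: any traversal of a pattern graph must tolerate cycles. A common approach, sketched here on a hypothetical `Node` class rather than TVM's actual printer, is to name each node on first visit and emit a reference to that name when the node is reached again.

```python
# Hypothetical pattern-graph node: a label plus child nodes (not a TVM class).
class Node:
    def __init__(self, label, *children):
        self.label, self.children = label, list(children)


def pretty_print(root):
    """Print a possibly cyclic pattern graph without looping forever.

    A node is named before its children are visited, so a back edge to a node
    that is still being printed resolves to its name instead of recursing.
    """
    names, lines = {}, []

    def visit(node):
        if node in names:
            return names[node]
        names[node] = f"%p{len(names)}"
        args = ", ".join(visit(c) for c in node.children)
        lines.append(f"{names[node]} = {node.label}({args})")
        return names[node]

    visit(root)
    return "\n".join(lines)


# The cyclic pattern from the redirecting example: cast(redirect | dense).
redirect = Node("wildcard")
dense = Node("is_op(nn.dense)", Node("wildcard"), Node("wildcard"))
root = Node("is_op(cast)", Node("alt", redirect, dense))
redirect.children.append(root)  # models redirect_to: the wildcard points back at root
print(pretty_print(root))
# %p2 = wildcard(%p0)
# %p4 = wildcard()
# %p5 = wildcard()
# %p3 = is_op(nn.dense)(%p4, %p5)
# %p1 = alt(%p2, %p3)
# %p0 = is_op(cast)(%p1)
```

The redirected wildcard prints as `wildcard(%p0)`, a forward reference to the root; the same visited-set idea applies to any other pass that walks a pattern graph.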
Prior Art
The composition of dataflow patterns is not a new idea. In this post, masahi described a situation where recursion is useful, and mbrookhart mentioned that recursive patterns had at least been considered. However, pattern composition does not seem to have been used anywhere in the TVM codebase so far.
Future Possibilities
Any suggestions and comments are appreciated.