Now I’m trying to produce a pattern that matches nodes if they have the same shape.
Is such a pattern available? I only saw has_shape which seems to compare to a fixed shape (which I don’t know).
I’m trying to use rewrite, and so it seems checking after the matching (and returning an unchanged expression) will lead to an infinite loop.
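For context, my understanding of has_shape is that it pins one concrete shape known up front (a small sketch, the shape is just an example):

from tvm.relay.dataflow_pattern import wildcard

# matches only nodes whose inferred shape is exactly [32, 128]
pat = wildcard().has_shape([32, 128])

What I’m after instead is “has the same shape as this other matched node”.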
Thank you, yes.
So I have this graph produced by gradient (plus graph normal form and removing the forward outputs) of a dense + bias_add. Obviously, the gradients would be ones_like(output).collapse_sum_like(bias) and a couple of dense( ) with grad_out or its transpose replacing weight and input, respectively, to get the gradient of the other.
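For reference, a minimal sketch of how a graph like this can be produced (the shapes and the first-order mode are just illustrative choices):

import tvm
from tvm import relay

# forward graph: dense + bias_add (shapes purely illustrative)
x = relay.var("x", shape=(32, 64), dtype="float32")
w = relay.var("w", shape=(128, 64), dtype="float32")
b = relay.var("b", shape=(128,), dtype="float32")
fn = relay.Function([x, w, b], relay.nn.bias_add(relay.nn.dense(x, w), b))

mod = relay.transform.InferType()(tvm.IRModule.from_expr(fn))
grad_fn = relay.transform.gradient(mod["main"], mod=mod, mode="first_order")
grmod = relay.transform.ToGraphNormalForm()(tvm.IRModule.from_expr(grad_fn))
# (removing the forward outputs from the returned tuple is left out here)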
import tvm
import tvm.relay.dataflow_pattern

class ZeroZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    def __init__(self):
        super().__init__()
        # match zeros_like(tensor) + other_tensor
        self.pattern_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.zeros_like = tvm.relay.dataflow_pattern.is_op("zeros_like")(self.pattern_tensor)
        self.other_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern = self.zeros_like + self.other_tensor

    def callback(self, pre, post, node_map):
        rt = node_map[self.pattern][0]       # the whole add
        ot = node_map[self.other_tensor][0]  # the non-zero summand
        if ot._checked_type_ == rt._checked_type_:
            # adding zeros is a no-op, keep the other operand
            return ot
        else:
            # otherwise the zeros only served to broadcast the other operand
            return tvm.relay.broadcast_to(ot, rt._checked_type_.shape)
class CollapseSumZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    def __init__(self):
        super().__init__()
        self.data_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern = tvm.relay.dataflow_pattern.is_op("collapse_sum_like")(
            self.data_tensor, self.pattern_tensor
        )

    def callback(self, pre, post, node_map):
        data = node_map[self.data_tensor][0]
        res = node_map[self.pattern][0]
        if data._checked_type_ == res._checked_type_:
            # collapsing to the same type is a no-op, drop it
            return data
        else:
            return res
grfn = tvm.relay.dataflow_pattern.rewrite(ZeroZapp(), grmod["main"])
grfn = tvm.relay.dataflow_pattern.rewrite(CollapseSumZapp(), grfn)
For the CollapseSumZapp in particular, I would replace the if in the callback by a more refined pattern.
Similarly, implicit broadcasting leaves me with many of these ops in the backward graph. The broadcast_to_like could probably be treated just like collapse_sum_like (see the sketch below).
Likewise, I might have a reshape, broadcast_to, … where I have a shape annotation for the input and output, or where I could take the input shape and the shape argument, but I don’t know how to use these.
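Roughly what I have in mind for broadcast_to_like, mirroring the CollapseSumZapp above (an untested sketch, my own naming, same caveats about checked_type):

class BroadcastToLikeZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    def __init__(self):
        super().__init__()
        self.data_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern = tvm.relay.dataflow_pattern.is_op("broadcast_to_like")(
            self.data_tensor, self.pattern_tensor
        )

    def callback(self, pre, post, node_map):
        data = node_map[self.data_tensor][0]
        res = node_map[self.pattern][0]
        # if broadcasting is a no-op (the input already has the target type), drop it
        if data._checked_type_ == res._checked_type_:
            return data
        else:
            return res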
The infinite loop probably was from me doing stupid things (re-creating the final step of the calculation instead of returning the original one…).
I’m always wondering whether I’m missing ready-made passes for removing some of the typical overhead of automatic differentiation (e.g. replacing ..._like with static ops or removing broadcasting / collapse_sum etc.). If not, would these be useful to make available?
Thank you Matt!
Oh no. (But checked_type isn’t the solution, unfortunately.)
I must admit the FFI is too clever for me. Without tab completion I’m lost.
I even have a 2-line patch to fix that for classes, but I don’t know where to put the unittest…
There is another way types can go awry in the dataflow matcher. When things get mutated they lose their type info until the rewrite is completed. We might want to start treating that behaviour as a bug because it’s caught me out before. Maybe @mbrookhart can comment?
The sort of case I’m thinking of is when a mutation takes place: the mutated part of the graph won’t have types associated with it (at least, not until type inference is run on the expression again). It’s not immediately obvious to me whether that’s happening in this example. But now that I’ve thought about it more, that’s not a bug; it would just be a requirement that you manually propagate the type info in your mutator.
I agree with @mbaret. The checked_type_ would be empty when a node is created, until InferType is run or a new function is added to the module. This means a node processed later may not get the type of its parents if the parents were replaced with new nodes without properly propagating their types. You could try adding new_node.checked_type_ = old_node.checked_type_.
The newly constructed node doesn’t have a type when it is created, but ZeroZapp can later find that node and assume it does have a type. Thus, the problem.
If you’re expecting types in later passes, I think the best thing is to put InferType in your callback, or between passes as you’re doing here. We could think about adding that to the rewrite infrastructure, but as I’ve mentioned in other threads, I don’t particularly want to force users to type their problems before using the pattern language in all cases.
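Concretely, that could look something like this (a sketch reusing the two callbacks from above):

grfn = tvm.relay.dataflow_pattern.rewrite(ZeroZapp(), grmod["main"])
# re-run type inference so the next callback sees checked_type on the new nodes
mod = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(grfn))
grfn = tvm.relay.dataflow_pattern.rewrite(CollapseSumZapp(), mod["main"])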
@t-vi I’ll take a closer look at your examples and see if I can figure out a way to distill it into a more refined pattern.
I can see why. But it seems that the shape processing gets really tedious here, with the inability to pass .shape back to Relay (because it is an Array rather than a list) being the final thing.
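For what it’s worth, something like this works around it inside a callback like ZeroZapp above (a sketch, assuming all dims are static):

# instead of passing the Array of PrimExprs directly:
shape = [int(d) for d in rt._checked_type_.shape]  # Array of IntImm -> list of Python ints
return tvm.relay.broadcast_to(ot, shape)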
Maybe if there were some way of saying I want types…
@t-vi Sorry for my delay, I had a lot of meetings today. I’ve finally read through this enough to grok the problem. I’m not sure the Pattern Language is the right tool for this pass.
As you said here:
This looks like more of a need for removing dynamic ops. I’m actually working on a pass like that related to Dynamic Ops in Relay - #17 - pre-RFC - Apache TVM Discuss. The pass basically does a loop of infer_type, fold constants, replace dynamic ops with constant or static versions, repeat.
It doesn’t support many use cases yet, but I can imagine plugging ones_like/zeros_like/broadcast_to_like in that pass and getting this behavior in a fairly straightforward way.
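For reference, the loop could look roughly like this (a sketch; DynamicToStatic and the helper name are my assumptions for the in-progress pass):

from tvm import relay

def make_static(mod, iterations=3):
    # loop: infer types, fold constants, replace dynamic ops with static versions
    for _ in range(iterations):
        mod = relay.transform.InferType()(mod)
        mod = relay.transform.FoldConstant()(mod)
        mod = relay.transform.DynamicToStatic()(mod)
    return mod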
Yeah, it all wants to be static to operate on.
But what I’m after is the next step: eliminating all ops not needed in a static setting.
This seems important for anything where the graph is created automatically, with the frontend converters as well as differentiation.