Replicating an operator with pattern rewrite gets stuck in an infinite loop

Hi All,

I was experimenting with decomposing a Relay expression (an operator or a function) into several instances of the same kind of operator. Specifically, for a vector addition, I want to replicate the addition so that each copy operates on part of the data.

For example, to add two 30×1 vectors, let's assume we need one vector adder.

Instead, I want to use two adders and have each one operate on half of the data (15 = 30/2 rows).

Here is an example I wrote using rewrite. I thought rewrite could do this easily, but for some reason the following program gets stuck in an infinite loop in the callback function. Any ideas?

from tvm import relay
from tvm.relay.dataflow_pattern import DFPatternCallback, rewrite, wildcard


class ExampleDecomposeByShape(DFPatternCallback):

    def __init__(self):
        super().__init__()
        self.x = wildcard()
        self.y = wildcard()
        # The rewriter looks for the attribute named "pattern".
        self.pattern = self.x + self.y

    def callback(self, pre, post, node_map):
        # This callback is stuck: execution keeps returning to the next line.
        x = node_map[self.x][0]
        y = node_map[self.y][0]

        # Shape dimensions are IntImm nodes; convert them to Python ints.
        dim1, dim2 = [int(d) for d in x.type_annotation.shape]

        x1 = relay.var('x1', shape=(dim1 // 2, dim2))
        y1 = relay.var('y1', shape=(dim1 // 2, dim2))
        x2 = relay.var('x2', shape=(dim1 // 2, dim2))
        y2 = relay.var('y2', shape=(dim1 // 2, dim2))

        res = relay.concatenate([x1 + y2, x2 + y1], axis=0)
        return res

# This is how I call it:
x = relay.var('x', shape=(30, 1))
y = relay.var('y', shape=(30, 1))
f = relay.Function([x, y], x + y)

out = rewrite(ExampleDecomposeByShape(), f)
print(out) 

At first glance, I guess it's because your new graph still matches the pattern (x1 + y2, for example), so the rewriter keeps firing the callback on its own output. I'm not sure whether this should be handled by the rewriter or by the callback.

@mbrookhart, do you have any suggestions?

I think the issue is that dim1 // 2 eventually becomes 0, so you'll keep adding together empty tensors forever. A simple fix for the infinite loop is to check dim1 and only continue the lowering if dim1 > 1.
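To see why that guard ends the loop, here is a pure-Python model (not TVM code; halving_chain is a hypothetical helper) that follows one branch's row count as the callback keeps firing on its own output. It assumes, as described above, that the rewriter keeps re-applying the callback as long as the graph changes.

```python
def halving_chain(rows, guard=True, max_steps=64):
    """Follow one branch's row count as a callback keeps rewriting its own output.

    Hypothetical model, not TVM code: each step replaces an add over `rows`
    rows with adds over `rows // 2` rows, as the callback in the question does.
    """
    chain = [rows]
    for _ in range(max_steps):
        if guard and rows <= 1:
            return chain  # the callback leaves the add unchanged: fixed point
        rows //= 2
        chain.append(rows)
    # Without the guard the count reaches 0 and stays there, so it never settles.
    raise RuntimeError(f"no fixed point after {max_steps} steps, tail={chain[-3:]}")


print(halving_chain(30))  # [30, 15, 7, 3, 1]
```

The chain also shows a second issue: 15 // 2 == 7, so once an operand has an odd number of rows, concatenating two 7-row sums yields 14 rows instead of 15 and the rewritten expression no longer has the original shape.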

Also, the graph you construct in the callback doesn't actually use the input data. Did you mean to split x and y into x1, x2 and y1, y2 instead of creating new variables?
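Combining both suggestions, here is a NumPy analogue (not TVM code; decomposed_add is a hypothetical helper) of a callback that splits the actual inputs instead of creating fresh variables and stops once an operand can no longer be halved evenly. In Relay the splitting step would presumably use relay.split(x, 2, axis=0), but that translation is an untested assumption here.

```python
import numpy as np


def decomposed_add(x, y):
    """Recursively split x and y in half along axis 0 and add the halves.

    Returns (result, number_of_leaf_adders).
    """
    rows = x.shape[0]
    # Guard: stop once the rows can no longer be split into two equal halves.
    if rows <= 1 or rows % 2 != 0:
        return x + y, 1
    x1, x2 = np.split(x, 2, axis=0)  # split the real inputs ...
    y1, y2 = np.split(y, 2, axis=0)  # ... instead of creating new variables
    top, n_top = decomposed_add(x1, y1)  # pair x1 with y1
    bottom, n_bottom = decomposed_add(x2, y2)  # and x2 with y2
    return np.concatenate([top, bottom], axis=0), n_top + n_bottom


x = np.arange(30, dtype=np.float32).reshape(30, 1)
y = np.ones((30, 1), dtype=np.float32)
out, adders = decomposed_add(x, y)
print(out.shape, adders)  # (30, 1) 2
```

For the 30×1 case this produces exactly the two 15-row adders from the original question, and the result matches x + y.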