Multiple pattern recognition with MergeComposite

Hi,

I am defining a transformer TTS pattern using tvm.relay.dataflow_pattern. When I use this pattern to match a model that has 3 TTS blocks inside, it fails to find them. With some debugging, I found that during MergeComposite it does match the pattern, but it doesn't save the group because it returns here:

Well, I am wondering: is this expected? All 3 TTS blocks in the model are complete and should not interfere with each other. Is it because the 3 pattern instances are constructed sequentially in the model (the ending of the first is the starting point of the second)?

I'd appreciate any explanation of this, or any suggestions that could work around the issue. Thanks in advance!

Another thing I am curious about: I found that this matcher/extractor is also called during relay.build, and in that process it can find and save all 3 groups.

My test script looks like this:

pattern_table = [("transformer_tts", tts_pattern.tts_pattern())]    
mod = relay.transform.MergeComposite(pattern_table)(mod)
print("------------------after mergeComposite-----------------")
mod["main"] = DeepCpuAnnotator().visit(mod["main"])
mod = relay.transform.PartitionGraph()(mod)
print("------------------after partition-----------------")
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target=target)
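
For context, since DeepCpuAnnotator is not shown here: in many BYOC setups the hand-written annotator can be replaced by the built-in AnnotateTarget pass. Below is a hedged sketch of that standard pipeline, assuming an external codegen registered under the placeholder name "mycodegen"; it is an alternative, not necessarily what the original script intends.

pattern_table = [("transformer_tts", tts_pattern.tts_pattern())]
mod = relay.transform.MergeComposite(pattern_table)(mod)
# Annotate regions for the external codegen instead of using a custom visitor
mod = relay.transform.AnnotateTarget(["mycodegen"])(mod)
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.PartitionGraph()(mod)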

And with some debug output I get this:

I am not sure what the groups saved during build are. Is there any relationship or difference between these two calls to MatchExtractor::CreateGroup?

For now, I have commented out the overlapping and outside-usage checks to make my model work. But I don't think this is a good solution, as I don't really understand what the checks are for, and I'm worried that it might break some other day. So, politely pinging @tqchen @comaniac here: could you take a look at this if you have some time?

IIUC, the check is to make sure pattern groups won't overlap with each other. I could take a look when I have time. Meanwhile, do you have a minimal example to reproduce this issue?

Also cc @mbrookhart

Thanks for reporting the issue. When I set up that logic, I tried to allow for sequential patterns, but perhaps there is a bug in my logic and it's being overly conservative. A minimal example for debugging would be very helpful.

@comaniac, @mbrookhart Thanks very much for taking a look at this! I simplified my model down to the test script below; you can use it to debug.

import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard, is_var, is_constant

def ss_pattern():
    inp = is_var("inp")
    inp_mask = is_var("inp_mask")
    attn_query_weight = is_var("attn_query_weight")
    attn_query_bias = is_var("attn_query_bias")
    attn_key_weight = is_var("attn_key_weight")
    attn_key_bias = is_var("attn_key_bias")
    v149 = wildcard()
    v70 = wildcard()
    v72 = wildcard()
    op17 = is_op("reshape")(inp)
    op18 = is_op("reshape")(attn_query_weight)
    op19 = is_op("transpose")(op18)
    op20 = is_op("nn.batch_matmul")(op17, op19)
    op21 = is_op("reshape")(op20) 
    op22 = is_op("add")(op21, attn_query_bias)
    op23 = is_op("reshape")(op22)
    op24 = is_op("transpose")(op23)
    op25 = is_op("reshape")(op24) 
    op26 = is_op("reshape")(inp)
    op27 = is_op("reshape")(attn_key_weight)
    op28 = is_op("transpose")(op27)
    op29 = is_op("nn.batch_matmul")(op26, op28)
    op30 = is_op("reshape")(op29)
    op31 = is_op("add")(op30, attn_key_bias)
    op32 = is_op("reshape")(op31)
    op33 = is_op("transpose")(op32)
    op34 = is_op("reshape")(op33)
    op35 = is_op("transpose")(op34)
    op36 = is_op("nn.batch_matmul")(op25, op35)
    op37 = is_op("reshape")(op36) 
    op38 = is_op("divide")(op37, v149) 
    op39 = is_op("cast")(inp_mask)
    op40 = is_op("expand_dims")(op39)
    op41 = is_op("expand_dims")(op40)
    op42 = is_op("cast")(op41) #true
    op43 = is_op("subtract")(v70, op42) 
    op44 = is_op("multiply")(op43, v72) 
    op45 = is_op("add")(op38, op44) 
    return op45

def getMod():
    inp = relay.var('inp', shape=(1,64,512))
    inp_mask = relay.var('inp_mask', shape=(1,64), dtype='int32')
    attn_query_weight = relay.var('attn_query_weight', shape=(1,512,512))
    attn_query_bias = relay.var('attn_query_bias', shape=(1,64,512))
    attn_key_weight = relay.var('attn_key_weight', shape=(1,512,512))
    attn_key_bias = relay.var('attn_key_bias', shape=(1,64,512))
    v149 = relay.const(8, "float32")
    v70 = relay.const(1, "float32")
    v72 = relay.const(-10000, "float32")

    op17 = relay.reshape(inp, newshape=[-1, 64, 512]) 
    op18 = relay.reshape(attn_query_weight, newshape=[-1, 512, 512]) 
    op19 = relay.transpose(op18, axes=[0, 2, 1]) 
    op20 = relay.nn.batch_matmul(op17, op19) 
    op21 = relay.reshape(op20, newshape=[1, 64, 512]) 
    op22 = relay.add(op21, attn_query_bias) 
    op23 = relay.reshape(op22, newshape=[1, 64, 8, 64]) 
    op24 = relay.transpose(op23, axes=[0, 2, 1, 3]) 
    op25 = relay.reshape(op24, newshape=[-1, 64, 64])
    op26 = relay.reshape(inp, newshape=[-1, 64, 512]) 
    op27 = relay.reshape(attn_key_weight, newshape=[-1, 512, 512]) 
    op28 = relay.transpose(op27, axes=[0, 2, 1]) 
    op29 = relay.nn.batch_matmul(op26, op28) 
    op30 = relay.reshape(op29, newshape=[1, 64, 512]) 
    op31 = relay.add(op30, attn_key_bias) 
    op32 = relay.reshape(op31, newshape=[1, 64, 8, 64]) 
    op33 = relay.transpose(op32, axes=[0, 2, 3, 1]) 
    op34 = relay.reshape(op33, newshape=[-1, 64, 64])
    op35 = relay.transpose(op34, axes=[0, 2, 1])
    op36 = relay.nn.batch_matmul(op25, op35)
    op37 = relay.reshape(op36, newshape=[1, 8, 64, 64]) 
    op38 = relay.divide(op37, v149) 
    op39 = relay.cast(inp_mask, dtype="int64")
    op40 = relay.expand_dims(op39, axis=1)
    op41 = relay.expand_dims(op40, axis=2)
    op42 = relay.cast(op41, dtype="float32")
    op43 = relay.subtract(v70, op42)
    op44 = relay.multiply(op43, v72)
    op45 = relay.add(op38, op44) 
    # repeat from here
    op107 = relay.reshape(op45, newshape=[-1, 64, 512]) 
    op108 = relay.reshape(attn_query_weight, newshape=[-1, 512, 512]) 
    op109 = relay.transpose(op108, axes=[0, 2, 1]) 
    op110 = relay.nn.batch_matmul(op107, op109) 
    op111 = relay.reshape(op110, newshape=[1, 64, 512]) 
    op112 = relay.add(op111, attn_query_bias) 
    op113 = relay.reshape(op112, newshape=[1, 64, 8, 64]) 
    op114 = relay.transpose(op113, axes=[0, 2, 1, 3]) 
    op115 = relay.reshape(op114, newshape=[-1, 64, 64])
    op116 = relay.reshape(op45, newshape=[-1, 64, 512]) 
    op117 = relay.reshape(attn_key_weight, newshape=[-1, 512, 512]) 
    op118 = relay.transpose(op117, axes=[0, 2, 1]) 
    op119 = relay.nn.batch_matmul(op116, op118) 
    op120 = relay.reshape(op119, newshape=[1, 64, 512]) 
    op121 = relay.add(op120, attn_key_bias) 
    op122 = relay.reshape(op121, newshape=[1, 64, 8, 64]) 
    op123 = relay.transpose(op122, axes=[0, 2, 3, 1]) 
    op124 = relay.reshape(op123, newshape=[-1, 64, 64])
    op125 = relay.transpose(op124, axes=[0, 2, 1])
    op126 = relay.nn.batch_matmul(op115, op125)
    op127 = relay.reshape(op126, newshape=[1, 8, 64, 64]) 
    op128 = relay.divide(op127, v149) 
    op129 = relay.add(op128, op44) #reuse the op44 here

    return op129

def mytest():
    mod = tvm.IRModule.from_expr(getMod())
    pattern_table = [("test_pattern", ss_pattern())]    
    mod = relay.transform.MergeComposite(pattern_table)(mod)
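
For reference, a small usage sketch of the reproducer (assuming the definitions above are in scope); printing the merged module makes it easy to see how many composite functions were actually created:

if __name__ == "__main__":
    mod = tvm.IRModule.from_expr(getMod())
    mod = relay.transform.MergeComposite([("test_pattern", ss_pattern())])(mod)
    print(mod)  # check how many Composite="test_pattern" functions appear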

In this mod, the ss_pattern is repeated twice. One thing to note is that op44 is reused in both instances. I guess that is what matters.

Yeah, that’s the problem: you have a pattern with multiple outputs, and we haven’t implemented the mechanics to really support that. You would need something that can match the pattern and then expose both op44 and op45 as outputs. Right now, you’re only searching for patterns with op45 as an output.

Supporting this case is doable, but it involves some non-trivial extensions to the pattern matcher.

Also, your second repeat doesn’t seem to match the pattern you provided; it stops at the divide, op38 in the pattern. Maybe it makes sense to stop the pattern there and merge a slightly smaller version of the pattern? That should get you two instances of it. I’m not sure exactly what you’re trying to lower this to; is it an API-specific implementation of attention?

    op39 = is_op("cast")(inp_mask)
    op40 = is_op("expand_dims")(op39)
    op41 = is_op("expand_dims")(op40)
    op42 = is_op("cast")(op41) #true
    op43 = is_op("subtract")(v70, op42) 
    op44 = is_op("multiply")(op43, v72) 

This appears to be pre-processing outside of the actual pattern. If you drop this from your pattern and do divide -> add, with an external input, I think it will work.
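
For reference, a minimal sketch of such a reduced pattern, reusing the op structure of ss_pattern() above. The op39..op44 mask branch becomes an external (wildcard) input to the final add; the input is also relaxed from is_var("inp") to a wildcard here, on the assumption that an instance fed by another instance's output (rather than a Var) should match too. This is a sketch of the suggested direction, not the exact pattern used.

from tvm.relay.dataflow_pattern import is_op, is_var, wildcard

def _proj_branch(data, weight, bias):
    # Shared reshape/batch_matmul/add/reshape/transpose/reshape chain used by
    # both the query and key projections in ss_pattern()
    x = is_op("reshape")(data)
    w = is_op("transpose")(is_op("reshape")(weight))
    y = is_op("reshape")(is_op("nn.batch_matmul")(x, w))
    y = is_op("reshape")(is_op("add")(y, bias))
    return is_op("reshape")(is_op("transpose")(y))

def ss_pattern_reduced():
    inp = wildcard()  # relaxed so an instance fed by another instance's output can match
    q = _proj_branch(inp, is_var("attn_query_weight"), is_var("attn_query_bias"))
    k = _proj_branch(inp, is_var("attn_key_weight"), is_var("attn_key_bias"))
    kt = is_op("transpose")(k)
    scores = is_op("reshape")(is_op("nn.batch_matmul")(q, kt))
    scaled = is_op("divide")(scores, wildcard())
    # The former op39..op44 mask computation is now an external input
    return is_op("add")(scaled, wildcard())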

Thanks for your effort. Yes, if I remove the small op39 -> op44 branch from the pattern, it finds all the instances. I’ll see if that works for my case.
