@comaniac, @mbrookhart Thanks very much for taking a look at this! I simplified my model into the test script below, which you can use to debug.
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard, is_var, is_constant
def ss_pattern():
    """Build the dataflow pattern that is registered as "test_pattern".

    The pattern has two branches joined by a final ``add``:

    * a "scores" branch: two reshape/batch_matmul/add/transpose chains (one
      fed by the ``attn_query_*`` vars, one by the ``attn_key_*`` vars, both
      reading the same ``inp`` var pattern) combined by ``nn.batch_matmul``,
      then ``reshape`` and ``divide`` by a wildcard;
    * a "mask" branch: ``cast -> expand_dims -> expand_dims -> cast`` applied
      to ``inp_mask``, subtracted from a wildcard and multiplied by a wildcard.

    Returns:
        The pattern rooted at the final ``add``.
    """
    data = is_var("inp")
    mask = is_var("inp_mask")
    query_weight = is_var("attn_query_weight")
    query_bias = is_var("attn_query_bias")
    key_weight = is_var("attn_key_weight")
    key_bias = is_var("attn_key_bias")
    scale = wildcard()
    minuend = wildcard()
    factor = wildcard()

    def _projection(source, weight, bias):
        # Every call creates fresh pattern nodes, so the query and key
        # branches do not share any reshape/transpose pattern objects; only
        # the `source` var pattern passed in is shared between them.
        lhs = is_op("reshape")(source)
        rhs = is_op("transpose")(is_op("reshape")(weight))
        product = is_op("reshape")(is_op("nn.batch_matmul")(lhs, rhs))
        biased = is_op("add")(product, bias)
        return is_op("reshape")(is_op("transpose")(is_op("reshape")(biased)))

    query = _projection(data, query_weight, query_bias)
    # The key branch has one extra transpose before the matmul.
    key = is_op("transpose")(_projection(data, key_weight, key_bias))
    scores = is_op("reshape")(is_op("nn.batch_matmul")(query, key))
    scaled_scores = is_op("divide")(scores, scale)

    expanded_mask = is_op("expand_dims")(is_op("expand_dims")(is_op("cast")(mask)))
    mask_value = is_op("cast")(expanded_mask)
    mask_term = is_op("multiply")(is_op("subtract")(minuend, mask_value), factor)

    return is_op("add")(scaled_scores, mask_term)
def getMod():
    """Build the Relay expression used to reproduce the issue.

    The expression contains the query/key score computation twice: first on
    the ``inp`` variable, then on the result of the first repetition. Both
    repetitions add the *same* mask term (one shared expression object) and
    divide by the same constant.

    Returns:
        The Relay expression rooted at the final ``add``.
    """
    inp = relay.var('inp', shape=(1, 64, 512))
    inp_mask = relay.var('inp_mask', shape=(1, 64), dtype='int32')
    attn_query_weight = relay.var('attn_query_weight', shape=(1, 512, 512))
    attn_query_bias = relay.var('attn_query_bias', shape=(1, 64, 512))
    attn_key_weight = relay.var('attn_key_weight', shape=(1, 512, 512))
    attn_key_bias = relay.var('attn_key_bias', shape=(1, 64, 512))
    scale = relay.const(8, "float32")
    one = relay.const(1, "float32")
    mask_factor = relay.const(-10000, "float32")

    def _projection(source, weight, bias, head_axes):
        # reshape -> batch_matmul with the transposed weight -> add bias ->
        # split into 8 heads -> permute with `head_axes` -> flatten heads.
        lhs = relay.reshape(source, newshape=[-1, 64, 512])
        rhs = relay.transpose(
            relay.reshape(weight, newshape=[-1, 512, 512]), axes=[0, 2, 1]
        )
        product = relay.reshape(
            relay.nn.batch_matmul(lhs, rhs), newshape=[1, 64, 512]
        )
        heads = relay.reshape(relay.add(product, bias), newshape=[1, 64, 8, 64])
        permuted = relay.transpose(heads, axes=head_axes)
        return relay.reshape(permuted, newshape=[-1, 64, 64])

    def _scaled_scores(source):
        # `source` is reshaped separately for the query and key branches, so
        # each branch gets its own reshape call node.
        query = _projection(
            source, attn_query_weight, attn_query_bias, [0, 2, 1, 3]
        )
        key = relay.transpose(
            _projection(source, attn_key_weight, attn_key_bias, [0, 2, 3, 1]),
            axes=[0, 2, 1],
        )
        scores = relay.reshape(
            relay.nn.batch_matmul(query, key), newshape=[1, 8, 64, 64]
        )
        return relay.divide(scores, scale)

    first_scores = _scaled_scores(inp)

    # Mask term: built once and reused by both repetitions below.
    mask_i64 = relay.cast(inp_mask, dtype="int64")
    mask_4d = relay.expand_dims(relay.expand_dims(mask_i64, axis=1), axis=2)
    mask_f32 = relay.cast(mask_4d, dtype="float32")
    mask_term = relay.multiply(relay.subtract(one, mask_f32), mask_factor)

    first_out = relay.add(first_scores, mask_term)

    # Second repetition: same computation, fed by the first result, and
    # adding the very same `mask_term` expression again.
    second_scores = _scaled_scores(first_out)
    return relay.add(second_scores, mask_term)
def mytest():
    """Run MergeComposite with the single "test_pattern" entry on the module.

    Wraps the expression from ``getMod()`` in an IRModule and applies the
    pass; the transformed module is not returned.
    """
    module = tvm.IRModule.from_expr(getMod())
    table = [("test_pattern", ss_pattern())]
    merge_composite = relay.transform.MergeComposite(table)
    module = merge_composite(module)
In this module, the structure described by ss_pattern is repeated twice. One thing worth mentioning is that op44 is reused in both repetitions; I suspect this is what matters.