%0 = expand_dims(%input_ids, axis=-1) /* bert/embeddings/ExpandDims */;
%1 = reshape(%0, newshape=[-1]) /* bert/embeddings/Reshape */;
%2 = take(%bert/embeddings/word_embeddings, %1, axis=0) /* bert/embeddings/GatherV2 */;
%3 = reshape(%segment_ids, newshape=[-1]) /* bert/embeddings/Reshape_2 */;
%4 = one_hot(%3, 1f, 0f, depth=2, dtype="float32") /* bert/embeddings/one_hot */;
%5 = transpose(%bert/embeddings/token_type_embeddings, axes=[1, 0]);
%6 = nn.dense(%4, %5, units=768) /* bert/embeddings/MatMul */;
%7 = reshape(%2, newshape=[8, 384, 768]) /* bert/embeddings/Reshape_1 */;
%8 = reshape(%6, newshape=[8, 384, 768]) /* bert/embeddings/Reshape_3 */;
%9 = strided_slice(%bert/embeddings/position_embeddings, begin=[0, 0], end=[384, -1], strides=[1, 1], slice_mode="size", axes=None) /* bert/embeddings/Slice */;
%10 = add(%7, %8) /* bert/embeddings/add */;
%11 = reshape(%9, newshape=[1, 384, 768]) /* bert/embeddings/Reshape_4 */;
%12 = add(%10, %11) /* bert/embeddings/add_1 */;
%13 = mean(%12, axis=[2], keepdims=True) /* bert/embeddings/LayerNorm/moments/mean */;
%14 = subtract(%12, %13);
%15 = multiply(%14, %14) /* bert/embeddings/LayerNorm/moments/SquaredDifference */;
%16 = mean(%15, axis=[2], keepdims=True) /* bert/embeddings/LayerNorm/moments/variance */;
%17 = add(%16, 1e-12f) /* bert/embeddings/LayerNorm/batchnorm/add */;
%18 = power(%17, -0.5f) /* bert/embeddings/LayerNorm/batchnorm/Rsqrt */;
%19 = multiply(%18, %bert/embeddings/LayerNorm/gamma) /* bert/embeddings/LayerNorm/batchnorm/mul */;
%20 = multiply(%13, %19) /* bert/embeddings/LayerNorm/batchnorm/mul_2 */;
%21 = multiply(%12, %19) /* bert/embeddings/LayerNorm/batchnorm/mul_1 */;
%22 = subtract(%bert/embeddings/LayerNorm/beta, %20) /* bert/embeddings/LayerNorm/batchnorm/sub */;
%23 = add(%21, %22) /* bert/embeddings/LayerNorm/batchnorm/add_1 */;
%24 = reshape(%23, newshape=[-1, 768]) /* bert/encoder/Reshape_1 */;
%25 = transpose(%bert/encoder/layer_0/attention/self/query/kernel, axes=[1, 0]);
%26 = nn.dense(%24, %25, units=768) /* bert/encoder/layer_0/attention/self/query/MatMul */;
%27 = add(%26, %bert/encoder/layer_0/attention/self/query/bias) /* bert/encoder/layer_0/attention/self/query/BiasAdd */;
%28 = reshape(%27, newshape=[8, 384, 12, 64]) /* bert/encoder/layer_0/attention/self/Reshape */;
%29 = transpose(%28, axes=[0, 2, 1, 3]) /* bert/encoder/layer_0/attention/self/transpose */;
%30 = transpose(%bert/encoder/layer_0/attention/self/key/kernel, axes=[1, 0]);
%31 = nn.dense(%24, %30, units=768) /* bert/encoder/layer_0/attention/self/key/MatMul */;
%32 = add(%31, %bert/encoder/layer_0/attention/self/key/bias) /* bert/encoder/layer_0/attention/self/key/BiasAdd */;
%33 = reshape(%32, newshape=[8, 384, 12, 64]) /* bert/encoder/layer_0/attention/self/Reshape_1 */;
%34 = transpose(%33, axes=[0, 2, 1, 3]) /* bert/encoder/layer_0/attention/self/transpose_1 */;
%35 = reshape(%29, newshape=[96, 384, 64]);
%36 = reshape(%34, newshape=[96, 384, 64]);
%37 = nn.batch_matmul(%35, %36, transpose_b=True);
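As far as I can tell, %13 through %23 above are TensorFlow's decomposed LayerNorm over the last axis: out = (x - mean) * rsqrt(var + 1e-12) * gamma + beta, with the scale folded so that %19 = rsqrt(var + eps) * gamma, %21 = x * scale and %22 = beta - mean * scale. A quick NumPy check of that reading (shapes taken from the dump, data is random):

import numpy as np

x = np.random.randn(8, 384, 768).astype("float32")
gamma = np.random.randn(768).astype("float32")
beta = np.random.randn(768).astype("float32")

# The decomposition as it appears in %13..%23 of the dump.
m = x.mean(axis=2, keepdims=True)                 # %13
v = ((x - m) ** 2).mean(axis=2, keepdims=True)    # %14..%16
s = (v + 1e-12) ** -0.5 * gamma                   # %17..%19
decomposed = x * s + (beta - m * s)               # %20..%23

# The fused form I want to rewrite it into.
fused = (x - m) / np.sqrt(v + 1e-12) * gamma + beta
print(np.allclose(decomposed, fused, atol=1e-5))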
I want to fuse the LayerNorm subgraph above into a single op, but the dataflow pattern I defined does not match it. Here is my rewriter:
from tvm import relay
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, rewrite, wildcard


class LayerNormRewriter(DFPatternCallback):
    def __init__(self, require_type=False, rewrite_once=False):
        super().__init__(require_type, rewrite_once)
        '''
        Target subgraph (from the dump above):
        %10 = add(%7, %8) /* bert/embeddings/add */;
        %11 = reshape(%9, newshape=[1, 384, 768]) /* bert/embeddings/Reshape_4 */;
        %12 = add(%10, %11) /* bert/embeddings/add_1 */;
        %13 = mean(%12, axis=[2], keepdims=True) /* bert/embeddings/LayerNorm/moments/mean */;
        %14 = subtract(%12, %13);
        %15 = multiply(%14, %14) /* bert/embeddings/LayerNorm/moments/SquaredDifference */;
        %16 = mean(%15, axis=[2], keepdims=True) /* bert/embeddings/LayerNorm/moments/variance */;
        %17 = add(%16, 1e-12f) /* bert/embeddings/LayerNorm/batchnorm/add */;
        %18 = power(%17, -0.5f) /* bert/embeddings/LayerNorm/batchnorm/Rsqrt */;
        %19 = multiply(%18, %bert/embeddings/LayerNorm/gamma) /* bert/embeddings/LayerNorm/batchnorm/mul */;
        %20 = multiply(%13, %19) /* bert/embeddings/LayerNorm/batchnorm/mul_2 */;
        %21 = multiply(%12, %19) /* bert/embeddings/LayerNorm/batchnorm/mul_1 */;
        %22 = subtract(%bert/embeddings/LayerNorm/beta, %20) /* bert/embeddings/LayerNorm/batchnorm/sub */;
        %23 = add(%21, %22) /* bert/embeddings/LayerNorm/batchnorm/add_1 */;
        %24 = reshape(%23, newshape=[-1, 768]) /* bert/encoder/Reshape_1 */;
        %25 = transpose(%bert/encoder/layer_0/attention/self/query/kernel, axes=[1, 0]);
        %26 = nn.dense(%24, %25, units=768) /* bert/encoder/layer_0/attention/self/query/MatMul */;
        '''
        # For now I only try to match the start of the decomposition,
        # multiply(subtract(x, mean(x)), subtract(x, mean(x))), i.e. %13..%15.
        self.data = wildcard()
        self.mean1 = is_op("mean")(self.data)
        self.sub1 = self.data - self.mean1
        self.mul1 = self.sub1 * self.sub1
        self.pattern = self.mul1

    def callback(self, pre, post, node_map):
        # I expect to stop here whenever the pattern matches.
        import pdb
        pdb.set_trace()
        # Placeholder rewrite: gamma and beta are not captured by the pattern
        # yet, so for now I only care about reaching the breakpoint above.
        return relay.nn.layer_norm(node_map[self.data][0],
                                   node_map[self.gamma][0],
                                   node_map[self.beta][0],
                                   epsilon=1e-12)


def rewrite_layer_norm(mod):
    mod["main"] = rewrite(LayerNormRewriter(), mod["main"])
    return mod
If the pattern I defined matches, the breakpoint should be hit, but nothing happens. Could anyone help me?
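For reference, this is roughly the fuller pattern I am working toward once the simple one matches, i.e. capturing x, gamma and beta from %13..%23 and emitting nn.layer_norm. It is an untested sketch, and the way I match the epsilon and -0.5 constants is a guess:

from tvm import relay
from tvm.relay.dataflow_pattern import (DFPatternCallback, is_constant, is_op,
                                        rewrite, wildcard)


class FullLayerNormRewriter(DFPatternCallback):
    def __init__(self):
        super().__init__(require_type=False, rewrite_once=False)
        self.x = wildcard()
        self.gamma = wildcard()
        self.beta = wildcard()
        mean = is_op("mean")(self.x)                      # %13
        diff = self.x - mean                              # %14
        var = is_op("mean")(diff * diff)                  # %15, %16
        rsqrt = is_op("power")(var + is_constant(),       # %17
                               is_constant())             # %18
        scale = rsqrt * self.gamma                        # %19
        # %20..%23: x * scale + (beta - mean * scale)
        self.pattern = self.x * scale + (self.beta - mean * scale)

    def callback(self, pre, post, node_map):
        return relay.nn.layer_norm(node_map[self.x][0],
                                   node_map[self.gamma][0],
                                   node_map[self.beta][0],
                                   axis=-1, epsilon=1e-12)


def rewrite_full_layer_norm(mod):
    mod["main"] = rewrite(FullLayerNormRewriter(), mod["main"])
    return mod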