Thanks for your reply. I will take a look at how to use the debug_executor to enable this functionality.
Also, I am asking this question because it is related to pipeline execution: I am still wondering whether a user can register operations in Relay IR as new global outputs.
In my case, I split a 12-layer BERT model into two subgraphs: the first contains the first 4 layers and the second contains the last 8 layers. Following Relay IR rules, each subgraph has only one global output, which is its last operation.
BERT in Relay IR:
running now: BERT original mod: #[version = "0.0.5"] fn (%tf_bert_for_sequence_classification/bert/embeddings/Gather/resource: Tensor[(30522, 768), float32], %x: Tensor[(1, 128), int32], %tf_bert_for_sequence_classification/bert/embeddings/Gather_1/resource: Tensor[(512, 768), float32], %tf_bert_for_sequence_classification/bert/embeddings/Gather_2/resource: Tensor[(2, 768), float32], %tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/mul/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/query/Tensordot/ReadVariableOp/resource: Tensor[(768, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/query/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/key/Tensordot/ReadVariableOp/resource: Tensor[(768, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/key/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/value/Tensordot/ReadVariableOp/resource: Tensor[(768, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/value/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/output/dense/Tensordot/ReadVariableOp/resource: Tensor[(768, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/output/dense/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/output/LayerNorm/batchnorm/mul/ReadVariableOp/resource: Tensor[(768), float32], 
%tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/output/LayerNorm/batchnorm/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._0/intermediate/dense/Tensordot/ReadVariableOp/resource: Tensor[(768, 3072), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._0/intermediate/dense/BiasAdd/ReadVariableOp/resource: Tensor[(3072), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._0/output/dense/Tensordot/ReadVariableOp/resource: Tensor[(3072, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._0/output/dense/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._0/output/LayerNorm/batchnorm/mul/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._0/output/LayerNorm/batchnorm/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._1/attention/self/query/Tensordot/ReadVariableOp/resource: Tensor[(768, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._1/attention/self/query/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._1/attention/self/key/Tensordot/ReadVariableOp/resource: Tensor[(768, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._1/attention/self/key/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._1/attention/self/value/Tensordot/ReadVariableOp/resource: Tensor[(768, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._1/attention/self/value/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._1/attention/output/dense/Tensordot/ReadVariableOp/resource: Tensor[(768, 768), float32], 
%tf_bert_for_sequence_classification/bert/encoder/layer_._1/attention/output/dense/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._1/attention/output/LayerNorm/batchnorm/mul/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._1/attention/output/LayerNorm/batchnorm/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._1/intermediate/dense/Tensordot/ReadVariableOp/resource: Tensor[(768, 3072), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._1/intermediate/dense/BiasAdd/ReadVariableOp/resource: Tensor[(3072), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._1/output/dense/Tensordot/ReadVariableOp/resource: Tensor[(3072, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._1/output/dense/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._1/output/LayerNorm/batchnorm/mul/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._1/output/LayerNorm/batchnorm/ReadVariableOp/resource: Tensor[(768), float32], ....%tf_bert_for_sequence_classification/bert/encoder/layer_._11/attention/self/query/Tensordot/ReadVariableOp/resource: Tensor[(768, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._11/attention/self/query/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._11/attention/self/key/Tensordot/ReadVariableOp/resource: Tensor[(768, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._11/attention/self/key/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._11/attention/self/value/Tensordot/ReadVariableOp/resource: Tensor[(768, 768), float32], 
%tf_bert_for_sequence_classification/bert/encoder/layer_._11/attention/self/value/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._11/attention/output/dense/Tensordot/ReadVariableOp/resource: Tensor[(768, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._11/attention/output/dense/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._11/attention/output/LayerNorm/batchnorm/mul/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._11/attention/output/LayerNorm/batchnorm/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._11/intermediate/dense/Tensordot/ReadVariableOp/resource: Tensor[(768, 3072), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._11/intermediate/dense/BiasAdd/ReadVariableOp/resource: Tensor[(3072), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._11/output/dense/Tensordot/ReadVariableOp/resource: Tensor[(3072, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._11/output/dense/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._11/output/LayerNorm/batchnorm/mul/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._11/output/LayerNorm/batchnorm/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/pooler/dense/MatMul/ReadVariableOp/resource: Tensor[(768, 768), float32], %tf_bert_for_sequence_classification/bert/pooler/dense/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/classifier/MatMul/ReadVariableOp/resource: Tensor[(768, 2), float32], %tf_bert_for_sequence_classification/classifier/BiasAdd/ReadVariableOp/resource: Tensor[(2), float32]) { %0 
= expand_dims(meta[relay.Constant][0], axis=0) /* tf_bert_for_sequence_classification/bert/embeddings/ExpandDims */; %1 = take(%tf_bert_for_sequence_classification/bert/embeddings/Gather_1/resource, %0, axis=0) /* tf_bert_for_sequence_classification/bert/embeddings/Gather_1 */; %2 = take(%tf_bert_for_sequence_classification/bert/embeddings/Gather/resource, %x, axis=0) /* tf_bert_for_sequence_classification/bert/embeddings/Gather */; %3 = tile(%1, reps=[1, 1, 1]) /* tf_bert_for_sequence_classification/bert/embeddings/Tile */; %4 = full(0, shape=[1, 128], dtype="int32") /* tf_bert_for_sequence_classification/bert/Fill_1 */; %5 = add(%2, %3) /* tf_bert_for_sequence_classification/bert/embeddings/add/add */; %6 = take(%tf_bert_for_sequence_classification/bert/embeddings/Gather_2/resource, %4, axis=0) /* tf_bert_for_sequence_classification/bert/embeddings/Gather_2 */; %7 = add(%5, %6) /* tf_bert_for_sequence_classification/bert/embeddings/add/add_1 */; %8 = mean(%7, axis=[2], keepdims=True) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/moments/mean */; %9 = subtract(%7, %8); %10 = multiply(%9, %9) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/moments/SquaredDifference */; %11 = mean(%10, axis=[2], keepdims=True) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/moments/variance */; %12 = add(%11, 1e-12f) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/add */; %13 = power(%12, -0.5f) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/Rsqrt */; %14 = multiply(%13, %tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/mul/ReadVariableOp/resource) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/mul */; %15 = multiply(%8, %14) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/mul_2 */; %16 = multiply(%7, %14) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/mul_1 */; %17 = 
subtract(%tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/ReadVariableOp/resource, %15) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/sub */; %18 = add(%16, %17) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/add_1 */; %19 = reshape(%18, newshape=[128, 768]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/query/Tensordot/Reshape */; %20 = transpose(%tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/query/Tensordot/ReadVariableOp/resource, axes=[1, 0]); %21 = nn.dense(%19, %20, units=768) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/query/Tensordot/MatMul */; %22 = reshape(%21, newshape=[1, 128, 768]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/query/Tensordot */; %23 = add(%22, %tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/query/BiasAdd/ReadVariableOp/resource) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/query/BiasAdd */; %24 = reshape(%23, newshape=[1, -1, 12, 64]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/Reshape */; %25 = transpose(%24, axes=[0, 2, 1, 3]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/transpose */; %26 = reshape(%18, newshape=[128, 768]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/key/Tensordot/Reshape */; %27 = transpose(%tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/key/Tensordot/ReadVariableOp/resource, axes=[1, 0]); %28 = nn.dense(%26, %27, units=768) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/key/Tensordot/MatMul */; %29 = reshape(%28, newshape=[1, 128, 768]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/key/Tensordot */; %30 = add(%29, 
%tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/key/BiasAdd/ReadVariableOp/resource) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/key/BiasAdd */; %31 = reshape(%30, newshape=[1, -1, 12, 64]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/Reshape_1 */; %32 = transpose(%31, axes=[0, 2, 1, 3]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/transpose_1 */; %33 = reshape(%25, newshape=[12, 128, 64]); %34 = reshape(%32, newshape=[12, 128, 64]); %35 = nn.batch_matmul(%33, %34, transpose_b=True); %36 = reshape(%35, newshape=[1, 12, 128, 128]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/MatMul */; %37 = full(1, shape=[1, 128], dtype="int32") /* tf_bert_for_sequence_classification/bert/Fill */; %38 = reshape(%37, newshape=[1, 1, 1, 128]) /* tf_bert_for_sequence_classification/bert/Reshape */; %39 = cast(%38, dtype="float32") /* tf_bert_for_sequence_classification/bert/Cast */; %40 = subtract(1f, %39) /* tf_bert_for_sequence_classification/bert/Sub */; %41 = divide(%36, 8f) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/truediv */; %42 = multiply(%40, -10000f) /* tf_bert_for_sequence_classification/bert/Mul */; ..... %361 = reshape(%339, newshape=[128, 768]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._4/attention/self/value/Tensordot/Reshape */; ... 
%974 = transpose(%tf_bert_for_sequence_classification/bert/pooler/dense/MatMul/ReadVariableOp/resource, axes=[1, 0]); %975 = nn.dense(%973, %974, units=768) /* tf_bert_for_sequence_classification/bert/pooler/dense/MatMul */; %976 = add(%975, %tf_bert_for_sequence_classification/bert/pooler/dense/BiasAdd/ReadVariableOp/resource) /* tf_bert_for_sequence_classification/bert/pooler/dense/BiasAdd */; %977 = tanh(%976) /* tf_bert_for_sequence_classification/bert/pooler/dense/Tanh */; %978 = transpose(%tf_bert_for_sequence_classification/classifier/MatMul/ReadVariableOp/resource, axes=[1, 0]); %979 = nn.dense(%977, %978, units=2) /* tf_bert_for_sequence_classification/classifier/MatMul */; add(%979, %tf_bert_for_sequence_classification/classifier/BiasAdd/ReadVariableOp/resource) /* tf_bert_for_sequence_classification/classifier/BiasAdd */ }
first subgraph (mod 0) in Relay IR:
mods 0: def @main(%tf_bert_for_sequence_classification/bert/embeddings/Gather/resource: Tensor[(30522, 768), float32], %x: Tensor[(1, 128), int32], .... (ignore){ %0 = expand_dims(meta[relay.Constant][0], axis=0) /* tf_bert_for_sequence_classification/bert/embeddings/ExpandDims */; %1 = take(%tf_bert_for_sequence_classification/bert/embeddings/Gather_1/resource, %0, axis=0) /* tf_bert_for_sequence_classification/bert/embeddings/Gather_1 */; %2 = take(%tf_bert_for_sequence_classification/bert/embeddings/Gather/resource, %x, axis=0) /* tf_bert_for_sequence_classification/bert/embeddings/Gather */; %3 = tile(%1, reps=[1, 1, 1]) /* tf_bert_for_sequence_classification/bert/embeddings/Tile */; %4 = full(0, shape=[1, 128], dtype="int32") /* tf_bert_for_sequence_classification/bert/Fill_1 */; %5 = add(%2, %3) /* tf_bert_for_sequence_classification/bert/embeddings/add/add */; %6 = take(%tf_bert_for_sequence_classification/bert/embeddings/Gather_2/resource, %4, axis=0) /* tf_bert_for_sequence_classification/bert/embeddings/Gather_2 */; %7 = add(%5, %6) /* tf_bert_for_sequence_classification/bert/embeddings/add/add_1 */; %8 = mean(%7, axis=[2], keepdims=True) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/moments/mean */; %9 = subtract(%7, %8); %10 = multiply(%9, %9) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/moments/SquaredDifference */; %11 = mean(%10, axis=[2], keepdims=True) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/moments/variance */; %12 = add(%11, 1e-12f) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/add */; %13 = power(%12, -0.5f) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/Rsqrt */; %14 = multiply(%13, %tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/mul/ReadVariableOp/resource) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/mul */; %15 = multiply(%8, %14) /* 
tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/mul_2 */; %16 = multiply(%7, %14) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/mul_1 */; %17 = subtract(%tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/ReadVariableOp/resource, %15) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/sub */; %18 = add(%16, %17) /* tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/batchnorm/add_1 */; %19 = reshape(%18, newshape=[128, 768]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/query/Tensordot/Reshape */; %20 = transpose(%tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/query/Tensordot/ReadVariableOp/resource, axes=[1, 0]); %21 = nn.dense(%19, %20, units=768) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/query/Tensordot/MatMul */; %22 = reshape(%21, newshape=[1, 128, 768]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/query/Tensordot */; %23 = add(%22, %tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/query/BiasAdd/ReadVariableOp/resource) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/query/BiasAdd */; %24 = reshape(%23, newshape=[1, -1, 12, 64]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/Reshape */; %25 = transpose(%24, axes=[0, 2, 1, 3]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/transpose */; %26 = reshape(%18, newshape=[128, 768]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/key/Tensordot/Reshape */; %27 = transpose(%tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/key/Tensordot/ReadVariableOp/resource, axes=[1, 0]); %28 = nn.dense(%26, %27, units=768) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/key/Tensordot/MatMul */; %29 = 
reshape(%28, newshape=[1, 128, 768]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/key/Tensordot */; %30 = add(%29, %tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/key/BiasAdd/ReadVariableOp/resource) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/key/BiasAdd */; %31 = reshape(%30, newshape=[1, -1, 12, 64]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/Reshape_1 */; %32 = transpose(%31, axes=[0, 2, 1, 3]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/transpose_1 */; %33 = reshape(%25, newshape=[12, 128, 64]); %34 = reshape(%32, newshape=[12, 128, 64]); %35 = nn.batch_matmul(%33, %34, transpose_b=True); %36 = reshape(%35, newshape=[1, 12, 128, 128]) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/MatMul */; %37 = full(1, shape=[1, 128], dtype="int32") /* tf_bert_for_sequence_classification/bert/Fill */; %38 = reshape(%37, newshape=[1, 1, 1, 128]) /* tf_bert_for_sequence_classification/bert/Reshape */; %39 = cast(%38, dtype="float32") /* tf_bert_for_sequence_classification/bert/Cast */; %40 = subtract(1f, %39) /* tf_bert_for_sequence_classification/bert/Sub */; %41 = divide(%36, 8f) /* tf_bert_for_sequence_classification/bert/encoder/layer_._0/attention/self/truediv */; %42 = multiply(%40, -10000f) /* tf_bert_for_sequence_classification/bert/Mul */; ..... (ignore)
second subgraph (mod 1) in Relay IR:
mods 1: def @main(%x: Tensor[(1, 128, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._4/attention/self/query/Tensordot/ReadVariableOp/resource: Tensor[(768, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._4/attention/self/query/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._4/attention/self/key/Tensordot/ReadVariableOp/resource: Tensor[(768, 768), float32], %tf_bert_for_sequence_classification/bert/encoder/layer_._4/attention/self/key/BiasAdd/ReadVariableOp/resource: Tensor[(768), float32], %x1: Tensor[(1, 1, 1, 128), float32],.... %0 = reshape(%x, newshape=[128, 768]) /* ty=Tensor[(128, 768), float32] */; %1 = transpose(%tf_bert_for_sequence_classification/bert/encoder/layer_._4/attention/self/query/Tensordot/ReadVariableOp/resource, axes=[1, 0]) /* ty=Tensor[(768, 768), float32] */; %2 = nn.dense(%0, %1, units=768) /* ty=Tensor[(128, 768), float32] */; %3 = reshape(%2, newshape=[1, 128, 768]) /* ty=Tensor[(1, 128, 768), float32] */; %4 = add(%3, %tf_bert_for_sequence_classification/bert/encoder/layer_._4/attention/self/query/BiasAdd/ReadVariableOp/resource) /* ty=Tensor[(1, 128, 768), float32] */; %5 = reshape(%4, newshape=[1, -1, 12, 64]) /* ty=Tensor[(1, 128, 12, 64), float32] */; %6 = transpose(%5, axes=[0, 2, 1, 3]) /* ty=Tensor[(1, 12, 128, 64), float32] */; %7 = reshape(%x, newshape=[128, 768]) /* ty=Tensor[(128, 768), float32] */; %8 = transpose(%tf_bert_for_sequence_classification/bert/encoder/layer_._4/attention/self/key/Tensordot/ReadVariableOp/resource, axes=[1, 0]) /* ty=Tensor[(768, 768), float32] */; %9 = nn.dense(%7, %8, units=768) /* ty=Tensor[(128, 768), float32] */; %10 = reshape(%9, newshape=[1, 128, 768]) /* ty=Tensor[(1, 128, 768), float32] */; %11 = add(%10, %tf_bert_for_sequence_classification/bert/encoder/layer_._4/attention/self/key/BiasAdd/ReadVariableOp/resource) /* ty=Tensor[(1, 128, 768), float32] */; %12 = 
reshape(%11, newshape=[1, -1, 12, 64]) /* ty=Tensor[(1, 128, 12, 64), float32] */; %13 = transpose(%12, axes=[0, 2, 1, 3]) /* ty=Tensor[(1, 12, 128, 64), float32] */; %14 = reshape(%6, newshape=[12, 128, 64]) /* ty=Tensor[(12, 128, 64), float32] */; %15 = reshape(%13, newshape=[12, 128, 64]) /* ty=Tensor[(12, 128, 64), float32] */; %16 = nn.batch_matmul(%14, %15, transpose_b=True) /* ty=Tensor[(12, 128, 128), float32] */; %17 = reshape(%16, newshape=[1, 12, 128, 128]) /* ty=Tensor[(1, 12, 128, 128), float32] */; %18 = divide(%17, 8f /* ty=float32 */) /* ty=Tensor[(1, 12, 128, 128), float32] */; %19 = add(%18, %x1) /* ty=Tensor[(1, 12, 128, 128), float32] */; %20 = nn.softmax(%19) /* ty=Tensor[(1, 12, 128, 128), float32] */; .....
When I check the data dependencies between the two subgraphs, I notice there are two:
- The last operation of the first subgraph → %x: Tensor[(1, 128, 768), float32] of the second subgraph. For this dependency, I can follow the reference code shown below to feed the output of the first subgraph into the input of the second subgraph.
- %42 of the first subgraph → %x1: Tensor[(1, 1, 1, 128), float32] of the second subgraph. This value is a constant that feeds into every layer (e.g., %19 in the second subgraph). However, I cannot pass this dependency to the next subgraph, since %42 is not registered as a global output of the first subgraph.
Thus, I am still wondering whether it is possible for a user to register operations in Relay IR as new outputs in order to read them out (or, in my case, send them to another subgraph).
Thanks for your help in advance.
cc @hjiang
============================================================
Reference code: