[Question] Partitioning DNN with `graph_split()` of `test_pipeline_executor`

Dear colleagues,

I am using the `graph_split()` function from `test_pipeline_executor`, following the official tutorial "Using Pipeline Executor in Relay".

First, I partitioned AlexNet, which has no branches or parallel data flows, and the results were correct.

However, when I try to partition ResNet, which has two data paths, e.g., by splitting after the "add" op (partition method a, the red line in the image below), the results turn out to be incorrect.

The partitioning code is:

split_config = [{"op_name": "add", "op_index": 0}]  # split right after the first "add" op
subgraphs = graph_split(mod["main"], split_config, params)
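
For completeness, `graph_split` itself is not a public API; as the tutorial suggests, I import it from the unit-test file. The path below is only an example and assumes a TVM source checkout:

import os
import sys

# graph_split is defined in the pipeline-executor unit test, not in the tvm package
sys.path.append(os.path.join("/path/to/tvm", "tests/python/relay"))
from test_pipeline_executor import graph_split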

Then I dove into the partitioned IR and found the following: if we partition after the add op, Partition 2 has two input args (namely `data_n_0` and `data_n_1`). However, the partition function does not provide a mapping between the output of Partition 1 and the inputs of Partition 2.
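
To make the gap concrete, here is a minimal sketch (not my actual code) of running the two partitions back to back with the plain graph executor, using a placeholder target and input tensor; it is unclear what should be bound to `data_n_1`:

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

input_data = np.random.rand(1, 3, 224, 224).astype("float32")

# Build and run Partition 1 (params are already bound as constants by graph_split)
lib0 = relay.build(subgraphs[0], target="llvm")
m0 = graph_executor.GraphModule(lib0["default"](tvm.cpu()))
m0.set_input("data", input_data)
m0.run()
out0 = m0.get_output(0)

# Build and run Partition 2
lib1 = relay.build(subgraphs[1], target="llvm")
m1 = graph_executor.GraphModule(lib1["default"](tvm.cpu()))
m1.set_input("data_n_0", out0)
# m1.set_input("data_n_1", ???)  # which output of Partition 1 should be connected here?
m1.run()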

So my questions are:

  1. How can I get the connection between the output of the former partition and the inputs of the latter partition?
  2. If I would like to partition after multiple blocks (partition method b in the figure), can the current `graph_split()` of `test_pipeline_executor` do this? If not, how can I do it?
  3. If I would like to partition based on the fused IR graph, how can I do that? I use the following:
with tvm.transform.PassContext(opt_level=3):
    mod_opt = relay.build(mod, target=target, params=params)
print(mod_opt.ir_mod["main"])

With this, I still only get the graph without fused ops. I would like to use fused ops such as `nn_conv2d_relu_xxx` as the basic units for partitioning.
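
Is `relay.optimize()` the right way to inspect the fused module instead? A sketch of what I mean (untested on my side):

with tvm.transform.PassContext(opt_level=3):
    # relay.optimize runs the standard optimization passes (including FuseOps)
    # and returns the optimized IRModule plus the updated params
    opt_mod, opt_params = relay.optimize(mod, target=target, params=params)
print(opt_mod["main"])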

Thank you guys in advance!

Appendix

The partitioned IR of ResNet produced by the split described above (Partition 1 followed by Partition 2).

def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=resnetv24_batchnorm0_fwd.data:0:0 */) {
  %0 = nn.batch_norm(%data, meta[relay.Constant][0] /* ty=Tensor[(3), float32] span=resnetv24_batchnorm0_fwd.resnetv24_batchnorm0_gamma:0:0 */, meta[relay.Constant][1] /* ty=Tensor[(3), float32] span=resnetv24_batchnorm0_fwd.resnetv24_batchnorm0_beta:0:0 */, meta[relay.Constant][2] /* ty=Tensor[(3), float32] span=resnetv24_batchnorm0_fwd.resnetv24_batchnorm0_running_mean:0:0 */, meta[relay.Constant][3] /* ty=Tensor[(3), float32] span=resnetv24_batchnorm0_fwd.resnetv24_batchnorm0_running_var:0:0 */) /* ty=(Tensor[(1, 3, 224, 224), float32], Tensor[(3), float32], Tensor[(3), float32]) */;
  %1 = %0.0 /* ty=Tensor[(1, 3, 224, 224), float32] */;
  %2 = nn.conv2d(%1, meta[relay.Constant][4] /* ty=Tensor[(64, 3, 7, 7), float32] span=resnetv24_conv0_fwd.resnetv24_conv0_weight:0:0 */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %3 = nn.batch_norm(%2, meta[relay.Constant][5] /* ty=Tensor[(64), float32] span=resnetv24_batchnorm1_fwd.resnetv24_batchnorm1_gamma:0:0 */, meta[relay.Constant][6] /* ty=Tensor[(64), float32] span=resnetv24_batchnorm1_fwd.resnetv24_batchnorm1_beta:0:0 */, meta[relay.Constant][7] /* ty=Tensor[(64), float32] span=resnetv24_batchnorm1_fwd.resnetv24_batchnorm1_running_mean:0:0 */, meta[relay.Constant][8] /* ty=Tensor[(64), float32] span=resnetv24_batchnorm1_fwd.resnetv24_batchnorm1_running_var:0:0 */) /* ty=(Tensor[(1, 64, 112, 112), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %4 = %3.0 /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %5 = nn.relu(%4) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %6 = nn.max_pool2d(%5, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %7 = nn.batch_norm(%6, meta[relay.Constant][9] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm0_fwd.resnetv24_stage1_batchnorm0_gamma:0:0 */, meta[relay.Constant][10] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm0_fwd.resnetv24_stage1_batchnorm0_beta:0:0 */, meta[relay.Constant][11] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm0_fwd.resnetv24_stage1_batchnorm0_running_mean:0:0 */, meta[relay.Constant][12] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm0_fwd.resnetv24_stage1_batchnorm0_running_var:0:0 */) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %8 = %7.0 /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %9 = nn.relu(%8) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %10 = nn.conv2d(%9, meta[relay.Constant][13] /* ty=Tensor[(64, 64, 1, 1), float32] span=resnetv24_stage1_conv0_fwd.resnetv24_stage1_conv0_weight:0:0 */, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %11 = nn.batch_norm(%10, meta[relay.Constant][14] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm1_fwd.resnetv24_stage1_batchnorm1_gamma:0:0 */, meta[relay.Constant][15] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm1_fwd.resnetv24_stage1_batchnorm1_beta:0:0 */, meta[relay.Constant][16] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm1_fwd.resnetv24_stage1_batchnorm1_running_mean:0:0 */, meta[relay.Constant][17] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm1_fwd.resnetv24_stage1_batchnorm1_running_var:0:0 */) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %12 = %11.0 /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %13 = nn.relu(%12) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %14 = nn.conv2d(%13, meta[relay.Constant][18] /* ty=Tensor[(64, 64, 3, 3), float32] span=resnetv24_stage1_conv1_fwd.resnetv24_stage1_conv1_weight:0:0 */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %15 = nn.batch_norm(%14, meta[relay.Constant][19] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm2_fwd.resnetv24_stage1_batchnorm2_gamma:0:0 */, meta[relay.Constant][20] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm2_fwd.resnetv24_stage1_batchnorm2_beta:0:0 */, meta[relay.Constant][21] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm2_fwd.resnetv24_stage1_batchnorm2_running_mean:0:0 */, meta[relay.Constant][22] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm2_fwd.resnetv24_stage1_batchnorm2_running_var:0:0 */) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %16 = %15.0 /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %17 = nn.relu(%16) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %18 = nn.conv2d(%17, meta[relay.Constant][23] /* ty=Tensor[(256, 64, 1, 1), float32] span=resnetv24_stage1_conv2_fwd.resnetv24_stage1_conv2_weight:0:0 */, padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 56, 56), float32] */;
  %19 = nn.conv2d(%9, meta[relay.Constant][24] /* ty=Tensor[(256, 64, 1, 1), float32] span=resnetv24_stage1_conv3_fwd.resnetv24_stage1_conv3_weight:0:0 */, padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 56, 56), float32] */;
  add(%18, %19) /* ty=Tensor[(1, 256, 56, 56), float32] */
}


def @main(%data_n_0: Tensor[(1, 256, 56, 56), float32] /* ty=Tensor[(1, 256, 56, 56), float32] */, %data_n_1: Tensor[(1, 256, 56, 56), float32] /* ty=Tensor[(1, 256, 56, 56), float32] */) {
  %0 = nn.batch_norm(%data_n_0, meta[relay.Constant][0] /* ty=Tensor[(256), float32] span=resnetv24_stage1_batchnorm3_fwd.resnetv24_stage1_batchnorm3_gamma:0:0 */, meta[relay.Constant][1] /* ty=Tensor[(256), float32] span=resnetv24_stage1_batchnorm3_fwd.resnetv24_stage1_batchnorm3_beta:0:0 */, meta[relay.Constant][2] /* ty=Tensor[(256), float32] span=resnetv24_stage1_batchnorm3_fwd.resnetv24_stage1_batchnorm3_running_mean:0:0 */, meta[relay.Constant][3] /* ty=Tensor[(256), float32] span=resnetv24_stage1_batchnorm3_fwd.resnetv24_stage1_batchnorm3_running_var:0:0 */) /* ty=(Tensor[(1, 256, 56, 56), float32], Tensor[(256), float32], Tensor[(256), float32]) */;
  %1 = %0.0 /* ty=Tensor[(1, 256, 56, 56), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(1, 256, 56, 56), float32] */;
  %3 = nn.conv2d(%2, meta[relay.Constant][4] /* ty=Tensor[(64, 256, 1, 1), float32] span=resnetv24_stage1_conv4_fwd.resnetv24_stage1_conv4_weight:0:0 */, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %4 = nn.batch_norm(%3, meta[relay.Constant][5] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm4_fwd.resnetv24_stage1_batchnorm4_gamma:0:0 */, meta[relay.Constant][6] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm4_fwd.resnetv24_stage1_batchnorm4_beta:0:0 */, meta[relay.Constant][7] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm4_fwd.resnetv24_stage1_batchnorm4_running_mean:0:0 */, meta[relay.Constant][8] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm4_fwd.resnetv24_stage1_batchnorm4_running_var:0:0 */) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %5 = %4.0 /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %6 = nn.relu(%5) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %7 = nn.conv2d(%6, meta[relay.Constant][9] /* ty=Tensor[(64, 64, 3, 3), float32] span=resnetv24_stage1_conv5_fwd.resnetv24_stage1_conv5_weight:0:0 */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %8 = nn.batch_norm(%7, meta[relay.Constant][10] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm5_fwd.resnetv24_stage1_batchnorm5_gamma:0:0 */, meta[relay.Constant][11] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm5_fwd.resnetv24_stage1_batchnorm5_beta:0:0 */, meta[relay.Constant][12] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm5_fwd.resnetv24_stage1_batchnorm5_running_mean:0:0 */, meta[relay.Constant][13] /* ty=Tensor[(64), float32] span=resnetv24_stage1_batchnorm5_fwd.resnetv24_stage1_batchnorm5_running_var:0:0 */) /* ty=(Tensor[(1, 64, 56, 56), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %9 = %8.0 /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %10 = nn.relu(%9) /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %11 = nn.conv2d(%10, meta[relay.Constant][14] /* ty=Tensor[(256, 64, 1, 1), float32] span=resnetv24_stage1_conv6_fwd.resnetv24_stage1_conv6_weight:0:0 */, padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 56, 56), float32] */;
  %12 = add(%11, %data_n_1) /* ty=Tensor[(1, 256, 56, 56), float32] */;