[Question] How to merge multi-branch cast ops into one cast?

In the IR below, %586 produces an int32 output that is consumed by two branches. Each branch has a cast op that casts the input to int8. I want to merge the two branches into one, so that the cast to int8 can be fused into a single output and we get fewer op invocations.

As far as I know, a pass like CombineParallelDense helps to merge 3 parallel dense branches into 1 dense op. Could I reuse that pass to merge the parallel casts into one cast? Or is there another pass better suited for this?

  %586 = fn (%p098: Tensor[(16, 64, 112, 112), int8], %p179: Tensor[(64, 1, 1), int32], Primitive=1, hash="b07808ae4c39a75b", layout="NCHW", out_layout="") -> Tensor[(16, 64, 56, 56), int32] {
    %562 = nn.max_pool2d(%p098, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(16, 64, 56, 56), int8] */;
    %563 = cast(%562, dtype="int32") /* ty=Tensor[(16, 64, 56, 56), int32] */;
    %564 = fixed_point_multiply(%563, multiplier=0, shift=0) /* ty=Tensor[(16, 64, 56, 56), int32] */;
    %565 = add(%564, %p179) /* ty=Tensor[(16, 64, 56, 56), int32] */;
    %566 = nn.relu(%565) /* ty=Tensor[(16, 64, 56, 56), int32] */;
    %567 = cast(%566, dtype="int64") /* ty=Tensor[(16, 64, 56, 56), int64] */;
    %568 = fixed_point_multiply(%567, multiplier=0, shift=0) /* ty=Tensor[(16, 64, 56, 56), int64] */;
    %569 = clip(%568, a_min=-127f, a_max=127f) /* ty=Tensor[(16, 64, 56, 56), int64] */;
    cast(%569, dtype="int32") /* ty=Tensor[(16, 64, 56, 56), int32] */
  };
  %587 = %586(%585, meta[relay.Constant][2] /* ty=Tensor[(64, 1, 1), int32] */) /* ty=Tensor[(16, 64, 56, 56), int32] */;
  %588 = fn (%p097: Tensor[(16, 64, 56, 56), int32], Primitive=1, hash="fa643999fd987de1") -> Tensor[(16, 64, 56, 56), int8] {
    cast(%p097, dtype="int8") /* ty=Tensor[(16, 64, 56, 56), int8] */
  };
  %589 = %588(%587) /* ty=Tensor[(16, 64, 56, 56), int8] */;
  %600 = fn (%p0102: Tensor[(16, 64, 56, 56), int32], Primitive=1, hash="fa643999fd987de1") -> Tensor[(16, 64, 56, 56), int8] {
    cast(%p0102, dtype="int8") /* ty=Tensor[(16, 64, 56, 56), int8] */
  };
  %601 = %600(%587) /* ty=Tensor[(16, 64, 56, 56), int8] */;

consider below IR graph:

def @main /* id=76660736 */(%data: Tensor[(1, 16), float32]) -> Tensor[(1, 16), int8] {
  %0 = cast(%data, dtype="int32") /* ty=Tensor[(1, 16), int32] */;
  %1 = cast(%0, dtype="int8") /* ty=Tensor[(1, 16), int8] */;
  %2 = nn.relu(%1) /* ty=Tensor[(1, 16), int8] */;
  %3 = cast(%0, dtype="int8") /* ty=Tensor[(1, 16), int8] */;
  %4 = nn.relu(%3) /* ty=Tensor[(1, 16), int8] */;
  %5 = nn.dense(%2, meta[relay.Constant][0] /* ty=Tensor[(16, 16), float32] */, units=None) /* ty=Tensor[(1, 16), int8] */;
  %6 = nn.dense(%4, meta[relay.Constant][1] /* ty=Tensor[(16, 16), float32] */, units=None) /* ty=Tensor[(1, 16), int8] */;
  add(%5, %6) /* ty=Tensor[(1, 16), int8] */
}

the final optimized graph may looks like below:

def @main /* id=76660736 */(%data: Tensor[(1, 16), float32], param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0], hash="e7029313c01517bd") -> Tensor[(1, 16), int8] {
  %2 = fn (%p02: Tensor[(1, 16), float32], Primitive=1, hash="7cf59ec644668506") -> Tensor[(1, 16), int32] {
    cast(%p02, dtype="int32") /* ty=Tensor[(1, 16), int32] */
  };
  %3 = %2(%data) /* ty=Tensor[(1, 16), int32] */;
  %4 = fn (%p01: Tensor[(1, 16), int32], Primitive=1, hash="e0392345e073f5a2") -> Tensor[(1, 16), int8] {
    %1 = cast(%p01, dtype="int8") /* ty=Tensor[(1, 16), int8] */;
    nn.relu(%1) /* ty=Tensor[(1, 16), int8] */
  };
  %6 = fn (%p04: Tensor[(1, 16), int32], Primitive=1, hash="e0392345e073f5a2") -> Tensor[(1, 16), int8] {
    %5 = cast(%p04, dtype="int8") /* ty=Tensor[(1, 16), int8] */;
    nn.relu(%5) /* ty=Tensor[(1, 16), int8] */
  };
  %7 = %6(%3) /* ty=Tensor[(1, 16), int8] */;
  %8 = fn (%p03: Tensor[(1, 16), int8], %p11: Tensor[(16, 16), float32], Primitive=1, hash="51a9dc624dc4ddf6") -> Tensor[(1, 16), int8] {
    nn.dense(%p03, %p11, units=None) /* ty=Tensor[(1, 16), int8] */
  };
  %9 = %4(%3) /* ty=Tensor[(1, 16), int8] */;
  %10 = %8(%7, meta[relay.Constant][1] /* ty=Tensor[(16, 16), float32] */) /* ty=Tensor[(1, 16), int8] */;
  %11 = fn (%p0: Tensor[(1, 16), int8], %p1: Tensor[(16, 16), float32], %p2: Tensor[(1, 16), int8], Primitive=1, hash="1290d3847b0634e8") -> Tensor[(1, 16), int8] {
    %0 = nn.dense(%p0, %p1, units=None) /* ty=Tensor[(1, 16), int8] */;
    add(%0, %p2) /* ty=Tensor[(1, 16), int8] */
  };
  %11(%9, meta[relay.Constant][0] /* ty=Tensor[(16, 16), float32] */, %10) /* ty=Tensor[(1, 16), int8] */
}

This fuses into five kernels, and the network flows int32 between the dense ops and before the relu. If we combine the multi-branch cast, the final IR graph may look like:

def @main /* id=76826848 */(%data: Tensor[(1, 16), float32], param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0], hash="16cc69aef39ca29b") -> Tensor[(1, 16), int8] {
  %3 = fn (%p01: Tensor[(1, 16), float32], Primitive=1, hash="482b0c61c3b9b60f") -> Tensor[(1, 16), int8] {
    %1 = cast(%p01, dtype="int32") /* ty=Tensor[(1, 16), int32] */;
    %2 = cast(%1, dtype="int8") /* ty=Tensor[(1, 16), int8] */;
    nn.relu(%2) /* ty=Tensor[(1, 16), int8] */
  };
  %4 = %3(%data) /* ty=Tensor[(1, 16), int8] */;
  %5 = fn (%p02: Tensor[(1, 16), int8], %p11: Tensor[(16, 16), float32], Primitive=1, hash="51a9dc624dc4ddf6") -> Tensor[(1, 16), int8] {
    nn.dense(%p02, %p11, units=None) /* ty=Tensor[(1, 16), int8] */
  };
  %6 = %5(%4, meta[relay.Constant][1] /* ty=Tensor[(16, 16), float32] */) /* ty=Tensor[(1, 16), int8] */;
  %7 = fn (%p0: Tensor[(1, 16), int8], %p1: Tensor[(16, 16), float32], %p2: Tensor[(1, 16), int8], Primitive=1, hash="1290d3847b0634e8") -> Tensor[(1, 16), int8] {
    %0 = nn.dense(%p0, %p1, units=None) /* ty=Tensor[(1, 16), int8] */;
    add(%0, %p2) /* ty=Tensor[(1, 16), int8] */
  };
  %7(%4, meta[relay.Constant][0] /* ty=Tensor[(16, 16), float32] */, %6) /* ty=Tensor[(1, 16), int8] */
}

This version only has 3 kernels, and the dataflow between the relu and the dense ops is int8.

The IR comes from a ResNet-50 model.

I found that CombineParallelOpBatch seems capable of doing this:

// Quoted from TVM (src/relay/transforms/combine_parallel_op_batch.cc):
// builds a Relay function pass that merges parallel occurrences of `op_name`
// (min_num_branches or more sibling branches sharing the same input) into a
// single batched `batch_op_name` call.
Pass CombineParallelOpBatch(const String& op_name, const String& batch_op_name,
                            uint64_t min_num_branches) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        // Delegates to the Function-level overload of the same name.
        return Downcast<Function>(
            CombineParallelOpBatch(f, op_name, batch_op_name, min_num_branches));
      };
  // Opt level 4; requires type information, hence the "InferType" dependency.
  return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {"InferType"});
}

But it does not seem to expose a Python interface — is there a way to invoke it from Python, or should a binding be registered for it?