Understanding TVM/Relay's PartitionGraph()(mod) function

For example, if I have two “add” operators, one in my true branch and one in my false branch, and I’d like to partition the true and false branches separately, can PartitionGraph() help me?

This is exactly what PartitionGraph does.

To me, PartitionGraph() seems limited, because it partitions based on annotations that are attached per operator kind.

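This is because you only invoke AnnotateTarget -> PartitionGraph, so every annotated operator ends up in its own region. There is another pass, MergeCompilerRegions, that merges adjacent regions for the same target and removes the unnecessary annotations between them, so you should run AnnotateTarget -> MergeCompilerRegions -> PartitionGraph.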
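To make the difference concrete, here is a toy model in plain Python of what the three passes do to straight-line branches. It is not TVM's implementation or API: the function names only mirror the TVM passes, and the SUPPORTED op set is an assumption for illustration. Each branch is merged on its own, which is consistent with the separate special_0/special_1 functions in the expected output.

```python
# Toy model only (hypothetical helpers, not TVM's API): a branch is a
# straight-line list of op names, and a region is (target, [ops]).
SUPPORTED = {"add", "multiply"}  # assumption: ops the "special" target claims


def annotate_target(ops, supported, target="special"):
    """Toy AnnotateTarget: wrap every supported op in its own one-op region."""
    return [(target if op in supported else "host", [op]) for op in ops]


def merge_compiler_regions(regions):
    """Toy MergeCompilerRegions: fuse neighbouring regions with the same target."""
    merged = []
    for target, ops in regions:
        if merged and merged[-1][0] == target and target != "host":
            merged[-1][1].extend(ops)  # drop the boundary between the two regions
        else:
            merged.append((target, list(ops)))  # copy so the input is not mutated
    return merged


def partition_graph(branches, target="special"):
    """Toy PartitionGraph: turn each target region into a numbered function."""
    functions = {}
    counter = 0
    for regions in branches:
        for tgt, ops in regions:
            if tgt == target:
                functions[f"{target}_{counter}"] = ops
                counter += 1
    return functions


true_branch = ["add", "add"]
false_branch = ["multiply"]

# AnnotateTarget -> PartitionGraph: one function per annotated op.
no_merge = [annotate_target(b, SUPPORTED) for b in (true_branch, false_branch)]
# AnnotateTarget -> MergeCompilerRegions -> PartitionGraph: one per branch.
with_merge = [merge_compiler_regions(r) for r in no_merge]

print(partition_graph(no_merge))
# {'special_0': ['add'], 'special_1': ['add'], 'special_2': ['multiply']}
print(partition_graph(with_merge))
# {'special_0': ['add', 'add'], 'special_1': ['multiply']}
```

Without the merge step the two adds in the true branch become two separate functions; with it they collapse into one, which is the per-operator behaviour the question describes.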

The expected result of your example should be:

def @special_0(%special_0_i0: Tensor[(10, 1), float32], %special_0_i1: Tensor[(10, 1), float32], global_symbol="special_0", Primitive=1, Compiler="special", Inline=1) -> Tensor[(10, 1), float32] {
  add(%special_0_i0, %special_0_i1) /* ty=Tensor[(10, 1), float32] */
}
def @special_1(%special_1_i0: Tensor[(10, 1), float32], %special_1_i1: Tensor[(10, 1), float32], global_symbol="special_1", Primitive=1, Compiler="special", Inline=1) -> Tensor[(10, 1), float32] {
  multiply(%special_1_i0, %special_1_i1) /* ty=Tensor[(10, 1), float32] */
  multiply(%special_0_i0, %special_0_i1) /* ty=Tensor[(10, 1), float32] */
}
def @main(%c: bool, %x: Tensor[(10, 1), float32], %y: Tensor[(10, 1), float32], %x1: Tensor[(10, 1), float32], %y1: Tensor[(10, 1), float32]) -> Tensor[(10, 1), float32] {
  if (%c) {
    @special_0(%x, %y) /* ty=Tensor[(10, 1), float32] */
  } else {
    @special_1(%x1, %y1) /* ty=Tensor[(10, 1), float32] */
  }
}

If it isn’t, then we may have some issues or bugs to fix.
