For example, if I have two “add” operators in my true and false branches, and I’d like to partition the true and false branches separately, can PartitionGraph() help me?
This is exactly what PartitionGraph does.
To me, PartitionGraph() looks limited, because it partitions based on annotations that are attached per operator kind.
This is because you only invoke `AnnotateTarget` → `PartitionGraph`. There is another pass, `MergeCompilerRegions`, that merges neighbouring regions annotated for the same target by removing the unnecessary annotations between them, so you should run `AnnotateTarget` → `MergeCompilerRegions` → `PartitionGraph`.
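As a rough mental model of why the three-pass pipeline produces one function per branch, here is a pure-Python toy. It is not TVM code: `annotate_target`, `merge_compiler_regions` and `partition_graph` are hypothetical stand-ins for the real passes, and each branch is simplified to a straight-line list of op names, whereas the real passes work on Relay dataflow graphs (where merging must also avoid creating cycles). The sketch only shows the idea: annotation is per operator, merging fuses neighbouring ops for the same target within one branch, and partitioning outlines each resulting region as a function.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical toy IR: each if/else branch is a straight-line list of op names.
Program = Dict[str, List[str]]
# Annotated item: (target or None, [ops]); None means "stays on the host".
Annotated = Dict[str, List[Tuple[Optional[str], List[str]]]]


def annotate_target(prog: Program, target: str, supported: set) -> Annotated:
    """Toy AnnotateTarget: every supported op gets its own one-op region."""
    return {
        branch: [(target if op in supported else None, [op]) for op in ops]
        for branch, ops in prog.items()
    }


def merge_compiler_regions(ann: Annotated) -> Annotated:
    """Toy MergeCompilerRegions: fuse neighbouring regions with the same
    target, but only inside one branch, so true/false never merge."""
    merged: Annotated = {}
    for branch, items in ann.items():
        out: List[Tuple[Optional[str], List[str]]] = []
        for tgt, ops in items:
            if out and tgt is not None and out[-1][0] == tgt:
                out[-1] = (tgt, out[-1][1] + ops)  # extend previous region
            else:
                out.append((tgt, list(ops)))
        merged[branch] = out
    return merged


def partition_graph(ann: Annotated) -> Dict[str, List[str]]:
    """Toy PartitionGraph: outline each target region as a function
    named <target>_<index>."""
    funcs: Dict[str, List[str]] = {}
    counter = 0
    for items in ann.values():
        for tgt, ops in items:
            if tgt is not None:
                funcs[f"{tgt}_{counter}"] = ops
                counter += 1
    return funcs


# add in the true branch, multiply in the false branch
prog = {"true": ["add"], "false": ["multiply"]}
ann = annotate_target(prog, "special", {"add", "multiply"})
print(partition_graph(merge_compiler_regions(ann)))
# {'special_0': ['add'], 'special_1': ['multiply']}
```

With a longer branch such as `{"true": ["add", "multiply"], "false": ["add"]}`, skipping `merge_compiler_regions` yields three one-op functions, while running it yields two: one per branch.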
The expected result of your example should be:
```
def @special_0(%special_0_i0: Tensor[(10, 1), float32], %special_0_i1: Tensor[(10, 1), float32], global_symbol="special_0", Primitive=1, Compiler="special", Inline=1) -> Tensor[(10, 1), float32] {
  add(%special_0_i0, %special_0_i1) /* ty=Tensor[(10, 1), float32] */
}

def @special_1(%special_1_i0: Tensor[(10, 1), float32], %special_1_i1: Tensor[(10, 1), float32], global_symbol="special_1", Primitive=1, Compiler="special", Inline=1) -> Tensor[(10, 1), float32] {
  multiply(%special_1_i0, %special_1_i1) /* ty=Tensor[(10, 1), float32] */
}

def @main(%c: bool, %x: Tensor[(10, 1), float32], %y: Tensor[(10, 1), float32], %x1: Tensor[(10, 1), float32], %y1: Tensor[(10, 1), float32]) -> Tensor[(10, 1), float32] {
  if (%c) {
    @special_0(%x, %y) /* ty=Tensor[(10, 1), float32] */
  } else {
    @special_1(%x1, %y1) /* ty=Tensor[(10, 1), float32] */
  }
}
```
If it isn't, then there may be bugs that need to be fixed.