Thanks, @comaniac.
I tried to use MergeCompilerRegions, but it raises an error with the following code. With the MergeCompilerRegions call commented out (as below), the code runs and produces output with UNMERGED @special_ definitions. Ideally, I would like to get one partition for the expressions in the true branch and another partition for the false branch.
def _register_external_op_helper(op_name, supported=True):
    """Register whether ``op_name`` may be offloaded to the "special" target.

    Attaches a ``"target.special"`` attribute to the Relay operator
    ``op_name`` via ``tvm.ir.register_op_attr``. In this script,
    ``relay.transform.AnnotateTarget(["special"])`` is what consults that
    attribute when deciding which calls to wrap in
    ``compiler_begin``/``compiler_end`` annotations.

    Args:
        op_name: Name of the Relay operator, e.g. ``"multiply"``.
        supported: Value returned for every call site of ``op_name``;
            ``True`` marks the operator as offloadable.

    Returns:
        The registered checker function.
    """

    @tvm.ir.register_op_attr(op_name, "target.special")
    def _func_wrapper(attrs, args):
        # The decision is per-operator, not per-call: attrs/args are ignored.
        return supported

    return _func_wrapper
# Mark the elementwise ops used by the test graphs as supported by the
# "special" external target.
_register_external_op_helper("multiply")
_register_external_op_helper("add")
_register_external_op_helper("subtract")

if graph_type == 1:
    # Test case for graph type 1: an If whose two branches are each a small
    # DAG of "special"-supported elementwise ops.
    print("Graph type 1")

    # True branch: (x1 * y1) + (x3 * y3)
    x1 = relay.var("x1", shape=(10, 1))
    y1 = relay.var("y1", shape=(10, 1))
    f1 = relay.op.multiply(x1, y1)
    x3 = relay.var("x3", shape=(10, 1))
    y3 = relay.var("y3", shape=(10, 1))
    f3 = relay.op.multiply(x3, y3)
    true_branch = relay.op.add(f1, f3)

    # False branch: (x2 + y2) + (x4 + y4)
    x2 = relay.var("x2", shape=(10, 1))
    y2 = relay.var("y2", shape=(10, 1))
    f2 = relay.op.add(x2, y2)
    x4 = relay.var("x4", shape=(10, 1))
    y4 = relay.var("y4", shape=(10, 1))
    f4 = relay.op.add(x4, y4)
    false_branch = relay.op.add(f2, f4)

    # Give the condition an explicit scalar bool type. This is the same type
    # inference assigns it (`free_var %c: bool` in the dumped IR), but it is
    # now visible at the declaration.
    cond = relay.var("c", shape=(), dtype="bool")
    result = relay.If(cond, true_branch=true_branch, false_branch=false_branch)

    # Bind every free variable of the If (c, x1..y4) as a parameter of main.
    f = relay.Function(relay.analysis.free_vars(result), result)
    mod = tvm.IRModule({"main": f})

    mod = relay.transform.AnnotateTarget(["special"])(mod)  # Output: Figure 2
    # NOTE: enabling MergeCompilerRegions on this If-containing module raises
    # "TVMError: Cannot find the corresponding region for end annotation";
    # without it, PartitionGraph emits unmerged @special_ functions.
    # mod = relay.transform.MergeCompilerRegions()(mod)
    mod = relay.transform.PartitionGraph()(mod)  # Output: Figure 4
Here is the error I get when I uncomment the MergeCompilerRegions call:
Graph type 1
Traceback (most recent call last):
File "C:/repos/tvm23/tvm/graph_opt/subgraph/PartitionGraphTry.py", line 62, in <module>
mod = relay.transform.MergeCompilerRegions()(mod)
File "C:\repos\tvm23\tvm\python\tvm\ir\transform.py", line 127, in __call__
return _ffi_transform_api.RunPass(self, mod)
File "C:\repos\tvm23\tvm\python\tvm\_ffi\_ctypes\packed_func.py", line 237, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: TVMError: Cannot find the corresponding region for end annotation:
#[version = "0.0.5"]
free_var %c: bool;
%0 = annotation.compiler_begin(%c, meta[relay.attrs.CompilerAttrs][0]) /* ty=bool */;
%25 = if (%0) {
free_var %x1: Tensor[(10, 1), float32];
%1 = annotation.compiler_begin(%x1, meta[relay.attrs.CompilerAttrs][1]) /* ty=Tensor[(10, 1), float32] */;
free_var %y1: Tensor[(10, 1), float32];
%2 = annotation.compiler_begin(%y1, meta[relay.attrs.CompilerAttrs][2]) /* ty=Tensor[(10, 1), float32] */;
%3 = multiply(%1, %2) /* ty=Tensor[(10, 1), float32] */;
%4 = annotation.compiler_end(%3, meta[relay.attrs.CompilerAttrs][3]) /* ty=Tensor[(10, 1), float32] */;
%5 = annotation.compiler_begin(%4, meta[relay.attrs.CompilerAttrs][4]) /* ty=Tensor[(10, 1), float32] */;
free_var %x3: Tensor[(10, 1), float32];
%6 = annotation.compiler_begin(%x3, meta[relay.attrs.CompilerAttrs][5]) /* ty=Tensor[(10, 1), float32] */;
free_var %y3: Tensor[(10, 1), float32];
%7 = annotation.compiler_begin(%y3, meta[relay.attrs.CompilerAttrs][6]) /* ty=Tensor[(10, 1), float32] */;
%8 = multiply(%6, %7) /* ty=Tensor[(10, 1), float32] */;
%9 = annotation.compiler_end(%8, meta[relay.attrs.CompilerAttrs][7]) /* ty=Tensor[(10, 1), float32] */;
%10 = annotation.compiler_begin(%9, meta[relay.attrs.CompilerAttrs][8]) /* ty=Tensor[(10, 1), float32] */;
%11 = add(%5, %10) /* ty=Tensor[(10, 1), float32] */;
%12 = annotation.compiler_end(%11, meta[relay.attrs.CompilerAttrs][9]) /* ty=Tensor[(10, 1), float32] */;
annotation.compiler_begin(%12, meta[relay.attrs.CompilerAttrs][10]) /* ty=Tensor[(10, 1), float32] */
} else {
free_var %x2: Tensor[(10, 1), float32];
%13 = annotation.compiler_begin(%x2, meta[relay.attrs.CompilerAttrs][11]) /* ty=Tensor[(10, 1), float32] */;
free_var %y2: Tensor[(10, 1), float32];
%14 = annotation.compiler_begin(%y2, meta[relay.attrs.CompilerAttrs][12]) /* ty=Tensor[(10, 1), float32] */;
%15 = add(%13, %14) /* ty=Tensor[(10, 1), float32] */;
%16 = annotation.compiler_end(%15, meta[relay.attrs.CompilerAttrs][13]) /* ty=Tensor[(10, 1), float32] */;
%17 = annotation.compiler_begin(%16, meta[relay.attrs.CompilerAttrs][14]) /* ty=Tensor[(10, 1), float32] */;
free_var %x4: Tensor[(10, 1), float32];
%18 = annotation.compiler_begin(%x4, meta[relay.attrs.CompilerAttrs][15]) /* ty=Tensor[(10, 1), float32] */;
free_var %y4: Tensor[(10, 1), float32];
%19 = annotation.compiler_begin(%y4, meta[relay.attrs.CompilerAttrs][16]) /* ty=Tensor[(10, 1), float32] */;
%20 = add(%18, %19) /* ty=Tensor[(10, 1), float32] */;
%21 = annotation.compiler_end(%20, meta[relay.attrs.CompilerAttrs][17]) /* ty=Tensor[(10, 1), float32] */;
%22 = annotation.compiler_begin(%21, meta[relay.attrs.CompilerAttrs][18]) /* ty=Tensor[(10, 1), float32] */;
%23 = add(%17, %22) /* ty=Tensor[(10, 1), float32] */;
%24 = annotation.compiler_end(%23, meta[relay.attrs.CompilerAttrs][19]) /* ty=Tensor[(10, 1), float32] */;
annotation.compiler_begin(%24, meta[relay.attrs.CompilerAttrs][20]) /* ty=Tensor[(10, 1), float32] */
};
annotation.compiler_end(%25, meta[relay.attrs.CompilerAttrs][21]) /* ty=Tensor[(10, 1), float32] */
/* For debugging purposes the metadata section has been omitted.
* If you would like to see the full metadata section you can set the
* option to `True` when invoking `astext`.
*/
Process finished with exit code 1