Hi All,
Hi @comaniac, I want to follow up on my post above. I removed the If expression, and now it works. Does that mean MergeCompilerRegions does not fully support If yet?
This is the code that works.
```python
import tvm
from tvm import relay

# Test case for graph type 1
# (assumes multiply/add are already registered as supported by the "special" target)
print("Graph type 1")
# Graph 1: true branch
x1 = relay.var('x1', shape=(10, 1))
y1 = relay.var('y1', shape=(10, 1))
f1 = relay.op.multiply(x1, y1)
x3 = relay.var('x3', shape=(10, 1))
y3 = relay.var('y3', shape=(10, 1))
f3 = relay.op.multiply(x3, y3)
true_branch = relay.op.add(f1, f3)
# Graph 2: false branch
x2 = relay.var('x2', shape=(10, 1))
y2 = relay.var('y2', shape=(10, 1))
f2 = relay.op.add(x2, y2)
x4 = relay.var('x4', shape=(10, 1))
y4 = relay.var('y4', shape=(10, 1))
f4 = relay.op.add(x4, y4)
false_branch = relay.op.add(f2, f4)
cond = relay.var('c')
# Without If: use the true branch directly as the function body
#result = relay.If(cond, true_branch=true_branch, false_branch=false_branch)
result = true_branch
#f = relay.Function([], result)
f = relay.Function(relay.analysis.free_vars(result), result)
mod = tvm.IRModule({"main": f})
mod = relay.transform.AnnotateTarget(["special"])(mod)  # Output: Figure 2
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.PartitionGraph()(mod)  # Output: Figure 4
```
This is the code that does NOT work. The only change is that `result` is the `relay.If` expression instead of `true_branch`.
```python
import tvm
from tvm import relay

# Test case for graph type 1
# (assumes multiply/add are already registered as supported by the "special" target)
print("Graph type 1")
# Graph 1: true branch
x1 = relay.var('x1', shape=(10, 1))
y1 = relay.var('y1', shape=(10, 1))
f1 = relay.op.multiply(x1, y1)
x3 = relay.var('x3', shape=(10, 1))
y3 = relay.var('y3', shape=(10, 1))
f3 = relay.op.multiply(x3, y3)
true_branch = relay.op.add(f1, f3)
# Graph 2: false branch
x2 = relay.var('x2', shape=(10, 1))
y2 = relay.var('y2', shape=(10, 1))
f2 = relay.op.add(x2, y2)
x4 = relay.var('x4', shape=(10, 1))
y4 = relay.var('y4', shape=(10, 1))
f4 = relay.op.add(x4, y4)
false_branch = relay.op.add(f2, f4)
cond = relay.var('c')
# With If: select between the two branches on `cond`
result = relay.If(cond, true_branch=true_branch, false_branch=false_branch)
#result = true_branch
#f = relay.Function([], result)
f = relay.Function(relay.analysis.free_vars(result), result)
mod = tvm.IRModule({"main": f})
mod = relay.transform.AnnotateTarget(["special"])(mod)  # Output: Figure 2
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.PartitionGraph()(mod)  # Output: Figure 4
```