How can I add right control flow to relay's dataflow?

Hi! My goal is to add an if-else to a multiply op. I traverse the CallNode because multiply uses other op’s result, so in true branch, it will involve all the dataflow.

import tvm
from tvm import relay
from tvm.relay import transform
import numpy as np
from tvm.relay.backend import Executor, Runtime

class OpInIF(relay.ExprMutator):
    def visit_call(self, call):
        if call.op.name == ("multiply"):
            # Wrap nn.conv2d in an if-else block
            condition = relay.const(True, "bool")
            # tmp = relay.RefCreate(call.args[0])
            # new_call = relay.multiply(tmp, call.args[1])
            then_branch = call
            else_branch = call.args[0]  # An empty tuple for the "else" branch
            ife = relay.If(condition, then_branch, else_branch)
            return ife
        else:
          return super().visit_call(call)

    def visit_if(self, call):
        if call.op.name == ("if"):
            # Wrap nn.conv2d in an if-else block
            ife = relay.var("expr")
            return ife
        else:
          return super().visit_if(call)


# Define your original Relay function
def example():
    shape = (1, 64, 54, 54)
    c_data = np.empty(shape).astype("float32")
    c = relay.const(c_data)
    weight = relay.var("weight", shape=(64, 64, 3, 3))
    x = relay.var("x", relay.TensorType((1, 64, 56, 56), "float32"))
    conv = relay.nn.conv2d(x, weight)
    y = relay.add(c, c)
    y = relay.multiply(y, relay.const(2, "float32"))
    y = relay.add(conv, y)
    z = relay.add(y, c)
    z1 = relay.add(y, c)
    z2 = relay.add(z, z1)
    return relay.Function([x, weight], z2)

f = example()
mod = tvm.IRModule.from_expr(f)

# Apply the transformation to the original function

print(mod)
passA = OpInIF()

transformed_function = passA.visit(mod["main"])
transformed_function = tvm.IRModule.from_expr(transformed_function)

#transformed_function = tvm.IRModule.from_expr(transformed_function)
# Print the transformed function
print("Transformed Function")
print(transformed_function)

I get result like:

def @main(%x: Tensor[(1, 64, 56, 56), float32], %weight: Tensor[(64, 64, 3, 3), float32]) {
  %0 = add(meta[relay.Constant][0], meta[relay.Constant][0]);
  %1 = nn.conv2d(%x, %weight, padding=[0, 0, 0, 0]);
  %2 = multiply(%0, 2f);
  %3 = add(%1, %2);
  %4 = add(%3, meta[relay.Constant][0]);
  %5 = add(%3, meta[relay.Constant][0]);
  add(%4, %5)
}


Transformed Function
def @main(%x: Tensor[(1, 64, 56, 56), float32], %weight: Tensor[(64, 64, 3, 3), float32]) {
  %1 = nn.conv2d(%x, %weight, padding=[0, 0, 0, 0]);
  %2 = if (True) {
    %0 = add(meta[relay.Constant][0], meta[relay.Constant][0]);
    multiply(%0, 2f)
  } else {
    %0
  };
  %3 = add(%1, %2);
  %4 = add(%3, meta[relay.Constant][0]);
  %5 = add(%3, meta[relay.Constant][0]);
  add(%4, %5)
}

So I assume that placing %0 = add outside if can be correct, but how can I achieve this?