Hi! My goal is to add an if-else to a multiply op. I traverse the CallNode because multiply uses other op’s result, so in true branch, it will involve all the dataflow.
import tvm
from tvm import relay
from tvm.relay import transform
import numpy as np
from tvm.relay.backend import Executor, Runtime
class OpInIF(relay.ExprMutator):
def visit_call(self, call):
if call.op.name == ("multiply"):
# Wrap nn.conv2d in an if-else block
condition = relay.const(True, "bool")
# tmp = relay.RefCreate(call.args[0])
# new_call = relay.multiply(tmp, call.args[1])
then_branch = call
else_branch = call.args[0] # An empty tuple for the "else" branch
ife = relay.If(condition, then_branch, else_branch)
return ife
else:
return super().visit_call(call)
def visit_if(self, call):
if call.op.name == ("if"):
# Wrap nn.conv2d in an if-else block
ife = relay.var("expr")
return ife
else:
return super().visit_if(call)
# Define your original Relay function
def example():
shape = (1, 64, 54, 54)
c_data = np.empty(shape).astype("float32")
c = relay.const(c_data)
weight = relay.var("weight", shape=(64, 64, 3, 3))
x = relay.var("x", relay.TensorType((1, 64, 56, 56), "float32"))
conv = relay.nn.conv2d(x, weight)
y = relay.add(c, c)
y = relay.multiply(y, relay.const(2, "float32"))
y = relay.add(conv, y)
z = relay.add(y, c)
z1 = relay.add(y, c)
z2 = relay.add(z, z1)
return relay.Function([x, weight], z2)
f = example()
mod = tvm.IRModule.from_expr(f)
# Apply the transformation to the original function
print(mod)
passA = OpInIF()
transformed_function = passA.visit(mod["main"])
transformed_function = tvm.IRModule.from_expr(transformed_function)
#transformed_function = tvm.IRModule.from_expr(transformed_function)
# Print the transformed function
print("Transformed Function")
print(transformed_function)
I get result like:
def @main(%x: Tensor[(1, 64, 56, 56), float32], %weight: Tensor[(64, 64, 3, 3), float32]) {
%0 = add(meta[relay.Constant][0], meta[relay.Constant][0]);
%1 = nn.conv2d(%x, %weight, padding=[0, 0, 0, 0]);
%2 = multiply(%0, 2f);
%3 = add(%1, %2);
%4 = add(%3, meta[relay.Constant][0]);
%5 = add(%3, meta[relay.Constant][0]);
add(%4, %5)
}
Transformed Function
def @main(%x: Tensor[(1, 64, 56, 56), float32], %weight: Tensor[(64, 64, 3, 3), float32]) {
%1 = nn.conv2d(%x, %weight, padding=[0, 0, 0, 0]);
%2 = if (True) {
%0 = add(meta[relay.Constant][0], meta[relay.Constant][0]);
multiply(%0, 2f)
} else {
%0
};
%3 = add(%1, %2);
%4 = add(%3, meta[relay.Constant][0]);
%5 = add(%3, meta[relay.Constant][0]);
add(%4, %5)
}
So I assume that placing %0 = add outside if can be correct, but how can I achieve this?