A `FunctionNode` simply represents a Relay function definition, and a `CallNode` is the caller of a function.
`FunctionNode` is used heavily in Relay fusion, where calls to multiple ops can be fused into a single Relay Function; that function is then lowered to a single function in TIR and, eventually, in the backend.
For example, if you have a graph with two sets of conv2d → bias_add → relu calls and would like to fuse them, you can do so by grouping each set into its own function. Below is sample code showing how that would look, followed by the expected output:
# Sample: build an unfused conv2d -> bias_add -> relu (x2) graph, print it,
# then regroup the two stages into separate global functions and print again.
# Assumes `import tvm` and `from tvm import relay` are already in scope.

# Free variables used as the graph inputs.
data = relay.var("data")
weights = relay.var("weights")
weights2 = relay.var("weights2")
bias = relay.var("bias")
bias2 = relay.var("bias2")

# First conv2d -> bias_add -> relu stage.
conv2d1 = relay.nn.conv2d(data, weights)
bias_add1 = relay.nn.bias_add(conv2d1, bias)
relu1 = relay.nn.relu(bias_add1)

# Second stage, fed by the output of the first.
conv2d2 = relay.nn.conv2d(relu1, weights2)
bias_add2 = relay.nn.bias_add(conv2d2, bias2)
relu2 = relay.nn.relu(bias_add2)

# Unfused module: a single "main" function containing all six op calls.
mod = tvm.IRModule()
mod["main"] = relay.Function([data, weights, bias, weights2, bias2], relu2)
print("original_mod")
print("------------")
print(mod)

# Wrap each stage in its own Relay Function and bind it to a GlobalVar so
# that "main" can refer to it through a Call.
func1 = relay.Function([data, weights, bias], relu1)
gvar1 = relay.GlobalVar("fused_conv2d_bias_add")
# NOTE(review): func2 declares two parameters, but call2 below passes three
# arguments, and relu2 still reaches data/weights/bias through relu1. The
# snippet looks illustrative rather than type-correct -- confirm before reuse.
func2 = relay.Function([weights2, bias2], relu2)
gvar2 = relay.GlobalVar("fused_conv2d_bias_add_2")

# Chain the two global functions: the result of the first call is the first
# argument of the second call.
call1 = relay.Call(gvar1, [data, weights, bias])
call2 = relay.Call(gvar2, [call1, weights2, bias2])

print("fused_mod")
print("---------")
# Fused module: the two stage functions plus a "main" that only calls them.
mod = tvm.IRModule({gvar1: func1, gvar2: func2})
mod["main"] = relay.Function([data, weights, bias, weights2, bias2], call2)
print(mod)
And the expected output would be:
original_mod
------------
def @main(%data, %weights, %bias, %weights2, %bias2) {
%0 = nn.conv2d(%data, %weights, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %bias);
%2 = nn.relu(%1);
%3 = nn.conv2d(%2, %weights2, padding=[0, 0, 0, 0]);
%4 = nn.bias_add(%3, %bias2);
nn.relu(%4)
}
fused_mod
---------
def @fused_conv2d_bias_add(%data, %weights, %bias) {
%0 = nn.conv2d(%data, %weights, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %bias);
nn.relu(%1)
}
def @fused_conv2d_bias_add_2(%weights2, %bias2) {
%2 = nn.relu(%1);
%3 = nn.conv2d(%2, %weights2, padding=[0, 0, 0, 0]);
%4 = nn.bias_add(%3, %bias2);
nn.relu(%4)
}
def @main(%data-malformed-ir, %weights-malformed-ir, %bias-malformed-ir, %weights2-malformed-ir, %bias2-malformed-ir) {
%5 = @fused_conv2d_bias_add(%data, %weights, %bias);
@fused_conv2d_bias_add_2(%5, %weights2, %bias2)
}