[PassInfra] Relay Function Pass

Hello,

The PassInfra Design and Development Doc briefly explains the function_pass decorator. The Python codebase also has a FunctionPass class, and objects of this type should be created with function_pass. The doc also says that a function pass “works on each tvm.relay.Function in a module”.
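To make “works on each tvm.relay.Function in a module” concrete, here is a toy model in plain Python. It is not TVM code: `function_pass`, `Module` and `log_and_keep` below are hypothetical stand-ins. It only illustrates the idea that a function-level pass wraps a per-function callback with the same `(func, mod, ctx)` shape as `transform_function`, and the module-level wrapper invokes that callback once per global function.

```python
from typing import Callable, Dict

# Toy stand-ins (hypothetical, not TVM types): a "module" maps global names
# to "functions"; a function is just an opaque Python object here.
Module = Dict[str, object]
PassFunc = Callable[[object, Module, dict], object]


def function_pass(pass_func: PassFunc) -> Callable[[Module], Module]:
    """Lift a per-function callback into a whole-module pass (toy model)."""

    def module_pass(mod: Module) -> Module:
        ctx: dict = {}  # placeholder for a pass context
        # The callback runs once per *global* entry of the module.
        return {name: pass_func(func, mod, ctx) for name, func in mod.items()}

    return module_pass


calls = []


@function_pass
def log_and_keep(func, mod, ctx):
    # Record which function we were handed and return it unchanged.
    calls.append(func)
    return func


mod = {"main": "main_body", "helper": "helper_body"}
new_mod = log_and_keep(mod)
print(calls)  # ['main_body', 'helper_body']
```

With two globals the callback fires twice; with only ‘main’ in the module it would fire once.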

I have the following code:

```python
import tvm
from tvm import relay

dtype = 'float32'
ishape = (1, 32, 14, 14)
wshape = (32, 32, 3, 3)
data1 = relay.var("data0", shape=ishape, dtype=dtype)
weight1 = relay.var("weight0", shape=wshape, dtype=dtype)
conv2d_1 = relay.nn.conv2d(data1, weight1, kernel_size=(3, 3), padding=(1, 1))
relu_1 = relay.nn.relu(conv2d_1)
f = relay.Function([data1, weight1], relu_1)
mod = tvm.IRModule()
mod['main'] = f

with tvm.transform.PassContext(opt_level=3):
    # I am only doing this to force the module to have many internal functions and function calls
    opt_mod, _ = relay.optimize(mod, 'llvm')

x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)

@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
    """Simple test pass: keep the first function it visits, replace every later one with new_func."""
    def __init__(self, new_func):
        self.new_func = new_func
        self.i = 0  # 0 until the first function has been visited

    def transform_function(self, func, mod, ctx):
        if self.i == 0:
            # First function: mark it as seen and keep it unchanged.
            self.i = 1
            return func
        # Every later function is replaced.
        return self.new_func

fpass = TestReplaceFunc(f1)
print(opt_mod)
end_mod = fpass(opt_mod)
print(end_mod)
```

The outputs of opt_mod and end_mod are identical, and if I set a breakpoint at the `if self.i == 0:` statement, it only stops once. opt_mod has many (internal) functions, but only one global function (‘main’).

  • Why is it not recursing through the rest of the internal functions?
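A likely explanation, hedged since I have not traced every TVM version: FunctionPass iterates over the module's global functions only. After relay.optimize, the fused operators are nested relay.Function nodes inside the body of ‘main’ rather than separate globals, so transform_function is called exactly once. The toy model below (plain Python; the `Func` type and the `fused_op_*` labels are hypothetical) contrasts what a per-global pass sees with what a recursive visitor sees. In real Relay code, reaching the nested functions would typically mean subclassing `relay.ExprMutator` and overriding `visit_function`, or using `relay.analysis.post_order_visit`, inside transform_function.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Func:
    """Toy function node (hypothetical): a label plus the functions nested in its body."""

    name: str
    body: List["Func"] = field(default_factory=list)


def names_seen_by_function_pass(mod: Dict[str, Func]) -> List[str]:
    # A function-level pass is handed only the module's global entries.
    return [func.name for func in mod.values()]


def names_seen_by_recursive_visitor(func: Func) -> List[str]:
    # A recursive visitor also descends into functions nested in the body.
    names = [func.name]
    for inner in func.body:
        names.extend(names_seen_by_recursive_visitor(inner))
    return names


# Roughly the shape after fusion: fused operators live *inside* 'main' as
# nested functions (real fused functions are anonymous; labels are made up).
main = Func("main", [Func("fused_op_0"), Func("fused_op_1")])
mod = {"main": main}

print(names_seen_by_function_pass(mod))       # ['main']
print(names_seen_by_recursive_visitor(main))  # ['main', 'fused_op_0', 'fused_op_1']
```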

If I run the same script but add another global function (e.g. `mod['not_main'] = another_relay_function`) before optimizing, opt_mod again contains only the ‘main’ function.
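This is consistent with a reachability-based cleanup: start from the entry function, keep everything it transitively calls, and drop the rest. If I read the build pipeline correctly, relay.optimize runs `relay.transform.RemoveUnusedFunctions` with ‘main’ as the default entry, which would explain why ‘main’ always survives; treat that as an assumption to check against your TVM version. Below is a toy model of the idea in plain Python (a hypothetical call-graph dict, not TVM's data structures).

```python
from typing import Dict, Iterable, Set


def remove_unused_functions(
    calls: Dict[str, Set[str]], entries: Iterable[str] = ("main",)
) -> Dict[str, Set[str]]:
    """Keep only globals reachable from the entry functions (toy model).

    `calls` maps each global function name to the names it calls.
    """
    reachable: Set[str] = set()
    stack = [e for e in entries if e in calls]
    while stack:
        name = stack.pop()
        if name in reachable:
            continue
        reachable.add(name)
        # Follow calls to other globals that exist in the module.
        stack.extend(callee for callee in calls[name] if callee in calls)
    return {name: callees for name, callees in calls.items() if name in reachable}


# 'main' never calls 'not_main', so 'not_main' is unreachable and dropped.
mod = {"main": set(), "not_main": set()}
print(remove_unused_functions(mod))  # {'main': set()}
```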

  • What pass is deleting the ‘not_main’ function?
    • There is no call to ‘not_main’ from ‘main’, so maybe dead code elimination? But I find it weird that it always preserves the ‘main’ function.
  • Is there a way to recurse through all global functions?
    • The IRModule has the attribute functions, but it's a and I don't know how to get its keys or an iterator over them
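On iterating over the globals: as far as I know, in TVM's Python API `mod.functions` is a TVM Map whose `.items()` yields `(GlobalVar, function)` pairs, `mod.get_global_vars()` returns the GlobalVar keys, and each GlobalVar exposes its name as `name_hint`; `mod['main']` also looks a function up by name. The keys being GlobalVar objects rather than strings is what makes them awkward to find. The toy model below mimics that shape in plain Python with a hypothetical frozen `GlobalVar` dataclass; it is not the TVM class.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class GlobalVar:
    """Toy stand-in for a global-function handle (hypothetical, hashable)."""

    name_hint: str


def global_names(functions: Dict[GlobalVar, object]) -> List[str]:
    # Keys are handle objects, not strings, so read the name off each key.
    return [gv.name_hint for gv, _func in functions.items()]


def lookup(functions: Dict[GlobalVar, object], name: str) -> object:
    # Find a function by name by scanning the handles.
    for gv, func in functions.items():
        if gv.name_hint == name:
            return func
    raise KeyError(name)


functions = {GlobalVar("main"): "main_body", GlobalVar("not_main"): "other_body"}
print(global_names(functions))  # ['main', 'not_main']
```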

Thank you for the help :slight_smile:

Hello, I have run into similar questions. How did you solve them? Could you give me some advice? :grinning: