Hello,
In the PassInfra Design and Development Doc, the `function_pass` decorator is briefly explained. The Python codebase also has the class `FunctionPass`; objects of this type should be created with `function_pass`. The doc also states that it “works on each tvm.relay.Function in a module”.
I have the following code:
```python
import tvm
from tvm import relay

dtype = 'float32'
ishape = (1, 32, 14, 14)
wshape = (32, 32, 3, 3)
data1 = relay.var("data0", shape=ishape, dtype=dtype)
weight1 = relay.var("weight0", shape=wshape, dtype=dtype)
conv2d_1 = relay.nn.conv2d(data1, weight1, kernel_size=(3, 3), padding=(1, 1))
relu_1 = relay.nn.relu(conv2d_1)
f = relay.Function([data1, weight1], relu_1)
mod = tvm.IRModule()
mod['main'] = f

with tvm.transform.PassContext(opt_level=3):
    # I am only doing this to force the module to have many internal functions and function calls
    opt_mod, _ = relay.optimize(mod, 'llvm')

x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)


@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
    """Test pass: keep the first function it visits and replace every later one with new_func."""

    def __init__(self, new_func):
        self.new_func = new_func
        self.i = 0

    def transform_function(self, func, mod, ctx):
        if self.i == 0:
            # First function visited: remember that we saw it and leave it unchanged.
            self.i = 1
            return func
        # Every later function is replaced by new_func.
        return self.new_func


fpass = TestReplaceFunc(f1)
print(opt_mod)
end_mod = fpass(opt_mod)
print(end_mod)
```
The printed `opt_mod` and `end_mod` are identical, and if I set a breakpoint on the `if self.i == 0:` statement, it only stops once.
The `opt_mod` has many (internal) functions, but only one global function (‘main’).
- Why is it not recursing through the rest of the internal functions?
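My working reading of “works on each tvm.relay.Function in a module” is that the pass callback runs once per *global* function, so anything nested inside ‘main’ (such as the functions `relay.optimize` fuses) would only be reachable by walking main's body. Here is a TVM-free toy model of that assumption; the `Func` class, `run_function_pass`, and the inner names `fused_0`/`fused_1` are hypothetical stand-ins, not TVM's implementation:

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List

# Toy model of an ASSUMED behavior; not TVM's actual FunctionPass implementation.


@dataclass
class Func:
    """Stand-in for a relay.Function: a name plus the functions nested in its body."""
    name: str
    nested: List["Func"] = field(default_factory=list)


def run_function_pass(module: Dict[str, Func],
                      transform: Callable[[Func], Func]) -> Dict[str, Func]:
    """Call `transform` once per global function; nested functions are never passed in."""
    return {gname: transform(func) for gname, func in module.items()}


def walk_nested(func: Func) -> List[str]:
    """Depth-first walk collecting the names of every function nested inside `func`."""
    names: List[str] = []
    for inner in func.nested:
        names.append(inner.name)
        names.extend(walk_nested(inner))
    return names


# Shaped like opt_mod: one global 'main' whose body holds (hypothetical) fused functions.
toy_opt_mod = {"main": Func("main", [Func("fused_0"), Func("fused_1")])}

visited: List[str] = []


def record(func: Func) -> Func:
    """Identity transform that logs each function it is handed."""
    visited.append(func.name)
    return func


toy_end_mod = run_function_pass(toy_opt_mod, record)
print(visited)                           # ['main']
print(walk_nested(toy_opt_mod["main"]))  # ['fused_0', 'fused_1']
print(toy_end_mod == toy_opt_mod)        # True
```

If this reading is right, a single global function means a single call to `transform_function`, which would match the breakpoint being hit only once.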
If I run the same script as above but add another global function (e.g. `mod['not_main'] = another_relay_function`) before optimizing, I also notice that `opt_mod` only has the ‘main’ function.
- What pass is deleting the ‘not_main’ function?
  - There is no call to ‘not_main’ from ‘main’, so maybe dead code elimination? But I find it weird that it always preserves the ‘main’ function.
- Is there a way to recurse through all global functions?
  - The `IRModule` has the attribute `functions`, but it's a … and I don't know how to get its keys or an iterator over them.
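The dead-code guess can be made concrete with a TVM-free toy model: treat the module as a hypothetical call graph (global name → names it calls) and drop every global that is not reachable from an entry list defaulting to ‘main’. If the real pass behaves like this, ‘main’ survives simply because it is the entry point, while the uncalled ‘not_main’ disappears. I believe TVM has a pass along these lines, `relay.transform.RemoveUnusedFunctions`, whose `entry_functions` argument defaults to `['main']`, but that `relay.optimize` is what runs it here is my assumption. The helper names below (`reachable_globals`, `remove_unused`, `helper`) are made up for the sketch:

```python
from typing import Dict, Iterable, Set

# Toy model of entry-point-based dead-code removal; not TVM's implementation.


def reachable_globals(call_graph: Dict[str, Set[str]],
                      entries: Iterable[str] = ("main",)) -> Set[str]:
    """Names of all globals reachable from `entries` by following calls."""
    seen: Set[str] = set()
    stack = list(entries)
    while stack:
        name = stack.pop()
        if name in seen or name not in call_graph:
            continue
        seen.add(name)
        stack.extend(call_graph[name])
    return seen


def remove_unused(call_graph: Dict[str, Set[str]],
                  entries: Iterable[str] = ("main",)) -> Dict[str, Set[str]]:
    """Keep only the globals reachable from `entries`; everything else is dropped."""
    keep = reachable_globals(call_graph, entries)
    return {name: callees for name, callees in call_graph.items() if name in keep}


# Hypothetical call graph: 'main' calls 'helper', and nothing calls 'not_main'.
module = {"main": {"helper"}, "helper": set(), "not_main": set()}

print(sorted(remove_unused(module)))                                # ['helper', 'main']
print(sorted(remove_unused(module, entries=("main", "not_main"))))  # ['helper', 'main', 'not_main']
```

In the toy, listing ‘not_main’ as an extra entry is enough to keep it; a function that ‘main’ actually calls is kept either way.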
Thank you for the help