Hi All,
I’ve tried to run the CustomPass tutorial from here: https://docs.tvm.ai/tutorials/dev/relay_pass_infra.html
According to the tutorial, it is supposed to replace constant in the multiple with the constant that I give inside the CustomPass(e.g., replace 66 with 5555555). However, it looks like it is not working. Any ideas?
import numpy as np
import tvm
import tvm.relay as relay
###############################################################################
# Create An Example Relay Program
def example():
shape = (1, 64, 54, 54)
c_data = np.empty(shape).astype("float32")
c = relay.const(c_data)
weight = relay.var('weight', shape=(64, 64, 3, 3))
x = relay.var("x", relay.TensorType((1, 64, 56, 56), "float32"))
conv = relay.nn.conv2d(x, weight)
y = relay.add(c, c)
y = relay.multiply(y, relay.const(66, "float32"))
y = relay.add(conv, y)
z = relay.add(y, c)
z1 = relay.add(y, c)
z2 = relay.add(z, z1)
return relay.Function([x], z2)
##############################################################################
# Implement a Pass Using Python Decorator
@relay.transform.function_pass(opt_level=1)
class CustomPipeline:
"""Simple test function to replace one argument to another."""
def __init__(self, multiplier):
self.multiplier = multiplier
# This function can define a pass.
def transform_function(self, func, mod, ctx):
obj = self
class ReplaceConstant(tvm.relay.ExprMutator):
def visit_const(self, c):
print("CCCCCCCCCCCCCCCCC")
return relay.multiply(obj.multiplier, c)
return ReplaceConstant().visit(func)
if __name__ == "__main__":
print("Testing Custom Pass")
f = example()
mod = relay.Module.from_expr(f)
print("Current module {}".format(mod))
print("---------------------------")
custom_pass = CustomPipeline(multiplier=relay.const(5555555, "float32"))
assert custom_pass.info.name == "CustomPipeline"
mod3 = custom_pass(mod)
print(mod3)