Custom pass from the tutorial is not working

Hi All,

I’ve tried to run the CustomPass tutorial from here: https://docs.tvm.ai/tutorials/dev/relay_pass_infra.html

According to the tutorial, the pass is supposed to replace each constant c in the program with multiply(multiplier, c), where multiplier is the constant I pass to CustomPipeline (e.g., 66 should become 5555555 * 66). However, it does not seem to work. Any ideas?


```python
import numpy as np
import tvm
import tvm.relay as relay

###############################################################################
# Create An Example Relay Program


def example():
    shape = (1, 64, 54, 54)
    c_data = np.empty(shape).astype("float32")
    c = relay.const(c_data)
    weight = relay.var('weight', shape=(64, 64, 3, 3))
    x = relay.var("x", relay.TensorType((1, 64, 56, 56), "float32"))
    conv = relay.nn.conv2d(x, weight)
    y = relay.add(c, c)
    y = relay.multiply(y, relay.const(66, "float32"))
    y = relay.add(conv, y)
    z = relay.add(y, c)
    z1 = relay.add(y, c)
    z2 = relay.add(z, z1)
    return relay.Function([x], z2)


##############################################################################
# Implement a Pass Using Python Decorator

@relay.transform.function_pass(opt_level=1)
class CustomPipeline:
    """Simple test function to replace one argument to another."""

    def __init__(self, multiplier):
        self.multiplier = multiplier

    # This function can define a pass.
    def transform_function(self, func, mod, ctx):
        obj = self

        class ReplaceConstant(tvm.relay.ExprMutator):
            def visit_const(self, c):
                print("CCCCCCCCCCCCCCCCC")
                return relay.multiply(obj.multiplier, c)
        return ReplaceConstant().visit(func)

if __name__ == "__main__":
    print("Testing Custom Pass")
    f = example()
    mod = relay.Module.from_expr(f)
    print("Current module {}".format(mod))
    print("---------------------------")
    custom_pass = CustomPipeline(multiplier=relay.const(5555555, "float32"))
    assert custom_pass.info.name == "CustomPipeline"
    mod3 = custom_pass(mod)
    print(mod3)
```



Experiencing the same problem here. The tutorial section on implementing a pass with a Python decorator does not seem to work: transform_function is called, but the visit_const method of the nested ReplaceConstant mutator never is.

Any ideas what’s going on?

@zhiics could you take a look at this question?

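This is a careless typo in the tutorial: the method must be named visit_constant. ExprMutator dispatches constant nodes to visit_constant, so a method named visit_const is never called and the default visit_constant returns every constant unchanged. A PR to fix it is welcome.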
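The failure is silent because Python never complains about a method that nothing calls. The toy model below uses hypothetical Constant, Multiply and ToyMutator classes (not TVM's API) that loosely mimic how relay's ExprFunctor.visit picks a method by node type. It also includes a small hypothetical helper, unknown_visitors, that flags visit_* methods the base class never defines, which is one way to catch this kind of typo early:

```python
# Toy model of visitor dispatch. These classes are hypothetical, not TVM's API;
# they only mimic how a visitor chooses visit_<kind> based on the node type.


class Constant:
    def __init__(self, value):
        self.value = value


class Multiply:
    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs


class ToyMutator:
    """Dispatches each node to visit_<kind>; the defaults rebuild nodes unchanged."""

    def visit(self, node):
        if isinstance(node, Constant):
            return self.visit_constant(node)
        if isinstance(node, Multiply):
            return self.visit_multiply(node)
        raise TypeError(f"unknown node: {node!r}")

    def visit_constant(self, node):
        return node  # default: keep the constant as-is

    def visit_multiply(self, node):
        return Multiply(self.visit(node.lhs), self.visit(node.rhs))


class Misnamed(ToyMutator):
    # Never called: the dispatcher only looks up visit_constant.
    def visit_const(self, node):
        return Constant(node.value * 100)


class Fixed(ToyMutator):
    # Overrides the method name the dispatcher actually uses.
    def visit_constant(self, node):
        return Constant(node.value * 100)


def unknown_visitors(cls, base=ToyMutator):
    """Return visit_* methods defined on cls that base never defines (likely typos)."""
    return [name for name in vars(cls)
            if name.startswith("visit_") and not hasattr(base, name)]


expr = Multiply(Constant(2), Constant(66))
print(Misnamed().visit(expr).rhs.value)  # 66: the misnamed override is ignored
print(Fixed().visit(expr).rhs.value)     # 6600
print(unknown_visitors(Misnamed))        # ['visit_const']
```

In the tutorial's pass, the corresponding fix is simply to rename ReplaceConstant.visit_const to visit_constant; the method body can stay the same.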


Works! I also needed to change the multiplier's dtype from "float" to "float32": "float" was inferred as float64 on my system, which caused a type error when the multiplier was multiplied with the float32 constants:

```python
custom_pass = CustomPipeline(multiplier=relay.const(3, "float32"))
```
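A likely explanation, stated as an assumption about the TVM version discussed here: relay.const(value, dtype) appears to build its data with np.array(value, dtype=dtype), so the string "float" is interpreted by NumPy, where it means float64. The snippet below only checks the NumPy side of that; the float32 reference dtype is taken from the tensors in example():

```python
import numpy as np

# dtype of the tensors built in example() (c_data, x and the 66 constant)
program_dtype = np.dtype("float32")

for spec in ("float", "float32"):
    # NumPy resolves the dtype string; "float" maps to float64
    dt = np.array(3, dtype=spec).dtype
    print(spec, "->", dt, "matches program:", dt == program_dtype)
```

Spelling out "float32" avoids relying on how the bare name "float" is resolved.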

@adi-muresan Thanks. Could you file a PR to fix the tutorial?

@zhiics done! It’s here.