How to add a relay pass

From the TVM docs, there are two ways to implement a pass: using the Python decorator or using C++. I found that most Relay passes, such as fold_constant, are implemented in C++. Only one pass, ChangeBatch (https://github.com/apache/tvm/blob/main/python/tvm/relay/transform/transform.py#L1062), is implemented using the Python decorator.

So if I want to add a custom Relay pass, is it better to implement it in C++?

Both implementations are OK. In most cases, writing a Python pass is enough.

But there are few examples of passes written in Python. Only one pass, ChangeBatch, is written in Python, so it may be hard for pass beginners.

BTW, if I want to insert an op into the Relay graph IR, should I use a module-level pass?

Hi, @zhaoyang-star.

I believe that is because the Python interface is recommended more for quick tryouts than for formal integration. If you are planning to push your pass to the main branch, I think a C++ implementation is generally recommended.

You can use either a module-level pass or a Relay function-level pass. But since a function-level pass applies its transformation to each Relay function independently, you would need a module-level pass if you want to make changes across functions.

Thanks for your kind reply.

@zhaoyang-star There is another pass implemented in Python.
The example is here: /path_to_tvm/gallery/how_to/extend_tvm/use_pass_infra.py

import tvm
from tvm import relay


@relay.transform.function_pass(opt_level=1)
class CustomPipeline:
    """Simple test pass that replaces each constant c with multiplier * c."""

    def __init__(self, multiplier):
        self.multiplier = multiplier

    # The pass infra calls this method once for each Relay function in the module.
    def transform_function(self, func, mod, ctx):
        obj = self

        class ReplaceConstant(tvm.relay.ExprMutator):
            def visit_constant(self, c):
                return relay.multiply(obj.multiplier, c)

        return ReplaceConstant().visit(func)


# example() is defined earlier in use_pass_infra.py and returns the relay.Function to transform.
f = example()
mod = tvm.IRModule.from_expr(f)
custom_pass = CustomPipeline(multiplier=relay.const(3, "float32"))
assert custom_pass.info.name == "CustomPipeline"
mod3 = custom_pass(mod)
print(mod3)