Defining block and matching pattern at the same time

I would like to define my model as logical blocks of layers (an MLP, for example), each written as an nn.Module as in the Optimize LLM example, and then extract these logical blocks into separate functions for high-level analysis.

Currently, for every block I write, I also have to write a matching pattern by hand.

Is there a way to write the code once and generate both from it, since they express the same logic?

I am currently using the Relax nn frontend to write the blocks.
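As a rough, TVM-independent illustration of "write it once, generate both", one generic approach is to write the block against an abstract `ops` object and run the same definition with two backends: one that computes and one that records the calls as a pattern. This is a minimal pure-Python sketch, not TVM API: `EvalOps`, `PatternOps`, and `mlp_block` are hypothetical names, the op strings only mirror Relax operator names, and `"*"` stands in for a wildcard.

```python
import math
from typing import List


class EvalOps:
    """Computes the block on plain lists: x is a row vector, weights are lists of rows."""

    def matmul(self, x: List[float], w: List[List[float]]) -> List[float]:
        # y[j] = sum_i x[i] * w[i][j]
        return [sum(xi * row[j] for xi, row in zip(x, w)) for j in range(len(w[0]))]

    def silu(self, x: List[float]) -> List[float]:
        # silu(v) = v * sigmoid(v)
        return [v / (1.0 + math.exp(-v)) for v in x]


class PatternOps:
    """Records the same calls as nested (op_name, *args) tuples instead of computing."""

    def matmul(self, x, w):
        return ("relax.matmul", x, w)

    def silu(self, x):
        return ("relax.nn.silu", x)


def mlp_block(ops, x, w_up, w_down):
    # Written once; the `ops` backend decides whether this computes or builds a pattern.
    return ops.matmul(ops.silu(ops.matmul(x, w_up)), w_down)


# Same definition, two interpretations ("*" plays the role of a wildcard input)
y = mlp_block(EvalOps(), [1.0, 2.0], [[1.0, 0.0], [0.0, 1.0]], [[1.0], [1.0]])
pat = mlp_block(PatternOps(), "*", "*", "*")
print(y)
print(pat)  # ('relax.matmul', ('relax.nn.silu', ('relax.matmul', '*', '*')), '*')
```

The trade-off is that every op used in a block must be implemented on both backends; deriving the pattern from the already-built Relax function avoids that duplication.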

I'm not sure whether there is a direct solution, but in the meantime I was able to generate the patterns with this expr visitor:

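import tvm
from tvm import relax
from tvm.relax.dpl import is_op, is_tuple, is_tuple_get_item, wildcard


@relax.expr_functor.visitor
class PatternExtractor(relax.PyExprVisitor):
    """Build a dataflow pattern that mirrors the bindings of a Relax function."""

    def _arg_pattern(self, arg):
        # Previously bound vars map to their pattern; constants, shapes, etc.
        # become wildcards so the call keeps its original number of arguments.
        if isinstance(arg, relax.Var):
            return self.tensor_dict[arg.name_hint]
        return wildcard()

    def visit_var_binding_(self, binding: relax.expr.VarBinding):
        value = binding.value
        name = binding.var.name_hint
        if isinstance(value, relax.expr.Call) and isinstance(value.op, tvm.ir.Op):
            args = [self._arg_pattern(arg) for arg in value.args]
            self.tensor_dict[name] = is_op(value.op.name)(*args)
        elif isinstance(value, relax.expr.Tuple):
            self.tensor_dict[name] = is_tuple([self._arg_pattern(f) for f in value.fields])
        elif isinstance(value, relax.expr.TupleGetItem):
            tup = self._arg_pattern(value.tuple_value)
            self.tensor_dict[name] = is_tuple_get_item(tup, value.index)
        elif isinstance(value, relax.expr.Var):
            # Alias such as `gv = lv3` (also covers DataflowVar, a subclass of Var)
            self.tensor_dict[name] = self.tensor_dict[value.name_hint]
        else:
            # Anything else (e.g. a call to another function) matches any expression
            self.tensor_dict[name] = wildcard()
        super().visit_var_binding_(binding)

    def extract(self, func: relax.Function):
        # Function parameters (inputs and weights) match anything
        self.tensor_dict = {param.name_hint: wildcard() for param in func.params}
        self.visit_expr(func.body)
        # Assumes the function body (a SeqExpr) returns a single Var
        return self.tensor_dict[func.body.body.name_hint]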
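The visitor above builds the pattern by walking the bindings in program order: parameters become wildcards, each operator call becomes `is_op(...)` applied to the patterns of its arguments, and aliases reuse an existing pattern. A toy, TVM-free version of that walk is sketched below to show the idea. The binding tuples, the `"*"` wildcard, and `extract_pattern`/`matches` are hypothetical stand-ins for the Relax IR and `tvm.relax.dpl`, and the bindings only approximate what an exported MLP block might contain.

```python
from typing import Dict, List, Optional, Tuple, Union

# Hypothetical toy IR: a binding is (var_name, op_name, arg_names);
# op_name None means an alias such as `gv = lv4`.
Binding = Tuple[str, Optional[str], Tuple[str, ...]]
# Toy patterns: "*" is a wildcard, (op_name, *arg_patterns) is a call pattern.
Pattern = Union[str, tuple]


def extract_pattern(params: List[str], bindings: List[Binding], output: str) -> Pattern:
    # Parameters (inputs and weights) become wildcards, like wildcard() in the visitor
    env: Dict[str, Pattern] = {p: "*" for p in params}
    for var, op, args in bindings:  # visited in program order
        if op is None:
            env[var] = env[args[0]]  # alias: reuse the existing pattern
        else:
            env[var] = (op, *(env[a] for a in args))  # like is_op(op)(*args)
    return env[output]


def matches(pattern: Pattern, expr) -> bool:
    # expr uses the same nested-tuple shape; "*" matches any sub-expression
    if pattern == "*":
        return True
    if not isinstance(expr, tuple) or len(expr) != len(pattern) or expr[0] != pattern[0]:
        return False
    return all(matches(p, e) for p, e in zip(pattern[1:], expr[1:]))


bindings = [
    ("lv", "relax.permute_dims", ("w_up",)),
    ("lv1", "relax.matmul", ("x", "lv")),
    ("lv2", "relax.nn.silu", ("lv1",)),
    ("lv3", "relax.permute_dims", ("w_down",)),
    ("lv4", "relax.matmul", ("lv2", "lv3")),
    ("gv", None, ("lv4",)),
]
pat = extract_pattern(["x", "w_up", "w_down"], bindings, "gv")

# Another instance of the same block (different names) matches; a GELU variant does not
mlp_expr = ("relax.matmul",
            ("relax.nn.silu", ("relax.matmul", "h", ("relax.permute_dims", "w1"))),
            ("relax.permute_dims", "w2"))
gelu_expr = ("relax.matmul",
             ("relax.nn.gelu", ("relax.matmul", "h", ("relax.permute_dims", "w1"))),
             ("relax.permute_dims", "w2"))
print(matches(pat, mlp_expr), matches(pat, gelu_expr))  # True False
```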