Pattern matching pass

Hi, I am trying to implement a simple pass that merges a certain pattern into a single “normalize” op. I use register_pattern_table, which I understand is supposed to register the pass:

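```python
from tvm.relay.dataflow_pattern import is_op, wildcard
from tvm.relay.op.contrib.register import register_pattern_table


def make_pattern():
    data = wildcard()
    data2 = wildcard()
    data3 = wildcard()

    copy_op = is_op("copy")(data)
    reshape1_op = is_op("reshape")(data2)
    subtract_op = is_op("subtract")(copy_op, reshape1_op)
    reshape2_op = is_op("reshape")(data3)
    # The last op of the pattern is a division, not another reshape.
    divide_op = is_op("divide")(subtract_op, reshape2_op)
    return divide_op


# register_pattern_table takes the target compiler name as its argument;
# "my_compiler" is a placeholder name used only as the registry key.
@register_pattern_table("my_compiler")
def normalize_pattern():
    normalize_pat = ("normalize", make_pattern())
    normalize_patterns = [normalize_pat]
    return normalize_patterns
```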

I import this code in my test but nothing happens; the code is never called. What am I missing? It would really help if you could point me to a tutorial or an example of adding a pattern matching pass end to end.


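You need to apply the MergeComposite pass (relay.transform.MergeComposite) to your module, passing it the pattern table. Registering a pattern table only stores it; it does not run anything on its own. Also note that register_pattern_table expects the compiler name as its argument, i.e. @register_pattern_table("..."); used bare, it never registers or calls the decorated function. You can search for MergeComposite usage in our repo, for example in its unit tests.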