I'm not sure whether there is a direct solution, but in the meantime I was able to generate patterns with this expression visitor:
@relax.expr_functor.visitor
class PatternExtractor(relax.PyExprVisitor):
    """Build a Relax DPL pattern that mirrors the dataflow of a Relax function.

    Every call binding ``v = op(a, b, ...)`` is translated into
    ``is_op(op)(pa, pb, ...)``, where ``pa``, ``pb``, ... are the patterns already
    recorded for the arguments; function parameters become ``wildcard()``.

    Typical use::

        pattern = PatternExtractor().extract(func)

    Patterns are stored in ``self.tensor_dict`` keyed by each variable's
    ``name_hint``. NOTE(review): two distinct variables sharing a name hint
    would overwrite each other's entry; keying by the ``relax.Var`` object
    itself would be safer if that can occur in your IR.
    """

    def visit_var_binding_(self, binding: relax.expr.VarBinding) -> None:
        """Record a pattern for ``binding.var``, then continue the default traversal."""
        value = binding.value
        if isinstance(value, relax.expr.Call):
            pattern = self._call_pattern(value)
        else:
            # An alias such as ``gv = lv`` reuses the aliased variable's pattern
            # (``relax.Var`` also covers ``DataflowVar``). Any other value kind
            # (tuple, constant, ...) is not translated and becomes a wildcard, so
            # later uses of ``binding.var`` resolve instead of raising KeyError.
            pattern = self._pattern_of(value)
        self.tensor_dict[binding.var.name_hint] = pattern
        super().visit_var_binding_(binding)

    def visit_binding_block_(self, block: relax.expr.BindingBlock) -> None:
        """Visit the bindings of ``block`` in program order."""
        # NOTE(review): Relax presumably dispatches DataflowBlocks to
        # visit_dataflow_block_ rather than here; its default also visits each
        # binding, so both block kinds should reach visit_var_binding_ — confirm.
        for binding in block.bindings:
            self.visit_binding(binding)

    def extract(self, func: relax.Function):
        """Return a pattern matching the computation that produces ``func``'s result.

        Assumes ``func.body`` is a ``SeqExpr`` whose result (``func.body.body``)
        is a variable, i.e. a parameter or a variable bound in the body; any
        other result expression is unsupported and may raise
        ``AttributeError``/``KeyError``.
        """
        self.tensor_dict = {param.name_hint: wildcard() for param in func.params}
        self.output = func.body.body
        self.visit_expr(func.body)
        return self.tensor_dict[self.output.name_hint]

    def _call_pattern(self, call: relax.expr.Call):
        """Translate one ``Call`` into a call pattern over its argument patterns."""
        # One sub-pattern per argument keeps the pattern's arity equal to the
        # call's; dropping non-variable arguments (constants, shapes, ...) would
        # yield a call pattern with fewer arguments than the call it should match.
        arg_patterns = [self._pattern_of(arg) for arg in call.args]
        # Operator callees expose ``.name``; other callees (e.g. a GlobalVar)
        # do not, so match any callee for them instead of raising AttributeError.
        op_name = getattr(call.op, "name", None)
        callee = is_op(op_name) if op_name is not None else wildcard()
        return callee(*arg_patterns)

    def _pattern_of(self, expr):
        """Return the pattern recorded for a variable; a wildcard for anything else."""
        if isinstance(expr, relax.Var):
            # A variable with no recorded pattern (presumably one bound by a
            # binding kind this visitor does not translate, such as match_cast)
            # gets a single stored wildcard, so repeated uses share one pattern.
            return self.tensor_dict.setdefault(expr.name_hint, wildcard())
        return wildcard()