I am trying to rewrite a sequence of ops into an L2 norm (`l2_normalize`) using the following code. I was able to match the input and the eps value, but I could not match the `axis` attribute. Any suggestions?
My intention is to match the pattern, get the values from `node_map`, and pass them to `l2_normalize`.
class L2NormRewrite(DFPatternCallback):
    """Rewrite ``data / (sqrt(sum(power(data, 2))) + eps)`` into ``nn.l2_normalize``.

    Why the axis is not a pattern input: in Relay, ``sum``'s ``axis`` is a
    call *attribute*, not an argument. So ``is_op("sum")(pat, axis, ...)``
    cannot match, and ``has_attr`` compares against concrete attribute
    values rather than sub-patterns. Instead, the ``sum`` call itself is
    captured as ``self.sum``, and its ``attrs.axis`` is read in ``callback``.
    """

    def __init__(self):
        super().__init__()
        # Function-scope import so this block does not rely on a module-level
        # import of is_constant (wildcard / is_op are imported by the file).
        from tvm.relay.dataflow_pattern import is_constant

        self.data = wildcard()
        # Exponent and epsilon must be constants so their values can be read
        # back in the callback via `.data.numpy()`.
        self.norm = is_constant()
        self.const = is_constant()
        # Named sub-pattern for the reduction, so the callback can reach the
        # matched `sum` call and its attributes (axis, exclude).
        self.sum = is_op("sum")(is_op("power")(self.data, self.norm))
        pat = is_op("sqrt")(self.sum)
        pat = is_op("add")(pat, self.const)
        # Reusing self.data forces numerator and powered tensor to be the
        # same expression.
        self.pattern = is_op("divide")(self.data, pat)

    def callback(self, pre, post, node_map):
        """Build an ``nn.l2_normalize`` replacement, or return ``post`` unchanged.

        Parameters
        ----------
        pre : tvm.relay.Expr
            The matched expression before rewriting (unused).
        post : tvm.relay.Expr
            The matched expression after its sub-expressions were rewritten.
        node_map : dict
            Maps each sub-pattern to the list of expressions it matched.

        Returns
        -------
        tvm.relay.Expr
            ``l2_normalize(data, eps, axis)`` when the match is a genuine L2
            norm: a single-valued exponent of 2, a single-valued epsilon, and
            a ``sum`` without ``exclude``. Otherwise ``post``, which leaves
            the graph unchanged.
        """
        data = node_map[self.data][0]
        exponent = node_map[self.norm][0].data.numpy()
        eps = node_map[self.const][0].data.numpy()

        # Only power(x, 2) is an L2 norm; any other exponent must not be
        # rewritten.
        if exponent.size != 1 or exponent.item() != 2:
            return post
        # l2_normalize takes a scalar eps.
        if eps.size != 1:
            return post

        sum_attrs = node_map[self.sum][0].attrs
        # exclude=True means "reduce over every axis except these", which
        # l2_normalize's axis argument cannot express.
        if getattr(sum_attrs, "exclude", False):
            return post
        # attrs.axis is None (reduce over all axes) or an array of IntImm.
        axis = None if sum_attrs.axis is None else [int(a) for a in sum_attrs.axis]

        # NOTE(review): per TVM's docs, l2_normalize computes
        # x / sqrt(max(sum(x^2), eps)), whereas the matched graph computes
        # x / (sqrt(sum(x^2)) + eps). The results differ slightly for small
        # norms; confirm this approximation is acceptable.
        # NOTE(review): the sum's keepdims is not checked; with keepdims=False
        # the original divide relies on broadcasting. Verify the matched
        # graphs broadcast the same way l2_normalize does.
        return tvm.relay.nn.l2_normalize(data, eps=float(eps.item()), axis=axis)