Is there a Relay pass that converts division by a constant into multiplication?

Maybe you can write one yourself: use the dataflow-pattern `rewrite` API to rewrite `divide` into `multiply` by the reciprocal of the constant divisor.

import numpy as np
import tvm  # type: ignore
from tvm import relay
from tvm import ir
from tvm.relay.dataflow_pattern import DFPatternCallback
from tvm.relay.dataflow_pattern import wildcard
from tvm.relay.dataflow_pattern import is_op, is_constant
from tvm.relay.dataflow_pattern import rewrite
from tvm.relay import Expr


class LegalizeDivideRewriter(DFPatternCallback):
    """Rewrite ``divide(x, c)`` with a constant ``c`` into ``multiply(x, 1 / c)``.

    Only floating-point divisors are rewritten. Integer division is left
    untouched because multiplying by a fractional reciprocal does not
    reproduce its semantics.

    Note: ``x * (1 / c)`` is not guaranteed to be bit-identical to ``x / c``;
    the two can differ by a rounding step.
    """

    def __init__(self, require_type=False, rewrite_once=False):
        super().__init__(require_type, rewrite_once)
        self.data = wildcard()
        self.weight = is_constant()
        # Matches `data / weight` where the divisor is a Relay constant.
        self.divide = is_op("divide")(self.data, self.weight)
        self.pattern = self.divide

    def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Expr:
        """Return the multiply replacement, or ``post`` unchanged if not applicable."""
        divisor = node_map[self.weight][0].data.numpy()
        # Integer `divide` does not equal multiplying by 1 / c, so only
        # floating-point divisions are rewritten.
        if not np.issubdtype(divisor.dtype, np.floating):
            return post
        # np.reciprocal keeps the divisor's own dtype (float16/32/64), so the
        # new constant matches `data` instead of being forced to float32.
        # A zero divisor yields +/-inf here, mirroring x / 0 for floats.
        with np.errstate(divide="ignore"):
            reciprocal = np.reciprocal(divisor)
        weight = relay.const(reciprocal, dtype=str(divisor.dtype))
        data = node_map[self.data][0]
        return relay.multiply(data, weight)


@ir.transform.module_pass(opt_level=1)
class LegalizeDivide:
    """Module pass that applies ``LegalizeDivideRewriter`` to each Relay function.

    Each Relay function in the module is rewritten and written back with
    ``mod.update_func``; the same module object is returned.
    """

    def transform_module(
        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
    ) -> tvm.ir.IRModule:
        # One rewriter can be shared by every function; it holds only the pattern.
        rewriter = LegalizeDivideRewriter()
        for global_var, func in mod.functions.items():
            # An IRModule can hold functions that are not Relay functions; the
            # Relay dataflow-pattern rewriter only applies to relay.Function.
            if not isinstance(func, relay.Function):
                continue
            mod.update_func(global_var, rewrite(rewriter, func))
        return mod

    def __call__(self, mod):
        # transform_module requires a PassContext; supply the currently active one
        # rather than omitting the argument (which would raise TypeError).
        return self.transform_module(mod, tvm.ir.transform.PassContext.current())


# Demo: build `data / const`, run LegalizeDivide, and print the rewritten module.
shape = (10,)
data = relay.var(name_hint="data", shape=shape, dtype="float32")
# Draw the divisor directly as float32 from a seeded generator so the printed
# module is reproducible, and keep it away from zero so 1 / c stays finite.
rng = np.random.default_rng(0)
weight = relay.const(rng.uniform(0.1, 1.0, shape).astype("float32"), dtype="float32")
div_out = relay.divide(data, weight)
mod = tvm.IRModule.from_expr(div_out)
mod = LegalizeDivide()(mod)
# Type-check the rewritten module so a dtype mismatch between `data` and the
# reciprocal constant surfaces here rather than later at build time.
mod = relay.transform.InferType()(mod)
print(mod)

1 Like