Is there relay pass to convert divide to multiply?

Hi,

Division is an expensive operation. When the divisor is a constant, I think it is more reasonable to convert the division into the corresponding multiplication by the reciprocal.

So my question is: does such a pass already exist? If not, how can I create one?

Thx

Maybe you can write one yourself, using the dataflow-pattern `rewrite` API to rewrite `divide` into `multiply`.

import numpy as np
import tvm  # type: ignore
from tvm import relay
from tvm import ir
from tvm.relay.dataflow_pattern import DFPatternCallback
from tvm.relay.dataflow_pattern import wildcard
from tvm.relay.dataflow_pattern import is_op, is_constant
from tvm.relay.dataflow_pattern import rewrite
from tvm.relay import Expr


class LegalizeDivideRewriter(DFPatternCallback):
    """Rewrite ``divide(x, c)`` into ``multiply(x, 1 / c)`` for a constant ``c``.

    Only floating-point divisors are rewritten. For integer tensors ``divide``
    is integer division, which cannot be expressed as a multiplication by a
    reciprocal. Note that multiplying by the reciprocal may differ from true
    division in the last bit of floating-point rounding.
    """

    def __init__(self, require_type=False, rewrite_once=False):
        super().__init__(require_type, rewrite_once)
        self.data = wildcard()
        self.weight = is_constant()
        # Pattern: divide(<any expression>, <constant>)
        self.divide = is_op("divide")(self.data, self.weight)
        self.pattern = self.divide

    def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Expr:
        """Return the multiply replacement, or ``post`` unchanged for non-float divisors."""
        divisor = node_map[self.weight][0].data.numpy()
        if not np.issubdtype(divisor.dtype, np.floating):
            # Integer (or other non-float) division: leave the call untouched.
            return post
        # Keep the reciprocal in the divisor's own dtype so the new constant
        # matches the dividend's type instead of being forced to float32.
        reciprocal = np.asarray(np.reciprocal(divisor))
        weight = relay.const(reciprocal, dtype=str(divisor.dtype))
        data = node_map[self.data][0]
        return relay.multiply(data, weight)


@ir.transform.module_pass(opt_level=1)
class LegalizeDivide:
    """Module pass applying ``LegalizeDivideRewriter`` to every Relay function in a module."""

    def transform_module(
        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
    ) -> tvm.ir.IRModule:
        """Rewrite constant divisions in each Relay function of ``mod`` and return it."""
        # Iterate over a snapshot: update_func modifies ``mod`` inside the loop.
        for global_var, func in list(mod.functions.items()):
            # A module may also contain non-Relay functions (e.g. TIR PrimFuncs),
            # which the Relay dataflow-pattern rewriter cannot process.
            if not isinstance(func, relay.Function):
                continue
            func = rewrite(LegalizeDivideRewriter(), func)
            mod.update_func(global_var, func)
        return mod

    def __call__(self, mod):
        # NOTE(review): with @module_pass, calling an instance normally goes
        # through TVM's Pass.__call__; this method is kept for direct use and
        # supplies the current PassContext, which transform_module requires.
        return self.transform_module(mod, tvm.ir.transform.PassContext.current())


# Demo: divide a float32 input by a random constant vector, then apply the pass
# and print the rewritten module.
shape = (10,)
data = relay.var("data", shape=shape, dtype="float32")
weight = relay.const(np.random.uniform(low=0, high=1, size=shape), dtype="float32")
div_out = relay.divide(data, weight)
mod = LegalizeDivide()(tvm.IRModule.from_expr(div_out))
print(mod)

1 Like

Alternatively, you can use an `ExprMutator`:

import numpy as np

from collections import Counter

import tvm
from tvm import relay
# from tvm.relay import ExprFunctor, ExprMutator, ExprVisitor
from tvm.relay.expr_functor import ExprMutator, Call


class ReplaceOP(ExprMutator):
    """Replace ``divide(x, c)`` with ``multiply(x, 1 / c)`` when ``c`` is a float constant.

    Integer division is left alone because it cannot be expressed as a
    multiplication by a reciprocal. The reciprocal keeps the divisor's dtype
    so the rewritten call still type-checks.
    """

    def visit_call(self, call):
        new_fn = self.visit(call.op)
        args = [self.visit(arg) for arg in call.args]

        # Compare the operator by name rather than by its printed form.
        if (
            isinstance(call.op, tvm.ir.Op)
            and call.op.name == "divide"
            and isinstance(call.args[1], relay.expr.Constant)
        ):
            divisor = call.args[1].data.numpy()
            if np.issubdtype(divisor.dtype, np.floating):
                print("replacing divide to multiply")
                reciprocal = np.asarray(np.reciprocal(divisor))
                return relay.multiply(
                    args[0], relay.const(reciprocal, dtype=str(divisor.dtype))
                )
        # Preserve type_args and span, as the default ExprMutator.visit_call does.
        return Call(new_fn, args, call.attrs, call.type_args, call.span)

def test_replace_op():
    """Check that ReplaceOP rewrites ``x / 0.5`` into ``x * 2``.

    The functions are still printed for interactive use, but the asserts make
    the check fail loudly instead of relying on someone reading the output.
    """
    x = relay.var("x", shape=[1, 10])
    func = relay.Function([x], relay.divide(x, relay.const(0.5)))
    print(func)
    # Expected output:
    # fn (%x: Tensor[(1, 10), float32]) {
    #     divide(%x, 0.5f)
    # }
    new_func = ReplaceOP().visit(func)
    print(new_func)
    # Expected output:
    # fn (%x: Tensor[(1, 10), float32]) {
    #    multiply(%x, 2f)
    # }
    body = new_func.body
    assert isinstance(body, relay.Call) and body.op.name == "multiply"
    np.testing.assert_allclose(body.args[1].data.numpy(), 2.0)
    # The rewritten function must still pass Relay type inference.
    relay.transform.InferType()(tvm.IRModule.from_expr(new_func))


if __name__ == "__main__":
    test_replace_op()
3 Likes

@Lyken17 @chenugray Thanks for your kind replies :wink: