[SOLVED][Relax][Fuse related] Can I fuse several Relax functions into one?

I am starting my work with a simple compute graph as follows:

from tvm.relax.frontend import nn


class NNModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, v):
        x = nn.op.softmax(x, axis=-1)  # softmax over the last axis of the (5, 5) input
        x = nn.op.matmul(x, v)         # (5, 5) x (5, 10) -> (5, 10)

        return x
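
For completeness, the export call looks roughly like the following. The spec is my reconstruction rather than a copy of my original script, so treat the exact shapes/dtypes as an assumption; export_tvm returns the IRModule together with any exported parameters (none here).

mod, _ = NNModule().export_tvm(
    spec={
        "forward": {
            "x": nn.spec.Tensor((5, 5), "float32"),
            "v": nn.spec.Tensor((5, 10), "float32"),
        }
    }
)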

After invoking NNModule().export_tvm(...), I get the following module (the softmax works on a [5, 5] input, and the matmul is [5, 5] x [5, 10] -> [5, 10]):

# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((5, 5), dtype="float32"), v: R.Tensor((5, 10), dtype="float32")) -> R.Tensor((5, 10), dtype="float32"):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            softmax: R.Tensor((5, 5), dtype="float32") = R.nn.softmax(x, axis=-1)
            matmul: R.Tensor((5, 10), dtype="float32") = R.matmul(softmax, v, out_dtype="void")
            gv: R.Tensor((5, 10), dtype="float32") = matmul
            R.output(gv)
        return gv
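
Legalizing is a single pass application (mod is the IRModule returned by export_tvm above; the variable name "legalized" is mine):

from tvm import relax

legalized = relax.transform.LegalizeOps()(mod)
legalized.show()  # prints the module shown below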

After applying relax.transform.LegalizeOps(), I get:

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def matmul(softmax: T.Buffer((T.int64(5), T.int64(5)), "float32"), v: T.Buffer((T.int64(5), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(5), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(5), T.int64(10), T.int64(5)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(softmax[v_i0, v_k], v[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + softmax[v_i0, v_k] * v[v_k, v_i1]

    @T.prim_func(private=True)
    def softmax(x: T.Buffer((T.int64(5), T.int64(5)), "float32"), T_softmax_norm: T.Buffer((T.int64(5), T.int64(5)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_softmax_maxelem = T.alloc_buffer((T.int64(5),))
        T_softmax_exp = T.alloc_buffer((T.int64(5), T.int64(5)))
        T_softmax_expsum = T.alloc_buffer((T.int64(5),))
        for i0, k in T.grid(T.int64(5), T.int64(5)):
            with T.block("T_softmax_maxelem"):
                v_i0, v_k = T.axis.remap("SR", [i0, k])
                T.reads(x[v_i0, v_k])
                T.writes(T_softmax_maxelem[v_i0])
                with T.init():
                    T_softmax_maxelem[v_i0] = T.float32(-340282346638528859811704183484516925440.0)
                T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], x[v_i0, v_k])
        for i0, i1 in T.grid(T.int64(5), T.int64(5)):
            with T.block("T_softmax_exp"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(x[v_i0, v_i1], T_softmax_maxelem[v_i0])
                T.writes(T_softmax_exp[v_i0, v_i1])
                T_softmax_exp[v_i0, v_i1] = T.exp(x[v_i0, v_i1] - T_softmax_maxelem[v_i0])
        for i0, k in T.grid(T.int64(5), T.int64(5)):
            with T.block("T_softmax_expsum"):
                v_i0, v_k = T.axis.remap("SR", [i0, k])
                T.reads(T_softmax_exp[v_i0, v_k])
                T.writes(T_softmax_expsum[v_i0])
                with T.init():
                    T_softmax_expsum[v_i0] = T.float32(0.0)
                T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_exp[v_i0, v_k]
        for i0, i1 in T.grid(T.int64(5), T.int64(5)):
            with T.block("T_softmax_norm"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0])
                T.writes(T_softmax_norm[v_i0, v_i1])
                T.block_attr({"axis": 1})
                T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0]

    @R.function
    def forward(x: R.Tensor((5, 5), dtype="float32"), v: R.Tensor((5, 10), dtype="float32")) -> R.Tensor((5, 10), dtype="float32"):
        R.func_attr({"num_input": 2})
        cls = Module
        with R.dataflow():
            softmax = R.call_tir(cls.softmax, (x,), out_sinfo=R.Tensor((5, 5), dtype="float32"))
            matmul = R.call_tir(cls.matmul, (softmax, v), out_sinfo=R.Tensor((5, 10), dtype="float32"))
            gv: R.Tensor((5, 10), dtype="float32") = matmul
            R.output(gv)
        return gv

My question is: can I fuse this softmax and matmul into one TIR function while keeping their original computations? What I want looks roughly like this:

...
def fused_softmax_matmul(x, v):
    # the original softmax computation, same loop nests as shown above
    for ...
    for ...
    ...
    # followed by the original matmul computation, same loop nest as shown above
    for ...
        for ...
            for ...
...
with R.dataflow():
    fused_softmax_matmul = R.call_tir(cls.fused_softmax_matmul, (x, v), ...)
    R.output(fused_softmax_matmul)
return fused_softmax_matmul
...

I tried FuseOpsByPattern, but on its own it only groups the matched operators by pattern; I did not get the actual fused computation function. I am new to TVM, so I could not figure out how to get further. Is this possible, and if so, how can I do it?
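
A minimal sketch of that first attempt (applied to the exported module, before LegalizeOps; the pattern is the same one used in the final pipeline below):

from tvm import relax
from tvm.relax.dpl import is_op, wildcard

pat = is_op('relax.matmul')(is_op('relax.nn.softmax')(wildcard()), wildcard())
grouped = relax.transform.FuseOpsByPattern([('custom_fuse.softmax_gemm', pat)])(mod)
grouped.show()  # softmax and matmul get grouped into a "Composite" Relax function,
                # but there is still no single TIR function containing both loop nests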

Update: I made it work by applying the following transforms:

import tvm
from tvm import relax
from tvm.relax.dpl import is_op, wildcard

pipeline = tvm.transform.Sequential([
    relax.transform.AnnotateTIROpPattern(),
    relax.transform.FoldConstant(),
    # relax.transform.FuseOps(),
    relax.transform.FuseOpsByPattern(
        patterns=[
            ('custom_fuse.softmax_gemm',
             is_op('relax.matmul')(is_op('relax.nn.softmax')(wildcard()), wildcard())),
        ],
        annotate_codegen=False,
        bind_constants=False,
    ),
    relax.transform.LegalizeOps(),
    relax.transform.FuseTIR(),
])
mod = pipeline(mod)  # mod is the module exported by export_tvm above
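
With this sequence the softmax and matmul loop nests end up in one prim_func, and forward calls it through a single R.call_tir. A quick way to check (the name TVM gives the fused function may differ from fused_softmax_matmul):

mod.show()
for gvar, func in mod.functions.items():
    print(gvar.name_hint, type(func).__name__)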