Relax function name inconsistency?

Hi,

I am playing with Relax and the PyTorch frontend. I have this IR imported with from_fx, and I am wondering whether some function names are incorrect. For example, in fused_relax_matmul11 a function lv1_1 is defined but never called, while lv1 is called even though lv1 is not a symbol defined in fused_relax_matmul11. Am I missing something?

Here is the IRModule:

# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def fused_relax_matmul11(lv4: R.Tensor((1, 512), dtype="float32"), lv5: R.Tensor((512, 512), dtype="float32")) -> R.Tensor((1, 512), dtype="float32"):
        R.func_attr({"Codegen": "tensorrt", "Primitive": 1, "global_symbol": "fused_relax_matmul11"})
        with R.dataflow():
            # from tvm.script import relax as R
            
            @R.function
            def lv1_1(lv4_1: R.Tensor((1, 512), dtype="float32"), lv5_1: R.Tensor((512, 512), dtype="float32")) -> R.Tensor((1, 512), dtype="float32"):
                R.func_attr({"Composite": "tensorrt.matmul", "Primitive": 1})
                with R.dataflow():
                    gv: R.Tensor((1, 512), dtype="float32") = R.matmul(lv4_1, lv5_1, out_dtype="float32")
                    R.output(gv)
                return gv

            gv: R.Tensor((1, 512), dtype="float32") = lv1(lv4, lv5)
            R.output(gv)
        return gv

    @R.function
    def fused_relax_matmul21(lv8: R.Tensor((1, 512), dtype="float32"), lv9: R.Tensor((512, 10), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"Codegen": "tensorrt", "Primitive": 1, "global_symbol": "fused_relax_matmul21"})
        with R.dataflow():
            # from tvm.script import relax as R
            
            @R.function
            def lv_1(lv8_1: R.Tensor((1, 512), dtype="float32"), lv9_1: R.Tensor((512, 10), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
                R.func_attr({"Composite": "tensorrt.matmul", "Primitive": 1})
                with R.dataflow():
                    gv: R.Tensor((1, 10), dtype="float32") = R.matmul(lv8_1, lv9_1, out_dtype="float32")
                    R.output(gv)
                return gv

            gv: R.Tensor((1, 10), dtype="float32") = lv(lv8, lv9)
            R.output(gv)
        return gv

    @R.function
    def fused_relax_matmul3(lv: R.Tensor((1, 784), dtype="float32"), lv1: R.Tensor((784, 512), dtype="float32")) -> R.Tensor((1, 512), dtype="float32"):
        R.func_attr({"Codegen": "tensorrt", "Primitive": 1, "global_symbol": "fused_relax_matmul3"})
        with R.dataflow():
            # from tvm.script import relax as R
            
            @R.function
            def lv2_1(lv_1: R.Tensor((1, 784), dtype="float32"), lv1_1: R.Tensor((784, 512), dtype="float32")) -> R.Tensor((1, 512), dtype="float32"):
                R.func_attr({"Composite": "tensorrt.matmul", "Primitive": 1})
                with R.dataflow():
                    gv: R.Tensor((1, 512), dtype="float32") = R.matmul(lv_1, lv1_1, out_dtype="float32")
                    R.output(gv)
                return gv

            gv: R.Tensor((1, 512), dtype="float32") = lv2(lv, lv1)
            R.output(gv)
        return gv

    @R.function
    def main(inp_0: R.Tensor((1, 1, 28, 28), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((1, 784), dtype="float32") = R.reshape(inp_0, R.shape([1, 784]))
            lv1: R.Tensor((784, 512), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][0], axes=None)
            lv_1: R.Tensor((1, 512), dtype="float32") = cls.fused_relax_matmul3(lv, lv1)
            lv3: R.Tensor((1, 512), dtype="float32") = R.add(lv_1, metadata["relax.expr.Constant"][1])
            lv4: R.Tensor((1, 512), dtype="float32") = R.nn.relu(lv3)
            lv5: R.Tensor((512, 512), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][2], axes=None)
            lv1_1: R.Tensor((1, 512), dtype="float32") = cls.fused_relax_matmul11(lv4, lv5)
            lv7: R.Tensor((1, 512), dtype="float32") = R.add(lv1_1, metadata["relax.expr.Constant"][3])
            lv8: R.Tensor((1, 512), dtype="float32") = R.nn.relu(lv7)
            lv9: R.Tensor((512, 10), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][4], axes=None)
            lv2: R.Tensor((1, 10), dtype="float32") = cls.fused_relax_matmul21(lv8, lv9)
            lv11: R.Tensor((1, 10), dtype="float32") = R.add(lv2, metadata["relax.expr.Constant"][5])
            gv: R.Tensor((1, 10), dtype="float32") = lv11
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.

This is the script that led there:

import torch
import torch.nn as nn
import torch.fx as fx

import tvm
from tvm import relax
import tvm.relax.frontend.torch
from tvm.relax.dpl import is_op, wildcard


class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


network = NeuralNetwork()
input_shape = (1, 1, 28, 28)
input_dtype = "float"

# Trace the model with torch.fx and import the traced graph into Relax.
traced_network = fx.symbolic_trace(network)
input_info = [(input_shape, input_dtype)]
mod = relax.frontend.torch.from_fx(traced_network, input_info)

# Group every relax.matmul into a "tensorrt.matmul" composite function,
# then merge the composites into functions marked Codegen="tensorrt".
patterns = [
    ("tensorrt.matmul", is_op("relax.matmul")(wildcard(), wildcard())),
]
mod1 = relax.transform.FuseOpsByPattern(patterns)(mod)
mod2 = relax.transform.MergeCompositeFunctions()(mod1)
mod2.show()

Note: you need the latest unity branch to run this code, as it was triggering a segfault until yesterday’s fix.

Yeah, I am aware of this issue and confused as well. I think it is only an issue with the script printer: if you actually compile without printing, it should always work.
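For instance, here is a minimal sketch of carrying mod2 from your script through to execution, assuming a TVM build with the TensorRT BYOC runtime enabled and a CUDA device (the exact pass order may differ from the default pipeline):

import numpy as np
import tvm
from tvm import relax

# Hand the Codegen="tensorrt" functions to the external TensorRT codegen,
# then lower the remaining Relax ops to TIR and build for the GPU.
mod3 = relax.transform.RunCodegen()(mod2)
mod3 = relax.transform.LegalizeOps()(mod3)
ex = relax.build(mod3, target="cuda")
vm = relax.VirtualMachine(ex, tvm.cuda())

# Run the compiled module on a random input.
inp = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"), tvm.cuda())
out = vm["main"](inp)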

cc @junrushao @cyx


Ok, it’s good that it’s only on the printing side. Thank you.

By the way, I’ve been playing with dynamo_capture_subgraphs as an alternative to from_fx for getting an IRModule from PyTorch, and the entry function gets the unconventional name subgraph_0 instead of main, which makes it incompatible with MergeCompositeFunctions, which looks for main.

Is this intended, or is it a bug?

dynamo_capture_subgraphs is a bit of an odd function. You need to do something like tvm.IRModule({"main": mod["subgraph_0"]}).
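For reference, a minimal sketch of that workaround, continuing with the NeuralNetwork model from your script (the example input tensor is an assumption on my part):

import torch
import tvm
from tvm.relax.frontend.torch import dynamo_capture_subgraphs

# dynamo_capture_subgraphs traces through torch.compile and names the captured
# entry function "subgraph_0" ("subgraph_1", ... if the trace has graph breaks).
mod = dynamo_capture_subgraphs(network, torch.randn(1, 1, 28, 28))

# Rebuild the module with the conventional entry name so that passes such as
# MergeCompositeFunctions, which look for "main", can find it.
mod = tvm.IRModule({"main": mod["subgraph_0"]})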

Note that the PT 2.0 import story in Relax is currently not good, primarily because the PT 2.0 compile situation is not mature at the moment, especially with respect to whole-graph export, which many TVM users care about but which is unfortunately not the first priority for Meta.

As one data point, in web-stable-diffusion the authors had to manually reimplement the SD UNet just so that it could be traced by FX. The authors are the same people who developed the FX frontend in Relax.


Also CC: @Hzfengsy @spectrometerHBH

I see. Out of curiosity, are Relax frontends actively developed? If so, which frontends are under active development?

I’d say frontend development in Relax is not active at the moment, because people care about only a few select models.

The FX importer in Relax was apparently developed for SD, but now that that is done we see only occasional PRs to it. And for LLMs, people write Relax modules by hand (see mlc-llm/mlc_llm/relax_model at main · mlc-ai/mlc-llm · GitHub), so there is no need for an importer.

I’m aware that there is an effort to develop a StableHLO frontend (https://github.com/apache/tvm/pull/14460), but it is at a very early stage.


Frontend capabilities are one of the reasons to integrate new hardware with TVM. I will keep an eye on this. Thank you.