[SOLVED][BYOC] Retrieving the output variable(s) of a function

How can I retrieve the output variables of a relax function?

I am working on a very simple C codegen for a new device that currently allocates all intermediate tensors. I haven't been able to tell the intermediate tensors apart from the output of my function, so an extra tensor is currently allocated for the output. That tensor is written to but never copied back to the host.

The pretty-printer makes these outputs visible, but I haven't found an easy way to get them from within the code.

Here is a simple example module to work from:

```python
# from tvm.script import ir as I
# from tvm.script import relax as R


@I.ir_module
class Module:
    @R.function
    def fused_relax_matmul_relax_matmul1(
        x: R.Tensor((128, 256), dtype="float32"),
        y: R.Tensor((256, 128), dtype="float32"),
    ) -> R.Tensor((128, 256), dtype="float32"):
        R.func_attr(
            {
                "Codegen": "foo",
                "Primitive": 1,
                "global_symbol": "fused_relax_matmul_relax_matmul1",
            }
        )
        with R.dataflow():
            # from tvm.script import relax as R

            @R.function
            def lv_1(
                x_1: R.Tensor((128, 256), dtype="float32"),
                y_1: R.Tensor((256, 128), dtype="float32"),
            ) -> R.Tensor((128, 128), dtype="float32"):
                R.func_attr({"Composite": "foo.matmul", "Primitive": 1})
                with R.dataflow():
                    gv: R.Tensor((128, 128), dtype="float32") = R.matmul(
                        x_1, y_1, out_dtype="void"
                    )
                    R.output(gv)
                return gv

            lv_1: R.Tensor((128, 128), dtype="float32") = lv(x, y)
            # from tvm.script import relax as R

            @R.function
            def lv1_1(
                lv0: R.Tensor((128, 128), dtype="float32"),
                x_1: R.Tensor((128, 256), dtype="float32"),
            ) -> R.Tensor((128, 256), dtype="float32"):
                R.func_attr({"Composite": "foo.matmul", "Primitive": 1})
                with R.dataflow():
                    gv: R.Tensor((128, 256), dtype="float32") = R.matmul(
                        lv0, x_1, out_dtype="void"
                    )
                    R.output(gv)
                return gv

            gv: R.Tensor((128, 256), dtype="float32") = lv1(lv_1, x)
            R.output(gv)
        return gv

    @R.function
    def main(
        x: R.Tensor((128, 256), dtype="float32"),
        y: R.Tensor((256, 128), dtype="float32"),
    ) -> R.Tensor((128, 256), dtype="float32"):
        cls = Module
        with R.dataflow():
            gv: R.Tensor(
                (128, 256), dtype="float32"
            ) = cls.fused_relax_matmul_relax_matmul1(x, y)
            R.output(gv)
        return gv
```

The output of a function is simply its body. Most likely that body is a SeqExpr, whose own `body` field is the output value.


Thank you! It seems like I overlooked the body member of SeqExpr.