How can I retrieve the output variables of a relax function?
I am working on a very simple C codegen for a new device that currently allocates all intermediate tensors. I haven’t been able to differentiate between intermediate tensors and the output of my function, so currently an extra tensor is allocated for the output, which is written to but never copied back to the host.
The pretty-printed output makes these outputs visible, but I haven’t found an easy way to get them from within the code.
Here is a simple example module as a basis for work:
# from tvm.script import ir as I
# from tvm.script import relax as R
# Pretty-printed Relax IRModule containing two functions:
#   * fused_relax_matmul_relax_matmul1 -- carries the "Codegen": "foo" attribute
#     and chains two "foo.matmul" composite functions;
#   * main -- calls the fused function and returns its result.
# NOTE(review): leading indentation is missing from this listing; nesting has
# to be read from the `def` / `with R.dataflow():` / `return` lines.
@I.ir_module
class Module:
# Fused function: takes x (128, 256) and y (256, 128) and is annotated as
# returning a (128, 256) float32 tensor.
@R.function
def fused_relax_matmul_relax_matmul1(
x: R.Tensor((128, 256), dtype="float32"),
y: R.Tensor((256, 128), dtype="float32"),
) -> R.Tensor((128, 256), dtype="float32"):
R.func_attr(
{
"Codegen": "foo",
"Primitive": 1,
"global_symbol": "fused_relax_matmul_relax_matmul1",
}
)
with R.dataflow():
# First composite: matmul of a (128, 256) tensor with a (256, 128) tensor,
# annotated as producing (128, 128).
# NOTE(review): the def below is printed as `lv_1`, while the binding after
# it calls `lv` and assigns to `lv_1` -- looks like a printer renaming
# artifact; confirm against the actual IR.
# from tvm.script import relax as R
@R.function
def lv_1(
x_1: R.Tensor((128, 256), dtype="float32"),
y_1: R.Tensor((256, 128), dtype="float32"),
) -> R.Tensor((128, 128), dtype="float32"):
R.func_attr({"Composite": "foo.matmul", "Primitive": 1})
with R.dataflow():
gv: R.Tensor((128, 128), dtype="float32") = R.matmul(
x_1, y_1, out_dtype="void"
)
R.output(gv)
return gv
# `lv_1` is bound here but is never passed to R.output in this block;
# presumably that makes it an intermediate (dataflow-local) value -- verify.
lv_1: R.Tensor((128, 128), dtype="float32") = lv(x, y)
# Second composite: matmul of a (128, 128) tensor with a (128, 256) tensor,
# annotated as producing (128, 256).
# NOTE(review): printed as `lv1_1` but called as `lv1` below -- same naming
# mismatch as above; confirm.
# from tvm.script import relax as R
@R.function
def lv1_1(
lv0: R.Tensor((128, 128), dtype="float32"),
x_1: R.Tensor((128, 256), dtype="float32"),
) -> R.Tensor((128, 256), dtype="float32"):
R.func_attr({"Composite": "foo.matmul", "Primitive": 1})
with R.dataflow():
gv: R.Tensor((128, 256), dtype="float32") = R.matmul(
lv0, x_1, out_dtype="void"
)
R.output(gv)
return gv
# `gv` is the only variable passed to R.output in the outer dataflow block
# and is also the value returned by the fused function.
gv: R.Tensor((128, 256), dtype="float32") = lv1(lv_1, x)
R.output(gv)
return gv
# Entry point: forwards x and y to the fused function through `cls` and
# returns the resulting (128, 256) tensor.
@R.function
def main(
x: R.Tensor((128, 256), dtype="float32"),
y: R.Tensor((256, 128), dtype="float32"),
) -> R.Tensor((128, 256), dtype="float32"):
cls = Module
with R.dataflow():
gv: R.Tensor(
(128, 256), dtype="float32"
) = cls.fused_relax_matmul_relax_matmul1(x, y)
R.output(gv)
return gv