Hi,
I am playing with Relax and the PyTorch frontend. I have the IR below, imported with from_fx, and I am wondering whether some of the printed function names are incorrect. For example, in fused_relax_matmul11 a local function lv1_1 is defined but never called, while lv1 is called even though lv1 is not a symbol defined anywhere in fused_relax_matmul11. Am I missing something?
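As a sanity check, one way to tell whether this is a real dangling reference or just a naming artifact of the printer would be to run Relax's well-formedness check on the module (a minimal sketch, using the mod2 built by the script at the end of this post):

import tvm
from tvm import relax

# If the printed call to `lv1` were really unbound, the module should fail
# this structural check; if it passes, the name is likely a printing artifact.
print(relax.analysis.well_formed(mod2))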
Here is the IRModule:
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def fused_relax_matmul11(lv4: R.Tensor((1, 512), dtype="float32"), lv5: R.Tensor((512, 512), dtype="float32")) -> R.Tensor((1, 512), dtype="float32"):
        R.func_attr({"Codegen": "tensorrt", "Primitive": 1, "global_symbol": "fused_relax_matmul11"})
        with R.dataflow():
            # from tvm.script import relax as R
            @R.function
            def lv1_1(lv4_1: R.Tensor((1, 512), dtype="float32"), lv5_1: R.Tensor((512, 512), dtype="float32")) -> R.Tensor((1, 512), dtype="float32"):
                R.func_attr({"Composite": "tensorrt.matmul", "Primitive": 1})
                with R.dataflow():
                    gv: R.Tensor((1, 512), dtype="float32") = R.matmul(lv4_1, lv5_1, out_dtype="float32")
                    R.output(gv)
                return gv

            gv: R.Tensor((1, 512), dtype="float32") = lv1(lv4, lv5)
            R.output(gv)
        return gv

    @R.function
    def fused_relax_matmul21(lv8: R.Tensor((1, 512), dtype="float32"), lv9: R.Tensor((512, 10), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"Codegen": "tensorrt", "Primitive": 1, "global_symbol": "fused_relax_matmul21"})
        with R.dataflow():
            # from tvm.script import relax as R
            @R.function
            def lv_1(lv8_1: R.Tensor((1, 512), dtype="float32"), lv9_1: R.Tensor((512, 10), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
                R.func_attr({"Composite": "tensorrt.matmul", "Primitive": 1})
                with R.dataflow():
                    gv: R.Tensor((1, 10), dtype="float32") = R.matmul(lv8_1, lv9_1, out_dtype="float32")
                    R.output(gv)
                return gv

            gv: R.Tensor((1, 10), dtype="float32") = lv(lv8, lv9)
            R.output(gv)
        return gv

    @R.function
    def fused_relax_matmul3(lv: R.Tensor((1, 784), dtype="float32"), lv1: R.Tensor((784, 512), dtype="float32")) -> R.Tensor((1, 512), dtype="float32"):
        R.func_attr({"Codegen": "tensorrt", "Primitive": 1, "global_symbol": "fused_relax_matmul3"})
        with R.dataflow():
            # from tvm.script import relax as R
            @R.function
            def lv2_1(lv_1: R.Tensor((1, 784), dtype="float32"), lv1_1: R.Tensor((784, 512), dtype="float32")) -> R.Tensor((1, 512), dtype="float32"):
                R.func_attr({"Composite": "tensorrt.matmul", "Primitive": 1})
                with R.dataflow():
                    gv: R.Tensor((1, 512), dtype="float32") = R.matmul(lv_1, lv1_1, out_dtype="float32")
                    R.output(gv)
                return gv

            gv: R.Tensor((1, 512), dtype="float32") = lv2(lv, lv1)
            R.output(gv)
        return gv

    @R.function
    def main(inp_0: R.Tensor((1, 1, 28, 28), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((1, 784), dtype="float32") = R.reshape(inp_0, R.shape([1, 784]))
            lv1: R.Tensor((784, 512), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][0], axes=None)
            lv_1: R.Tensor((1, 512), dtype="float32") = cls.fused_relax_matmul3(lv, lv1)
            lv3: R.Tensor((1, 512), dtype="float32") = R.add(lv_1, metadata["relax.expr.Constant"][1])
            lv4: R.Tensor((1, 512), dtype="float32") = R.nn.relu(lv3)
            lv5: R.Tensor((512, 512), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][2], axes=None)
            lv1_1: R.Tensor((1, 512), dtype="float32") = cls.fused_relax_matmul11(lv4, lv5)
            lv7: R.Tensor((1, 512), dtype="float32") = R.add(lv1_1, metadata["relax.expr.Constant"][3])
            lv8: R.Tensor((1, 512), dtype="float32") = R.nn.relu(lv7)
            lv9: R.Tensor((512, 10), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][4], axes=None)
            lv2: R.Tensor((1, 10), dtype="float32") = cls.fused_relax_matmul21(lv8, lv9)
            lv11: R.Tensor((1, 10), dtype="float32") = R.add(lv2, metadata["relax.expr.Constant"][5])
            gv: R.Tensor((1, 10), dtype="float32") = lv11
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
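To see which names are actually bound inside the fused function, independent of how the printer renders them, something along these lines should work (a sketch; I'm assuming the relax.analysis.bound_vars/free_vars helpers are the right tools here):

from tvm import relax

# Compare the variables bound inside fused_relax_matmul11 against its free
# variables: if `lv1` shows up as free, the call really is dangling.
func = mod2["fused_relax_matmul11"]
print([v.name_hint for v in relax.analysis.bound_vars(func)])
print([v.name_hint for v in relax.analysis.free_vars(func)])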
This is the script that led there:
import torch
import torch.nn as nn
import torch.fx as fx

import tvm
from tvm import relax
import tvm.relax.frontend.torch
from tvm.relax.dpl import is_op, wildcard


class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


network = NeuralNetwork()
input_shape = (1, 1, 28, 28)
input_dtype = "float"
traced_network = fx.symbolic_trace(network)
input_info = [(input_shape, input_dtype)]
mod = relax.frontend.torch.from_fx(traced_network, input_info)

patterns = [
    ("tensorrt.matmul", is_op("relax.matmul")(wildcard(), wildcard())),
]
mod1 = relax.transform.FuseOpsByPattern(patterns)(mod)
mod2 = relax.transform.MergeCompositeFunctions()(mod1)
mod2.show()
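The weight constants are elided in the dump above; per the comment at the end of the printed output, they can be included like this:

# Print the module together with the metadata section (the Constant tensors).
print(mod2.script(show_meta=True))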
Note: you need the latest unity branch to run the code, as it was triggering a segfault until yesterday's fix.