[MSC] Translate Relay to Relax without losing information

Why do this?

After spending some time playing with Relax, I found that Relax has quite a different structure from Relay. I talked with some developers and colleagues and learned that several important features are built on Relay, and so is my own work. This means that adopting Relax as the new infrastructure may require time to re-implement those features. The solution I came up with is to translate Relay to Relax without losing information.

How is it solved currently?

A common solution is shown in the test:

This example uses relay_translator.from_relay to translate Relay to Relax. This approach lowers the Relay model to TIR functions, which loses the operator types and attributes. Without that information, model optimizations are hard to develop, because most algorithms need specific information about each node. So this approach solves the problem for model inference but leaves problems unsolved for model deployment and optimization.
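To make the information loss concrete, here is a deliberately simplified sketch in plain Python. Node, lower_to_tir and find_grouped_convs are hypothetical stand-ins, not TVM APIs: the "lowered" form keeps only an opaque kernel name, so an analysis that filters on op type and attributes no longer finds anything to match.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Node:
    """Toy graph node (hypothetical, not TVM's IR)."""
    op: str  # operator type, e.g. "nn.conv2d"
    attrs: Dict[str, object] = field(default_factory=dict)


def lower_to_tir(nodes: List[Node]) -> List[Node]:
    """Toy lowering: every op becomes an opaque call to a generated kernel.

    Only a generated function name survives; op type and attrs are dropped.
    """
    return [Node("call_tir", {"func": f"kernel_{i}"}) for i in range(len(nodes))]


def find_grouped_convs(nodes: List[Node]) -> List[Node]:
    """Example analysis that needs both the op type and an attribute."""
    return [n for n in nodes if n.op == "nn.conv2d" and n.attrs.get("groups", 1) > 1]


graph = [
    Node("nn.conv2d", {"groups": 32, "kernel_size": (3, 3)}),
    Node("nn.relu"),
    Node("nn.conv2d", {"groups": 1, "kernel_size": (1, 1)}),
]

print(len(find_grouped_convs(graph)))                # 1: found on the high-level graph
print(len(find_grouped_convs(lower_to_tir(graph))))  # 0: nothing left to match
```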

How does MSC (Multi-System Compiler) solve it?

The best solution, of course, is to implement all the Relay-based features in Relax. I know developers in the community are working on that, and I'm sure the final solution is on its way. I just want to show another way to solve this. Based on MSC (see [RFC][Unity][MSC] Introduction to Multi-System Compiler), I map Relay to an MSCGraph, which has a DAG structure and keeps all the information, and then codegen the MSCGraph to Relax. This way the function can be translated easily, and the cost is small: just an implementation of the codegen. Since MSC is not mainly designed for this task, I have not tested many different cases. I just found that the design of MSC makes this translation very cheap, and it gives Relax access to Relay-based features in a simple way. If anyone is interested in this, please let me know. Some cases are tested in tvm/tests/python/contrib/test_msc/test_translate_relay.py (branch unity).
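A rough sketch of the codegen idea, using hypothetical ToyNode/ToyGraph classes rather than MSC's real MSCGraph API: because each node still carries its op type and weights, codegen reduces to a lookup from op type to an emitter. The msc.linear_bias expansion and the const/const_1 weight names follow the printouts in the example below; the emitted text only imitates Relax script and is not parsed by TVM.

```python
from dataclasses import dataclass, field
from itertools import count
from typing import Callable, Dict, List


@dataclass
class ToyNode:
    """Toy DAG node that keeps op type, attributes and weights (hypothetical)."""
    name: str
    optype: str
    inputs: List[str]
    attrs: Dict[str, str] = field(default_factory=dict)
    weights: Dict[str, str] = field(default_factory=dict)  # weight role -> constant name


@dataclass
class ToyGraph:
    inputs: List[str]
    nodes: List[ToyNode]  # assumed to be topologically sorted
    output: str


def emit_linear_bias(node: ToyNode, fresh: Callable[[], str]) -> List[str]:
    """Expand a fused linear+bias node into Relax-style permute/matmul/add calls."""
    transposed, product = fresh(), fresh()
    return [
        f"{transposed} = R.permute_dims({node.weights['weight']}, axes=None)",
        f'{product} = R.matmul({node.inputs[0]}, {transposed}, out_dtype="float32")',
        f"{node.name} = R.add({product}, {node.weights['bias']})",
    ]


# One emitter per op type: supporting a new op means registering a new entry.
EMITTERS: Dict[str, Callable[[ToyNode, Callable[[], str]], List[str]]] = {
    "msc.linear_bias": emit_linear_bias,
}


def codegen_relax_text(graph: ToyGraph) -> str:
    counter = count()

    def fresh() -> str:
        return f"lv{next(counter)}"

    body: List[str] = []
    for node in graph.nodes:
        body += EMITTERS[node.optype](node, fresh)
    lines = [f"def main({', '.join(graph.inputs)}):", "    with R.dataflow():"]
    lines += [f"        {line}" for line in body]
    lines += [f"        R.output({graph.output})", f"    return {graph.output}"]
    return "\n".join(lines)


graph = ToyGraph(
    inputs=["input0"],
    nodes=[ToyNode("gv", "msc.linear_bias", ["input0"],
                   weights={"weight": "const", "bias": "const_1"})],
    output="gv",
)
print(codegen_relax_text(graph))
```

In this toy version the translation cost is concentrated in the emitter table, which mirrors the point that most of the work is the codegen itself.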

Example of translating torch.nn.Linear

import numpy as np
import torch
from torch import fx
from torch.nn import Module

import tvm
from tvm.relax.frontend.torch import from_fx
from tvm.relay.frontend import from_pytorch
from tvm.contrib.msc.core.ir import translate
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen


def get_model():
    class Dense(Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=True)

        def forward(self, input):
            return self.linear(input)

    input_info = [([1, 3, 10, 10], "float32")]
    return Dense(), input_info


if __name__ == "__main__":
    model, input_info = get_model()
    # reference Relax module, built directly from torch.fx
    graph_model = fx.symbolic_trace(model)
    with torch.no_grad():
        expected = from_fx(graph_model, input_info)
    print("expected " + str(expected))

    # build the Relay module from a TorchScript trace of the same model
    input_datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
    t_inputs = [torch.from_numpy(i) for i in input_datas]
    scripted_model = torch.jit.trace(model, tuple(t_inputs)).eval()  # type: ignore
    shape_list = [("input" + str(idx), i) for idx, i in enumerate(input_info)]
    relay_mod, params = from_pytorch(scripted_model, shape_list)
    print("relay " + str(relay_mod["main"]))

    # Relay -> MSCGraph: keeps op types, attributes and weights
    graph, weights = translate.from_relay(relay_mod, params)
    print("graph " + str(graph))

    # MSCGraph -> Relax through the TVM codegen
    codegen_config = {"explicit_name": False, "from_relay": True}
    mod = tvm_codegen.to_relax(graph, weights, codegen_config)
    print("relax " + str(mod))

    tvm.ir.assert_structural_equal(mod, expected)

Result:

expected
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tensor((1, 3, 10, 7), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][0], axes=None)
            lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32")
            lv2: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv1, metadata["relax.expr.Constant"][1])
            gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv2
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.

relay
fn (%input0: Tensor[(1, 3, 10, 10), float32] /* span=aten::linear_0.input0:0:0 */, %aten::linear_0.weight: Tensor[(7, 10), float32] /* span=aten::linear_0.weight:0:0 */, %aten::linear_0.bias: Tensor[(7), float32] /* span=aten::linear_0.bias:0:0 */) {
  %0 = broadcast_to(%input0, shape=[1, 3, 10, 10]) /* span=aten::linear_0:0:0 */;
  %1 = transpose(%aten::linear_0.weight, axes=[1, 0]) /* span=aten::linear_0:0:0 */;
  %2 = broadcast_to(%1, shape=[1, 3, 10, 7]) /* span=aten::linear_0:0:0 */;
  %3 = reshape(%0, newshape=[-1, 10, 10]) /* span=aten::linear_0:0:0 */;
  %4 = reshape(%2, newshape=[-1, 10, 7]) /* span=aten::linear_0:0:0 */;
  %5 = nn.batch_matmul(%3, %4) /* span=aten::linear_0:0:0 */;
  %6 = reshape(%5, newshape=[1, 3, 10, 7]) /* span=aten::linear_0:0:0 */;
  %7 = squeeze(%6, axis=[]) /* span=aten::linear_0:0:0 */;
  nn.bias_add(%7, %aten::linear_0.bias, axis=-1) /* span=aten::linear_0:0:0 */
}

graph
main <INPUTS: input0:0| OUTPUTS: msc.linear_bias:0>
ID_0 input0 <PARENTS: | CHILDERN: msc.linear_bias>
  OUT: input0:0(input0)<1,3,10,10|float32>
  OPTYPE: input

ID_1 msc.linear_bias <PARENTS: input0| CHILDERN: >
  IN: input0:0(input0)<1,3,10,10|float32>
  OUT: msc.linear_bias:0(msc.linear_bias_0)<1,3,10,7|float32>
  OPTYPE: msc.linear_bias
  SCOPE: block
  ATTRS: allowzero_1=0 out_dtype= transpose_a=0 allowzero=0 axis= dtype_1= axes=1,0 dtype= shape_1=1,3,10,7 newshape_1=-1,10,7 transpose_b=0 allowzero_2=0 shape=1,3,10,10 newshape_2=1,3,10,7 axis_1=-1 newshape=-1,10,10
  WEIGHTS:
    weight: const<7,10|float32>
    bias: const_1<7|float32>

relax 
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(input0: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tensor((1, 3, 10, 7), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][0], axes=None)
            lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul(input0, lv, out_dtype="float32")
            lv2: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv1, metadata["relax.expr.Constant"][1])
            gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv2
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
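Reading the ATTRS line of the graph printout, the attributes of all nine Relay ops folded into msc.linear_bias appear to be merged into one map, with a numeric suffix added when a key repeats (the three reshape ops give newshape, newshape_1, newshape_2). Below is a minimal sketch of that merging rule, inferred from the printout rather than taken from MSC's source; merge_attrs is hypothetical, and the assignment of each key to a specific Relay op is my reading of the Relay printout.

```python
from collections import Counter
from typing import Dict, List, Tuple


def merge_attrs(ops: List[Tuple[str, Dict[str, str]]]) -> Dict[str, str]:
    """Merge the attrs of fused ops into one map (hypothetical helper).

    The first occurrence of a key keeps its name; later ones get _1, _2, ...
    """
    seen: Counter = Counter()
    merged: Dict[str, str] = {}
    for _op, attrs in ops:
        for key, value in attrs.items():
            merged[key if seen[key] == 0 else f"{key}_{seen[key]}"] = value
            seen[key] += 1
    return merged


# The Relay ops %0..%7 plus bias_add from the printout, in order.
linear_ops = [
    ("broadcast_to", {"shape": "1,3,10,10", "dtype": ""}),
    ("transpose", {"axes": "1,0"}),
    ("broadcast_to", {"shape": "1,3,10,7", "dtype": ""}),
    ("reshape", {"newshape": "-1,10,10", "allowzero": "0"}),
    ("reshape", {"newshape": "-1,10,7", "allowzero": "0"}),
    ("nn.batch_matmul", {"out_dtype": "", "transpose_a": "0", "transpose_b": "0"}),
    ("reshape", {"newshape": "1,3,10,7", "allowzero": "0"}),
    ("squeeze", {"axis": ""}),
    ("nn.bias_add", {"axis": "-1"}),
]

merged = merge_attrs(linear_ops)
print(merged["newshape_2"], merged["axis_1"])  # 1,3,10,7 -1
```

The result has the same 16 keys and values as the ATTRS line above; only the print order differs.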