Import RNN-T PyTorch model into TVM

We tried to import the RNN-T PyTorch model (https://github.com/mlperf/inference/tree/master/v0.7/speech_recognition/rnnt/pytorch) into TVM.

Pre-trained RNN-T model for MLPerf Inference: https://zenodo.org/record/3662521

We got the following error:

NotImplementedError: The following operators are not implemented: ['prim::RaiseException', 'prim::Uninitialized', 'aten::lstm', 'aten::_cast_Int', 'aten::__derive_index', 'aten::__not__', 'aten::tensor', 'aten::item', 'aten::_cast_Float', 'prim::data', 'aten::format', 'aten::copy_', 'aten::append', 'aten::FloatImplicit', 'prim::dtype', 'prim::shape', 'prim::TupleIndex', 'aten::dim', 'aten::warn', 'aten::__is__', 'prim::unchecked_cast']

How can we work around these errors and get RNN-T running on TVM?

You are using torch.jit.script. Please try torch.jit.trace.
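TVM's PyTorch frontend collects the node kinds in the TorchScript graph and reports the ones it has no converter for, which is where the list above comes from. The sketch below (PyTorch only; `TinyEncoder` is a hypothetical one-layer stand-in for the RNN-T encoder) does roughly the same walk for a scripted and a traced module. Tracing records only the operators that actually run on the example input, so the control flow, exception paths and type checks that scripting keeps drop out, but `aten::lstm` remains because the whole LSTM is a single operator.

```python
import torch


def collect_op_kinds(graph):
    """Return every node kind in a TorchScript graph, including nested blocks."""
    kinds = set()

    def visit(block):
        for node in block.nodes():
            kinds.add(node.kind())
            for sub_block in node.blocks():  # bodies of prim::If / prim::Loop
                visit(sub_block)

    visit(graph)
    return kinds


class TinyEncoder(torch.nn.Module):
    """Hypothetical stand-in for the RNN-T encoder: a single LSTM."""

    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=8, hidden_size=16)

    def forward(self, x):
        out, _ = self.lstm(x)
        return out


model = TinyEncoder().eval()
example = torch.randn(10, 1, 8)  # (T, B, F)

traced = torch.jit.trace(model, example)
scripted = torch.jit.script(model)

# inlined_graph pulls submodule calls (here self.lstm) into one graph.
traced_ops = collect_op_kinds(traced.inlined_graph)
scripted_ops = collect_op_kinds(scripted.inlined_graph)

print("only in scripted graph:", sorted(scripted_ops - traced_ops))
print("aten::lstm in traced graph:", "aten::lstm" in traced_ops)
```

This is why switching to `torch.jit.trace` shrinks the list, but it cannot remove `aten::lstm` itself.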

I also tried torch.jit.trace, but 'aten::lstm' and 'aten::copy_' are still not supported:

Traceback (most recent call last):
  File "/Automation/zzhen/pycharm-community-2019.3.3/plugins/python-ce/helpers/pydev/pydevd.py", line 1434, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/Automation/zzhen/pycharm-community-2019.3.3/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/root/test6_root/zzhen/2020/SamsungGit/0911/quantization/tutorials/quantization/rnnt/quantize_rnnt.py", line 257, in <module>
    main()
  File "/root/test6_root/zzhen/2020/SamsungGit/0911/quantization/tutorials/quantization/rnnt/quantize_rnnt.py", line 244, in main
    mod, params = rnnt_model_to_tvm_mod(rnnt_model)
  File "/root/test6_root/zzhen/2020/SamsungGit/0911/quantization/tutorials/quantization/rnnt/quantize_rnnt.py", line 164, in rnnt_model_to_tvm_mod
    mod, params = relay.frontend.from_pytorch(model.encoder, input_shapes=None)
  File "/root/test6_root/zzhen/2020/SamsungGit/0911/quantization/python/tvm/relay/frontend/pytorch.py", line 2547, in from_pytorch
    _report_missing_conversion(op_names, convert_map)
  File "/root/test6_root/zzhen/2020/SamsungGit/0911/quantization/python/tvm/relay/frontend/pytorch.py", line 2060, in _report_missing_conversion
    raise NotImplementedError(msg)
NotImplementedError: The following operators are not implemented: ['aten::lstm', 'aten::copy_']

Can you share your script so that I can reproduce the problem?

The RNN-T model checkpoint is from https://zenodo.org/record/3662521

Sorry, can you make a git repo with all the necessary files?

Sorry, our code cannot be uploaded to a git repo, so here is only the part that loads the model into TVM. Just run 'import_rnnt.py' and it will reproduce the error. It needs 'rnnt.toml' in the same directory, and you need to download the PyTorch RNN-T checkpoint from https://zenodo.org/record/3662521

python import_rnnt.py --ckpt <path to rnnt.pt>

import_rnnt.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import argparse
from typing import Optional, Tuple

import toml
import torch
import tvm
from tvm import relay

def get_args():
    """Parse commandline."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', type=str, required=True, help='The rnnt model path (pytorch ckpt path)')
    parser.add_argument("--model_toml", type=str, default='rnnt.toml',
                        help='relative model configuration path given dataset folder')
    args = parser.parse_args()
    return args


def rnn(rnn, input_size, hidden_size, num_layers, norm=None,
        forget_gate_bias=1.0, dropout=0.0, **kwargs):
    """Build the RNN layer; only rnn="lstm" with norm=None is supported."""
    if rnn != "lstm":
        raise ValueError(f"Unknown rnn={rnn}")
    if norm not in [None]:
        raise ValueError(f"unknown norm={norm}")

    if rnn == "lstm":
        return LstmDrop(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            forget_gate_bias=forget_gate_bias,
            **kwargs
        )


class LstmDrop(torch.nn.Module):

    def __init__(self, input_size, hidden_size, num_layers, dropout, forget_gate_bias,
                 **kwargs):
        """Returns an LSTM with forget gate bias init to `forget_gate_bias`.

        Args:
            input_size: See `torch.nn.LSTM`.
            hidden_size: See `torch.nn.LSTM`.
            num_layers: See `torch.nn.LSTM`.
            dropout: See `torch.nn.LSTM`.
            forget_gate_bias: For each layer and each direction, the total value
                to initialise the forget gate bias to.

        Returns:
            A `torch.nn.LSTM`.
        """
        super(LstmDrop, self).__init__()

        # Interesting, torch LSTM allows specifying number of
        # layers... Fan-out parallelism.
        # WARNING: Is dropout repeated twice?
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        if forget_gate_bias is not None:
            for name, v in self.lstm.named_parameters():
                if "bias_ih" in name:
                    bias = getattr(self.lstm, name)
                    bias.data[hidden_size:2 * hidden_size].fill_(forget_gate_bias)
                if "bias_hh" in name:
                    bias = getattr(self.lstm, name)
                    bias.data[hidden_size:2 * hidden_size].fill_(0)

        self.inplace_dropout = (torch.nn.Dropout(dropout, inplace=True)
                                if dropout else None)

    #zzhen developed
    def forward(self, x: torch.Tensor, h: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        x, h = self.lstm(x, h)
        if self.inplace_dropout is not None:
            self.inplace_dropout(x.data)

        return x, h


class StackTime(torch.nn.Module):
    def __init__(self, factor):
        super().__init__()
        self.factor = int(factor)

    def forward(self, x, x_lens):
        # T, B, U
        # x, x_lens = x
        seq = [x]
        for i in range(1, self.factor):
            # This doesn't seem to make much sense...
            tmp = torch.zeros_like(x)
            tmp[:-i, :, :] = x[i:, :, :]
            seq.append(tmp)
        x_lens = torch.ceil(x_lens.float() / self.factor).int()
        # Gross, this is horrible. What a waste of memory...
        return torch.cat(seq, dim=2)[::self.factor, :, :], x_lens


class RNNT(torch.nn.Module):
    """A Recurrent Neural Network Transducer (RNN-T).

    Args:
        in_features: Number of input features per step per batch.
        vocab_size: Number of output symbols (inc blank).
        forget_gate_bias: Total initialized value of the bias used in the
            forget gate. Set to None to use PyTorch's default initialisation.
            (See: http://proceedings.mlr.press/v37/jozefowicz15.pdf)
        batch_norm: Use batch normalization in encoder and prediction network
            if true.
        encoder_n_hidden: Internal hidden unit size of the encoder.
        encoder_rnn_layers: Encoder number of layers.
        pred_n_hidden:  Internal hidden unit size of the prediction network.
        pred_rnn_layers: Prediction network number of layers.
        joint_n_hidden: Internal hidden unit size of the joint network.
        rnn_type: string. Type of rnn in SUPPORTED_RNNS.
    """

    def __init__(self, rnnt=None, num_classes=1, **kwargs):
        super().__init__()
        if kwargs.get("no_featurizer", False):
            in_features = kwargs.get("in_features")
        else:
            feat_config = kwargs.get("feature_config")
            # This may be useful in the future, for MLPerf
            # configuration.
            in_features = feat_config['features'] * \
                feat_config.get("frame_splicing", 1)

        self._pred_n_hidden = rnnt['pred_n_hidden']

        self.encoder_n_hidden = rnnt["encoder_n_hidden"]
        self.encoder_pre_rnn_layers = rnnt["encoder_pre_rnn_layers"]
        self.encoder_post_rnn_layers = rnnt["encoder_post_rnn_layers"]

        self.pred_n_hidden = rnnt["pred_n_hidden"]
        self.pred_rnn_layers = rnnt["pred_rnn_layers"]

        self.encoder = Encoder(in_features,
                               rnnt["encoder_n_hidden"],
                               rnnt["encoder_pre_rnn_layers"],
                               rnnt["encoder_post_rnn_layers"],
                               rnnt["forget_gate_bias"],
                               None if "norm" not in rnnt else rnnt["norm"],
                               rnnt["rnn_type"],
                               rnnt["encoder_stack_time_factor"],
                               rnnt["dropout"],
                               )

        self.prediction = self._predict(
            num_classes,
            rnnt["pred_n_hidden"],
            rnnt["pred_rnn_layers"],
            rnnt["forget_gate_bias"],
            None if "norm" not in rnnt else rnnt["norm"],
            rnnt["rnn_type"],
            rnnt["dropout"],
        )

        self.joint_net = self._joint_net(
            num_classes,
            rnnt["pred_n_hidden"],
            rnnt["encoder_n_hidden"],
            rnnt["joint_n_hidden"],
            rnnt["dropout"],
        )

    def _predict(self, vocab_size, pred_n_hidden, pred_rnn_layers,
                 forget_gate_bias, norm, rnn_type, dropout):
        layers = torch.nn.ModuleDict({
            "embed": torch.nn.Embedding(vocab_size - 1, pred_n_hidden),
            "dec_rnn": rnn(
                rnn=rnn_type,
                input_size=pred_n_hidden,
                hidden_size=pred_n_hidden,
                num_layers=pred_rnn_layers,
                norm=norm,
                forget_gate_bias=forget_gate_bias,
                dropout=dropout,
            ),
        })
        return layers

    def _joint_net(self, vocab_size, pred_n_hidden, enc_n_hidden,
                   joint_n_hidden, dropout):
        layers = [
            torch.nn.Linear(pred_n_hidden + enc_n_hidden, joint_n_hidden),
            torch.nn.ReLU(),
        ] + ([torch.nn.Dropout(p=dropout), ] if dropout else []) + [
            torch.nn.Linear(joint_n_hidden, vocab_size)
        ]
        return torch.nn.Sequential(
            *layers
        )

    # Perhaps what I really need to do is provide a value for
    # state. But why can't I just specify a type for abstract
    # interpretation? That's what I really want!
    # We really want two "states" here...
    def forward(self, batch, state=None):
        # batch: ((x, y), (x_lens, y_lens))

        raise RuntimeError(
            "RNNT::forward is not currently used. "
            "It corresponds to training, where your entire output sequence "
            "is known before hand.")

        # x: TxBxF
        (x, y_packed), (x_lens, y_lens) = batch
        x_packed = torch.nn.utils.rnn.pack_padded_sequence(x, x_lens)

        f, x_lens = self.encode(x_packed)

        g, _ = self.predict(y_packed, state)
        out = self.joint(f, g)

        return out, (x_lens, y_lens)

    def predict(self, y, state=None, add_sos=True):
        """
        B - batch size
        U - label length
        H - Hidden dimension size
        L - Number of decoder layers = 2

        Args:
            y: (B, U)

        Returns:
            Tuple (g, hid) where:
                g: (B, U + 1, H)
                hid: (h, c) where h is the final sequence hidden state and c is
                    the final cell state:
                        h (tensor), shape (L, B, H)
                        c (tensor), shape (L, B, H)
        """
        if isinstance(y, torch.Tensor):
            y = self.prediction["embed"](y)
        elif isinstance(y, torch.nn.utils.rnn.PackedSequence):
            # Teacher-forced training mode
            # (B, U) -> (B, U, H)
            y = y._replace(data=self.prediction["embed"](y.data))
        else:
            # inference mode
            B = 1 if state is None else state[0].size(1)
            y = torch.zeros((B, 1, self.pred_n_hidden)).to(
                device=self.joint_net[0].weight.device,
                dtype=self.joint_net[0].weight.dtype
            )

        # prepend blank "start of sequence" symbol
        if add_sos:
            B, U, H = y.shape
            start = torch.zeros((B, 1, H)).to(device=y.device, dtype=y.dtype)
            y = torch.cat([start, y], dim=1).contiguous()   # (B, U + 1, H)
        else:
            start = None   # makes del call later easier

        y = y.transpose(0, 1)  # .contiguous()   # (U + 1, B, H)
        g, hid = self.prediction["dec_rnn"](y, state)
        g = g.transpose(0, 1)  # .contiguous()   # (B, U + 1, H)
        del y, start, state
        return g, hid


    def joint(self, f, g):
        """
        f should be shape (B, T, H)
        g should be shape (B, U + 1, H)

        returns:
            logits of shape (B, T, U, K + 1)
        """
        # Combine the input states and the output states
        B, T, H = f.shape
        B, U_, H2 = g.shape

        f = f.unsqueeze(dim=2)   # (B, T, 1, H)
        f = f.expand((B, T, U_, H))

        g = g.unsqueeze(dim=1)   # (B, 1, U + 1, H)
        g = g.expand((B, T, U_, H2))

        inp = torch.cat([f, g], dim=3)   # (B, T, U, 2H)
        res = self.joint_net(inp)
        del f, g, inp
        return res



class Encoder(torch.nn.Module):
    def __init__(self, in_features, encoder_n_hidden,
                 encoder_pre_rnn_layers, encoder_post_rnn_layers,
                 forget_gate_bias, norm, rnn_type, encoder_stack_time_factor,
                 dropout):
        super().__init__()
        self.pre_rnn = rnn(
            rnn=rnn_type,
            input_size=in_features,
            hidden_size=encoder_n_hidden,
            num_layers=encoder_pre_rnn_layers,
            norm=norm,
            forget_gate_bias=forget_gate_bias,
            dropout=dropout,
        )
        self.stack_time = StackTime(factor=encoder_stack_time_factor)
        self.post_rnn = rnn(
            rnn=rnn_type,
            input_size=encoder_stack_time_factor * encoder_n_hidden,
            hidden_size=encoder_n_hidden,
            num_layers=encoder_post_rnn_layers,
            norm=norm,
            forget_gate_bias=forget_gate_bias,
            norm_first_rnn=True,
            dropout=dropout,
        )

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
        x, _ = self.pre_rnn(x, None)
        x, x_lens = self.stack_time(x, x_lens)
        x, _ = self.post_rnn(x, None)
        x = x.transpose(0, 1)
        return x, x_lens


def get_rnnt_model(featurizer_config, model_definition, ctc_vocab, ckpt):
    model = RNNT(
        feature_config=featurizer_config,
        rnnt=model_definition['rnnt'],
        num_classes=len(ctc_vocab)
    )
    checkpoint = torch.load(ckpt, map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.eval()
    return model

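def rnnt_model_to_tvm_mod(model):
    # Encoder input is (T, B, F) = (316 frames, batch 1, 80 features x 3 spliced frames).
    input_shape = (316, 1, 240)
    # x_lens holds one length per batch element, so its shape is (B,).
    len_shape = (1,)
    t_audio_signal_e = torch.randn(input_shape)
    t_a_sig_length_e = torch.full(len_shape, float(input_shape[0]))
    model.encoder = torch.jit.trace(model.encoder, (t_audio_signal_e, t_a_sig_length_e)).eval()

    # from_pytorch expects one (name, shape) pair per graph input.
    shape_list = [("audio_signal", input_shape), ("audio_length", len_shape)]
    mod, params = relay.frontend.from_pytorch(model.encoder, shape_list)
    mod = relay.transform.RemoveUnusedFunctions()(mod)
    return mod, params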

def add_blank_label(labels):
    if not isinstance(labels, list):
        raise ValueError("labels must be a list of symbols")
    labels.append("<BLANK>")
    return labels


def main():
    args = get_args()
    model_definition = toml.load(args.model_toml)
    dataset_vocab = model_definition['labels']['labels']
    ctc_vocab = add_blank_label(dataset_vocab)
    featurizer_config = model_definition['input_eval']

    rnnt_model = get_rnnt_model(featurizer_config, model_definition, ctc_vocab, args.ckpt)
    mod, params = rnnt_model_to_tvm_mod(rnnt_model)

if __name__ == '__main__':
    main()

rnnt.toml

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2019, Myrtle Software Limited. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

model = "RNNT"

[input]
normalize = "per_feature"
sample_rate = 16000
window_size = 0.02
window_stride = 0.01
window = "hann"
features = 80
n_fft = 512
frame_splicing = 3
dither = 0.00001
feat_type = "logfbank"
normalize_transcripts = true
trim_silence = true
pad_to = 0   # TODO
max_duration = 16.7
speed_perturbation = true


cutout_rect_regions = 0
cutout_rect_time = 60
cutout_rect_freq = 25


cutout_x_regions = 2
cutout_y_regions = 2
cutout_x_width = 6
cutout_y_width = 6


[input_eval]
normalize = "per_feature"
sample_rate = 16000
window_size = 0.02
window_stride = 0.01
window = "hann"
features = 80
n_fft = 512
frame_splicing = 3
dither = 0.00001
feat_type = "logfbank"
normalize_transcripts = true
trim_silence = true
pad_to = 0


[rnnt]
rnn_type = "lstm"
encoder_n_hidden = 1024
encoder_pre_rnn_layers = 2
encoder_stack_time_factor = 2
encoder_post_rnn_layers = 3
pred_n_hidden = 320
pred_rnn_layers = 2
forget_gate_bias = 1.0
joint_n_hidden = 512
dropout=0.32


[labels]
labels = [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"]
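In the traced encoder, a likely source of `aten::copy_` is the in-place slice assignment `tmp[:-i, :, :] = x[i:, :, :]` in `StackTime.forward`, which tracing records as slice views followed by `aten::copy_`. The sketch below is a hypothetical out-of-place variant (`StackTimeNoCopy` is our name, not part of the MLPerf code) that builds each shifted copy with `torch.cat` and zero padding, checks it against the original logic, and inspects the traced graph. It only removes `aten::copy_`; `aten::lstm` still needs a converter, and we have not checked that the TVM version in the traceback converts every remaining op.

```python
import torch


class StackTimeReference(torch.nn.Module):
    """StackTime as written in import_rnnt.py (in-place slice assignment)."""

    def __init__(self, factor):
        super().__init__()
        self.factor = int(factor)

    def forward(self, x, x_lens):
        seq = [x]
        for i in range(1, self.factor):
            tmp = torch.zeros_like(x)
            tmp[:-i, :, :] = x[i:, :, :]  # recorded by the tracer as aten::copy_
            seq.append(tmp)
        x_lens = torch.ceil(x_lens.float() / self.factor).int()
        return torch.cat(seq, dim=2)[::self.factor, :, :], x_lens


class StackTimeNoCopy(torch.nn.Module):
    """Hypothetical out-of-place variant producing the same output."""

    def __init__(self, factor):
        super().__init__()
        self.factor = int(factor)

    def forward(self, x, x_lens):
        seq = [x]
        for i in range(1, self.factor):
            # Shift x up by i frames and fill the last i frames with zeros.
            pad = torch.zeros_like(x[:i])
            seq.append(torch.cat([x[i:], pad], dim=0))
        x_lens = torch.ceil(x_lens.float() / self.factor).int()
        return torch.cat(seq, dim=2)[::self.factor, :, :], x_lens


def traced_op_kinds(module, *inputs):
    """Top-level node kinds of the traced graph (enough for this flat module)."""
    graph = torch.jit.trace(module, inputs).inlined_graph
    return {node.kind() for node in graph.nodes()}


x = torch.randn(7, 1, 4)   # (T, B, F)
lens = torch.tensor([7.0])
ref_x, ref_lens = StackTimeReference(2)(x, lens)
new_x, new_lens = StackTimeNoCopy(2)(x, lens)
print(torch.equal(ref_x, new_x), torch.equal(ref_lens, new_lens))
print("aten::copy_" in traced_op_kinds(StackTimeNoCopy(2), x, lens))
```

Swapping this class in for `StackTime` keeps the encoder's output unchanged while avoiding the in-place write.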

OK, I was able to reproduce the issue. Supporting aten::lstm seems complicated, and I'm not an LSTM expert, so I created an issue https://github.com/apache/incubator-tvm/issues/6474 to ask for help.
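To illustrate what a converter for `aten::lstm` has to do: it essentially unrolls the recurrence into ops the frontend already handles (matmul, sigmoid, tanh, elementwise add and multiply). The sketch below does that in plain PyTorch for the configuration RNN-T uses (unidirectional, `bias=True`, sequence-first, no projection), assuming eval mode so that inter-layer dropout is inactive. It relies on PyTorch packing the four gates in the order input, forget, cell, output, which is also why `LstmDrop` writes the forget gate bias into `bias[hidden_size:2 * hidden_size]`. `lstm_unrolled` is a hypothetical helper, not TVM code.

```python
import torch


def lstm_unrolled(lstm, x, state=None):
    """Re-express a unidirectional torch.nn.LSTM as per-step elementwise ops.

    x: (T, B, input_size); state: optional (h0, c0), each (num_layers, B, H).
    Returns (output (T, B, H), (h_n, c_n)) like torch.nn.LSTM.
    """
    T, B, _ = x.shape
    H, L = lstm.hidden_size, lstm.num_layers
    if state is None:
        h0 = x.new_zeros(L, B, H)
        c0 = x.new_zeros(L, B, H)
    else:
        h0, c0 = state

    layer_input = x
    h_n, c_n = [], []
    for layer in range(L):
        w_ih = getattr(lstm, f"weight_ih_l{layer}")  # (4H, input)
        w_hh = getattr(lstm, f"weight_hh_l{layer}")  # (4H, H)
        b = getattr(lstm, f"bias_ih_l{layer}") + getattr(lstm, f"bias_hh_l{layer}")
        h, c = h0[layer], c0[layer]
        outputs = []
        for t in range(T):
            gates = layer_input[t] @ w_ih.t() + h @ w_hh.t() + b  # (B, 4H)
            # PyTorch packs the gates as input, forget, cell (g), output.
            i, f, g, o = gates.chunk(4, dim=1)
            c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
            h = torch.sigmoid(o) * torch.tanh(c)
            outputs.append(h)
        # The next layer consumes this layer's whole output sequence.
        layer_input = torch.stack(outputs, dim=0)
        h_n.append(h)
        c_n.append(c)
    return layer_input, (torch.stack(h_n, dim=0), torch.stack(c_n, dim=0))


torch.manual_seed(0)
lstm = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=2).eval()
x = torch.randn(5, 3, 8)  # (T, B, F)
with torch.no_grad():
    ref_out, (ref_h, ref_c) = lstm(x)
    out, (h, c) = lstm_unrolled(lstm, x)
print(torch.allclose(out, ref_out, atol=1e-5))
```

Unrolling over time makes the graph size grow with T, which is one reason a dedicated converter (or a loop construct) is preferable for long inputs such as the 316-frame encoder input.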

For now, I recommend exporting the model to ONNX and using our ONNX frontend, which supports the ONNX LSTM op.
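A minimal sketch of that route, assuming the `Encoder` from import_rnnt.py, opset 11 (an arbitrary choice), and the `onnx` package; the helper names and the input/output names are hypothetical. `torch.onnx.export` runs the module on example inputs much like `torch.jit.trace`, and `relay.frontend.from_onnx` takes a dict mapping input names to shapes and returns `(mod, params)`. Exporter behaviour differs across PyTorch releases, and we have not verified that every op in the encoder (for example the slice assignment in `StackTime`) exports cleanly.

```python
import torch


def encoder_example_inputs(T=316, B=1, F=240):
    """Dummy (audio_signal, audio_length) pair for Encoder.forward: (T, B, F) and (B,)."""
    x = torch.randn(T, B, F)
    x_lens = torch.full((B,), float(T))
    return x, x_lens


def export_encoder_to_onnx(encoder, path, opset_version=11):
    """Export the encoder to ONNX and return the shape dict for from_onnx."""
    x, x_lens = encoder_example_inputs()
    torch.onnx.export(
        encoder,
        (x, x_lens),
        path,
        input_names=["audio_signal", "audio_length"],
        output_names=["encoded", "encoded_length"],
        opset_version=opset_version,
    )
    return {"audio_signal": tuple(x.shape), "audio_length": tuple(x_lens.shape)}


def onnx_to_relay(path, shape_dict):
    """Load the exported file and convert it with TVM's ONNX frontend."""
    import onnx  # optional dependencies, imported only when needed
    from tvm import relay

    onnx_model = onnx.load(path)
    return relay.frontend.from_onnx(onnx_model, shape=shape_dict)
```

Usage, with `rnnt_model` from `main()`: `shape_dict = export_encoder_to_onnx(rnnt_model.encoder, "rnnt_encoder.onnx")` followed by `mod, params = onnx_to_relay("rnnt_encoder.onnx", shape_dict)`.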