[Question] How can TVM run a text generation model like GPT-2?

If we run text generation using GPT-2 from Hugging Face, the seq_len increases at each step because it is an autoregressive model. In other words, the model input is dynamic. Can TVM support such a text generation case?

Relax supports dynamic shapes by design, so it’s highly recommended to check it out.

Hey, thanks for your reply. Where can we find the official docs and demos for Relax, and when will we be able to use Relax on the TVM main branch?

As far as I know, Relax has not been released yet. How can I use this feature?

The script below converts GPT-2 to a TorchScript model and verifies it with both PyTorch inference and TVM inference. PyTorch inference works fine, while TVM inference raises an error.

The error KeyError: 'default_token.1' occurred when running relay.frontend.from_pytorch. I printed out the PyTorch graph below and found default_token in the graph. Do I need to add some custom op conversion? @masahi

Thanks~~

graph(%self.1 : __torch__.FinishMySentence,
      %x.1 : Tensor):
  %34 : bool = prim::Constant[value=1]()
  %31 : int = prim::Constant[value=1]()
  %29 : NoneType = prim::Constant()
  %12 : int = prim::Constant[value=9223372036854775807]() 
  %25 : int = prim::Constant[value=0]() 
  %26 : int = prim::Constant[value=-1]() 
  %default_token.1 : Tensor = prim::GetAttr[name="default_token"](%self.1)
  %eos.1 : Tensor = prim::GetAttr[name="eos"](%self.1)
  %9 : Tensor = aten::ne(%default_token.1, %eos.1)
  %11 : bool = aten::Bool(%9) 
  %sentence : Tensor = prim::Loop(%12, %11, %x.1)
    block0(%15 : int, %sentence.9 : Tensor):
      %next_token_predictor.1 : __torch__.transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel = prim::GetAttr[name="next_token_predictor"](%self.1)
      %20 : (Tensor, ((Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor))) = prim::CallMethod[name="forward"](%next_token_predictor.1, %sentence.9) # demo2.py:23:29
      %predictions.1 : Tensor, %23 : ((Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor), (Tensor, Tensor)) = prim::TupleUnpack(%20)
      %27 : Tensor = aten::select(%predictions.1, %25, %26) 
      %32 : Tensor = aten::slice(%27, %25, %29, %29, %31) 
      %token.1 : Tensor = aten::argmax(%32, %25, %34) 
      %38 : Tensor[] = prim::ListConstruct(%sentence.9, %token.1)
      %sentence0.1 : Tensor = aten::cat(%38, %25) 
      %eos : Tensor = prim::GetAttr[name="eos"](%self.1)
      %45 : Tensor = aten::ne(%token.1, %eos) 
      %47 : bool = aten::Bool(%45) 
      -> (%47, %sentence0.1)
  return (%sentence)

import torch
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer


class FinishMySentence(torch.nn.Module):
    def __init__(self, model=None, eos=198):
        super(FinishMySentence, self).__init__()
        self.eos = torch.tensor([eos])
        self.next_token_predictor = model
        self.default_token = torch.tensor([0])

    def forward(self, x):
        sentence = x
        token = self.default_token
        while token != self.eos:
            predictions, _ = self.next_token_predictor(sentence)
            token = torch.argmax(predictions[-1, :], dim=0, keepdim=True)
            sentence = torch.cat((sentence, token), 0)

        return sentence


# Convert to scripted model
token_predictor = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True).eval()

# trace
random_tokens = torch.randint(10000, (5,))
traced_token_predictor = torch.jit.trace(token_predictor, random_tokens)
torch.jit.save(traced_token_predictor, "traced_gpt2.pt")

# script
model = FinishMySentence(model=traced_token_predictor)
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, "scripted_gpt2.pt")


# Use PyTorch inference
sentence_fragment = "The Manhattan bridge is a major"

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
context = torch.tensor(tokenizer.encode(sentence_fragment))

# torch_out = scripted_model(context)
loaded_model = torch.jit.load("scripted_gpt2.pt").eval()
print(loaded_model.graph)
torch_out = loaded_model(context)
generated_text_torch = tokenizer.decode(torch_out)
print("Fragment: {}".format(sentence_fragment))
print("Completed: {}".format(generated_text_torch))


# Use TVM
import tvm
from tvm import relay

inputs = [("dummy_input_name", (5,))]
mod, params = relay.frontend.from_pytorch(loaded_model, inputs)
print(mod)

You should use torch.jit.trace.

Thanks for your kind reply. Sorry, I’m not that familiar with Torch, so I still can’t get the point.

Since there is a while loop in the model, I used jit.script to generate the TorchScript model named scripted_gpt2.pt.

Could you give me more hint? @masahi

OK, it seems you are trying to convert the decoding step (FinishMySentence) of text generation. This is not a learned component, so there is little point in converting it to TVM.

On the other hand, GPT2LMHeadModel should be converted to TVM without issues (what you call traced_token_predictor). If not, I can take a look.
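
To make the suggested split concrete, here is a minimal sketch, assuming a hypothetical run_lm_head callable that wraps whatever compiled forward pass you end up with (TVM or otherwise); the greedy loop itself stays in plain Python:

import torch

def generate(run_lm_head, input_ids, eos_id, max_new_tokens=20):
    """Greedy decoding kept in Python; only the LM-head forward is compiled."""
    sentence = input_ids
    for _ in range(max_new_tokens):
        logits = run_lm_head(sentence)          # (seq_len, vocab_size) for a 1-D input
        token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
        sentence = torch.cat((sentence, token), 0)
        if token.item() == eos_id:
            break
    return sentence

Note that with static shapes, a TVM-compiled run_lm_head only accepts the fixed sequence length it was built for, which is exactly the dynamic-shape limitation discussed further down in this thread.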

An error occurred after changing the call to mod, params = relay.frontend.from_pytorch(traced_token_predictor, inputs, default_dtype="int64"):

The value type is tvm.relay.expr.Call, while it is expected to be a scalar or NDArray.

Traceback (most recent call last):
  File "test1.py", line 58, in <module>
    mod, params = relay.frontend.from_pytorch(traced_token_predictor, inputs)
  File "/WORK/Dev/tvm/python/tvm/relay/frontend/pytorch.py", line 4173, in from_pytorch
    outputs = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)
  File "/WORK/Dev/tvm/python/tvm/relay/frontend/pytorch.py", line 3547, in convert_operators
    relay_out = relay_op(
  File "/WORK/Dev/tvm/python/tvm/relay/frontend/pytorch.py", line 750, in full
    return self.full_impl(data, fill_value, dtype)
  File "/WORK/Dev/tvm/python/tvm/relay/frontend/pytorch.py", line 671, in full_impl
    out = _op.full(_expr.const(fill_value, dtype=dtype), size, dtype=dtype)
  File "/WORK/Dev/tvm/python/tvm/relay/expr.py", line 517, in const
    raise ValueError("value has to be scalar or NDArray")
ValueError: value has to be scalar or NDArray

It works for me using this script:

from tvm import relay

import torch
from transformers import GPT2LMHeadModel

token_predictor = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True).eval()

random_tokens = torch.randint(10000, (5,))
traced_token_predictor = torch.jit.trace(token_predictor, random_tokens)

inputs = [("dummy_input_name", (5,))]
mod, params = relay.frontend.from_pytorch(traced_token_predictor, inputs, default_dtype="int64")
print(mod)
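
As a follow-up usage sketch (not from the thread): once mod and params are obtained as above, the module can be built and run with the graph executor. The llvm target and the reuse of the fixed-length "dummy_input_name" input are assumptions; this runs a single forward pass for the 5-token input the model was imported with.

import tvm
from tvm.contrib import graph_executor

target = "llvm"  # CPU; any other TVM target should work the same way
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

dev = tvm.cpu(0)
module = graph_executor.GraphModule(lib["default"](dev))

# Feed the same fixed-length (5-token) int64 input the model was imported with.
tokens = random_tokens.numpy().astype("int64")
module.set_input("dummy_input_name", tvm.nd.array(tokens))
module.run()
logits = module.get_output(0).numpy()  # expected shape: (5, vocab_size)
print(logits.shape)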

I updated TVM to the latest version and it works. Thanks a lot for your kind help :)

Because GPT-2 requires the input length to grow at each step, with the code above and the static shapes of current TVM (main branch) I can only run inference on a fixed sequence length. How can I solve this problem? I know Relax may solve it.

It is also possible to import the model with dynamic shape in Relay. But the performance would be extremely poor.

It is also possible to import the model with dynamic shape in Relay.

Is there some demo code for this? I’d like to give it a try.

Another question: when will the next version with Relax be released?

This is an example for ONNX; the PT frontend doesn’t support dynamic input shapes, but it wouldn’t be difficult to add such a feature.
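
For what it’s worth, a rough sketch of that ONNX route (this is not the example referred to above; the input/output names, opset version, and dynamic-axis setup are my own assumptions). Dynamic shapes in Relay go through the VM executor rather than the graph executor:

import onnx
import torch
import tvm
from tvm import relay
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True).eval()
dummy = torch.randint(10000, (5,))

# Mark the sequence dimension as dynamic so the exported graph keeps it symbolic.
torch.onnx.export(
    model, dummy, "gpt2.onnx",
    input_names=["input_ids"], output_names=["logits"],
    dynamic_axes={"input_ids": {0: "seq_len"}},
    opset_version=14,
)

onnx_model = onnx.load("gpt2.onnx")
# With no fixed shape given, the symbolic seq_len is imported as a dynamic dim (Any).
mod, params = relay.frontend.from_onnx(onnx_model, freeze_params=True)

# The graph executor needs static shapes, so compile for the Relay VM instead.
with tvm.transform.PassContext(opt_level=3):
    vm_exec = relay.vm.compile(mod, target="llvm", params=params)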

Thanks for your kind help. I will give it a try. Another question: when will the next version with Relax be released?

Check out Introducing Web-LLM: Running large language model on web

Amazing! I will try it on the browser.

Hi @zhaoyang-star, how do we generate the sentence after getting mod and params? Please help with TVM.

When will TVM support dynamic shapes and dynamic-shape tuning on GPU? @tqchen