Failed to convert the latest HuggingFace BERT model to Relay

Hi, I’m following the blog posts "Bridging PyTorch and TVM" and "Export to TorchScript" to convert a standard BERT model to Relay.

The code is copied from these two posts:

from transformers import BertModel, BertTokenizer, BertConfig
import torch
import tvm

enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# Initializing the model with the torchscript flag
# Flag set to True even though it is not necessary as this model does not have an LM Head.
config = BertConfig(
    vocab_size=32000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    torchscript=True,
)

# Instantiating the model
model = BertModel(config)

# The model needs to be in evaluation mode
model.eval()

# If you are instantiating the model with *from_pretrained* you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

# Creating the trace
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt")

shape_list = [(i.debugName().split('.')[0], i.type().sizes()) for i in list(traced_model.graph.inputs())[1:]]

mod_bert, params_bert = tvm.relay.frontend.pytorch.from_pytorch(traced_model,
                        shape_list, default_dtype="float32")
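
A side note on the shape_list expression above: from_pytorch takes a list of (input name, shape) pairs, and the expression derives each name by dropping the numeric suffix that TorchScript may append to a debug name. The sketch below reproduces only that name handling in plain Python, so it runs without TVM or PyTorch. The helper build_shape_list and the sample debug names are hypothetical and chosen for illustration; the real names depend on the traced graph. The length 14 matches segments_ids in the script.

```python
def build_shape_list(debug_inputs):
    """Turn (debug_name, sizes) pairs into (name, shape) pairs.

    Mirrors `i.debugName().split('.')[0]` from the script above;
    sizes are converted to tuples for readability.
    """
    return [(name.split(".")[0], tuple(sizes)) for name, sizes in debug_inputs]


# Hypothetical debug names; the real ones come from traced_model.graph.inputs().
example_inputs = [("input_ids.1", [1, 14]), ("attention_mask.1", [1, 14])]
shape_list = build_shape_list(example_inputs)
print(shape_list)
```

Printing the list this way before calling from_pytorch is a cheap check that the names and shapes are what you expect.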

Then I got the following error:

Traceback (most recent call last):
  File "bert_test.py", line 48, in <module>
    mod_bert, params_bert = tvm.relay.frontend.pytorch.from_pytorch(traced_model,
  File "/root/miniconda3/envs/py38/lib/python3.8/site-packages/tvm/relay/frontend/pytorch.py", line 4554, in from_pytorch
    outputs = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)
  File "/root/miniconda3/envs/py38/lib/python3.8/site-packages/tvm/relay/frontend/pytorch.py", line 3928, in convert_operators
    relay_out = relay_op(
  File "/root/miniconda3/envs/py38/lib/python3.8/site-packages/tvm/relay/frontend/pytorch.py", line 1638, in linear
    mm_out = self.matmul(
  File "/root/miniconda3/envs/py38/lib/python3.8/site-packages/tvm/relay/frontend/pytorch.py", line 1858, in matmul
    out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
  File "/root/miniconda3/envs/py38/lib/python3.8/site-packages/tvm/relay/op/transform.py", line 309, in reshape
    tempshape.append(int(shape))
TypeError: int() argument must be a string, a bytes-like object or a number, not 'Any'

My environment is:

HuggingFace transformers: 4.25.1
TVM: apache-tvm-cu116-cu116 from https://tlcpack.ai/wheels
Torch: 1.13.0+cu116
Python: 3.8
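
A list like this can also be collected programmatically, which avoids typos when reporting an issue. The following is a minimal sketch using only the standard library (importlib.metadata, available from Python 3.8). The helper installed_versions is hypothetical, and the distribution names passed to it are assumptions: "apache-tvm-cu116-cu116" is taken from the list above, while "tvm" is a guess at the name a source build might register under.

```python
from importlib import metadata


def installed_versions(dist_names):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in dist_names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Assumed distribution names; adjust them to match your installation.
names = ["transformers", "torch", "apache-tvm-cu116-cu116", "tvm"]
report = installed_versions(names)
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```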

See https://github.com/apache/tvm/issues/13664

Thank you for the reply.

I’ve tried the code in https://github.com/masahi/torchscript-to-tvm/blob/master/transformers/bert_clean.py, but the error is the same. It seems that this issue has not been fully addressed.

As I said in the GitHub issue, HF BERT has always worked fine with TVM. Can you try the latest source build of TVM instead of the released build? That is what all developers use.
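
When a released wheel and a source build are both present on one machine, it may be worth checking which copy Python actually picks up. The sketch below is one possible way to do that with the standard library only; the helper locate is hypothetical, and the assumption is that the path it prints is enough to tell a site-packages wheel from a source checkout.

```python
import importlib.util


def locate(module_name):
    """Return the file a top-level module would be loaded from, or None if not found."""
    spec = importlib.util.find_spec(module_name)
    return None if spec is None else spec.origin


# Prints the path of tvm's __init__.py, or None if tvm is not importable.
print("tvm resolves to:", locate("tvm"))
```

A path under site-packages suggests the wheel is in use, while a path inside a cloned repository suggests the source build.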

I’ve tried the latest source build and it works. Thank you!

Hi @masahi, I have a question about HuggingFace. Have you tried using TVM to run and export the HuggingFace GPT-2 model?

I haven’t tried, but if you hit errors and have a repro script, I can take a look.