Failing to convert BERT from PyTorch (Hugging Face) to a Relay graph

I am trying to compile the “bert-base-uncased” model via the PyTorch frontend.

I followed the instructions in “Exporting transformers models” (transformers 4.7.0 documentation) and got a TorchScript traced model. Then, when I try to use relay.frontend.from_pytorch, it says:

The Relay type checker is unable to show the following types match.
In particular dimension 0 conflicts: 512 does not match 768.
The Relay type checker is unable to show the following types match.
In particular `Tensor[(768), float32]` does not match `Tensor[(512), float32]`

The full diagnostic output is:

Traceback (most recent call last):
  File "torchscript_compile.py", line 69, in <module>
    mod, params = relay.frontend.from_pytorch(script_module, input_infos)
  File "/home/yifanlu/TVM/tvm/python/tvm/relay/frontend/pytorch.py", line 3335, in from_pytorch
    ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)[0]
  File "/home/yifanlu/TVM/tvm/python/tvm/relay/frontend/pytorch.py", line 2759, in convert_operators
    self.record_output_type(relay_out)
  File "/home/yifanlu/TVM/tvm/python/tvm/relay/frontend/pytorch.py", line 219, in record_output_type
    self.infer_type_with_prelude(output)
  File "/home/yifanlu/TVM/tvm/python/tvm/relay/frontend/pytorch.py", line 167, in infer_type_with_prelude
    body = self.infer_type(val, self.prelude.mod)
  File "/home/yifanlu/TVM/tvm/python/tvm/relay/frontend/pytorch.py", line 160, in infer_type
    new_mod = transform.InferType()(new_mod)
  File "/home/yifanlu/TVM/tvm/python/tvm/ir/transform.py", line 161, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/yifanlu/TVM/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm.error.DiagnosticError: Traceback (most recent call last):
  6: TVMFuncCall
  5: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM8::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM8::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  4: tvm::transform::Pass::operator()(tvm::IRModule) const
  3: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  1: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  0: tvm::DiagnosticContext::Render()
  File "/home/yifanlu/TVM/tvm/src/ir/diagnostic.cc", line 105
DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.

My TVM version is the latest one.

LLVM is 12.0.0

I am new to TVM and want to know how to solve this. The full code is shown below:

from transformers import BertModel, BertTokenizer, BertConfig
import torch
import tvm
from tvm import relay
from tvm.contrib.download import download_testdata


tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenizing input text
sentence_a = "Who was Jim Henson ?"
sentence_b = "Jim Henson was a puppeteer."
tokenized_text = tokenizer(sentence_a, sentence_b, padding='max_length')

# Masking one of the input tokens
input_ids = tokenized_text['input_ids']
atten_mask = tokenized_text['attention_mask']
token_type_ids = tokenized_text['token_type_ids']

masked_index = 8
input_ids[masked_index] = tokenizer.convert_tokens_to_ids("[MASK]")  # replace one token id with the [MASK] id

# Creating a dummy input
input_ids_tensor = torch.tensor([input_ids])
atten_mask_tensors = torch.tensor([atten_mask])
token_type_ids_tensors = torch.tensor([token_type_ids])

dummy_input = [input_ids_tensor, atten_mask_tensors, token_type_ids_tensors]

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, torchscript=True)

# Instantiating the model
model = BertModel(config)

# The model needs to be in evaluation mode
model.eval()

# If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

# Creating the trace
traced_model = torch.jit.trace(model, dummy_input)
traced_model.eval()

# tvm part
script_module = traced_model
input_infos = [("input_ids", ((1, 512), "int")),
               ("attention_mask", ((1, 512), "int")),
               ("token_type_ids", ((1, 512), "int"))]
mod, params = relay.frontend.from_pytorch(script_module, input_infos)

Since this did not work, I also tried converting through an ONNX model, but I ran into another problem.

I would really appreciate it if you could help me out.
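For reference, a minimal sketch of how such an ONNX file can be exported (the opset version and output path here are assumptions for illustration, not necessarily what I used):

# Rough sketch of the ONNX export; opset_version and the output file name
# are assumptions for illustration.
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
model.eval()

enc = tokenizer("Who was Jim Henson ?", "Jim Henson was a puppeteer.",
                padding="max_length", return_tensors="pt")

torch.onnx.export(
    model,
    (enc["input_ids"], enc["attention_mask"], enc["token_type_ids"]),
    "bert-base-uncased.onnx",
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["last_hidden_state", "pooler_output"],
    opset_version=11,
)

The compile script that hits the error is: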

from transformers import BertModel, BertTokenizer, BertConfig
import torch
import tvm
import numpy as np
from tvm import relay
import onnx

model_path = "/home/yifanlu/TVM/TVM_Sample_Text/bert/onnx/bert-base-uncased.onnx"
onnx_model = onnx.load(model_path)

target = "llvm"
dev = tvm.cpu()

shape_dict = {"input_ids": (1, 512), "attention_mask": (1, 512), "token_type_ids": (1, 512)}
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
# succeed in converting to relay graph

with tvm.transform.PassContext(opt_level=1):
    intrp = relay.build_module.create_executor("graph", mod, tvm.cpu(0), target)

x1 = np.random.randint(0, 2, (1, 512))
x2 = np.random.randint(0, 2, (1, 512))
x3 = np.random.randint(0, 2, (1, 512))
tvm_output = intrp.evaluate()(tvm.nd.array(x1), tvm.nd.array(x2), tvm.nd.array(x3), **params).numpy()

The last line raises an error:

Traceback (most recent call last):
  File "onnx_compile_test.py", line 25, in <module>
    tvm_output = intrp.evaluate()(tvm.nd.array(x1),tvm.nd.array(x2),tvm.nd.array(x3), **params).numpy()
  File "/home/yifanlu/TVM/tvm/python/tvm/relay/backend/interpreter.py", line 172, in evaluate
    return self._make_executor()
  File "/home/yifanlu/TVM/tvm/python/tvm/relay/build_module.py", line 482, in _make_executor
    "Graph Executor only supports static graphs, got output type", ret_type
ValueError: ('Graph Executor only supports static graphs, got output type', TupleTypeNode([TensorType([?, 512, 768], float32), TensorType([?, 768], float32)]))

I think I have fixed the input size with batch size 1, but the dynamic-shape error still happens. I changed “graph” to “debug” and “vm”, but that leads to other errors.
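For reference, one possible workaround for the dynamic `?` dimension is to freeze the parameters during the ONNX import and run the DynamicToStatic pass, roughly like this (a sketch continuing the script above; I have not verified it for this model):

# Sketch: import with frozen parameters so shape-related subgraphs can be
# constant-folded, then turn the remaining dynamic ops into static ones so
# the graph executor accepts the module. Reuses onnx_model, shape_dict,
# target, dev, and x1/x2/x3 from the script above.
from tvm.contrib import graph_executor

mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)
mod = relay.transform.DynamicToStatic()(mod)

with tvm.transform.PassContext(opt_level=1):
    lib = relay.build(mod, target=target, params=params)

m = graph_executor.GraphModule(lib["default"](dev))
m.set_input("input_ids", tvm.nd.array(x1.astype("int64")))
m.set_input("attention_mask", tvm.nd.array(x2.astype("int64")))
m.set_input("token_type_ids", tvm.nd.array(x3.astype("int64")))
m.run()
tvm_output = m.get_output(0).numpy()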

What’s your transformers version? AFAIK, transformers 3.5 was the version the community was using when testing BERT conversion. I personally use transformers 4.3 locally.

Also, if you are using PyTorch 1.9, it is not expected to work. See “[PyTorch] Unable to import a simple torch.nn.Linear to Relay”.
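A quick way to confirm which versions are installed in the environment:

# Print the versions of the packages involved (run in the same Python env).
import torch
import transformers
print("torch:", torch.__version__)
print("transformers:", transformers.__version__)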

Yes!

I downgraded the PyTorch version to 1.8.0, and now it works. Thank you!

Thank you for your reply!

The transformers version is all right. The problem lies in the PyTorch version, and I have solved it.
