Issue: Converting a model from PyTorch to a Relay model

Hello TVM developers and community,

I am trying to convert Transformer-like models such as BERT from different frameworks (TensorFlow or PyTorch) to Relay models.

For TensorFlow models, I was able to convert them into Relay models successfully by following this tutorial: Deploy a Hugging Face Pruned Model on CPU — tvm 0.8.dev0 documentation (apache.org)

However, for some models such as SqueezeBERT, Hugging Face only provides a PyTorch implementation, so I need to convert from PyTorch to a Relay model.

Here is the code I use (it is very similar to the TensorFlow tutorial):

import torch
import transformers
from tvm import relay

# config, batch_size, s (sequence length), and np_input are defined earlier in the script
model = transformers.SqueezeBertForSequenceClassification(config)
shape_dict = {"input_1": (batch_size, s)}
traced_script_module = torch.jit.trace(model, np_input, strict=False)
mod, params = relay.frontend.from_pytorch(traced_script_module, input_infos=shape_dict)

However, the error “The following operators are not implemented” shows up:

Traceback (most recent call last):
  File "TVMBertModel.py", line 187, in <module>
    mod, params, shape_dict, test_input = pytorch_get_tvm_model(BERT_type, config, sl, bs, np_input)
  File "TVMBertModel.py", line 85, in pytorch_get_tvm_model
    mod, params = relay.frontend.from_pytorch(
  File "/home/hungyangchang/.local/lib/python3.8/site-packages/tvm-0.9.dev32+gecd8a9ce3-py3.8-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 3858, in from_pytorch
    converter.report_missing_conversion(op_names)
  File "/home/hungyangchang/.local/lib/python3.8/site-packages/tvm-0.9.dev32+gecd8a9ce3-py3.8-linux-x86_64.egg/tvm/relay/frontend/pytorch.py", line 3091, in report_missing_conversion
    raise NotImplementedError(msg)
NotImplementedError: The following operators are not implemented: ['prim::DictConstruct']
free(): invalid pointer
Aborted (core dumped)

More info: I have read similar questions such as [Frontend][Pytorch] Convert RCNN model from Torch Vision Model - Apache TVM Discuss, but I still don’t know how to solve it.

Any thoughts are welcome.

I would suggest looking into converting the model → ONNX → Relay if possible. The ONNX frontend is much more mature.
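For reference, a minimal sketch of that route, reusing the model, batch_size, and s from the post above; the file name model.onnx, the input name input_ids, and the dummy input are illustrative choices, not part of the original code:

import onnx
import torch
from tvm import relay

# Dummy token ids just for the export; SqueezeBERT expects integer ids of shape (batch_size, seq_len)
dummy_input = torch.randint(0, 30000, (batch_size, s), dtype=torch.long)

# Export the PyTorch model to ONNX
torch.onnx.export(model, dummy_input, "model.onnx", input_names=["input_ids"], opset_version=11)

# Import the ONNX model into Relay
onnx_model = onnx.load("model.onnx")
mod, params = relay.frontend.from_onnx(onnx_model, {"input_ids": (batch_size, s)})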


Be careful with making such claims 🙂 Actually, the PyTorch frontend is fairly good, and I can generally recommend it to PyTorch users.

@popojames You are probably using torch.jit.script, since dict construction does not normally appear when tracing. You should use torch.jit.trace.
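For clarity, a minimal tracing sketch of what is recommended here, reusing model, batch_size, and s from the first post; the dummy input and vocabulary size are illustrative:

import torch

# Trace the model by recording the ops executed for a concrete example input;
# relay.frontend.from_pytorch expects such a TorchScript module.
dummy_input = torch.randint(0, 30000, (batch_size, s), dtype=torch.long)
traced = torch.jit.trace(model, dummy_input, strict=False)

# Inspect the TorchScript graph, e.g. to check whether prim::DictConstruct appears
print(traced.graph)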


Hello @AndrewZhaoLuo, @masahi, thanks for your answers.

@AndrewZhaoLuo Yes, I can definitely try converting the model → ONNX → Relay, but I still want to try PyTorch for now.

@masahi I have used torch.jit.trace to produce the traced model, and it looks normal:

SqueezeBertForSequenceClassification(
  original_name=SqueezeBertForSequenceClassification
  (transformer): SqueezeBertModel(
    original_name=SqueezeBertModel
    (embeddings): SqueezeBertEmbeddings(
      original_name=SqueezeBertEmbeddings
      (word_embeddings): Embedding(original_name=Embedding)
      (position_embeddings): Embedding(original_name=Embedding)
      (token_type_embeddings): Embedding(original_name=Embedding)
      (LayerNorm): LayerNorm(original_name=LayerNorm)
      (dropout): Dropout(original_name=Dropout)
    ).......... 

However, as I showed above, when I try to use

mod, params = relay.frontend.from_pytorch(traced_script_module, input_infos=shape_dict)

the error “The following operators are not implemented: ['prim::DictConstruct']” still shows up.

Note: I am using the newest TVM version: tvm 0.9.dev32+gecd8a9ce3

Update:

According to PyTorch convert function for op ‘dictconstruct’ not implemented · Issue #1157 · apple/coremltools (github.com):

After changing my code from

model = transformers.SqueezeBertForSequenceClassification(config)

into

model = transformers.SqueezeBertForSequenceClassification.from_pretrained('squeezebert/squeezebert-uncased', return_dict=False)

I was able to convert it successfully!
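To summarize, here is a minimal end-to-end sketch of what worked. With return_dict=False the model returns a plain tuple instead of a ModelOutput dict, so tracing no longer produces prim::DictConstruct. The batch size, sequence length, and dummy input below are illustrative, and I pass input_infos as a list of (name, shape) tuples as in the TVM PyTorch tutorials:

import torch
import transformers
from tvm import relay

batch_size, seq_len = 1, 128  # illustrative values

# return_dict=False avoids the dict output that caused prim::DictConstruct
model = transformers.SqueezeBertForSequenceClassification.from_pretrained(
    "squeezebert/squeezebert-uncased", return_dict=False
)
model.eval()

# Dummy token ids just for tracing
dummy_input = torch.randint(0, model.config.vocab_size, (batch_size, seq_len), dtype=torch.long)
traced = torch.jit.trace(model, dummy_input, strict=False)

# Convert the traced module to a Relay module
shape_list = [("input_ids", (batch_size, seq_len))]
mod, params = relay.frontend.from_pytorch(traced, shape_list)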
