Hi everyone, I was trying to compile the gpt2 model using TVM, however I encountered the problem while converting the PyTorch model into Relay IR and search for many days but still couldn’t solve it , so I am wondering if anyone knows how to fix this problem?
Thank you very much
Here are my packages and versions
tvm = 0.10.0
torch = 2.0.1
transformers = 4.33.2
This is my code , the error "AttributeError: ‘GPT2LMHeadModel’ object has no attribute ‘graph’ " was raised at the last line
import tvm
from tvm.contrib import graph_runtime
from tvm.relay.op.contrib import get_pattern_table
from transformers import AutoModel, AutoTokenizer, MobileBertTokenizer, AutoModelForCausalLM
import numpy as np
# sphinx_gallery_start_ignore
from tvm import testing
target = tvm.target.Target('c')
dev = tvm.cpu(0)
model_path = './model/traced_tiny-gpt2.pt'
data_type = 'float32' # input's data type
result = './tvm_generated_files/'
sequence = "Hello world!"
# Load a pretrained PyTorch model
# -------------------------------
import torch
from torchinfo import summary
model = torch.jit.load(model_path)
model = model.eval()
# Set input size
tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2", torchscript=True)
tokens=tokenizer('The cat is on the table.', return_tensors='pt')['input_ids']
verify = trace_model(tokens)[0].shape
# Convert the PyTorch model into Relay IR
inputs = [("dummy_input_name", (5,))]
from tvm import relay
mod, params = relay.frontend.from_pytorch(model, inputs )
This is the whole error report
File "compile_gpt2.py", line 59, in <module>
mod, params = relay.frontend.from_pytorch(model, inputs )
File "/root/tvm/python/tvm/relay/frontend/pytorch.py", line 4507, in from_pytorch
graph = script_module.graph.copy()
File "/root/anaconda3/envs/python38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'GPT2LMHeadModel' object has no attribute 'graph' ```