I am trying to add GPT-2 to the MetaSchedule relay_workload at https://github.com/apache/tvm/blob/main/python/tvm/meta_schedule/testing/relay_workload.py
I modeled my code on the existing "bert" branch. Here is the source code:
    elif name == "gpt2":
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        # pip3 install transformers==3.5 torch==1.7
        import torch  # type: ignore
        import transformers  # type: ignore
        assert layout is None
        config_dict = {
            "gpt2": transformers.GPT2Config(
                return_dict=False,
            ),
        }
        configuration = config_dict[name]
        model = transformers.GPT2Model(configuration)
        # model = transformers.GPT2LMHeadModel.from_pretrained('gpt2', torchscript=True)
        input_name = "input_ids"
        input_dtype = "int64"
        a = torch.randint(10000, input_shape)  # pylint: disable=no-member
        model.eval()
        scripted_model = torch.jit.trace(model, [a], strict=False)
        # print(type(scripted_model))
        shape_list = [(input_name, input_shape)]
        mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
        print(mod)
        mod = relay.transform.FastMath()(mod)
        mod = relay.transform.CombineParallelBatchMatmul()(mod)
        inputs = (input_name, input_shape, input_dtype)
I ran into this error:
Does anyone know how to solve this problem?
Many thanks