I am new to TVM and I am trying to profile the operators in different LLMs to do a comparative analysis for a paper I am working on. However, I am struggling with this. Previously I was able to profile the operators in DL models (e.g. ResNet) using relax_vm.profile(), but I am not able to do the same for the LLMs from MLC-LLM. I am importing the LLM models as described in the tutorial. However, one function in particular, create_paged_kv_cache, is not resolved when converting to a Relax VM, and it gives the error: InternalError: Check failed: (func.defined()) is false: Error: Cannot find PackedFunc mlc.create_paged_kv_cache_generic in either Relax VM kernel library, or in TVM runtime PackedFunc registry, or in global Relax functions of the VM executable
from mlc_llm.model.gpt2 import gpt2_model
import tvm
from tvm import relax
from tvm import dlight as dl
from transformers import AutoTokenizer, GPT2LMHeadModel
from tvm.runtime import relax_vm
from tvm.relax.frontend import nn
# GPT-2 configuration values, later passed to GPT2Config.from_dict.
config_dict = {
    "architectures": ["GPT2LMHeadModel"],
    # Both special tokens use the same id.
    "bos_token_id": 50256,
    "eos_token_id": 50256,
    "hidden_act": "gelu_new",
    # Model dimensions.
    "n_ctx": 1024,
    "n_embd": 768,
    "n_head": 12,
    "n_layer": 12,
    "n_positions": 1024,
    "layer_norm_epsilon": 1e-05,
    "scale_attn_by_inverse_layer_idx": False,
    "vocab_size": 50257,
}
# Load the pretrained "gpt2" tokenizer, then encode the sample prompt.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
prompt = "My favorite music is "

# return_tensors="pt" requests PyTorch tensors from the tokenizer.
inputs = tokenizer(prompt, return_tensors="pt")
print(inputs)
# Build the MLC-LLM GPT-2 model object from the configuration dictionary
# and export it with its default spec.
config = gpt2_model.GPT2Config.from_dict(config_dict)
model = gpt2_model.GPT2LMHeadModel(config)
mod, named_params = model.export_tvm(spec=model.get_default_spec())

# detach_params returns a (module, params) pair; keep working on the
# returned module and keep the original names bound for later use.
mod, params_from_torch = tvm.relax.frontend.detach_params(mod)
mod_from_torch = mod

# Run LegalizeOps, FuseOps and FuseTIR on the module, in that order.
for relax_pass in (
    relax.transform.LegalizeOps(),
    relax.transform.FuseOps(),
    relax.transform.FuseTIR(),
):
    mod = relax_pass(mod)

print(mod.get_global_vars())
# Create the CUDA target once and reuse it for scheduling, building and
# device selection.
target = tvm.target.Target("cuda")

# Apply dlight's default GPU schedule rules to the module under the target.
with target:
    gpu_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.GEMV(),
        dl.gpu.Reduction(),
        dl.gpu.GeneralReduction(),
        dl.gpu.Fallback(),
    )(mod)

# Build the scheduled module (gpu_mod) rather than the unscheduled `mod`,
# so the schedules applied above are actually used.
# NOTE(review): this does not address the reported
# `mlc.create_paged_kv_cache_generic` lookup failure; presumably that call
# must be rewritten by MLC-LLM's own compilation passes before building --
# TODO confirm against the MLC-LLM compiler pipeline.
exec = relax.build(gpu_mod, target=target)
dev = tvm.device(str(target.kind), 0)
vm = relax.VirtualMachine(exec, dev)