I recently found that TVM's documentation is quite incomplete. When I searched for optimization guides, the results were all from a few years ago, and some of the links were broken, which suggests there used to be many application-level documents for TVM that have since been deleted or left unmaintained. The examples on the official website are too simple. I have been looking for a guide on optimizing BERT with dynamic shapes (ONNX), but I haven't found one, and I haven't managed to get it running myself on v0.20. I'm stuck and don't know how to fix it.
Can you provide the BERT error log?
Since TVM does not provide examples of dynamic shapes, I have only been modifying other examples to try it, but none have worked so far. Here is my current script:
import time
from collections import OrderedDict

import numpy as np
import onnx

import tvm
from tvm import relax, meta_schedule as ms
from tvm.relax.frontend.onnx import from_onnx
def measure_tvm(ex, device, input_dict, num_runs=5):
    vm = relax.VirtualMachine(ex, device)
    print("---measure_tvm begin")
    input_list = []
    for v in input_dict.values():
        print("npy shape: ", v.shape)
        array = tvm.nd.array(v, device)
        print("array: ", array)
        input_list.append(array)
    print("input_list: ", input_list)
    print("begin warm up")
    output = vm["main"](*input_list)
    base = output[0].numpy()  # keep the warm-up output as the reference result
    print(base)
    time_list = []
    ret_list = []
    print("--measure loop, num_runs=", num_runs)
    for _ in range(num_runs):
        print("begin--------------------")
        start_time = time.time()
        output = vm["main"](*input_list)
        device.sync()  # CUDA kernels launch asynchronously; sync so the timing covers the actual work
        end_time = time.time()
        cmp = output[0].numpy()
        ret = np.allclose(base, cmp, atol=1e-5)  # compare each run against the reference
        ret_list.append(ret)
        delta = end_time - start_time
        print("one consume: {} second".format(delta))
        time_list.append(delta)
    print("---measure end")
    print("time_list: ", time_list)
    print("diff ret: ", ret_list)
    consume_avg = sum(time_list) / num_runs
    print("tvm consume, avg", consume_avg * 1000, " ms")
def tune_2(mod, target, input_dict, work_dir, trials=8000, num_runs=1):
    device = tvm.device(target, 0)
    tvm_target = tvm.target.Target(target)
    print("device: ", device)
    print("target: ", tvm_target)
    print("Convert operators for inference mode.")
    mod = relax.transform.DecomposeOpsForInference()(mod)
    # Legalize any remaining Relax ops into TensorIR
    print("Legalize any relax ops into tensorir")
    mod = relax.transform.LegalizeOps()(mod)
    mod["main"].show()
    with tvm_target:
        seq = tvm.ir.transform.Sequential(
            [
                relax.get_pipeline("zero"),
                relax.transform.MetaScheduleTuneTIR(work_dir=work_dir, max_trials_global=trials),
                relax.transform.MetaScheduleApplyDatabase(work_dir=work_dir),
            ]
        )
        mod = seq(mod)
    ex = relax.build(mod, target=tvm_target)
    print("measure_tvm on", target)
    measure_tvm(ex, device, input_dict, num_runs)
# begin run
# load the BERT model (batch and seq_len are dynamic)
onnx_model = onnx.load("/bert.onnx")
# convert to Relax
mod = from_onnx(onnx_model)
# prepare input; these are just example sizes for testing, the exported
# model keeps both dims symbolic
batch_size = 1
seq_len = 128
vocab_size = 30522  # assumed vocab size of stock BERT base; adjust to the exported model
input_dict = OrderedDict()
input_dict['input_ids'] = np.random.randint(0, vocab_size, size=(batch_size, seq_len), dtype="int64")
input_dict['att_mask'] = np.ones((batch_size, seq_len), dtype="int64")  # randint(0, 1) always returns 0, which would mask every token
input_dict['token_type_ids'] = np.zeros((batch_size, seq_len), dtype="int64")  # rand().astype("int64") truncates to all zeros anyway
# begin tune and measure
tune_2(mod, "cuda", input_dict, "./bert_dynamic_shape_tune_cuda", num_runs=10)
I found that no tuning task is actually processed:
2025-06-09 23:02:43 [INFO] [task_scheduler.cc:159] Initializing Task #0: "main"
2025-06-09 23:02:43 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "main"
2025-06-09 23:02:46 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0
2025-06-09 23:02:46 [INFO] [task_scheduler.cc:320]
ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done
---------------------------------------------------------------------------------------------------
0 | main | 1 | 1 | N/A | N/A | N/A | 0 | Y
---------------------------------------------------------------------------------------------------
Total trials: 0
Total latency (us): 0
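My understanding is that MetaSchedule cannot build a search space for PrimFuncs whose loop extents are symbolic, so the task is marked done immediately (the FLOP column shows the placeholder value 1 and zero trials run). A quick way to check for symbolic dims is to dump the buffer shapes of the legalized module; a rough sketch:

# sketch: print each PrimFunc's buffer shapes; tir.Var entries indicate dynamic dims
for gv, func in mod.functions.items():
    if isinstance(func, tvm.tir.PrimFunc):
        shapes = [list(buf.shape) for buf in func.buffer_map.values()]
        print(gv.name_hint, shapes)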
After a while, it raises an exception:
Did you forget to bind?
Variable `lv10` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `T_add` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `lv8` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `lv7` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
File "/tvm_home/tvm-0.20/src/tir/analysis/verify_memory.cc", line 203