Hello @echuraev ! I have solved an issue, thank you very much! Now I’m working on tuning the iOS model and I’m having next error:
Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 47, in _launch
reduction.dump(process_obj, fp)
File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'default_module_loader.<locals>.default_module_loader_mgr'
And here’s my script for tuning the model:
import os
import numpy as np
import tvm
from tvm import relay, autotvm
import tvm.relay.testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib.utils import tempdir
import tvm.contrib.graph_executor as runtime
from tvm.contrib import xcode
from fan_model import FAN
from utils_inference_2 import *
import shutil
def get_model():
model_folder_path = "path-to-model
model_name = "model_name"
# detector = MTCNN()
model = FAN(2)
model_path = "pytorch-model-path"
checkpoint = torch.load(model_path, map_location='cpu')['state_dict']
model = torch.nn.DataParallel(model)
model.load_state_dict(checkpoint)
model = model.eval()
input_shape = [1, 3, 256, 256]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()
input_name = "input"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
return mod, params, (1, 3, 256, 256)
def fcompile(*args):
xcode.create_dylib(*args, arch=arch, sdk=sdk)
path = args[0]
xcode.codesign(path)
xcode.popen_test_rpc(proxy_host, proxy_port, device_key, destination=destination, libs=[path])
fcompile.output_format = "dylib"
proxy_host = "192.168.1.4"
proxy_port = 9090
device_key = "iphone"
destination = "platform=iOS,id=phone-id"
# Change target configuration, this is setting for iphone6s
# arch = "x86_64"
# sdk = "iphonesimulator"
arch = "arm64"
sdk = "iphoneos"
target = "metal"
target_host = "llvm -mtriple=%s-apple-darwin" % arch
model_name = "model_name"
lib_path = "path-to-tvm-folder"
log_file = "model.log"
tuning_option = {
"log_filename": log_file,
"tuner": "xgb",
"n_trial": 1000,
"early_stopping": 450,
'measure_option': autotvm.measure_option(
builder=autotvm.LocalBuilder(
n_parallel=1,
build_func=fcompile,
timeout=60
),
runner=autotvm.RPCRunner(
device_key,
host='127.0.0.1', # I'm not sure. Here might be an actual IP address of proxy/host machine
port=9190,
number=20, repeat=3, timeout=60, min_repeat_ms=150)
),
}
# You can skip the implementation of this function for this tutorial.
def tune_tasks(
tasks,
measure_option,
tuner="xgb",
n_trial=1000,
early_stopping=None,
log_filename="tuning.log",
use_transfer_learning=True,
):
# create tmp log file
tmp_log_file = log_filename + ".tmp"
if os.path.exists(tmp_log_file):
os.remove(tmp_log_file)
for i, tsk in enumerate(reversed(tasks)):
prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
# create tuner
if tuner == "xgb" or tuner == "xgb-rank":
tuner_obj = XGBTuner(tsk, loss_type="rank")
elif tuner == "ga":
tuner_obj = GATuner(tsk, pop_size=50)
elif tuner == "random":
tuner_obj = RandomTuner(tsk)
elif tuner == "gridsearch":
tuner_obj = GridSearchTuner(tsk)
else:
raise ValueError("Invalid tuner: " + tuner)
if use_transfer_learning:
if os.path.isfile(tmp_log_file):
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
# do tuning
tsk_trial = min(n_trial, len(tsk.config_space))
tuner_obj.tune(
n_trial=tsk_trial,
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
autotvm.callback.log_to_file(tmp_log_file),
],
)
# pick best records to a cache file
autotvm.record.pick_best(tmp_log_file, log_filename)
os.remove(tmp_log_file)
def tune_and_evaluate(tuning_opt):
# extract workloads from relay program
print("Extract tasks...")
mod, params, input_shape = get_model()
tasks = autotvm.task.extract_from_program(
mod["main"],
target=target,
params=params,
ops=(relay.op.get("nn.conv2d"),),
)
# run tuning tasks
print("Tuning...")
tune_tasks(tasks, **tuning_opt)
# compile kernels with history best records
with autotvm.apply_history_best(log_file):
print("Compile...")
with tvm.transform.PassContext(opt_level=3):
lib = relay.build_module.build(mod, target=target, params=params)
# export library
tmp = tempdir()
if os.path.exists(lib_path) and os.path.isdir(lib_path):
shutil.rmtree(lib_path)
os.mkdir(lib_path)
path_dylib = lib_path + "/" + model_name + ".dylib"
lib.export_library(path_dylib, xcode.create_dylib, arch=arch, sdk=sdk)
# upload module to device
print("Upload...")
remote = autotvm.measure.request_remote(device_key, "127.0.0.1", 9190, timeout=10000)
remote.upload(tmp.relpath(path_dylib))
rlib = remote.load_module(path_dylib)
# FINISH IT LATER!
# # upload parameters to device
# dev = remote.device(str(target), 0)
# module = runtime.GraphModule(rlib["default"](dev))
# data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
# module.set_input("data", data_tvm)
#
# # evaluate
# print("Evaluate inference time cost...")
# ftimer = module.module.time_evaluator("run", dev, number=1, repeat=30)
# prof_res = np.array(ftimer().results) * 1000 # convert to millisecond
# print(
# "Mean inference time (std dev): %.2f ms (%.2f ms)"
# % (np.mean(prof_res), np.std(prof_res))
# )
if __name__ == "__main__":
tune_and_evaluate(tuning_option)
Please, let me know if I need to create a separate thread in order to discuss my new issue. Thank you very much!