Tune the model in iOS

Hello everyone! I have successfully built model for iOS and runtime for iOS (Thanks to your amazing support!). But right now I’m having problems while trying to tune my model.

So my tuning process is listed below:

  1. Launch the rpc.tracker via terminal command python -m tvm.exec.rpc_tracker --host=0.0.0.0 --port=9190

  2. Launch rpc.proxy via terminal command python -m tvm.exec.rpc_proxy --host 0.0.0.0 --tracker 0.0.0.0:9190. So far so good

  3. Launch “iOS_rpc” application and connect to the rpc.proxy using presented UI. Connection is successful

  4. Launch my tuning script that was built using tutorial from TVM docs (Auto-tuning a Convolutional Network for Mobile GPU — tvm 0.8.dev0 documentation):

    import os

    import numpy as np

    import tvm from tvm import relay, autotvm import tvm.relay.testing

    from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner from tvm.contrib.utils import tempdir import tvm.contrib.graph_executor as runtime from tvm.contrib import xcode

    from fan_model import FAN from utils_inference_2 import *

    import shutil

    def get_model():
        """Load the FAN checkpoint, trace it and convert it to a Relay module.

        Returns:
            A ``(mod, params, input_shape)`` tuple: the Relay module and
            parameters produced by ``relay.frontend.from_pytorch`` and the
            input shape ``(1, 3, 256, 256)``.
        """
        model = FAN(2)
        model_path = "path_to_pytorch_model"
        checkpoint = torch.load(model_path, map_location='cpu')['state_dict']
        # The state dict is loaded into the DataParallel wrapper, not the bare model.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(checkpoint)

        model = model.eval()

        input_shape = [1, 3, 256, 256]
        input_data = torch.randn(input_shape)
        scripted_model = torch.jit.trace(model, input_data).eval()

        input_name = "input"
        shape_list = [(input_name, input_shape)]
        mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

        return mod, params, (1, 3, 256, 256)
    

    def fcompile(*args):
        """Build, sign and deploy a dylib for one measurement.

        Forwards ``*args`` to ``xcode.create_dylib`` with the module-level
        ``arch`` and ``sdk``, code-signs the output file (``args[0]``) and
        launches the test RPC app with that library attached.
        """
        xcode.create_dylib(*args, arch=arch, sdk=sdk)
        path = args[0]
        xcode.codesign(path)
        xcode.popen_test_rpc(proxy_host, proxy_port, device_key, destination=destination, libs=[path])

    # Declares the artifact format produced by the build function.
    fcompile.output_format = "dylib"

    proxy_host = "192.168.1.4"
    proxy_port = 9090
    device_key = "iphone"
    destination = "platform=iOS,id=device-id"

    # Change target configuration, this is setting for iphone6s
    # arch = "x86_64"
    # sdk = "iphonesimulator"
    arch = "arm64"
    sdk = "iphoneos"
    target_host = "llvm -mtriple=%s-apple-darwin" % arch

    tgt = tvm.target.Target("metal -device=gpu", host=target_host)

    model_name = "model_name"
    lib_path = "path_where_to_store_tuning_results"

    log_file = "model.log"
    # Keyword arguments later unpacked into tune_tasks().
    tuning_option = {
        "log_filename": log_file,
        "tuner": "xgb",
        "n_trial": 1000,
        "early_stopping": 450,
        'measure_option': autotvm.measure_option(
            builder=autotvm.LocalBuilder(
                n_parallel=1,
                build_func=fcompile,
                timeout=60
            ),
            runner=autotvm.RPCRunner(
                device_key,
                host='127.0.0.1',  # I'm not sure. Here might be an actual IP address of proxy/host machine
                port=9190,
                number=20, repeat=3, timeout=60, min_repeat_ms=150)
        ),
    }

    # You can skip the implementation of this function for this tutorial.
    def tune_tasks(
        tasks,
        measure_option,
        tuner="xgb",
        n_trial=1000,
        early_stopping=None,
        log_filename="tuning.log",
        use_transfer_learning=True,
    ):
        """Tune every task in ``tasks`` and keep the best records in ``log_filename``.

        Args:
            tasks: AutoTVM tasks to tune; they are processed in reverse order.
            measure_option: Measurement options passed to each tuner.
            tuner: One of "xgb", "xgb-rank", "ga", "random" or "gridsearch".
            n_trial: Upper bound of trials per task (capped by the config space size).
            early_stopping: Forwarded to ``tuner_obj.tune``.
            log_filename: Destination file for the best records.
            use_transfer_learning: When true, history from the temporary log is
                loaded into each new tuner.

        Raises:
            ValueError: If ``tuner`` is not one of the supported names.
        """
        # create tmp log file
        tmp_log_file = log_filename + ".tmp"
        if os.path.exists(tmp_log_file):
            os.remove(tmp_log_file)

        for i, tsk in enumerate(reversed(tasks)):
            prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))

            # create tuner
            if tuner == "xgb" or tuner == "xgb-rank":
                tuner_obj = XGBTuner(tsk, loss_type="rank")
            elif tuner == "ga":
                tuner_obj = GATuner(tsk, pop_size=50)
            elif tuner == "random":
                tuner_obj = RandomTuner(tsk)
            elif tuner == "gridsearch":
                tuner_obj = GridSearchTuner(tsk)
            else:
                raise ValueError("Invalid tuner: " + tuner)

            if use_transfer_learning:
                if os.path.isfile(tmp_log_file):
                    tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

            # do tuning
            tsk_trial = min(n_trial, len(tsk.config_space))
            tuner_obj.tune(
                n_trial=tsk_trial,
                early_stopping=early_stopping,
                measure_option=measure_option,
                callbacks=[
                    autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
                    autotvm.callback.log_to_file(tmp_log_file),
                ],
            )

        # pick best records to a cache file
        autotvm.record.pick_best(tmp_log_file, log_filename)
        os.remove(tmp_log_file)
    

    def tune_and_evaluate(tuning_opt):
        """Extract conv2d tasks, tune them, then build, export and upload the library.

        Args:
            tuning_opt: Keyword arguments unpacked into ``tune_tasks``.
        """
        # extract workloads from relay program
        print("Extract tasks...")
        mod, params, input_shape = get_model()
        tasks = autotvm.task.extract_from_program(
            mod["main"],
            target=tgt,
            params=params,
            ops=(relay.op.get("nn.conv2d"),),
        )

        # run tuning tasks
        print("Tuning...")
        tune_tasks(tasks, **tuning_opt)

        # compile kernels with history best records
        with autotvm.apply_history_best(log_file):
            print("Compile...")
            with tvm.transform.PassContext(opt_level=3):
                # The module-level Target object is named `tgt`; this script
                # never defines a name `target`.
                lib = relay.build_module.build(mod, target=tgt, params=params)

            # export library into a freshly created lib_path directory
            if os.path.exists(lib_path) and os.path.isdir(lib_path):
                shutil.rmtree(lib_path)

            os.mkdir(lib_path)

            path_dylib = lib_path + "/" + model_name + ".dylib"
            lib.export_library(path_dylib, xcode.create_dylib, arch=arch, sdk=sdk)

            # upload module to device
            print("Upload...")
            remote = autotvm.measure.request_remote(device_key, "127.0.0.1", 9190, timeout=10000)
            # The dylib was written to path_dylib (not into a temp dir), so
            # upload that file and load it on the remote by its file name.
            remote.upload(path_dylib)
            rlib = remote.load_module(os.path.basename(path_dylib))

            # FINISH IT LATER!
            # # upload parameters to device
            # dev = remote.device(str(target), 0)
            # module = runtime.GraphModule(rlib["default"](dev))
            # data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
            # module.set_input("data", data_tvm)
            #
            # # evaluate
            # print("Evaluate inference time cost...")
            # ftimer = module.module.time_evaluator("run", dev, number=1, repeat=30)
            # prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
            # print(
            #     "Mean inference time (std dev): %.2f ms (%.2f ms)"
            #     % (np.mean(prof_res), np.std(prof_res))
            # )


    if __name__ == "__main__":
        tune_and_evaluate(tuning_option)
    

When I launch the script, I have this log:

Tuning...
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (0/1000) | 0.00 s[14:06:09] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:06:09] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (1/1000) | 5.20 s[14:06:13] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:06:13] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (2/1000) | 8.88 s[14:06:17] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:06:17] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (3/1000) | 12.70 s[14:06:20] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:06:20] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (4/1000) | 16.55 s[14:06:24] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:06:24] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (5/1000) | 20.29 s[14:06:28] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:06:28] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (6/1000) | 24.47 s[14:06:32] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:06:32] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (7/1000) | 28.62 s[14:06:36] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:06:36] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (8/1000) | 32.69 s[14:06:40] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:06:40] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (9/1000) | 36.83 s[14:06:45] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:06:45] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (10/1000) | 41.11 s[14:06:49] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:06:49] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (11/1000) | 45.27 s[14:06:53] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:06:53] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (12/1000) | 49.66 s[14:06:57] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:06:57] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (13/1000) | 53.67 s[14:07:01] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:07:01] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (14/1000) | 57.94 s[14:07:06] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:07:06] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (15/1000) | 62.52 s[14:07:10] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:07:10] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (16/1000) | 66.57 s[14:07:14] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:07:14] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (17/1000) | 70.64 s[14:07:18] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:07:18] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (18/1000) | 74.94 s[14:07:23] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:07:23] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (19/1000) | 78.99 s[14:07:27] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:07:27] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
[Task  1/48]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (20/1000) | 82.87 s[14:07:31] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 0, name=AMD Radeon Pro 5300M
[14:07:31] /Users/superyevhen/Documents/Work/TVM/tvm/src/runtime/metal/metal_device_api.mm:152: Intializing Metal device 1, name=Intel(R) UHD Graphics 630

As you can see, my GFLOPS always stay 0.00 value. I did my own digging and I have noticed that all my tasks from the network have suffix “cuda”. Can it be a problem?

Thanks in advance!

Hi @L1onKing!

In short, “0.00/0.00” means that an internal error happened during kernel measurement, but the real reason for the trial failure was not printed. It definitely looks like a usability issue in the autotvm logger: it doesn’t report any error marker, it just goes on.

Anyway, I’ve tried your script and got exactly the same log output. In fact, the real cause of the error on my side was a missing function name in the module:

RPCError: Error caught from RPC call: [16:53:12] /Users/apeskov/git/tvm/src/runtime/library_module.cc:49: Check failed: (entry_name != nullptr) is false: Symbol tvm_main is not presented

I’ll continue investigating this error a little bit later.

But in the meantime, could you please clarify the reason for the failure in your environment? Just print the error message in python/tvm/autotvm/measure/measure_methods.py:598.

Also, I would like to note that an iOS RPC server that does not need to be relaunched via “xcode.popen_test_rpc” was introduced recently, and I strongly recommend you try it; it may solve a lot of your issues. For more details about using this new iOS RPC mechanism, take a look at the following doc section:

Hello @apeskov !

Thank you for your help! I did try to output the log at measure_methods.py but it looks like that my code doesn’t go in it. Although I did some modifications regarding Xcode.popen_test_rpc which I’m sure are not correct.

So now before moving any further, I have decided to build new ios_rpc. I did download custom_dso_loader successfully but now I’m having a linker error:

Undefined symbols for architecture arm64:
  "tvm::runtime::SplitKernels(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >)", referenced from:

Have you encountered such issue before? I have checked this function, it is presented and CMake has a reference to the directory where it’s placed. So meanwhile I’m trying to solve it and make sure new ios_rpc is working and then continue with my tuning script.

Thank you again for your help, I really appreciate it!

UPDATE:

I have just fixed a linker issue by adding this line at the top of TVMRuntime.mm:

#include "../../../src/runtime/source_utils.cc"

Hi @L1onKing,

Thank you for such quick update! You are right that’s a new issue with building of iOS RPC App. I faced the same problem today as well. The patch leading to this “undef symbol” was merged only yesterday evening. Your fix is absolutely correct, you may use it till we fix it.

Currently I’m preparing a patch with refactoring iOS RPC app. It will exclude this kind of problem related with source including. Hope it will be finished this week.

Small update,

I’ve checked autotvm tuning with the iOS Metal backend one more time and I can confirm it’s a little bit broken. The problem is in the kernel compilation step: it produces an empty module without exported symbols. However, the auto_scheduler builder works well. I recommend you switch to auto_scheduler tuning; it should work without significant problems.

Example of using auto_scheduler with your code, put it into your script and modify tuning options if you would like:

from tvm import auto_scheduler


# The RPC server does not have to be restarted through xcode.popen_test_rpc,
# so building the dylib is the only step left here.
def fcompile(*args):
    """Forward ``*args`` to ``xcode.create_dylib`` with the module-level ``arch`` and ``sdk``."""
    build_settings = {"arch": arch, "sdk": sdk}
    xcode.create_dylib(*args, **build_settings)


# Declares the artifact format produced by this build function.
fcompile.output_format = "dylib"


def tune_and_evaluate_auto_scheduler():
    """Tune the model with auto_scheduler over RPC and export the tuned dylib.

    Extracts the tasks of the model returned by ``get_model``, measures
    candidate schedules through the RPC tracker, records them in ``log_file``
    and finally builds the model with the best records applied, writing
    ``lib_ios_tuned.dylib``.
    """
    # extract workloads from relay program
    print("Extract tasks...")
    mod, params, _ = get_model()
    tasks, task_weights = auto_scheduler.extract_tasks(mod, params, target=target, target_host=target_host)

    print("Tuning...")
    measure_runner = auto_scheduler.RPCRunner(
        device_key,
        "0.0.0.0",
        9190,
        min_repeat_ms=300,
        timeout=30,
        n_parallel=1,
    )
    builder = auto_scheduler.LocalBuilder(build_func=fcompile, n_parallel=8, timeout=10000)
    tune_option = auto_scheduler.TuningOptions(
        builder=builder,
        num_measure_trials=256,
        num_measures_per_round=32,
        runner=measure_runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        verbose=2,
    )

    # Hand the extracted weights to the scheduler; without them every task is
    # weighted equally in the objective.
    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tuner.tune(tune_option)

    print("Compile...")
    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
            lib = relay.build(mod, target=target, target_host=target_host, params=params)
            # xcode.create_dylib is the dylib builder used elsewhere in this
            # file; it receives arch/sdk as extra keyword arguments.
            lib.export_library("lib_ios_tuned.dylib", fcompile=xcode.create_dylib, arch=arch, sdk=sdk)

The iOS RPC app should be built with “custom_loader” support as described in doc paper. Connect it to rpc_tracker through rpc_proxy as before.

By the way, could you point out the Python version you’re using?

I’m using Python 3.8.9 from homebrew.

Hi @apeskov!

For the time being I have such an error while trying to tune the model:

 File "/Users/superyevhen/Documents/Work/TVM/tvm/python/tvm/auto_scheduler/measure.py", line 1078, in _timed_rpc_run
    random_fill = remote.get_function("tvm.contrib.random.random_fill")

What I got from this error is that every inference stage that is being optimised is built as a separate .dylib and that’s where I need to integrate USE_RANDOM=ON option. Could you please give me an advise on how to fix this issue?

Just add one more include line:

#include "../../../src/runtime/contrib/random/random.cc"

If you run into some other missing function, add the corresponding “.cc” include line.

Hi @apeskov ! Got it, thank you for your help! Right now I’m trying to tune my model using scheduler as you suggested, but script does one pass for one task and then quits.

Python gives this kind of output:

/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/multiprocessing/resource_tracker.py:216: UserWarning: resource_tracker: There appear to be 8 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
Error in atexit._run_exitfuncs:
Traceback (most recent call last):
  File "/Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pydev/pydevd.py", line 1785, in stoptrace
    debugger.exiting()
  File "/Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pydev/pydevd.py", line 1471, in exiting
    sys.stdout.flush()
ValueError: I/O operation on closed file.

Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)

And during debugging I have found out that the code doesn’t run further than this line

measure_inputs, measure_results = self.search_policies[task_idx].continue_search_one_round(
            self.num_measures_per_round, self.measurer
        ) #that's where it breaks

I have found this SearchOneRound method and put a breakpoint there in the ios_rpc app, but no luck. Could you tell me have you dealt with this issue before? Thanks!

UPDATE:

During my digging I have found out that the problem might occur because of xgb.DMatrix missing implementation. That’s where Python code silently breaks:

xgb_model.py, line 418:
ret = xgb.DMatrix(np.array(x_flatten), y_flatten)

And I can not find DMatrix implementation anywhere in the python code. Could you tell me where can I find it?

UPDATE 2:

xgb_model.py doesn’t have xgb implementation. Comment this line:

# xgb = None

and add this line:

import xgboost as xgb

Finally I have managed to produce a tuned model using this script:

import os

import numpy as np

import tvm
from tvm import relay, autotvm
import tvm.relay.testing

from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib.utils import tempdir
import tvm.contrib.graph_executor as runtime
from tvm.contrib import xcode

from tvm import auto_scheduler

from fan_model import FAN
from utils_inference_2 import *

import shutil

def get_model():
    """Build the traced FAN network and import it into Relay.

    Returns:
        ``(mod, params, input_shape)`` where ``mod`` and ``params`` come from
        ``relay.frontend.from_pytorch`` and ``input_shape`` is
        ``(1, 3, 256, 256)``.
    """
    input_shape = [1, 3, 256, 256]

    net = FAN(2)
    # Weights are restored on CPU and loaded into the DataParallel wrapper.
    state_dict = torch.load("path-to-my-pytorch-model", map_location='cpu')['state_dict']
    net = torch.nn.DataParallel(net)
    net.load_state_dict(state_dict)
    net = net.eval()

    # Trace with a random tensor of the expected input shape.
    traced = torch.jit.trace(net, torch.randn(input_shape)).eval()

    mod, params = relay.frontend.from_pytorch(traced, [("input", input_shape)])
    return mod, params, tuple(input_shape)

def fcompile(*args):
    """Forward ``*args`` to ``xcode.create_dylib`` with the module-level ``arch`` and ``sdk``."""
    build_settings = {"arch": arch, "sdk": sdk}
    xcode.create_dylib(*args, **build_settings)


# Declares the artifact format produced by this build function.
fcompile.output_format = "dylib"

# RPC proxy endpoint and device identification.
proxy_host, proxy_port = "192.168.1.4", 9090
device_key = "iphone"
destination = "platform=iOS,id=device_id"

# Change target configuration, this is setting for iphone6s.
# Simulator alternative:
#   arch = "x86_64"
#   sdk = "iphonesimulator"
arch, sdk = "arm64", "iphoneos"

# Kernels are generated for Metal; the host side is built with LLVM.
# CPU-only alternative: target = "llvm -mtriple=%s-apple-darwin" % arch
target = "metal"
target_host = "llvm -mtriple=%s-apple-darwin" % (arch,)

model_name = "flt_209"
lib_path = "/Users/superyevhen/Documents/Work/TVM/tmp_resources/flt_model_tvm"

log_file = "facetrace_model.log"

# AutoTVM settings: tuner choice, trial budget, local builder and RPC runner.
tuning_option = dict(
    log_filename=log_file,
    tuner="xgb",
    n_trial=1000,
    early_stopping=450,
    measure_option=autotvm.measure_option(
        builder=autotvm.LocalBuilder(n_parallel=1, build_func=fcompile, timeout=60),
        runner=autotvm.RPCRunner(
            device_key,
            # I'm not sure. Here might be an actual IP address of proxy/host machine
            host='127.0.0.1',
            port=9190,
            number=20,
            repeat=3,
            timeout=60,
            min_repeat_ms=150,
        ),
    ),
)

def tune_and_evaluate_auto_scheduler():
    """Tune the model with auto_scheduler over RPC and export the tuned dylib.

    Extracts the tasks of the model returned by ``get_model``, measures
    candidate schedules through the RPC tracker, records them in ``log_file``
    and finally builds the model with the best records applied, writing
    ``lib_ios_tuned.dylib``.
    """
    # extract workloads from relay program
    print("Extract tasks...")
    mod, params, _ = get_model()
    tasks, task_weights = auto_scheduler.extract_tasks(mod, params, target=target, target_host=target_host)

    print("Tuning...")
    measure_runner = auto_scheduler.RPCRunner(
        device_key,
        "0.0.0.0",
        9190,
        min_repeat_ms=300,
        timeout=30,
        n_parallel=1,
    )
    builder = auto_scheduler.LocalBuilder(build_func=fcompile, n_parallel=8, timeout=10000)
    tune_option = auto_scheduler.TuningOptions(
        builder=builder,
        num_measure_trials=256,
        num_measures_per_round=32,
        runner=measure_runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        verbose=2,
    )

    # Hand the extracted weights to the scheduler; without them every task is
    # weighted equally in the objective.
    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tuner.tune(tune_option)

    print("Compile...")
    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
            lib = relay.build(mod, target=target, target_host=target_host, params=params)
            # The build function is passed by keyword; arch/sdk are forwarded
            # to it as extra keyword arguments.
            lib.export_library("lib_ios_tuned.dylib", fcompile=xcode.create_dylib, arch=arch, sdk=sdk)


if __name__ == "__main__":
    tune_and_evaluate_auto_scheduler()

Unfortunately, the produced tuned model has the exact same inference as it had before tuning (around 850-900 ms).

Right now I’m thinking that maybe ApplyHistoryBest method doesn’t work correctly so I’ll dig into that as much as I can. Meanwhile I’d very appreciate if someone could give me some advice. Thanks in advance!

What I can see is that model is not using device power to the full extent. Although my model is Metal-based, I can see only CPU usage which is about 15%. I do not see any noticeable engagement of the GPU component. Is there a chance to boost the model usage via runtime API? @apeskov

Hi @L1onKing,

Glad to know that you managed to solve all configuration problems!

relay.build under the ApplyHistoryBest context prints warnings when no proper tuning statistics were found for a kernel. Example:

-----------------------------------
fused_nn.conv2d_add_nn.relu
Cannot find tuned schedules for target=metal -keys=metal,gpu -max_num_threads=256, workload_key=["6279408267b8cb7afd087563948b64ae", 1, 960, 7, 7, 960, 1, 3, 3, 1, 960, 1, 1, 1, 960, 7, 7]. A fallback TOPI schedule is used, which may bring great performance regression or even compilation failure. Compute DAG info:
placeholder = PLACEHOLDER [1, 960, 7, 7]
PaddedInput(i0, i1, i2, i3) = tir.if_then_else(((((i2 >= 1) && (i2 < 8)) && (i3 >= 1)) && (i3 < 8)), placeholder[i0, i1, (i2 - 1), (i3 - 1)], 0f)
placeholder = PLACEHOLDER [960, 1, 3, 3]
DepthwiseConv2d(b, c, i, j) += (PaddedInput[b, c, (i + di), (j + dj)]*placeholder[c, 0, di, dj])
placeholder = PLACEHOLDER [1, 960, 1, 1]
T_add(ax0, ax1, ax2, ax3) = (DepthwiseConv2d[ax0, ax1, ax2, ax3] + placeholder[ax0, ax1, 0, 0])
T_relu(ax0, ax1, ax2, ax3) = max(T_add[ax0, ax1, ax2, ax3], 0f)

If you do not see such messages then it means that your tuning statistic is applied correctly.

The other point is that you need to spend a sufficient number of tuning trials to achieve good performance. Sometimes it may take several hours or even days. You can control this via the tuning option attributes, e.g. num_measure_trials=4000 and num_measures_per_round=64. During the tuning process it will report a table with the estimated time per kernel, so you can monitor the tuning process and interrupt it when the result is good enough.

Thank you for this info! I will definitely take a look.

Although I get a feeling that my tuning should have some impact. When I run the model without tuning, I have an inference time on my iPhone about 900 ms.

Then I run my tuning script and at the end my Estimated Total Latency is 250 ms. Here’s the final log:

Estimated total latency: 247.746 ms Trials: 252 Used time : 1981 s Next ID: 3

But, when I deploy pre-tuned model on the device, inference doesn’t change at all. And what I have noticed, that device is not really doing any serious job running inference whatsoever. CPU load is about 10-15%, and although there’s no straight way to check GPU load via Xcode, I can see that it doesn’t really do any job whatsoever.

Do I need to build runtime differently? I have built it with DEBUG flags. Or maybe there’re some settings I should incorporate in order to boost inference process?

That’s very suspicious to have zero load of GPU reported by Xcode. You also may verify GPU workload submitting with Instruments profiling, it may shed light on current GPU related problems.

Unfortunately, there are no additional settings to speed up GPU scoring with Metal (at least I haven’t come across any). The performance is mostly determined by the shader code generated during compilation of the resulting dylib. You may double-check that Metal is actually engaged in the scoring process: just add a printf in MetalWrappedFunc::operator(…), which is the place where a kernel is submitted to the command queue.

Regarding the Debug/Release build: as far as I know, this doesn’t significantly impact the final performance. But you may definitely check it.

Regarding the discrepancy between the “estimated total latency” and the real time: the estimated time is just a sum of task costs; it doesn’t account for other execution overheads like data transfer and reordering. Such overheads may lead to a discrepancy of up to 2-3 times. In addition, TVM may perform lazy kernel compilation during the first inference. So I recommend you exclude the first inference time and measure the average of the next 10-20 inferences.

Unfortunately, I didn’t get any improvement using auto_scheduler so I have tried to use tuning algorithm presented in this link (Auto-tuning a Convolutional Network for Mobile GPU — tvm 0.8.dev0 documentation)

As we have established earlier, this algorithm doesn’t work for “Metal” target at this point, but it works for CPU. So I have decided to try it with CPU model just to get at least some improvement.

Unfortunately, after 20 hours of tuning my results got worse. Before model was running about 600 ms, and after tuning it’s running 650 ms.

To be honest, I’m a bit confused with the result and right now I don’t know in what direction should I be going in my research. Clearly, I’m doing something horribly wrong, but I don’t know what exactly. Any help would be appreciated. Thank you

Hello @apeskov! I have a few questions and I’d be happy if you can shed me some light here:)

  1. What’s the difference between auto_scheduler and autotvm? I see that with the latter we have a choice of picking some tuner (XGB, Random, GridSearch, etc), whereas auto_scheduler doesn’t provide those options.
  2. What’s the difference between different tuners? Is there an article where these tuners are compared and where I can find some suggestions about what tuner should I pick and when
  3. If using autotvm, is there a point to run one model through each tuner individually? Is it fair to assume that it will give me the best performance possible for my specific model?

Thank you for your help, I’m very grateful

Hi @L1onKing!

It’s a good question. Will try to explain in general words.

Both autoTVM and autoScheduler do the same thing from customer perspective. They try to define best program for each subpart of model (we use a term “task” for this subgraph) based on performance evaluation of differently compiled kernels on real device.

The AutoTVM tuning approach requires a search space; it iterates over the compilation options in that space and picks the best version. By default, TVM provides a set of templates that generate the search space for each type of task, so you don’t need to do that manually, although you can. Usually the search space contains a huge number of points (10^9 and more), and it is impossible to cover it all in a reasonable period of time. For that case TVM suggests using an advanced tuner strategy that walks through the search space intelligently (like the model-based XGBoost tuner). So the tuner option in the context of autoTVM is a way to decrease the number of trials on the real device and speed up the search. You can read about the differences between tuning strategies in this section of the autoTVM manual: Optimizing Operators with Templates and AutoTVM — tvm 0.8.dev0 documentation

The AutoScheduler tuning approach does not require a predefined search space. It uses ML-based techniques to predict the performance of each particular task schedule and defines the search space based on these predictions. You can treat autoScheduler as the successor of autoTVM, which tries to eliminate the weaknesses of the previous approach. AutoScheduler doesn’t require TVM developers/customers to provide a manually written search space generator for each type of platform and primitive. You can read about that in this section of the autoScheduler tutorial: Optimizing Operators with Auto-scheduling — tvm 0.8.dev0 documentation

And now answering your questions:

  1. AutoScheduler is based on an XGB model. That is an essential part of the autoScheduler tuning process, so there is no sense in specifying it from outside. This is unlike autoTVM, where you may choose different techniques for iterating through the search space depending on your time resources.

  2. Some suggestions are in Optimizing Operators with Templates and AutoTVM — tvm 0.8.dev0 documentation

  3. I don’t think so. The search space will be the same for all tuner types. It’s better to spend the time performing as many trials as possible with one tuner (ideally, though unattainably, brute-forcing the whole search space). I suppose the XGB tuner is the best option for convolution tuning.

Hope it will help you to understand TVM tuning a little better and reach your goals with Metal on iOS!

Hi @apeskov !

Thank you for such an amazing explanation. Now I understand TVM far far better now and I understand towards what direction I should move in my research/development.

If I have some questions, I will let you know. Thank you again!

With regards, Yevhen