[iOS - AutoTVM] Auto tuning is not happening for the iOS Metal target

Hi Experts, @tqchen, @kazum

I tried to auto-tune the sample ResNet model for the iOS Metal target, following the ([Auto-TVM] How to Auto tune the model on iOS device) thread. While tuning, the logs always show "Current/Best: 0.00/ 0.00 GFLOPS" for every task, yet the run ends with "Test Succeeded". I tuned for a while (n_trial=100) and compared the ResNet model before and after tuning, but saw no difference in performance. Also, when picking the best records from the cache file with autotvm.record.pick_best(tmp_log_file, log_filename), the resulting log_filename has zero bytes (no records were picked from the tuned log file). Please find the results for the ResNet model before and after tuning below.

With tuning (ResNet-18): mean inference time 54.05 ms, std dev 1.27 ms

Without tuning (ResNet-18): mean inference time 51.07 ms, std dev 0.14 ms
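
In case it helps to narrow this down: when every entry in the temporary tuning log carries a nonzero error code, pick_best has nothing valid to select and writes an empty file. Below is a minimal sketch to check this, assuming the temporary log resnet-18.log.tmp still exists (i.e. run it before os.remove(tmp_log_file)):

from collections import Counter
from tvm import autotvm

# Tally the error codes of all measurements recorded in the temporary log.
# error_no == 0 means the measurement succeeded; any other value marks a
# failed build/run, which pick_best silently skips.
counts = Counter()
for inp, res in autotvm.record.load_from_file("resnet-18.log.tmp"):
    counts[res.error_no] += 1
print(counts)

If all records report a nonzero error_no, the 0.00 GFLOPS readings are failed measurements rather than slow kernels.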

Code:

import os
import numpy as np
import tvm
from tvm import autotvm, relay
from tvm.contrib import xcode, graph_runtime as runtime
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner

target = 'metal'
proxy_port = 9090
key = "iphone"
arch = "arm64"
sdk = "iphoneos"
target_host = "llvm -mtriple=%s-apple-darwin" % arch

# Cross-compile generated Metal kernels with the Xcode toolchain.
@tvm.register_func("tvm_callback_metal_compile")
def compile_metal(src):
    return xcode.compile_metal(src, sdk=sdk)

#### TUNING OPTION ####
network = 'resnet-18'
log_file = "%s.log" % network
dtype = 'float32'

# Make AutoTVM's remote-device availability check always succeed
# (workaround for the iOS RPC setup).
autotvm.measure.measure_methods.check_remote = lambda *args: True

# Build function for AutoTVM's LocalBuilder: create a dylib for the device
# and codesign it so it can be loaded by the iOS RPC app.
def fcompile(*args):
    from tvm.contrib import xcode
    xcode.create_dylib(*args, arch=arch, sdk=sdk)
    path = args[0]
    xcode.codesign(path)
    # xcode.popen_test_rpc(proxy_host, proxy_port, key,
    #                      destination=destination,
    #                      libs=[path])

fcompile.output_format = "dylib"

tuning_option = {
    'log_filename': log_file,
    'tuner': 'xgb',
    'early_stopping': None,

    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(
            n_parallel=1,
            build_func=fcompile,
            timeout=60
        ),
        runner=autotvm.RPCRunner(
            key, host='127.0.0.1', port=9190,
            number=20, repeat=3, timeout=60, min_repeat_ms=150)
    ),
}
def tune_tasks(tasks,
               measure_option,
               tuner='xgb',
               n_trial=100,
               early_stopping=None,
               log_filename='tuning.log',
               use_transfer_learning=False):
    # create tmp log file
    tmp_log_file = log_filename + ".tmp"
    if os.path.exists(tmp_log_file):
        os.remove(tmp_log_file)

    for i, tsk in enumerate(reversed(tasks)):
        prefix = "[Task %2d/%2d] " %(i+1, len(tasks))

        # create tuner
        if tuner == 'xgb' or tuner == 'xgb-rank':
            tuner_obj = XGBTuner(tsk, loss_type='rank')
        elif tuner == 'ga':
            tuner_obj = GATuner(tsk, pop_size=100)
        elif tuner == 'random':
            tuner_obj = RandomTuner(tsk)
        elif tuner == 'gridsearch':
            tuner_obj = GridSearchTuner(tsk)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        if use_transfer_learning:
            if os.path.isfile(tmp_log_file):
                tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

        # do tuning
        tsk_trial = min(n_trial, len(tsk.config_space))
        tuner_obj.tune(n_trial=tsk_trial,
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
                           autotvm.callback.log_to_file(tmp_log_file)
                       ])

    # pick best records to a cache file
    autotvm.record.pick_best(tmp_log_file, log_filename)
    os.remove(tmp_log_file)


########################################################################
# Finally, we launch tuning jobs and evaluate the end-to-end performance.

def tune_and_evaluate(tuning_opt):
    # extract workloads from relay program
    print("Extract tasks...")
    mod, params, input_shape, out_shape = get_network(network, batch_size=1)

    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                              params=params,
                                              target_host=target_host,
                                              ops=(relay.op.get("nn.conv2d"),))

    # run tuning tasks
    print("Tuning...")
    tune_tasks(tasks, **tuning_opt)

    # compile kernels with the best schedules found during tuning
    with autotvm.apply_history_best(log_file):
        print("Compile...")
        with tvm.transform.PassContext(opt_level=3):
            graph, lib, params = relay.build_module.build(
                mod, target=target, params=params)

    # export library
    path_dso = "tuned_deploy.dylib"
    lib.export_library(path_dso, xcode.create_dylib, arch=arch, sdk=sdk)
    xcode.codesign(path_dso)

    # Evaluate inference cost on tuned lib
    xcode.popen_test_rpc(proxy_host, proxy_port, device_key, destination=destination, libs=[path_dso])

    remote = autotvm.measure.request_remote(device_key, '0.0.0.0', 9190,
                                            timeout=10000)

    # Upload not needed for ios because dylib is built into app
    # remote.upload(path_dso)

    rlib = remote.load_module(path_dso)

    ctx = remote.metal(0)
    
    module = runtime.create(graph, rlib, ctx)
    data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
    module.set_input('data', data_tvm)
    module.set_input(**params)

    # evaluate
    print("Evaluate inference time cost...")
    ftimer = module.module.time_evaluator("run", ctx, number=3, repeat=20)
    prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
    print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
          (np.mean(prof_res), np.std(prof_res))

Any idea what is happening and how to fix this issue? Any help in understanding and resolving it would be appreciated.

Thanks

Hi Team,

Any suggestions on the above issue?

Thanks

@Dileep Can you run tests/ios_rpc_test.py successfully? Can your tuning script run against the llvm target?
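
You could also verify that the device is actually registered with the tracker your RPCRunner points at (your script uses port 9190 for the runner but 9090 for the proxy). A quick check, assuming a local tracker on port 9190:

from tvm import rpc

# Connect to the RPC tracker and print which device keys are registered
# and how many servers are free.
tracker = rpc.connect_tracker("127.0.0.1", 9190)
print(tracker.text_summary())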

Can you try the following patch and check if something changes or not?

diff --git a/python/tvm/autotvm/measure/local_executor.py b/python/tvm/autotvm/measure/local_executor.py
index a0a826abc..ed7e855c1 100644
--- a/python/tvm/autotvm/measure/local_executor.py
+++ b/python/tvm/autotvm/measure/local_executor.py
@@ -132,7 +132,7 @@ class LocalExecutor(executor.Executor):
         (e.g. cuda runtime, cudnn). Set this to False if you have used these runtime
         before submitting jobs.
     """
-    def __init__(self, timeout=None, do_fork=True):
+    def __init__(self, timeout=None, do_fork=False):
         self.timeout = timeout or executor.Executor.DEFAULT_TIMEOUT
         self.do_fork = do_fork
 
@@ -143,7 +143,11 @@ class LocalExecutor(executor.Executor):
 
     def submit(self, func, *args, **kwargs):
         if not self.do_fork:
-            return LocalFutureNoFork(func(*args, **kwargs))
+            try:
+                res = func(*args, **kwargs)
+            except Exception as e:
+                res = e
+            return LocalFutureNoFork(res)
 
         queue = Queue(2)  # Size of 2 to avoid a race condition with size 1.
         process = Process(target=call_with_timeout,
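
Independently of the patch, enabling AutoTVM's debug logging should surface the underlying build/run errors behind the 0.00 GFLOPS entries. A small snippet to add before calling tune_tasks:

import logging
import sys

# Show AutoTVM's per-measurement debug output (including compile and
# runtime errors) instead of only the progress bar.
logging.getLogger("autotvm").setLevel(logging.DEBUG)
logging.getLogger("autotvm").addHandler(logging.StreamHandler(sys.stderr))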

Hi @kazum, sorry for the late response.

Can you run tests/ios_rpc_test.py successfully? --> Yes, I'm able to run ios_rpc_test.py and get results for both the iPhone CPU and Metal targets.

Can your tuning script run against the llvm target? --> Yes, I'm able to tune the model on the llvm target, and I can see some change in the GFLOPS.

Tuning options for the llvm target:

target = "llvm"

batch_size = 1
dtype = "float32"
model_name = "resnet-18"
log_file = "%s.log" % model_name
graph_opt_sch_file = "%s_graph_opt.log" % model_name

# Set the input name of the graph
# For ONNX models, it is typically "0".
input_name = "data"
num_threads = 1
os.environ["TVM_NUM_THREADS"] = str(num_threads)

tuning_option = {
    'log_filename': log_file,
    'tuner': 'random',
    'early_stopping': None,
    'n_trial': 100,
    "measure_option": autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(
            number=1, repeat=10, min_repeat_ms=0, enable_cpu_cache_flush=True
        ),
    ),
}

Output:

[Task  1/12]  Current/Best:   10.13/  19.01 GFLOPS | Progress: (100/100) | 354.49 s Done.
Cannot connect to tracker ('0.0.0.0', 9000), retry in 5 secs...
[Task  2/12]  Current/Best:    4.21/  20.30 GFLOPS | Progress: (100/100) | 300.69 s Done.
Cannot connect to tracker ('0.0.0.0', 9002), retry in 5 secs...
[Task  3/12]  Current/Best:    9.94/  19.52 GFLOPS | Progress: (100/100) | 282.74 s Done.
Cannot connect to tracker ('0.0.0.0', 9000), retry in 5 secs...

Can you try the following patch and check if something changes or not? --> I'm working on this; I will update you if there is any change in the GFLOPS.

Please let me know if you have any suggestions.

Thanks,