[Auto-TVM] How to auto-tune a model on an iOS device

Hi Experts,

I need to auto-tune an ONNX model on an iOS device. I went through some of the tutorials at https://tvm.apache.org/docs/tutorials/index.html#auto-tuning. I’m able to tune the model for the CUDA (NVIDIA GPU) target, but I’m not able to auto-tune it on an iOS device (there is no tutorial for iOS devices). Can we auto-tune a model on an iOS device? If yes, can you please help me with how to do it?

Thank you.


Hi Team,

Any suggestions on the above issue?

Thanks

cc @kazum, who might have some experience. iOS requires a special RPC setup; you can find some instructions here: https://github.com/apache/incubator-tvm/tree/master/apps/ios_rpc

Hi @kazum, @tqchen, thank you for the response. I followed the above link and was able to connect the iOS device to the RPC proxy. However, the tutorials I went through (https://tvm.apache.org/docs/tutorials/autotvm/tune_relay_mobile_gpu.html) start an RPC tracker before tuning the model. I followed the same steps to connect the iOS device to the RPC tracker, but I’m not able to connect it ([iOS-RPC] Not able to connect RPC tracker on iOS device).

If I need to tune the model on an iOS device, do I need to start the RPC tracker before tuning? If yes, how do I connect the iOS device to the RPC tracker?

@kazum, are any tutorials or scripts available for auto-tuning a model on an iOS device?

Thanks

@kazum, @tqchen, any suggestions on the above issue?

Thanks

@Dileep Sorry for the late response. I’ve tried auto-tuning on an iOS device before and needed some hacks to make it work.

  • Pass a customized build function to LocalBuilder to compile your model with Xcode.
  • Modify autotvm.measure.measure_methods.check_remote to make it always return True. This is necessary because, with the iOS RPC workflow, the devices are not visible to the tuner before compiling.
  • Run rpc_proxy on the host machine so that the iOS device can connect to the tracker.
    python -m tvm.exec.rpc_proxy --host [HOST_IP] --tracker [TRACKER_IP]:9190
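
As a small convenience sketch (not part of TVM), the two host-side processes from the bullets above can be described as argument lists and started with subprocess. It assumes, as clarified later in this thread, that both the tracker and the proxy run on the host machine, and it only uses the flags that appear in this thread (--host, --port, --tracker, --no-fork). `rpc_commands` and `start_rpc_services` are hypothetical helper names.

```python
import shlex
import subprocess


def rpc_commands(host_ip, tracker_port=9190, python="python3"):
    """Return (tracker_argv, proxy_argv) for running both on the host machine."""
    # The tracker listens on all interfaces so the proxy can register with it.
    tracker = [python, "-m", "tvm.exec.rpc_tracker",
               "--host=0.0.0.0", "--port=%d" % tracker_port, "--no-fork"]
    # The proxy must point at the tracker on the host, not at the iPhone.
    proxy = [python, "-m", "tvm.exec.rpc_proxy",
             "--host", host_ip,
             "--tracker", "%s:%d" % (host_ip, tracker_port),
             "--no-fork"]
    return tracker, proxy


def start_rpc_services(host_ip):
    """Start tracker and proxy as child processes; both keep running."""
    tracker_argv, proxy_argv = rpc_commands(host_ip)
    print("starting:", shlex.join(tracker_argv))
    tracker = subprocess.Popen(tracker_argv)
    # If the proxy comes up first, it retries the tracker connection
    # (the "try reconnect in 2 sec" message seen later in this thread).
    print("starting:", shlex.join(proxy_argv))
    proxy = subprocess.Popen(proxy_argv)
    return tracker, proxy
```

For example, `tracker, proxy = start_rpc_services("192.168.0.10")`, and call `terminate()` on both when tuning is done. Running each command in its own terminal works just as well.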
    

Here is part of the code. Hope this is helpful.

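from tvm import autotvm
from tvm.contrib import xcode

# arch, sdk, proxy_host, proxy_port, key, destination and log_file are
# assumed to be defined elsewhere (e.g. arch = "arm64", sdk = "iphoneos").

# The iOS device only registers with the tracker after the app is built,
# so skip the up-front device availability check.
autotvm.measure.measure_methods.check_remote = lambda *args: True

def fcompile(*args):
    # Build the dylib with Xcode and sign it, then rebuild and launch the
    # tvmrpc app with the dylib bundled so that it connects to the proxy.
    xcode.create_dylib(*args, arch=arch, sdk=sdk)
    path = args[0]
    xcode.codesign(path)
    xcode.popen_test_rpc(proxy_host, proxy_port, key,
                         destination=destination,
                         libs=[path])

fcompile.output_format = "dylib"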

tuning_option = {
    'log_filename': log_file,
    'tuner': 'random',
    'early_stopping': None,

    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(
            n_parallel=1,
            build_func=fcompile,
            timeout=60
        ),
        runner=autotvm.RPCRunner(
            key, host='127.0.0.1', port=9190,
            number=20, repeat=3, timeout=60, min_repeat_ms=150)
    ),
}
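
The check_remote assignment above replaces the function for the whole process. If you would rather undo it after tuning, a scoped variant built on unittest.mock.patch.object could look like this. It assumes check_remote is looked up as a module attribute at call time (which the plain assignment above also relies on); the SimpleNamespace object in the demo is only a stand-in for tvm.autotvm.measure.measure_methods so the sketch runs without TVM, and `scoped_check_remote_override` is a hypothetical name.

```python
import types
from unittest import mock


def scoped_check_remote_override(measure_methods):
    """Context manager that makes measure_methods.check_remote always succeed.

    measure_methods is meant to be tvm.autotvm.measure.measure_methods; any
    object with a check_remote attribute works. The original is restored on exit.
    """
    # Accept keyword arguments too, in case check_remote is called with them.
    return mock.patch.object(measure_methods, "check_remote",
                             lambda *args, **kwargs: True)


if __name__ == "__main__":
    # Stand-in module so the sketch runs without TVM installed.
    fake = types.SimpleNamespace(check_remote=lambda *args, **kwargs: False)
    with scoped_check_remote_override(fake):
        print(fake.check_remote("llvm", "iphone"))  # True while patched
    print(fake.check_remote("llvm", "iphone"))      # False again afterwards
```

In a real script this would wrap the tuning call, e.g. `with scoped_check_remote_override(autotvm.measure.measure_methods): tune_tasks(tasks, **tuning_option)`.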

Hi @kazum, thank you for the response.

  • I ran rpc_proxy on the host machine with the command below:

python -m tvm.exec.rpc_proxy --host [HOST_IP] --tracker [TRACKER_IP]:9190

HOST_IP = IP address of the host machine, TRACKER_IP = IP address of the remote device (the iOS device's IP)

  • Tried to connect the iOS device to this RPC proxy via the tvmrpc app (opened the app, set Address to the IP address of the host machine, Port to 9090, and Key to “iphone”).

  • The app connects to the RPC proxy, but while it connects I get the error below in the host terminal:

      INFO:root:Handler ready TCPSocketProxy:192.168.43.55:server:iphone
      INFO:root:Lost tracker connection: [Errno 61] Connection refused, try reconnect in 2 sec
    

I set TRACKER_IP to the IP address of the remote device (the iOS device's IP address). Am I right? I also tried changing the network, but the issue still exists. Can you please help me fix it?

Thanks,

No, the tracker should run on the host machine. I think you also need to start the tracker with the following command:

python -m tvm.exec.rpc_tracker --host=0.0.0.0 --port=9190
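
Errno 61 is ECONNREFUSED on macOS, meaning nothing was listening at the address the proxy used for the tracker. Before starting a tuning run, a quick TCP check on the tracker port (9190) and the proxy port (9090) can rule that out. This is a plain-socket sketch, not a TVM API; `port_is_open` is a hypothetical helper.

```python
import socket


def port_is_open(host, port, timeout=2.0):
    """Return True if a TCP connection to host:port can be opened."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers "Connection refused" (errno 61 on macOS) and timeouts.
        return False


if __name__ == "__main__":
    for name, port in (("tracker", 9190), ("proxy", 9090)):
        state = "listening" if port_is_open("127.0.0.1", port) else "NOT reachable"
        print("%s on 127.0.0.1:%d is %s" % (name, port, state))
```

If the tracker check fails, start the tracker on the host before the proxy; the proxy's --tracker address should then point at the host, not at the phone.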

Hi @kazum - Thank you for the previous suggestions, I am also looking at how to use autotvm to tune a model on iOS.

Below is a modified version of ‘tutorials/autotvm/tune_relay_arm.py’ based on your earlier suggestion of adding a build_func, but something isn’t working quite right yet.

Tuning tasks are stuck at 0 GFLOPS and the tuning trials time out.

[Task 1/12] Current/Best: 0.00/ 0.00 GFLOPS | Progress: (0/100) | 0.00 s

If I skip tuning (comment out ‘tune_tasks(tasks, **tuning_opt)’), it successfully builds and runs the untuned model and reports an inference result.

Any idea what step might be missing here?

Thank you!

  1. Assumption: you have a single macOS-based host running the RPC proxy, the tracker and Xcode, with local network IP 192.168.0.10

  2. Set up environment variables:

      export TVM_IOS_CODESIGN='Apple Development: <your@email.com> (<SIGNINGCODE>)'
      export TVM_IOS_RPC_ROOT=${TVM_HOME}/apps/ios_rpc
      export TVM_IOS_RPC_PROXY_HOST=192.168.0.10
      #export TVM_IOS_RPC_DESTINATION='platform=iOS Simulator,id=<simulator id>'
      export TVM_IOS_RPC_DESTINATION='platform=iOS,id=<ios device id>'
    
  3. Start the tracker

     python3 -m tvm.exec.rpc_tracker --host=0.0.0.0 --port=9190 --no-fork
     INFO:RPCTracker:bind to 0.0.0.0:9190
    
  4. Start the rpc proxy and point it to the tracker

     python3 -m tvm.exec.rpc_proxy --host 0.0.0.0 --tracker 0.0.0.0:9190 --no-fork
     INFO:root:RPCProxy: client port bind to 0.0.0.0:9090
    
  5. Run tuning:

     cd ${TVM_HOME}/apps/ios_rpc
     python3 tests/tune_relay_ios.py
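
The script below reads these variables straight from os.environ, so a missing variable shows up as a KeyError and a malformed destination likely only surfaces later as an xcodebuild error. A hedged pre-flight check for the step 2 variables, including the `platform=...,id=...` shape of TVM_IOS_RPC_DESTINATION; `parse_destination` and `ios_rpc_settings` are hypothetical helpers, and they only check the format, not whether Xcode can actually see the device.

```python
import os

REQUIRED_VARS = ("TVM_IOS_CODESIGN", "TVM_IOS_RPC_ROOT",
                 "TVM_IOS_RPC_PROXY_HOST", "TVM_IOS_RPC_DESTINATION")


def parse_destination(spec):
    """Split 'platform=iOS,id=<device id>' into {'platform': ..., 'id': ...}."""
    fields = {}
    for item in spec.split(","):
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError("malformed destination field: %r" % item)
        fields[key.strip()] = value.strip()
    if "platform" not in fields or "id" not in fields:
        raise ValueError("destination needs both platform=... and id=...: %r" % spec)
    return fields


def ios_rpc_settings(environ=None):
    """Check the step 2 variables and return the values the script uses."""
    env = os.environ if environ is None else environ
    missing = [name for name in REQUIRED_VARS if not env.get(name)]
    if missing:
        raise RuntimeError("missing environment variables: " + ", ".join(missing))
    return {
        "proxy_host": env["TVM_IOS_RPC_PROXY_HOST"],
        "destination": env["TVM_IOS_RPC_DESTINATION"],
        # Parsed only to validate the format; xcodebuild takes the raw string.
        "destination_fields": parse_destination(env["TVM_IOS_RPC_DESTINATION"]),
    }
```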
    

Code:

"""
apps/ios_rpc/tests/tune_relay_ios.py

Auto-tuning a convolutional network for iPhone CPU
==================================================

"""

import os
import numpy as np
import tvm
from tvm import te
from tvm import autotvm
from tvm import relay
import tvm.relay.testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime
from tvm.contrib import xcode

#################################################################
# Define network
# --------------
# First we need to define the network in relay frontend API.
# We can load some pre-defined network from :code:`relay.testing`.
# We can also load models from MXNet, ONNX and TensorFlow.

def get_network(name, batch_size):
    """Get the symbol definition and random weight of a network"""
    input_shape = (batch_size, 3, 224, 224)
    output_shape = (batch_size, 1000)

    if "resnet" in name:
        n_layer = int(name.split('-')[1])
        mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
    elif "vgg" in name:
        n_layer = int(name.split('-')[1])
        mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
    elif name == 'mobilenet':
        mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size)
    elif name == 'squeezenet_v1.1':
        mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
    elif name == 'inception_v3':
        input_shape = (1, 3, 299, 299)
        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
    elif name == 'mxnet':
        # an example for mxnet model
        from mxnet.gluon.model_zoo.vision import get_model
        block = get_model('resnet18_v1', pretrained=True)
        mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
        net = mod["main"]
        net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
        mod = tvm.IRModule.from_expr(net)
    else:
        raise ValueError("Unsupported network: " + name)

    return mod, params, input_shape, output_shape

#################################################################
# Start RPC Tracker
# -----------------
# python3 -m tvm.exec.rpc_tracker --host=0.0.0.0 --port=9190 --no-fork
#
#  - AutoTVM will use the tracker to orchestrate tuning test runs.
#
# Start RPC Proxy
# ---------------
# python3 -m tvm.exec.rpc_proxy --host 0.0.0.0 --tracker 0.0.0.0:9190 --no-fork


###########################################
# Set Tuning Options
# ------------------

#### DEVICE CONFIG ####

# Set to be the address of the TVM RPC proxy.
proxy_host = os.environ["TVM_IOS_RPC_PROXY_HOST"]

# Set your destination via an environment variable.
# It should be in the format "platform=iOS,id=<the test device uuid>"
destination = os.environ["TVM_IOS_RPC_DESTINATION"]

device_key = 'iphone'
proxy_port = 9090

arch = "arm64"
sdk = "iphoneos"
target = "llvm -mtriple=%s-apple-darwin" % arch
target_host = "llvm -mtriple=%s-apple-darwin" % arch

#### TUNING OPTION ####
network = 'resnet-18'
log_file = "%s.%s.log" % (device_key, network)
dtype = 'float32'

autotvm.measure.measure_methods.check_remote = lambda *args: True

def fcompile(*args):
    xcode.create_dylib(*args, arch=arch, sdk=sdk)
    path = args[0]
    xcode.codesign(path)
    xcode.popen_test_rpc(proxy_host, proxy_port, device_key, destination=destination, libs=[path])

fcompile.output_format = "dylib"

tuning_option = {
    'log_filename': log_file,
    'tuner': 'random',
    'early_stopping': None,
    'n_trial': 100,

    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(
            n_parallel=1,
            build_func=fcompile,
            timeout=60
        ),
        runner=autotvm.RPCRunner(
            device_key, host='127.0.0.1', port=9190,
            number=20, repeat=3, timeout=60, min_repeat_ms=150)
    ),
}

###################################################################
# Begin Tuning
# ------------

def tune_tasks(tasks,
               measure_option,
               tuner='random',
               n_trial=1000,
               early_stopping=None,
               log_filename='tuning.log',
               use_transfer_learning=True):
    # create tmp log file
    tmp_log_file = log_filename + ".tmp"
    if os.path.exists(tmp_log_file):
        os.remove(tmp_log_file)

    for i, tsk in enumerate(reversed(tasks)):
        prefix = "[Task %2d/%2d] " % (i+1, len(tasks))

        # create tuner
        if tuner == 'xgb' or tuner == 'xgb-rank':
            tuner_obj = XGBTuner(tsk, loss_type='rank')
        elif tuner == 'xgb_knob':
            tuner_obj = XGBTuner(tsk, loss_type='rank', feature_type='knob')
        elif tuner == 'ga':
            tuner_obj = GATuner(tsk, pop_size=50)
        elif tuner == 'random':
            tuner_obj = RandomTuner(tsk)
        elif tuner == 'gridsearch':
            tuner_obj = GridSearchTuner(tsk)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        if use_transfer_learning:
            if os.path.isfile(tmp_log_file):
                tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
        # do tuning
        tsk_trial = min(n_trial, len(tsk.config_space))
        tuner_obj.tune(n_trial=tsk_trial,
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
                           autotvm.callback.log_to_file(tmp_log_file)
                       ])

    # pick best records to a cache file
    autotvm.record.pick_best(tmp_log_file, log_filename)
    os.remove(tmp_log_file)

########################################################################
# Finally, we launch tuning jobs and evaluate the end-to-end performance.

def tune_and_evaluate(tuning_opt):
    # extract workloads from relay program
    print("Extract tasks...")
    mod, params, input_shape, _ = get_network(network, batch_size=1)
    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                              params=params,
                                              ops=(relay.op.get("nn.conv2d"),))

    # run tuning tasks
    print("Tuning...")
    tune_tasks(tasks, **tuning_opt)

    # compile kernels with history best records
    with autotvm.apply_history_best(log_file):
        print("Compile...")
        with tvm.transform.PassContext(opt_level=3):
            graph, lib, params = relay.build_module.build(
                mod, target=target, params=params)

        # export library
        path_dso = "tuned_deploy.dylib"
        lib.export_library(path_dso, xcode.create_dylib, arch=arch, sdk=sdk)
        xcode.codesign(path_dso)

        # Evaluate inference cost on tuned lib
        xcode.popen_test_rpc(proxy_host, proxy_port, device_key, destination=destination, libs=[path_dso])

        remote = autotvm.measure.request_remote(device_key, '0.0.0.0', 9190,
                                                timeout=10000)

        # Upload not needed for ios because dylib is built into app
        # remote.upload(path_dso)

        rlib = remote.load_module(path_dso)

        ctx = remote.cpu(0)
        
        module = runtime.create(graph, rlib, ctx)
        data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
        module.set_input('data', data_tvm)
        module.set_input(**params)

        # evaluate
        print("Evaluate inference time cost...")
        ftimer = module.module.time_evaluator("run", ctx, number=3, repeat=20)
        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
              (np.mean(prof_res), np.std(prof_res)))

# Tuning takes a long time, so it only starts when this script is run directly.
if __name__ == '__main__':
    if os.path.exists("rpc_config.txt"):
        os.remove("rpc_config.txt")
    tune_and_evaluate(tuning_option)

######################################################################
# Sample Output
# -------------

Hi @kazum, thank you for your suggestion. I’m able to start the RPC tracker.

I also tried mostly the same approach as @jacobpostman, but while tuning the sample model I got the logs below.

[Task  1/16]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (15/100) | 600.01 s
[19:07:37] /Users/Dileep/LatestTVM/22_08/tvm/src/runtime/metal/metal_device_api.mm:137: Intializing Metal device 0, name=Intel Iris Pro Graphics
[Task  1/16]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (16/100) | 640.05 s
[19:08:18] /Users/Dileep/LatestTVM/22_08/tvm/src/runtime/metal/metal_device_api.mm:137: Intializing Metal device 0, name=Intel Iris Pro Graphics
[Task  1/16]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (17/100) | 680.11 s
[19:08:57] /Users/Dileep/LatestTVM/22_08/tvm/src/runtime/metal/metal_device_api.mm:137: Intializing Metal device 0, name=Intel Iris Pro Graphics
[Task  1/16]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (18/100) | 720.17 s
[19:09:38] /Users/Dileep/LatestTVM/22_08/tvm/src/runtime/metal/metal_device_api.mm:137: Intializing Metal device 0, name=Intel Iris Pro Graphics
[Task  1/16]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (19/100) | 760.24 s
[19:10:18] /Users/Dileep/LatestTVM/22_08/tvm/src/runtime/metal/metal_device_api.mm:137: Intializing Metal device 0, name=Intel Iris Pro Graphics

I always see “0.00/ 0.00 GFLOPS” while tuning the model. Is the model being tuned properly, or am I missing something?
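
As far as I can tell from AutoTVM's progress_bar callback, the “Current” figure is the task's FLOP count divided by the mean measured cost, and a trial that comes back with a non-zero error_no is counted as 0. A stand-alone sketch of that rule (hypothetical helper names, not TVM code) shows why the display can stay at 0.00/0.00 indefinitely:

```python
def trial_gflops(task_flop, costs, error_no):
    """GFLOPS for one measurement, following the rule described above.

    task_flop: floating-point operations of the tuning task.
    costs: measured run times in seconds.
    error_no: 0 means success; anything else means the trial failed.
    """
    if error_no != 0 or not costs:
        return 0.0
    mean_cost = sum(costs) / len(costs)
    return task_flop / mean_cost / 1e9


def best_gflops(results):
    """Best GFLOPS over (task_flop, costs, error_no) tuples; 0.0 if none succeeded."""
    return max((trial_gflops(*r) for r in results), default=0.0)
```

So a run that stays at 0.00/0.00 most likely means no trial has succeeded yet, and the error_no recorded for each trial says why.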

Here is part of the code:

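import os
import tvm
from tvm import autotvm, relay
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib import xcode

# proxy_host, destination and get_network() are assumed to be defined as in the tune_relay_ios.py script earlier in this thread.
target = 'metal'
proxy_port = 9090
key = "iphone"
arch = "arm64"
sdk = "iphoneos"
target_host = "llvm -mtriple=%s-apple-darwin" % arch

@tvm.register_func("tvm_callback_metal_compile")
def compile_metal(src):
    return xcode.compile_metal(src, sdk=sdk)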

#### TUNING OPTION ####
network = 'resnet-18'
log_file = "%s.log" % network
dtype = 'float32'

autotvm.measure.measure_methods.check_remote = lambda *args: True

def fcompile(*args):
    from tvm.contrib import xcode
    xcode.create_dylib(*args, arch=arch, sdk=sdk)
    path = args[0]
    xcode.codesign(path)
    xcode.popen_test_rpc(proxy_host, proxy_port, key,
                         destination=destination,
                         libs=[path])

fcompile.output_format = "dylib"

tuning_option = {
    'log_filename': log_file,
    'tuner': 'xgb',
    'early_stopping': None,

    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(
            n_parallel=1,
            build_func=fcompile,
            timeout=60
        ),
        runner=autotvm.RPCRunner(
            key, host='127.0.0.1', port=9190,
            number=20, repeat=3, timeout=60, min_repeat_ms=150)
    ),
}
def tune_tasks(tasks,
               measure_option,
               tuner='xgb',
               n_trial=100,
               early_stopping=None,
               log_filename='tuning.log',
               use_transfer_learning=False):
    # create tmp log file
    tmp_log_file = log_filename + ".tmp"
    if os.path.exists(tmp_log_file):
        os.remove(tmp_log_file)

    for i, tsk in enumerate(reversed(tasks)):
        prefix = "[Task %2d/%2d] " %(i+1, len(tasks))

        # create tuner
        if tuner == 'xgb' or tuner == 'xgb-rank':
            tuner_obj = XGBTuner(tsk, loss_type='rank')
        elif tuner == 'ga':
            tuner_obj = GATuner(tsk, pop_size=100)
        elif tuner == 'random':
            tuner_obj = RandomTuner(tsk)
        elif tuner == 'gridsearch':
            tuner_obj = GridSearchTuner(tsk)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        if use_transfer_learning:
            if os.path.isfile(tmp_log_file):
                tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

        # do tuning
        tsk_trial = min(n_trial, len(tsk.config_space))
        tuner_obj.tune(n_trial=tsk_trial,
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
                           autotvm.callback.log_to_file(tmp_log_file)
                       ])

    # pick best records to a cache file
    autotvm.record.pick_best(tmp_log_file, log_filename)
    os.remove(tmp_log_file)


########################################################################
# Finally, we launch tuning jobs and evaluate the end-to-end performance.

def tune_and_evaluate(tuning_opt):
    # extract workloads from relay program
    print("Extract tasks...")
    mod, params, input_shape, out_shape = get_network(network, batch_size=1)

    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                              params=params,
                                              target_host=target_host,
                                              ops=(relay.op.get("nn.conv2d"),))

    # run tuning tasks
    print("Tuning...")
    tune_tasks(tasks, **tuning_opt)

Please let me know if I’m missing anything in the above code.

Thanks,

Hi @kazum,

I enabled the AutoTVM debug logs, and tuning the model on the iOS device is failing. Can you please have a look at the debug logs below?

Extract tasks...
Tuning...
Get devices for measurement successfully!
[15:27:22] /Users/Dileep/LatestTVM/31_08/tvm/src/runtime/metal/metal_device_api.mm:137: Intializing Metal device 0, name=Intel Iris Pro Graphics
No: 1	GFLOPS: 0.00/0.00	result: MeasureResult(costs=('Cannot request iphone after 5 retry, last_error:Traceback (most recent call last):\n  [bt] (4) 5   libtvm.dylib                        0x0000000113c86f26 TVMFuncCall + 70\n  [bt] (3) 4   libtvm.dylib                        0x0000000113cdf3b0 std::__1::__function::__func<tvm::runtime::$_0, std::__1::allocator<tvm::runtime::$_0>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 336\n  [bt] (2) 3   libtvm.dylib                        0x0000000113cde753 tvm::runtime::RPCClientConnect(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, int, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, tvm::runtime::TVMArgs) + 99\n  [bt] (1) 2   libtvm.dylib                        0x0000000113cdd206 tvm::runtime::RPCConnect(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, int, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, tvm::runtime::TVMArgs) + 390\n  [bt] (0) 1   libtvm.dylib                        0x0000000112f92641 dmlc::LogMessageFatal::~LogMessageFatal() + 113\n  File "/Users/Dileep/LatestTVM/31_08/tvm/src/runtime/rpc/rpc_socket_impl.cc", line 73\nTVMError: Check failed: sock.Connect(addr): Connect to 127.0.0.1:9090 failed',), error_no=7, all_cost=60, timestamp=1599127077.6432252)	[('tile_f', [-1, 1, 128, 1]), ('tile_y', [-1, 7, 1, 1]), ('tile_x', [-1, 1, 7, 1]), ('tile_rc', [-1, 1]), ('tile_ry', [-1, 1]), ('tile_rx', [-1, 1]), ('auto_unroll_max_step', 1500), ('unroll_explicit', 0)],None,65389
[15:28:01] /Users/Dileep/LatestTVM/31_08/tvm/src/runtime/metal/metal_device_api.mm:137: Intializing Metal device 0, name=Intel Iris Pro Graphics

I always see logs like the above while tuning the model. Can you please suggest what is wrong?
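
Since the record above carries error_no=7 and only successful trials contribute GFLOPS, it helps to tally the outcome of every trial instead of reading them one by one. A sketch over the JSON-lines records AutoTVM writes; the tune_tasks helper in the script above logs every trial to `<log_filename>.tmp` and deletes that file after pick_best, so summarize it while tuning is still running. The record layout and the code names below reflect AutoTVM's log format and MeasureErrorNo as I understand them for this TVM version; please verify them against tvm/autotvm/record.py and measure.py. `summarize_records` and `summarize_log` are hypothetical helpers.

```python
import json
from collections import Counter

# Codes from autotvm's MeasureErrorNo (verify against your TVM version).
ERROR_NAMES = {
    0: "NO_ERROR", 1: "INSTANTIATION_ERROR", 2: "COMPILE_HOST",
    3: "COMPILE_DEVICE", 4: "RUNTIME_DEVICE", 5: "WRONG_ANSWER",
    6: "BUILD_TIMEOUT", 7: "RUN_TIMEOUT", 8: "UNKNOWN_ERROR",
}


def summarize_records(lines):
    """Count trial outcomes in AutoTVM JSON-lines records.

    Assumes each record's "result" field is [costs, error_no, all_cost, timestamp].
    """
    counts = Counter()
    for line in lines:
        line = line.strip()
        if not line:
            continue
        error_no = json.loads(line)["result"][1]
        counts[ERROR_NAMES.get(error_no, "UNKNOWN_CODE_%d" % error_no)] += 1
    return counts


def summarize_log(path):
    """Convenience wrapper, e.g. summarize_log("resnet-18.log.tmp")."""
    with open(path) as f:
        return summarize_records(f)
```

If every trial lands in the same bucket (here, the RPC request timing out), the problem is in the device connection rather than in the schedules being tried.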

Hi @jacobpostman, did you find any solution for the above issue?

Thanks,

Hi @Dileep -

There were a few problems I was running into:

  1. The Xcode build process started by popen_test_rpc() was being terminated prematurely when run inside the fcompile build_func.
  2. The temp dylib filename with its full path was not accessible from the iOS RPC runner. As a hack to test this, I removed tmp_dir and created the dylib in the local run directory instead.
  3. remote.upload() should not be called because the dylib is built into the app.

Following the code in my previous post, making the following changes should be all you need to get started. I’m going to look at a real solution that’s not just a quick hack, but this may get you up and running.

def fcompile(*args):
    print("\nCalling fcompile. args[0] = %s" % args[0])
    xcode.create_dylib(*args, arch=arch, sdk=sdk)
    path = args[0]
    xcode.codesign(path)

diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index 9cef674d3..784184e61 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -426,8 +426,9 @@ class _WrappedBuildFunc():
         """
         tic = time.time()
         try:
-            filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % (
-                getrandbits(64), self.build_func.output_format))
+            # filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % (
+            #     getrandbits(64), self.build_func.output_format))
+            filename = "tmp_func_%0x.%s" % (getrandbits(64), self.build_func.output_format)
             # TODO(tvm-team) consider linline _build_func_common
             func, arg_info = _build_func_common(measure_input, **kwargs)
             func.export_library(filename, self.build_func)
@@ -485,6 +486,8 @@ def run_through_rpc(measure_input, build_result,
     errno = MeasureErrorNo.NO_ERROR
     try:
         # upload built module
+        from tvm.contrib import xcode
+        xcode.popen_test_rpc(os.environ["TVM_IOS_RPC_PROXY_HOST"], 9090, "iphone", destination=os.environ["TVM_IOS_RPC_DESTINATION"], libs=[os.path.split(build_result.filename)[1]])
         remote = request_remote(*remote_args)
         # Program the FPGA every single time when targeting VTA
         if hasattr(measure_input.target, 'device_name') and \
@@ -493,7 +496,7 @@ def run_through_rpc(measure_input, build_result,
             from vta import program_fpga, reconfig_runtime
             program_fpga(remote, None)
             reconfig_runtime(remote)
-        remote.upload(build_result.filename)
+        # remote.upload(build_result.filename)
         func = remote.load_module(os.path.split(build_result.filename)[1])
         ctx = remote.context(str(measure_input.target), 0)

Hi @jacobpostman, thank you for your suggestion. I tried the changes you noted, but while tuning the model I got an Xcode build error, and I don’t know what is causing it:

[Task  1/12]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (0/100) | 0.00 s
xcodebuild: error: Unable to find a destination matching the provided destination specifier:
		{ platform:iOS, id:00008020-001514A03E99002E }

	The requested device could not be found because no available devices matched the request.

	Available destinations for the "tvmrpc" scheme:
		{ platform:macOS, arch:x86_64h, id:0965DDFC-9045-5428-9A4C-481CD59D4C72 }
		{ platform:macOS, arch:x86_64, id:0965DDFC-9045-5428-9A4C-481CD59D4C72 }
		{ platform:macOS, arch:i386, id:0965DDFC-9045-5428-9A4C-481CD59D4C72 }
		{ platform:tvOS Simulator, id:491A1CC1-2C70-42B0-8C6F-AA0DC91B8E87, OS:12.1, name:Apple TV }
		{ platform:tvOS Simulator, id:B55611F5-2EB2-439E-9296-E0CADCA6FC68, OS:12.1, name:Apple TV 4K }
		{ platform:tvOS Simulator, id:4060BF62-506E-4E64-9335-55AD37A0F044, OS:12.1, name:Apple TV 4K (at 1080p) }
		{ platform:watchOS Simulator, id:B314B495-123F-48BC-9662-BCC8077E465B, OS:5.1, name:Apple Watch Series 2 - 38mm }
		{ platform:watchOS Simulator, id:8C261AE9-B004-4877-B331-1BA12A8A2E77, OS:5.1, name:Apple Watch Series 2 - 42mm }
		{ platform:watchOS Simulator, id:1092E5D2-DFAC-436A-97F2-93C9F65AFC28, OS:5.1, name:Apple Watch Series 3 - 38mm }
		{ platform:watchOS Simulator, id:92DDACA5-6154-4D00-9958-BD5982F043AF, OS:5.1, name:Apple Watch Series 3 - 42mm }
		{ platform:watchOS Simulator, id:5C532166-073B-429C-BEDC-FA4C82DD4B90, OS:5.1, name:Apple Watch Series 4 - 40mm }
		{ platform:watchOS Simulator, id:EA4BD647-0FEF-44BA-BD10-10F7E12DD262, OS:5.1, name:Apple Watch Series 4 - 44mm }
		{ platform:iOS Simulator, id:2BA73EBC-5F9D-4223-B19C-8CFFA63CDC5E, OS:12.1, name:iPad (5th generation) }
		{ platform:iOS Simulator, id:1ED2580F-CFC1-47EB-8E72-F9CAA67AD975, OS:12.1, name:iPad (6th generation) }
		{ platform:iOS Simulator, id:0C4FF287-2331-47F6-8F34-6C52B6C2375D, OS:12.1, name:iPad Air }
		{ platform:iOS Simulator, id:28C1076C-A55B-4A35-B016-D14FEB48E0CD, OS:12.1, name:iPad Air 2 }
		{ platform:iOS Simulator, id:F10C8D9F-33F5-4F94-85F4-6408D51016F2, OS:12.1, name:iPad Pro (9.7-inch) }
		{ platform:iOS Simulator, id:E477E30E-03AD-4783-8E6C-3E01A8BFCDA2, OS:12.1, name:iPad Pro (10.5-inch) }
		{ platform:iOS Simulator, id:723A360B-9EBB-46FC-9D16-65A2A40B4865, OS:12.1, name:iPad Pro (11-inch) }
		{ platform:iOS Simulator, id:D9BBE5B1-A6B5-4F6F-958F-AD1C44FFE5CA, OS:12.1, name:iPad Pro (12.9-inch) }
		{ platform:iOS Simulator, id:F68E8EF5-297A-4329-84CC-02AED15D14A8, OS:12.1, name:iPad Pro (12.9-inch) (2nd generation) }
		{ platform:iOS Simulator, id:FAD0A9FA-12DD-44D9-809F-366B87A394AE, OS:12.1, name:iPad Pro (12.9-inch) (3rd generation) }
		{ platform:iOS Simulator, id:40D706E8-5B44-46BC-82E2-07D518A3B7A9, OS:12.1, name:iPhone 5s }
		{ platform:iOS Simulator, id:56173D15-C370-49DF-9B1E-B6F2C4D1BF24, OS:12.1, name:iPhone 6 }
		{ platform:iOS Simulator, id:F4EBE651-3F39-4B27-8B0A-AAE4C39EC727, OS:12.1, name:iPhone 6 Plus }
		{ platform:iOS Simulator, id:14A8DF52-13B3-40EE-A340-002D75F9A8CF, OS:12.1, name:iPhone 6s }
		{ platform:iOS Simulator, id:CAAA0B53-6994-44E8-B6FC-C3AB5B3454FB, OS:12.1, name:iPhone 6s Plus }
		{ platform:iOS Simulator, id:4CC104C8-8B6D-4C87-B7D1-287F0F4C5315, OS:12.1, name:iPhone 7 }
		{ platform:iOS Simulator, id:08119D0F-8737-4749-A6AD-76A558413FCF, OS:12.1, name:iPhone 7 Plus }
		{ platform:iOS Simulator, id:9E744E61-8E50-48DC-AD37-6E91EF6E941D, OS:12.1, name:iPhone 8 }
		{ platform:iOS Simulator, id:A9FF7E04-1E26-4675-B983-7214F8E1EFD3, OS:12.1, name:iPhone 8 Plus }
		{ platform:iOS Simulator, id:D74ECB78-618E-4875-A839-CACA6B09EE2F, OS:12.1, name:iPhone SE }
		{ platform:iOS Simulator, id:7682392E-A0E3-45E8-A120-3E87CDC8509F, OS:12.1, name:iPhone X }
		{ platform:iOS Simulator, id:CBBB98EF-0C3A-4973-BCA3-7337D2466294, OS:12.1, name:iPhone Xs }
		{ platform:iOS Simulator, id:35ED32F1-FCC5-4D96-A894-2EFD3555BF70, OS:12.1, name:iPhone Xs Max }
		{ platform:iOS Simulator, id:09FFC6F0-818D-424F-B984-67D2208D0918, OS:12.1, name:iPhone Xʀ }

I’ve given the correct destination ID (UUID) of the iPhone.
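
xcodebuild only accepts a destination that appears in its “Available destinations” list, and in the output above the requested iPhone id is not in that list (only macOS, tvOS/watchOS/iOS simulators are). A small sketch for checking this programmatically from captured xcodebuild output; the regex is written against the entry format shown in this error message and is an assumption, not a documented format. Both helper names are hypothetical.

```python
import re

# One entry per line, e.g. "{ platform:iOS Simulator, id:<uuid>, OS:12.1, name:iPhone X }".
_ENTRY = re.compile(r"\{\s*platform:([^,}]+),.*?\bid:([^,}\s]+)")


def available_destinations(xcodebuild_output):
    """(platform, id) pairs listed after "Available destinations" in the output."""
    # Skip the echoed request, which uses the same "{ platform:..., id:... }" form.
    _, found, listing = xcodebuild_output.partition("Available destinations")
    if not found:
        return []
    return [(m.group(1).strip(), m.group(2))
            for m in _ENTRY.finditer(listing)]


def destination_available(xcodebuild_output, device_id, platform="iOS"):
    """True if xcodebuild listed a destination with this platform and id."""
    return (platform, device_id) in available_destinations(xcodebuild_output)
```

If the physical device never shows up in that list, xcodebuild does not consider it a usable destination for the tvmrpc scheme, regardless of the id passed in TVM_IOS_RPC_DESTINATION.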

Working environment: iOS 13.7; Xcode 10.2 (I downloaded the iOS 13.7 device support files from GitHub and copied them into the Xcode package contents, and I’m able to build the iOS RPC app for the latest iOS version); macOS Catalina (10.15.5).

Maybe I need to try upgrading Xcode to the latest version to fix the above issue.

@kazum, any suggestions on the above issues?

Thanks

@Dileep - Are you able to install and run the tvmrpc app on the device from Xcode? It may just be that you need to give the app permission to be installed on your device. On the iPhone, you do this by going to Settings -> General -> Device Management -> select the developer app -> Verify App.

Hi @jacobpostman,

Yes, I’m able to install the tvmrpc app on the device from Xcode, and I grant all the permissions you noted, but the issue is still the same. If I connect the iPhone to the Mac via a USB cable (RPC connected), I don’t get the above error, but while tuning the model I get ** BUILD INTERRUPTED **, and each time the tvmrpc app closes and reopens (maybe the app is crashing). Please see the logs below:

	bootArgs:                  (null)
		connected:                 yes
		isWirelessEnabled:         no
		connectionType:            direct
		hostname:                  (null)
		bonjourServiceName:        30:57:14:9c:4e:a0@fe80::3257:14ff:fe9c:4ea0._apple-mobdev2._tcp.local.
		} (13.7 (17H35))
Test Suite 'All tests' started at 2020-09-05 10:53:11.863
Test Suite 'tvmrpcLauncher.xctest' started at 2020-09-05 10:53:11.863
Test Suite 'tvmrpcLauncher' started at 2020-09-05 10:53:11.863
Test Case '-[tvmrpcLauncher testRPC]' started.
** BUILD INTERRUPTED **
[Task  1/12]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (4/100) | 102.09 s

Thanks

Hi @jacobpostman,

Even after updating Xcode and iOS to the latest versions, I’m facing the same issue. Which Xcode and iOS versions are you using for tuning?

Thanks, Dileep

Hi @Dileep - To clarify: the phone needs to be connected via USB for tuning to work. The tvmrpc app is actually being rebuilt and reinstalled at each tuning step with the signed dylib included in the build. (If there is a way to do remote installation over IP, I haven’t used that).

I do still get the BUILD INTERRUPTED messages, but moving the xcode.popen_test_rpc() call allowed the tuning results to get reported. It looked to me like the process started by popen_test_rpc() was getting killed prematurely, but the change I made kept it up long enough for the app -> proxy -> tracker connection to be made.

For each tuning step the dylib gets built, the tvmrpc app gets rebuilt and installed, then the xcode testing framework initiates the rpc proxy connection from the app to the proxy server. The proxy server then exposes the iphone device to the tracker so that the tuning script can request the connection from the tracker.

Again, the workaround I posted was just a hack to try out tuning on iOS. I haven’t had a chance to look into a proper fix, but I imagine it will require something like overriding the run_through_rpc() function when tuning with xcode.
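
One way to read the “override run_through_rpc()” idea is a wrapper that relaunches the app with the freshly built dylib right before the normal RPC measurement, which is what the diff above does inline. A hedged sketch of that wrapper: `with_ios_relaunch` is hypothetical, and the wrapped function and relaunch callback stand in for AutoTVM's run_through_rpc and a call to xcode.popen_test_rpc. How to hook it into RPCRunner depends on the TVM version and is not shown here.

```python
import functools
import os


def with_ios_relaunch(run_func, relaunch):
    """Return run_func wrapped so that relaunch(lib_name) runs first.

    run_func(measure_input, build_result, *args, **kwargs) plays the role of
    autotvm's run_through_rpc; relaunch plays the role of a call like
    xcode.popen_test_rpc(host, port, key, destination=..., libs=[lib_name]).
    Both roles are assumptions made for this sketch.
    """
    @functools.wraps(run_func)
    def wrapped(measure_input, build_result, *args, **kwargs):
        # The dylib is bundled into the rebuilt app, so the device only
        # knows it by its base name, not by the host-side temp path.
        lib_name = os.path.basename(build_result.filename)
        relaunch(lib_name)
        return run_func(measure_input, build_result, *args, **kwargs)

    return wrapped
```

Keeping the relaunch next to the measurement, rather than in the build function, matches the observation above that the process started by popen_test_rpc() needs to stay alive until the app -> proxy -> tracker connection is made.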

Hi @jacobpostman, thank you for the clarification. With the changes you noted, I’m able to run tuning on the iOS target, but every time I see "Current/Best: 0.00/ 0.00 GFLOPS" in the logs. I’ve used the same script you shared ([Auto-TVM] How to Auto tune the model on iOS device). I ran the tuning for more than an hour; for more info, I’m attaching the tuning logs.

Test Suite 'All tests' started at 2020-09-15 18:28:32.249
Test Suite 'tvmrpcLauncher.xctest' started at 2020-09-15 18:28:32.249
Test Suite 'tvmrpcLauncher' started at 2020-09-15 18:28:32.249
Test Case '-[tvmrpcLauncher testRPC]' started.
2020-09-15 18:28:32.344791+0530 tvmrpc[3515:190115] [18:28:32] /Users/Dileep/LatestTVM/22_08/tvm/apps/ios_rpc/tvmrpc/TVMRuntime.mm:146: Load module from /private/var/containers/Bundle/Application/B8353EEF-9E61-40CC-BF6E-80EBDDF824F3/tvmrpc.app/Frameworks/tvm/tmp_func_17cd287b713d0fed.dylib ...
Test Case '-[tvmrpcLauncher testRPC]' passed (0.126 seconds).
Test Suite 'tvmrpcLauncher' passed at 2020-09-15 18:28:32.376.
	 Executed 1 test, with 0 failures (0 unexpected) in 0.126 (0.126) seconds
Test Suite 'tvmrpcLauncher.xctest' passed at 2020-09-15 18:28:32.376.
	 Executed 1 test, with 0 failures (0 unexpected) in 0.126 (0.127) seconds
Test Suite 'All tests' passed at 2020-09-15 18:28:32.376.
	 Executed 1 test, with 0 failures (0 unexpected) in 0.126 (0.127) seconds
[Task 10/12]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (42/100) | 456.39 s

Test session results and logs:
	/Users/Library/Developer/Xcode/DerivedData/tvmrpc-crkmdkqlzonxjwdcvrxtvofnlaup/Logs/Test/Test-tvmrpc-2020.09.15_18-28-23-+0530.xcresult

2020-09-15 18:28:32.398 xcodebuild[35740:545190] [MT] IDETestOperationsObserverDebug: 7.995 elapsed -- Testing started completed.
2020-09-15 18:28:32.398 xcodebuild[35740:545190] [MT] IDETestOperationsObserverDebug: 0.000 sec, +0.000 sec -- start
2020-09-15 18:28:32.398 xcodebuild[35740:545190] [MT] IDETestOperationsObserverDebug: 7.995 sec, +7.995 sec -- end
** TEST SUCCEEDED **

Every time I can see that the test succeeded, but the GFLOPS never change. Is the model being tuned properly, or am I missing a step?

In the same way, I tried to tune the model for the iPhone Metal target by changing the target configuration as below:

target = 'metal'
proxy_port = 9090
key = "iphone"
arch = "arm64"
sdk = "iphoneos"
target_host = "llvm -mtriple=%s-apple-darwin" % arch

tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                          target_host=target_host, params=params,
                                          ops=(relay.op.get("nn.conv2d"),))

After changing the target configuration above, tuning does not happen for the iPhone Metal target. Any idea what step might be missing here? Have you tried tuning a model for the iPhone Metal target?

Thanks,