[Auto-TVM][VTA][Colab] Can I apply Auto-tuning algorithms to ResNet ONNX model compiled with VTA instead of LLVM?

Here is the problem: the first time, I followed this tutorial and used the default auto-tuning methods (XGBTuner, GATuner, etc.), and everything worked. Then I wanted to change the target from LLVM to VTA, so I modified the original code as follows:

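import onnx
import tvm
from tvm import autotvm, relay, rpc
from tvm.contrib import graph_executor, utils
from tvm.contrib.download import download_testdata

import vta
from vta.top import graph_pack

# Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file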
env = vta.get_env()

# This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
# Set ``device=arm_cpu`` to run inference on the CPU
# or ``device=vta`` to run inference on the FPGA.
device = "vta"
target = env.target if device == "vta" else env.target_vta_cpu

# Name of the ONNX model to compile
network = "resnet18-v2-7"
network_onnx = network + ".onnx"
print(network_onnx)

model_url = "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/" + network_onnx
model_path = download_testdata(model_url, network_onnx, module="onnx")
onnx_model = onnx.load(model_path)

# ``start_name`` and ``stop_name`` mark where the graph packing relay pass
# starts and stops, in other words, where offloading to VTA starts and ends.
start_name = "nn.max_pool2d"
stop_name = "nn.global_avg_pool2d"

print(env.TARGET)
print(target.device_name)

remote = rpc.LocalSession()

# Get execution context from remote
ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)

# Load pre-configured AutoTVM schedules
with autotvm.tophub.context(target):

    # Populate the shape and data type dictionary for ImageNet classifier input
    dtype_dict = {"data": "float32"}
    shape_dict = {"data": (env.BATCH, 3, 224, 224)}

    # Convert the ONNX model to Relay
    mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

    # Update shape and type dictionary
    shape_dict.update({k: v.shape for k, v in params.items()})
    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})

    # Perform quantization in Relay
    # Note: We set opt_level to 3 in order to fold batch norm
    with tvm.transform.PassContext(opt_level=3):
        with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
            mod = relay.quantize.quantize(mod, params=params)
        # Perform graph packing and constant folding for VTA target
        assert env.BLOCK_IN == env.BLOCK_OUT
        # Do device annotation only if the target is intelfocl
        relay_prog = graph_pack(
            mod["main"],
            env.BATCH,
            env.BLOCK_OUT,
            env.WGT_WIDTH,
            start_name=start_name,
            stop_name=stop_name,
            device_annot=(env.TARGET == "intelfocl"),
        )
      
    with vta.build_config(
        opt_level=3, disabled_pass={"AlterOpLayout", "tir.CommonSubexprElimTIR"}
    ):
        graph, lib, params = relay.build(
            relay_prog, target=tvm.target.Target(target, host=env.target_host), params=params
        )

    # Send the inference library over to the remote RPC server
    temp = utils.tempdir()
    lib.export_library(temp.relpath("graphlib.tar"))
    remote.upload(temp.relpath("graphlib.tar"))
    lib = remote.load_module("graphlib.tar")

    m = graph_executor.create(graph, lib, ctx)

The code above runs normally. However, after I added the tuning options and used XGBTuner to auto-tune ResNet, it kept running but never showed any results:

import os
from tvm.autotvm.tuner import GATuner, GridSearchTuner, RandomTuner, XGBTuner

tracker_host = os.environ.get("TVM_TRACKER_HOST", "127.0.0.1")
tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190))

tuning_option = {
    "tuner": "xgb",
    "trials": 1000,
    "early_stopping": None,
    "measure_option": autotvm.measure_option(
        builder=autotvm.LocalBuilder(), 
        runner=autotvm.RPCRunner(
            env.TARGET,
            host=tracker_host,
            port=tracker_port,
            number=5,
            timeout=60,
            module_loader=vta.module_loader(),
            # check_correctness=True, # TODO: re-enable when check_correctness works again.
        ),
    )
}

logfile = network + "_vta_" + tuning_option["tuner"] + ".json"
print(logfile)

# begin by extracting the tasks from the onnx model
tasks = autotvm.task.extract_from_program(mod["main"], target=target, params=params)

# Tune the extracted tasks sequentially.
for i, task in enumerate(tasks):
    prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))

    if tuning_option["tuner"] == "xgb":
        tuner_obj = XGBTuner(task, loss_type="rank")
    elif tuning_option["tuner"] == "ga":
        tuner_obj = GATuner(task, pop_size=50)
    elif tuning_option["tuner"] == "random":
        tuner_obj = RandomTuner(task)
    elif tuning_option["tuner"] == "gridsearch":
        tuner_obj = GridSearchTuner(task)

    tuner_obj.tune(
        n_trial=min(tuning_option["trials"], len(task.config_space)),
        early_stopping=tuning_option["early_stopping"],
        measure_option=tuning_option["measure_option"],
        callbacks=[
            autotvm.callback.progress_bar(tuning_option["trials"], prefix=prefix),
            autotvm.callback.log_to_file(logfile),
        ],
    )
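For comparison, the VTA AutoTVM tutorial (tune_relay_vta.py), as far as I recall, does not extract tasks from mod["main"]. It extracts them from the graph-packed relay_prog with ops=(relay.op.get("nn.conv2d"),) and target=tvm.target.Target(target, host=env.target_host), then keeps only the packed conv2d tasks using the predicate `len(t.args[0][1]) > 4 and "conv" in t.name`. Since mod["main"] has not been through graph_pack, its conv2d layers are likely not in the packed layout that the VTA schedules expect. The sketch below reproduces only that filtering predicate so it runs without TVM; `FakeTask`, the task names and the shapes are hypothetical stand-ins, not real extracted tasks.

```python
from collections import namedtuple

# Hypothetical stand-in for an AutoTVM task, carrying only the two attributes
# the filter reads. args[0] describes the data tensor as ("TENSOR", shape, dtype).
FakeTask = namedtuple("FakeTask", ["name", "args"])


def is_packed_conv(task) -> bool:
    """Keep conv tasks whose data tensor has more than 4 dims (packed layout)."""
    data_shape = task.args[0][1]
    return len(data_shape) > 4 and "conv" in task.name


# Illustrative tasks: one packed (6-D) conv, one plain NCHW conv, one dense.
tasks = [
    FakeTask("conv2d_packed.vta", (("TENSOR", (1, 4, 56, 56, 1, 16), "int8"),)),
    FakeTask("conv2d_nchw.x86", (("TENSOR", (1, 3, 224, 224), "float32"),)),
    FakeTask("dense_nopack.x86", (("TENSOR", (1, 512), "float32"),)),
]
vta_tasks = [t for t in tasks if is_packed_conv(t)]
print([t.name for t in vta_tasks])  # ['conv2d_packed.vta']
```

With real TVM tasks, the same predicate would be applied to the list returned by autotvm.task.extract_from_program before the tuning loop.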

Here is the screenshot (I did not use PYNQ; all experiments were run on Google Colab):
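When the progress bar stays at 0.00 GFLOPS, the file written by log_to_file can show why, because each measured config is logged together with an error code. The sketch below tallies those codes. It assumes the JSON-lines layout of recent AutoTVM versions, where "result" is [costs, error_no, all_cost, timestamp], and it assumes the code-to-name mapping of AutoTVM's MeasureErrorNo; check both against your TVM checkout. `summarize_log` and `summarize_logfile` are hypothetical helpers. An empty tally means no measurement ever finished, which points at the runner or RPC setup rather than at the schedules.

```python
import json
from collections import Counter

# Assumed to mirror AutoTVM's MeasureErrorNo; verify against
# tvm/autotvm/measure/measure.py in your TVM version.
ERROR_NAMES = {
    0: "NO_ERROR",
    1: "INSTANTIATION_ERROR",
    2: "COMPILE_HOST",
    3: "COMPILE_DEVICE",
    4: "RUNTIME_DEVICE",
    5: "WRONG_ANSWER",
    6: "BUILD_TIMEOUT",
    7: "RUN_TIMEOUT",
    8: "UNKNOWN_ERROR",
}


def summarize_log(lines):
    """Count measurement error codes in AutoTVM JSON log lines."""
    counts = Counter()
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        error_no = record["result"][1]  # result = [costs, error_no, all_cost, timestamp]
        counts[ERROR_NAMES.get(error_no, f"CODE_{error_no}")] += 1
    return counts


def summarize_logfile(path):
    """Read an AutoTVM log file (e.g. the `logfile` above) and tally its error codes."""
    with open(path) as f:
        return summarize_log(f)


# Hypothetical log lines for illustration; only the "result" field is read.
sample = [
    '{"result": [[1e9], 4, 3.2, 0]}',
    '{"result": [[1e9], 7, 60.0, 0]}',
]
print(dict(summarize_log(sample)))
```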


I also tried using LocalRunner in tuning_option:

tuning_option = {
    "tuner": "xgb",
    "trials": 1000,
    "early_stopping": None,
    "measure_option": autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),
        runner=autotvm.LocalRunner(
            timeout=10,
            number=4,
            repeat=3,
            min_repeat_ms=0,
            cooldown_interval=0.1,
            enable_cpu_cache_flush=False,
            module_loader=None,
        ),
    )
}

logfile = network + "_vta_" + tuning_option["tuner"] + ".json"
print(logfile)

# begin by extracting the tasks from the onnx model
tasks = autotvm.task.extract_from_program(mod["main"], target=target, params=params)

It failed with the following error:

[Task  1/43]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (0/1000) | 0.00 s
Exception in thread Thread-26:
Traceback (most recent call last):
  File "/usr/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/content/drive/My Drive/VTA-Project/tvm/python/tvm/autotvm/measure/measure_methods.py", line 812, in _check
    while not dev.exist:  # wait until we get an available device
  File "/content/drive/My Drive/VTA-Project/tvm/python/tvm/_ffi/runtime_ctypes.py", line 263, in exist
    return self._GetDeviceAttr(self.device_type, self.device_id, 0) != 0
  File "/content/drive/My Drive/VTA-Project/tvm/python/tvm/_ffi/runtime_ctypes.py", line 247, in _GetDeviceAttr
    return tvm.runtime._ffi_api.GetDeviceAttr(device_type, device_id, attr_id)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 323, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 163, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  7: TVMFuncCall
  6: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIUlNS0_7TVMArgsEPNS0_11TVMRetValueEE0_EEE4CallEPKS1_S4_S6_
  5: tvm::runtime::RPCDeviceAPI::GetAttr(DLDevice, tvm::runtime::DeviceAttrKind, tvm::runtime::TVMRetValue*)
  4: non-virtual thunk to tvm::runtime::RPCClientSession::GetAttr(DLDevice, tvm::runtime::DeviceAttrKind, tvm::runtime::TVMRetValue*)
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::RPCEndpoint::Init()::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::runtime::RPCEndpoint::HandleUntilReturnEvent(bool, std::function<void (tvm::runtime::TVMArgs)>)
  1: tvm::runtime::SockChannel::Recv(void*, unsigned long)
  0: tvm::support::Socket::Error(char const*)
  File "/content/drive/MyDrive/VTA-Project/tvm/src/runtime/rpc/../../support/socket.h", line 362
TVMError: Socket SockChannel::Recv Error:Connection reset by peer

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-36-a97ad8762603> in <module>()
     22         callbacks=[
     23             autotvm.callback.progress_bar(tuning_option["trials"], prefix=prefix),
---> 24             autotvm.callback.log_to_file(logfile),
     25         ],
     26     )

4 frames
/content/drive/My Drive/VTA-Project/tvm/python/tvm/autotvm/measure/measure_methods.py in set_task(self, task)
    324         else:
    325             raise RuntimeError(
--> 326                 "Cannot get remote devices from the tracker. "
    327                 "Please check the status of tracker by "
    328                 "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "

RuntimeError: Cannot get remote devices from the tracker. Please check the status of tracker by 'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' and make sure you have free devices on the queue status.

This does not seem easy to solve.
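The final RuntimeError says the runner could not obtain a device from the tracker. The RPCRunner configuration earlier points at a tracker on 127.0.0.1:9190 and asks for the key env.TARGET, but nothing on Colab starts that tracker automatically, and the rpc.LocalSession() used for the build does not involve a tracker at all. A first check is whether anything is listening on that port (a tracker can be started with `python -m tvm.exec.rpc_tracker --host=0.0.0.0 --port=9190`). The stdlib-only sketch below performs only that check; `tracker_reachable` is a hypothetical helper, and a successful TCP connection does not prove that a server is registered under the right key, which is what `python -m tvm.exec.query_rpc_tracker --port 9190` reports.

```python
import os
import socket


def tracker_reachable(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True if some process accepts TCP connections on host:port.

    This only shows that something is listening; it does not confirm that it
    is an RPC tracker or that a server is registered under a given key.
    """
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Connection refused, timed out, or host unreachable.
        return False


if __name__ == "__main__":
    # Same environment variables and defaults as the tuning_option code above.
    host = os.environ.get("TVM_TRACKER_HOST", "127.0.0.1")
    port = int(os.environ.get("TVM_TRACKER_PORT", 9190))
    if tracker_reachable(host, port):
        print(f"Something is listening on {host}:{port}.")
    else:
        print(f"Nothing is listening on {host}:{port}; start an RPC tracker first.")
```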