Can Autoscheduler support tuning for multiple targets?

Hi,

I have implemented a transform pass that annotates operators with a device: nn.conv2d is annotated to run on the cuda device, and the fallback device is set to llvm. But when I extract tasks from Relay like this:

target = {"llvm": "llvm", "cuda": "cuda"}
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)

I get the following error:

Extract tasks...
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/pads/Desktop/tvm_v0.8/python/tvm/auto_scheduler/relay_integration.py", line 71, in call_all_topi_funcs
    opt_mod, _ = relay.optimize(mod_clone, target, params)
  File "/home/pads/Desktop/tvm_v0.8/python/tvm/relay/build_module.py", line 404, in optimize
    target = _update_target(target)
  File "/home/pads/Desktop/tvm_v0.8/python/tvm/relay/build_module.py", line 56, in _update_target
    dev_type = tvm_expr.IntImm("int32", _nd.device(dev).device_type)
  File "/home/pads/Desktop/tvm_v0.8/python/tvm/runtime/ndarray.py", line 276, in device
    return Device(dev_type, dev_id)
  File "/home/pads/Desktop/tvm_v0.8/python/tvm/_ffi/runtime_ctypes.py", line 205, in __init__
    self.device_type = int(device_type)
TypeError: int() argument must be a string, a bytes-like object or a number, not 'Target'

If I use a single target, tasks are extracted successfully, so I'm wondering whether the auto-scheduler can extract tasks and tune for multiple targets.

I'm not sure how you implemented the pass, but the error does not seem related to the auto-scheduler. Can your model run without tuning (regardless of the performance)?

Yes, my model works without tuning, and task extraction also succeeds with a single target. But if I pass two targets, target = {"llvm": "llvm", "cuda": "cuda"}, for tuning, the errors above occur. I think the reason is in the _update_target function:

def _update_target(target):
    target = target if target else Target.current()
    if target is None:
        raise ValueError("Target is not set in env or passed as argument.")

    tgts = {}
    if isinstance(target, (str, Target)):
        dev_type = tvm_expr.IntImm("int32", _nd.device(str(target)).device_type)
        tgts[dev_type] = Target(target)

    elif isinstance(target, dict):
        for dev, tgt in target.items():
            dev_type = tvm_expr.IntImm("int32", _nd.device(dev).device_type)
            tgts[dev_type] = Target(tgt)
              
    else:
        raise TypeError(
            "target is expected to be str or "
            + "tvm.target.Target, but received "
            + "{}".format(type(target))
        )
    
    return tgts

If there is only one target, "llvm", it is updated to {1: llvm -keys=cpu -link-params=0}. But with two targets, the dict that reaches this function already looks like {llvm -keys=cpu -link-params=0: 'llvm', cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32: 'cuda'}: its keys are Target objects rather than strings, so _nd.device(dev) cannot convert them into a device type and raises the error above. So I'm not sure whether the auto-scheduler can extract tasks and tune for multiple targets, or only for a single one. Thank you for your reply.
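
To make the failure mode concrete, here is a minimal sketch of the difference (assuming a TVM build around v0.8; the calls are illustrative, not copied from the TVM sources):

import tvm
from tvm.target import Target

# A device-name string works: "llvm" is mapped to the CPU device type.
print(tvm.device("llvm").device_type)   # 1 (kDLCPU)

# A Target object does not: Device.__init__ calls int(device_type) and raises
# "TypeError: int() argument must be a string, a bytes-like object or a
# number, not 'Target'" -- the same error as in the traceback above.
# tvm.device(Target("llvm"))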

What confuses me is that this function is not specific to the auto-scheduler, which means you should hit this issue even without the auto-scheduler. If you don't, then the optimization process in task extraction is likely inconsistent with the one in the normal build process. It would be helpful if you could provide a minimal reproducible script.

Here’s my script


import numpy as np

import tvm
from tvm import relay, auto_scheduler
import tvm.relay.testing
from tvm.contrib import graph_executor

def get_network(name, batch_size, layout="NHWC", dtype="float32"):
    """Get the symbol definition and random weight of a network"""

    # auto-scheduler prefers NHWC layout
    if layout == "NHWC":
        image_shape = (224, 224, 3)
    elif layout == "NCHW":
        image_shape = (3, 224, 224)
    else:
        raise ValueError("Invalid layout: " + layout)

    input_shape = (batch_size,) + image_shape
    output_shape = (batch_size, 1000)

    if name.startswith("resnet-"):
        n_layer = int(name.split("-")[1])
        mod, params = relay.testing.resnet.get_workload(
            num_layers=n_layer,
            batch_size=batch_size,
            layout=layout,
            dtype=dtype,
            image_shape=image_shape,
        )
    elif name.startswith("resnet3d-"):
        n_layer = int(name.split("-")[1])
        mod, params = relay.testing.resnet_3d.get_workload(
            num_layers=n_layer,
            batch_size=batch_size,
            layout=layout,
            dtype=dtype,
            image_shape=image_shape,
        )
    elif name == "mobilenet":
        mod, params = relay.testing.mobilenet.get_workload(
            batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape
        )
    elif name == "squeezenet_v1.1":
        assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout"
        mod, params = relay.testing.squeezenet.get_workload(
            version="1.1",
            batch_size=batch_size,
            dtype=dtype,
            image_shape=image_shape,
        )
    elif name == "inception_v3":
        input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3)
        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
    elif name == "mxnet":
        # an example for mxnet model
        from mxnet.gluon.model_zoo.vision import get_model

        assert layout == "NCHW"

        block = get_model("resnet18_v1", pretrained=True)
        mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype)
        net = mod["main"]
        net = relay.Function(
            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
        )
        mod = tvm.IRModule.from_expr(net)

    return mod, params, input_shape, output_shape


# Define the neural network and compilation target
network = "resnet-18"
batch_size = 1
layout = "NCHW"
target = {"llvm": "llvm", "cuda": "cuda"}
dtype = "float32"
log_file = "%s-%s-B%d-%s.json" % (network, layout, batch_size, "cuda")

mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype)

@relay.transform.function_pass(opt_level=1)
class MyPass:
    def __init__(self):
        self.var = 0
    # The pass: annotate every nn.conv2d call to run on the cuda device;
    # all other ops keep the fallback device.
    def transform_function(self, func, mod, ctx):
        obj = self
        class Test(tvm.relay.ExprMutator):
            def visit_call(self, expr):
                visit = super().visit_call(expr)
                if expr.op == tvm.relay.op.get("nn.conv2d"):
                    return relay.annotation.on_device(visit, 'cuda')             
                else:
                    return visit

        return Test().visit(func)

dev1 = tvm.device("llvm")
dev2 = tvm.device("cuda")
custom_pass = MyPass()
mod = custom_pass(mod)

#print("Extract tasks...")
''' tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)

for idx, task in enumerate(tasks):
    print("========== Task %d  (workload key: %s) ==========" % (idx, task.workload_key))
    print(task.compute_dag) '''


''' def run_tuning():
    print("Begin tuning...")
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(repeat=1, min_repeat_ms=300, timeout=10)
    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=200,  # change this to 20000 to achieve the best performance
        runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )
    tuner.tune(tune_option) '''


#run_tuning()

# Compile with the history best
print("Compile...")
#with auto_scheduler.ApplyHistoryBest(log_file):
with tvm.transform.PassContext(opt_level=3):    #, config={"relay.backend.use_auto_scheduler": True}
    lib = relay.build(mod, target=target, params=params)

# Create graph executor

module = graph_executor.create(lib.get_graph_json(), lib.get_lib(), [dev1, dev2])
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input(**lib.get_params())
module.set_input("data", data_tvm)
module.run()

Without tuning, my code executes with no errors; the errors only occur when extracting tasks. Maybe you're right that the optimization process is different.
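
A workaround I might try (untested, so this is only a sketch based on the observations above) is to extract and tune the tasks with the single cuda target, and only use the heterogeneous target dict for the final build:

# Untested sketch: tune against a single target, then build heterogeneously.
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, "cuda")
run_tuning()  # the same tuning loop as the commented-out block above

with auto_scheduler.ApplyHistoryBest(log_file):
    with tvm.transform.PassContext(
        opt_level=3, config={"relay.backend.use_auto_scheduler": True}
    ):
        lib = relay.build(mod, target=target, params=params)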

It turns out that task extraction needs to be updated to process the target correctly. I'll look into this.
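
For reference, the kind of tolerance the dict branch of _update_target would need might look roughly like this (just an illustration of the idea, not a tested or merged patch):

elif isinstance(target, dict):
    for dev, tgt in target.items():
        # Keys can arrive either as device-name strings ("llvm", "cuda") or,
        # as in the traceback above, as already-constructed Target objects.
        # Recover a device name from the Target kind before calling
        # _nd.device(), which only accepts a string or an integer.
        dev_name = dev.kind.name if isinstance(dev, Target) else dev
        dev_type = tvm_expr.IntImm("int32", _nd.device(dev_name).device_type)
        tgts[dev_type] = Target(tgt)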