AutoScheduler on mali GPU raises "TypeError: conv2d_winograd_nhwc_mali() takes from 6 to 7 positional arguments but 8 were given"

Hi, everybody. I have tried to use AutoScheduler on a mali GPU to tune a VGG model. It seems Relay selects topi.nn.conv2d_winograd_nhwc for the VGG model, but it raises the error below:

Exception in thread Thread-1:
Traceback (most recent call last):
  File "/usr/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/auto_scheduler/relay_integration.py", line 67, in call_all_topi_funcs
    opt_mod, _ = relay.optimize(mod, target, params)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/relay/build_module.py", line 400, in optimize
    mod, params = bld_mod.optimize(mod, target, params)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/relay/build_module.py", line 187, in optimize
    mod = self._optimize(mod, target)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
TypeError: Traceback (most recent call last):
  28: TVMFuncCall
  27: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#10}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  26: tvm::relay::backend::RelayBuildModule::Optimize(tvm::IRModule, tvm::runtime::Map<tvm::Integer, tvm::Target, void, void> const&, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::NDArray, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, tvm::runtime::NDArray> > > const&)
  25: tvm::transform::Pass::operator()(tvm::IRModule) const
  24: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  23: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  22: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  21: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::AlterOpLayout()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::AlterOpLayout()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  20: tvm::relay::alter_op_layout::AlterOpLayout(tvm::RelayExpr const&)
  19: tvm::relay::ForwardRewrite(tvm::RelayExpr const&, tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)> const&, std::function<tvm::runtime::ObjectRef (tvm::relay::Call const&)>, std::function<tvm::RelayExpr (tvm::RelayExpr const&)>)
  18: tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)
  17: tvm::relay::MixedModeMutator::VisitLeaf(tvm::RelayExpr const&)
  16: _ZN3tvm5relay16MixedModeMutator17DispatchVisitExprERKNS_9Re
  15: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  14: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  13: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlR
  12: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  11: tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)
  10: tvm::relay::MixedModeMutator::VisitLeaf(tvm::RelayExpr const&)
  9: _ZN3tvm5relay16MixedModeMutator17DispatchVisitExprERKNS_9Re
  8: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  7: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  6: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlR
  5: tvm::relay::MixedModeMutator::VisitExpr_(tvm::relay::CallNode const*)
  4: tvm::relay::ForwardRewriter::Rewrite_(tvm::relay::CallNode const*, tvm::RelayExpr const&)
  3: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  2: tvm::RelayExpr tvm::relay::LayoutRewriter<tvm::relay::alter_op_layout::AlterTransformMemorizer>(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)
  1: tvm::relay::alter_op_layout::AlterTransformMemorizer::CallWithNewLayouts(tvm::relay::Call const&, std::vector<tvm::RelayExpr, std::allocator<tvm::RelayExpr> > const&)
  0: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/relay/op/nn/_nn.py", line 184, in alter_op_layout_conv2d
    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
  File "<decorator-gen-58>", line 2, in conv2d_alter_layout
  File "/home/shaohaizhu/workspace/tvm/python/tvm/target/generic_func.py", line 275, in dispatch_func
    return dispatch_dict[k](*args, **kwargs)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/topi/mali/conv2d.py", line 489, in _alter_conv2d_layout
    relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
  File "/home/shaohaizhu/workspace/tvm/python/tvm/relay/backend/compile_engine.py", line 209, in select_implementation
    outs = best_plevel_impl.compute(attrs, inputs, out_type) # haizhu.shao: calls the conv2d_nhwc function in conv2d.py
  File "/home/shaohaizhu/workspace/tvm/python/tvm/relay/op/op.py", line 90, in compute
    return _OpImplementationCompute(self, attrs, inputs, out_type)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
  3: TVMFuncCall
  2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::relay::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#4}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  1: tvm::relay::OpImplementation::Compute(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)
  0: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/relay/op/strategy/generic.py", line 240, in _compute_conv2d
    return [topi_compute(*args)]
  File "<decorator-gen-60>", line 2, in conv2d_winograd_nhwc
  File "/home/shaohaizhu/workspace/tvm/python/tvm/target/generic_func.py", line 275, in dispatch_func
    return dispatch_dict[k](*args, **kwargs)
TypeError: conv2d_winograd_nhwc_mali() takes from 6 to 7 positional arguments but 8 were given

Then I checked the function decorated with @conv2d_winograd_nhwc.register(["mali"]) in python/tvm/topi/mali/conv2d.py and the @tvm.target.generic_func def conv2d_winograd_nhwc in python/tvm/topi/nn/conv2d.py. It seems the function in python/tvm/topi/mali/conv2d.py is missing one argument, auto_scheduler_rewritten_layout="". I added this argument and then tried to auto-tune with AutoScheduler again.

I have no idea whether my solution is right. Has anyone else met the same problem and fixed it?

@merrymercy @comaniac

@merrymercy @comaniac @FrozenGene Hi, could you help me with this problem? I have tried to fix it with the solution below: modify the function in python/tvm/topi/mali/conv2d.py and add an argument auto_scheduler_rewritten_layout:

@conv2d_winograd_nhwc.register(["mali"])
def conv2d_winograd_nhwc_mali(
    data, weight, strides, padding, dilation, out_dtype, pre_computed=False, auto_scheduler_rewritten_layout=""
):
    """Conv2D Winograd in NHWC layout.
    This is a clean version to be used by the auto-scheduler for mali.
    """
    tile_size = _pick_tile_size(data, weight, layout="NHWC")
    return _conv2d_winograd_nhwc_impl(
        data, weight, strides, padding, dilation, out_dtype, tile_size, pre_computed, auto_scheduler_rewritten_layout
    )

Then AutoScheduler can auto-tune successfully, but after the tuning finished, when I call the command below, it fails:

    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
            lib = relay.build(mod, target=target, target_host=target_host, params=params)

The error message is:

Traceback (most recent call last):
  File "./tune_vgg.py", line 146, in <module>
    run()
  File "./tune_vgg.py", line 143, in run
    lib = tune()
  File "./tune_vgg.py", line 96, in tune
    lib = relay.build(mod, target=target, target_host=target_host, params=params)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/relay/build_module.py", line 333, in build
    mod=ir_mod, target=target, params=params, executor=executor
  File "/home/shaohaizhu/workspace/tvm/python/tvm/relay/build_module.py", line 148, in build
    self._build(mod, target, target_host, executor)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
ValueError: Traceback (most recent call last):
  29: TVMFuncCall
  28: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  27: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::NDArray, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, tvm::runtime::NDArray> > > const&)
  26: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::relay::backend::GraphExecutorCodegenModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  25: tvm::relay::backend::GraphExecutorCodegen::Codegen(tvm::relay::Function)
  24: tvm::relay::backend::MemoizedExprTranslator<std::vector<tvm::relay::backend::GraphNodeRef, std::allocator<tvm::relay::backend::GraphNodeRef> > >::VisitExpr(tvm::RelayExpr const&)
  23: _ZZN3tvm5relay11ExprFunctorIFSt6vectorINS0_7backend12GraphNodeRefESaI
  22: tvm::relay::backend::GraphExecutorCodegen::VisitExpr_(tvm::relay::CallNode const*)
  21: tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::operator()<tvm::relay::CompileEngine&, tvm::relay::CCacheKey&>(tvm::relay::CompileEngine&, tvm::relay::CCacheKey&) const
  20: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::relay::CachedFunc (tvm::relay::CompileEngine, tvm::relay::CCacheKey)>::AssignTypedLambda<tvm::relay::{lambda(tvm::relay::CompileEngine, tvm::relay::CCacheKey)#9}>(tvm::relay::{lambda(tvm::relay::CompileEngine, tvm::relay::CCacheKey)#9}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  19: tvm::relay::CompileEngineImpl::LowerInternal(tvm::relay::CCacheKey const&)
  18: tvm::relay::CreateSchedule(tvm::relay::Function const&, tvm::Target const&)
  17: tvm::relay::ScheduleGetter::Create(tvm::relay::Function const&)
  16: tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor, void> >::VisitExpr(tvm::RelayExpr const&)
  15: tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  14: _ZZN3tvm5relay11ExprFunctorIFNS_7runtime5ArrayINS_2te6TensorEvEERKNS_
  13: tvm::relay::ScheduleGetter::VisitExpr_(tvm::relay::CallNode const*)
  12: tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor, void> >::VisitExpr(tvm::RelayExpr const&)
  11: tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  10: _ZZN3tvm5relay11ExprFunctorIFNS_7runtime5ArrayINS_2te6TensorEvEERKNS_
  9: tvm::relay::ScheduleGetter::VisitExpr_(tvm::relay::CallNode const*)
  8: tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor, void> >::VisitExpr(tvm::RelayExpr const&)
  7: tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  6: _ZZN3tvm5relay11ExprFunctorIFNS_7runtime5ArrayINS_2te6TensorEvEERKNS_
  5: tvm::relay::ScheduleGetter::VisitExpr_(tvm::relay::CallNode const*)
  4: tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor, void> >::VisitExpr(tvm::RelayExpr const&)
  3: tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  2: _ZZN3tvm5relay11ExprFunctorIFNS_7runtime5ArrayINS_2te6TensorEvEERKNS_
  1: tvm::relay::ScheduleGetter::VisitExpr_(tvm::relay::CallNode const*)
  0: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/relay/backend/compile_engine.py", line 304, in lower_call
    best_impl, outputs = select_implementation(op, call.attrs, inputs, ret_type, target)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/relay/backend/compile_engine.py", line 209, in select_implementation
    outs = best_plevel_impl.compute(attrs, inputs, out_type) # haizhu.shao: calls the conv2d_nhwc function in conv2d.py
  File "/home/shaohaizhu/workspace/tvm/python/tvm/relay/op/op.py", line 90, in compute
    return _OpImplementationCompute(self, attrs, inputs, out_type)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
  3: TVMFuncCall
  2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::relay::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#4}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  1: tvm::relay::OpImplementation::Compute(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)
  0: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/relay/op/strategy/generic.py", line 240, in _compute_conv2d
    return [topi_compute(*args)]
  File "/home/shaohaizhu/workspace/tvm/python/tvm/topi/nn/conv2d.py", line 1261, in conv2d_winograd_nhwc_without_weight_transform
    auto_scheduler_rewritten_layout=auto_scheduler_rewritten_layout,
  File "<decorator-gen-60>", line 2, in conv2d_winograd_nhwc
  File "/home/shaohaizhu/workspace/tvm/python/tvm/target/generic_func.py", line 275, in dispatch_func
    return dispatch_dict[k](*args, **kwargs)
  File "/home/shaohaizhu/workspace/tvm/python/tvm/topi/mali/conv2d.py", line 593, in conv2d_winograd_nhwc_mali
    data, weight, strides, padding, dilation, out_dtype, tile_size, pre_computed, auto_scheduler_rewritten_layout
  File "/home/shaohaizhu/workspace/tvm/python/tvm/topi/nn/conv2d.py", line 1062, in _conv2d_winograd_nhwc_impl
    H_CAT, W_CAT, CO, CI = get_const_tuple(weight.shape)
ValueError: too many values to unpack (expected 4)

So this error occurred during replay? Does the tuning process itself work?

Unfortunately, I currently don't have a mali device to test this problem. This seems like a bug in the "layout_rewrite" feature of Ansor. Add disabled_pass={"AutoSchedulerLayoutRewrite"} to the tvm.transform.PassContext(...) to see if the build can pass.
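For example, a minimal sketch based on your build call above:

# sketch: same build call as before, but with the Ansor layout-rewrite pass disabled
with auto_scheduler.ApplyHistoryBest(log_file):
    with tvm.transform.PassContext(
        opt_level=3,
        config={"relay.backend.use_auto_scheduler": True},
        disabled_pass={"AutoSchedulerLayoutRewrite"},
    ):
        lib = relay.build(mod, target=target, target_host=target_host, params=params)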

We may not be able to get the best performance with the layout rewrite feature disabled, so the best solution is still to figure out the root cause of this bug. You can try to debug the op strategy of conv2d on mali.
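One way to start is to compare the signature of the generic conv2d_winograd_nhwc with the mali override, to see which argument is missing. A rough sketch (assuming a local TVM checkout on PYTHONPATH):

import importlib
import inspect

from tvm import topi

# the mali override lives in tvm.topi.mali.conv2d; import the module explicitly
mali_conv2d = importlib.import_module("tvm.topi.mali.conv2d")

# generic definition (takes auto_scheduler_rewritten_layout) vs. the mali override
print(inspect.signature(topi.nn.conv2d_winograd_nhwc))
print(inspect.signature(mali_conv2d.conv2d_winograd_nhwc_mali))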

@jcf94 Thanks for your help. After I added the argument auto_scheduler_rewritten_layout="" to mali's conv2d_winograd_nhwc function, the tuning process seems to work well, but the estimated total latency is not correct. Below is the tuning log; with my manual kernel, the VGG model runs in about 57 ms.


----------------------------------------------------------------------
------------------------------  [ Task Scheduler ]
----------------------------------------------------------------------
|  ID  | Latency (ms) | Speed (GFLOPS) | Trials |
-------------------------------------------------
|    0 |        0.024 |       43538.19 |     64 |
|    1 |        0.165 |       16314.41 |   1216 |
|    2 |        5.740 |          86.44 |    256 |
-------------------------------------------------
Estimated total latency: 8.730 ms       Trials: 0       Used time : 8 s Next ID: 2
----------------------------------------------------------------------
------------------------------  [ Search ]

@jcf94 Hi, below is my test code. Could you help me check whether there is anything wrong with the configuration? Thanks for your help.

import numpy as np

import tvm
from tvm import relay, auto_scheduler, transform
import tvm.relay.testing
from tvm.contrib import graph_runtime, graph_executor
from tvm.auto_scheduler.utils import request_remote
import os, time

#################################################################
tracker_port = 9190
device_key = "android"

network = "vgg_float.tflite"
input_tensor = "input"
input_shape = (1, 256, 256, 1)
input_dtype = "float32"
local_memory = 16

# Set this to True if you use ndk tools for cross compiling
use_ndk = True
os.environ["TVM_NDK_CC"] = "/opt/android-ndk/android-ndk-r21e/android-toolchain-arm64/bin/aarch64-linux-android-g++"
target_host = tvm.target.Target("llvm -mtriple=aarch64-linux-android")
target = tvm.target.Target("opencl -device=mali")
log_file = "%s-NHWC-%s-%s-%d.json" % (network,  target.kind.name, input_dtype, local_memory)

#################################################################
# Define a Network
def get_network(name):
    tflite_model_file = os.path.join("/home/shaohaizhu/workspace/aibenchmark", name)
    tflite_model_buf  = open(tflite_model_file, "rb").read()
    try:
        import tflite
        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
    except AttributeError:
        import tflite.Model
        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
    
    fp16 = True if input_dtype == "float16"  else False
    mod, params = relay.frontend.from_tflite(
        tflite_model, shape_dict={input_tensor: input_shape}, dtype_dict={input_tensor: input_dtype}, is_fp16 = fp16
    )

    return mod, params

# Extract Search Tasks
print("before get_network...")
mod, params = get_network(network)
print("after get_network...")

#################################################################
# get the hardware parameters from remote device
def get_hw_params():    
    print("begin get hardware params......")
    remote = request_remote(device_key, "0.0.0.0", tracker_port)
    ctx = remote.cl()
    max_vthread_extent = int(ctx.warp_size / 4) if int(ctx.warp_size / 4) > 1 else ctx.warp_size
    warp_size = ctx.warp_size
    max_shared_memory_per_block = ctx.max_shared_memory_per_block
    max_threads_per_block = ctx.max_threads_per_block
    # There is no explicit local memory limitation, so we can use INT32_MAX to disable the check on local_memory.
    # max_local_memory_per_block = 2147483647 # INT32_MAX
    max_local_memory_per_block = local_memory    
    hardware_params = auto_scheduler.HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block, max_threads_per_block, max_vthread_extent, warp_size)    
    print("end get hardware params......")
    return hardware_params
hardware_params = get_hw_params()

tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target, target_host, hardware_params)

for idx, task in enumerate(tasks):
    print("========== Task %d  (workload key: %s) ==========" % (idx, task.workload_key))
    print(task.compute_dag)

#################################################################
# Tuning
def tune():
    print("Begin tuning...")
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=20000, # change this to 20000 to achieve the best performance
        builder=auto_scheduler.LocalBuilder(build_func="ndk" if use_ndk else "default"),
        runner=auto_scheduler.RPCRunner(device_key, host="0.0.0.0", port=tracker_port, repeat=3, timeout=50),
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )

    tuner = auto_scheduler.TaskScheduler(tasks, task_weights, load_log_file=log_file)
    tuner.tune(tune_option)

    # Compile the whole network
    print("Compile to build lib...")

    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
            lib = relay.build(mod, target=target, target_host=target_host, params=params)
    
    print("Done build.")  
    return lib

#################################################################
# Evaluate
def evaluate(lib):  
    print("Begin evaluate...")
    remote = request_remote(device_key, "0.0.0.0", tracker_port, timeout=1000)
    ctx = remote.cl()
    print("=============== Get Remote and Context ===============")

    from tvm.contrib import utils, ndk
    temp = utils.tempdir()
    filename = "deploy_lib_%s_%s.so" % (network, target.kind.name)
    path_lib = temp.relpath(filename)
    print("Begin ndk.create_shared...")
    lib.export_library(path_lib, ndk.create_shared)
    print("End   ndk.create_shared...")
    remote.upload(path_lib)
    loaded_lib = remote.load_module(filename)
    print("=============== Done load remote module ===============")
    module = graph_executor.GraphModule(loaded_lib["default"](ctx))
    data = (np.random.uniform(size=input_shape)).astype(input_dtype)
    data_tvm = tvm.nd.array(data)
    module.set_input("input", data_tvm)

    # Evaluate    
    iter = 100
    print("Evaluate inference time cost ...")
    ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=iter, min_repeat_ms=0)
    prof_res = np.array(ftimer().results) * 1e3  # convert to millisecond 
    print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)))    

    #out_tvm = module.get_output(0)
    #true_out = (np.loadtxt("./mv2_output.txt", dtype=np.float32)).astype(input_dtype)
    #true_out.shape = output_shape
    #print("Evaluate inference correctness...")
    #np.testing.assert_allclose(out_tvm.asnumpy(), true_out, rtol=1e-3)
    #print("Done Evaluate inference correctness")

def run():
    lib = tune()
    evaluate(lib)

run()

It seems the official TVM version already adds disabled_pass={"AutoSchedulerLayoutRewrite"} by default; I found it in python/tvm/auto_scheduler/relay_integration.py:call_all_topi_funcs. When I tried removing the disabled_pass={"AutoSchedulerLayoutRewrite"} line, auto_scheduler.extract_tasks returned 0 tasks.
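For reference, the task-extraction path wraps relay.optimize roughly like this (paraphrased from what I see in call_all_topi_funcs; details such as opt_level may differ):

# paraphrased sketch of the PassContext used by call_all_topi_funcs during task extraction
with tvm.transform.PassContext(
    opt_level=3,
    config={"relay.backend.use_auto_scheduler": True},
    disabled_pass={"AutoSchedulerLayoutRewrite"},
):
    opt_mod, _ = relay.optimize(mod, target, params)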

I don't have a mali device now. Since it reports too many values to unpack, could you help print what the value being unpacked is? I think that is the starting point for debugging.
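For example, a temporary print just before the failing unpack in python/tvm/topi/nn/conv2d.py would show the shape (a debugging sketch only):

# inside _conv2d_winograd_nhwc_impl, temporarily replacing the failing line
shape = get_const_tuple(weight.shape)
print("winograd weight shape:", shape)  # the code below expects exactly 4 values
H_CAT, W_CAT, CO, CI = shape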

@merrymercy @comaniac @jcf94 @FrozenGene Hi, I have solved this problem by adding some snippets of code to python/tvm/topi/mali/conv2d.py and python/tvm/relay/op/strategy/mali.py. The code is listed below:

# python/tvm/topi/mali/conv2d.py
# add argument ' auto_scheduler_rewritten_layout = "" ' in conv2d_winograd_nhwc_mali
@conv2d_winograd_nhwc.register(["mali"])
def conv2d_winograd_nhwc_mali(
    data, weight, strides, padding, dilation, out_dtype, pre_computed=False, auto_scheduler_rewritten_layout=""
):
    """Conv2D Winograd in NHWC layout.
    This is a clean version to be used by the auto-scheduler for mali.
    """
    tile_size = _pick_tile_size(data, weight, layout="NHWC")
    return _conv2d_winograd_nhwc_impl(
        data, weight, strides, padding, dilation, out_dtype, tile_size, pre_computed, auto_scheduler_rewritten_layout
    )

# python/tvm/relay/op/strategy/mali.py
# add need_auto_scheduler_layout=True in conv2d_strategy_mali and conv2d_winograd_without_weight_transfrom_strategy_mali
@conv2d_strategy.register("mali")
def conv2d_strategy_mali(attrs, inputs, out_type, target):
    # ...
    if is_winograd_applicable:
        strategy.add_implementation(
            wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc, need_auto_scheduler_layout=True),
            naive_schedule,  # this implementation should never be picked by autotvm
            name="conv2d_nhwc.winograd",
            plevel=15,
        )
    # ...

@conv2d_winograd_without_weight_transfrom_strategy.register("mali")
def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_type, target):
    # ...
    if not is_auto_scheduler_enabled():
        raise RuntimeError(
            "Winograd conv2d NHWC is not enabled for mali without auto_scheduler."
        )
    strategy.add_implementation(
        wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform, need_auto_scheduler_layout=True),
        naive_schedule,  # this implementation should never be picked by autotvm
        name="conv2d_nhwc_winograd_without_weight_transform",
        plevel=15,
    )

Could you send a PR to upstream?

Hi, could you give some ideas about how to set max_local_memory_per_block to a suitable value for a mali GPU?