AutoScheduler failed to find a valid schedule

Hello,
I am trying to generate the source code of an operator that involves conv2d_nchw and gradient, and I tried to achieve this with AutoScheduler. However, the search fails when the operator has a relatively larger shape, and when task.apply_best(log_file) is run, it tells me that there is no valid schedule in the JSON file.

To be exact, the program works well when the layer looks like:

@auto_scheduler.register_workload
def conv2d_layer():
    x = tvm.te.placeholder((1, 64, 7, 7), name='x')
    w1 = tvm.te.placeholder((64, 64, 3, 3), name='w1')
    z1 = topi.nn.conv2d_nchw(x, w1, (1, 1), (1, 1), dilation=1, out_dtype="float32")
    [dw1] = tvm.te.gradient(z1, [w1])
    return [x, w1, dw1]

but will fail when it looks like:

@auto_scheduler.register_workload
def conv2d_layer():
    x = tvm.te.placeholder((1, 128, 7, 7), name='x')
    w1 = tvm.te.placeholder((128, 128, 3, 3), name='w1')
    z1 = topi.nn.conv2d_nchw(x, w1, (1, 1), (1, 1), dilation=1, out_dtype="float32")
    [dw1] = tvm.te.gradient(z1, [w1])
    return [x, w1, dw1]

The whole program looks like this:

import os

import numpy as np
import tvm
from tvm import te, auto_scheduler, topi
from tvm.topi.testing import conv2d_nchw_python

@auto_scheduler.register_workload
def conv2d_layer():
    x = tvm.te.placeholder((1, 128, 7, 7), name='x')
    w1 = tvm.te.placeholder((128, 128, 3, 3), name='w1')
    z1 = topi.nn.conv2d_nchw(x, w1, (1, 1), (1, 1), dilation=1, out_dtype="float32")
    [dw1] = tvm.te.gradient(z1, [w1])
    return [x, w1, dw1]

target = tvm.target.Target("cuda")

task = auto_scheduler.SearchTask(
    func=conv2d_layer, args=(), target=target
)


print("Computational DAG:")
print(task.compute_dag)

log_file = "conv2d.json"
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=1000,  # change this to 1000 to achieve the best performance
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)
task.tune(tune_option)

sch, args = task.apply_best(log_file)

del measure_ctx

print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))

print("CUDA source code:")
print(task.print_best(log_file, print_mode="cuda"))

The last part of the output is below:

==================================================
No: 1000        GFLOPS: 0.00 / 0.00     results: MeasureResult(error_type:RuntimeDeviceError, error_msg:Traceback (most recent call last):
  File "/root/nnfusion/artifacts/.deps/tvm-0.7/python/tvm/auto_scheduler/measure.py", line 1120, in _rpc_run
    random_fill(empty_array)
  File "/root/nnfusion/artifacts/.deps/tvm-0.7/python/tvm/_ffi/_ctypes/packed_func.
...
----------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (code == RPCCode::kReturn) is false: code=1
, all_cost:13.82, Tstamp:1631297770.85)
==================================================
Placeholder: x, k
blockIdx.x nn.0@ff.0@yy.0@xx.0@ (0,14)
  threadIdx.x nn.2@ff.2@yy.2@xx.2@ (0,8)
    for rc.0 (0,8)
      for ax0@ax1@ax2@ax3@.0.0 (0,1152)
        threadIdx.x ax0@ax1@ax2@ax3@.0.1 (0,8)
          k.shared = ...
      for ax0@ax1@ax2@ax3@.0.0 (0,18)
        threadIdx.x ax0@ax1@ax2@ax3@.0.1 (0,8)
          vectorize ax0@ax1@ax2@ax3@.1 (0,3)
            pad_temp.shared = ...
      for rc.1 (0,16)
        for xx_c.3 (0,7)
          for ry.2 (0,3)
            for rx.2 (0,3)
              for ff_c.4 (0,8)
                compute.local = ...
    for ff.3 (0,8)
      for xx.3 (0,7)
        compute = ...
blockIdx.x ax0.0@ax1.0@ax2.0@ax3.0@ax4.0@ax5.0@ax6.0@ax7.0@ (0,1024)
  vthread ax0.1@ax1.1@ax2.1@ax3.1@ax4.1@ax5.1@ax6.1@ax7.1@ (0,14)
    threadIdx.x ax0.2@ax1.2@ax2.2@ax3.2@ax4.2@ax5.2@ax6.2@ax7.2@ (0,288)
      compute.k.grad.local auto_unroll: 1024
      for n0_n1_k2_shifted_shifted.0 (0,7)
        for n1_n2_k3_shifted_shifted.0 (0,7)
          for ax0@ax1@ax2@ax3@ax4@ax5@ax6@ax7@.0.0 (0,6)
            threadIdx.x ax0@ax1@ax2@ax3@ax4@ax5@ax6@ax7@.0.1 (0,288)
              compute.compute.grad.shared = ...
          for ax0@ax1@ax2@ax3@.0.0 (0,2)
            threadIdx.x ax0@ax1@ax2@ax3@.0.1 (0,288)
              pad_temp.d.shared = ...
          for ax3_c.3 (0,7)
            for ax4_c.3 (0,4)
              for ax4_c.4 (0,4)
                for ax5_c.4 (0,2)
                  compute.k.grad.local = ...
      for ax3.3 (0,7)
        for ax4.3 (0,16)
          for ax5.3 (0,2)
            compute.k.grad = ...

[18:16:10] /root/nnfusion/artifacts/.deps/tvm-0.7/src/auto_scheduler/measure.cc:299: Warning: Too many errors happened during tuning. Switching to debug mode.

Time elapsed for measurement: 426.86 s
----------------------------------------------------------------------
------------------------------  [ Done ]
----------------------------------------------------------------------
No valid state found in this search round. Check if it has traversed all of the search space.
/root/nnfusion/artifacts/.deps/anaconda3/lib/python3.6/site-packages/xgboost/training.py:17: UserWarning: Old style callback is deprecated.  See: https://xgboost.readthedocs.io/en/latest/python/callbacks.html
  warnings.warn(f'Old style callback is deprecated.  See: {link}', UserWarning)
MeasureInput with old format workload key ["conv2d_layer"] should be updated using the script from https://github.com/apache/tvm/pull/7317.
MeasureInput with old format workload key ["dense_layer"] should be updated using the script from https://github.com/apache/tvm/pull/7317.



Traceback (most recent call last):
  File "conv_layer_tuning_grad.py", line 38, in <module>
    sch, args = task.apply_best(log_file)
  File "/root/nnfusion/artifacts/.deps/tvm-0.7/python/tvm/auto_scheduler/search_                                                                                                                              task.py", line 522, in apply_best
    "Cannot find any valid schedule for %s in file %s" % (self.workload_key, log                                                                                                                              _file)
RuntimeError: Cannot find any valid schedule for ["conv2d_layer"] in file conv2d                                                                                                                              .json

I am wondering how to make it work.
Can anyone help me?

Does anyone have any suggestions? I am hoping for your advice! Thanks a lot!

Hi @comaniac, thanks for your reply in the GitHub issue. I checked my cmake/config.cmake, and the differences between it and the TVM version (commit id: 88b2be8e84199addbbb73dbc942248191420e14c) include:

my version: set(USE_CUDA ON) vs tvm version: set(USE_CUDA OFF)
my version: set(USE_LLVM ON) vs tvm version: set(USE_LLVM OFF)

The USE_RANDOM option you mentioned is set to ON by default.
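To double-check that side, here is a small check (just my own sketch; it assumes the packed function registered by USE_RANDOM is named tvm.contrib.random.random_fill, which is what the failing _rpc_run call in my log appears to use):

    # Sketch: verify that the random_fill packed function used by the measurer
    # is actually registered in the local runtime (it comes from USE_RANDOM=ON).
    import tvm

    f = tvm.get_global_func("tvm.contrib.random.random_fill", allow_missing=True)
    print("random_fill registered:", f is not None)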

And in this environment, I can find valid schedules for other workloads, like:

@auto_scheduler.register_workload
def my_workload():
    A = tvm.te.placeholder((128, 3, 224, 224), name='input0')
    W = tvm.te.placeholder((64, 3, 3, 3), name='input1')
    C = topi.nn.conv2d(A, W, (1,1), (1,1), (1,1), layout='NCHW', out_dtype=A.dtype) # (128, 64, 224, 224)
    return [A, W, C]
# one of valid records: {"i": [["[\"workload\", null, {\"input_shape\": [128, 3, 224, 224], \"filter_shape\": [64, 3, 3, 3], \"output_shape\": [128, 64, 224, 224], \"window_movement_strides\": [1, 1], \"window_dilation_strides\": [1, 1], \"padding_below_diff\": [1, 1], \"add_bias\": false, \"add_relu\": false}]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CHW", 3, "local"], ["SP", 3, 0, 128, [2, 2, 2, 2], 1], ["SP", 3, 5, 64, [1, 4, 2, 1], 1], ["SP", 3, 10, 224, [1, 2, 1, 2], 1], ["SP", 3, 15, 224, [1, 2, 1, 1], 1], ["SP", 3, 20, 3, [1, 1], 1], ["SP", 3, 23, 3, [1, 3], 1], ["SP", 3, 26, 3, [1, 3], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 18, 22, 25, 28, 4, 9, 14, 19]], ["FSP", 4, 0, 1, 3], ["FSP", 4, 4, 2, 3], ["FSP", 4, 8, 3, 3], ["FSP", 4, 12, 4, 3], ["RE", 4, [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]], ["CA", 3, 4, 11], ["CHR", 2, "shared", [3]], ["CA", 3, 4, 14], ["CHR", 1, "shared", [4]], ["CA", 2, 5, 14], ["CI", 1], ["FU", 6, [0, 1, 2, 3]], ["AN", 6, 0, 5], ["FU", 6, [1, 2, 3, 4]], ["AN", 6, 1, 4], ["FU", 6, [2, 3, 4, 5]], ["AN", 6, 2, 6], ["FU", 4, [0, 1, 2, 3]], ["SP", 4, 0, 18, [1], 1], ["AN", 4, 1, 2], ["FFSP", 4, 0, [4, 3, 2, 1], 1, 1], ["AN", 4, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 48, [8], 1], ["AN", 2, 1, 2], ["FFSP", 2, 0, [4, 3, 2, 1], 1, 1], ["AN", 2, 1, 6], ["PR", 5, 0, "auto_unroll_max_step$512"]]]], "r": [[0.0210064], 0, 11.5793, 1636032812], "v": "v0.6"}

Hmm, I have no clue why this workload fails for you. You could probably try the TOPI schedule first to see if the default AutoTVM schedule works for that workload. If so, then it should be a problem in the auto-scheduler; otherwise it's more likely that the workload somehow cannot be handled, which doesn't make sense to me though.
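Something along these lines (just a rough sketch, assuming topi.cuda.conv2d_nchw / topi.cuda.schedule_conv2d_nchw are the right entry points for NCHW on CUDA, and that the device handle API matches your TVM version):

    # Rough sketch: build only the forward conv2d_nchw with the default TOPI CUDA
    # schedule (no auto-scheduler involved) and run it once on the GPU.
    import numpy as np
    import tvm
    from tvm import te, topi

    target = tvm.target.Target("cuda")
    x = te.placeholder((1, 128, 7, 7), name="x")
    w1 = te.placeholder((128, 128, 3, 3), name="w1")

    with target:
        z1 = topi.cuda.conv2d_nchw(x, w1, (1, 1), (1, 1), dilation=1, out_dtype="float32")
        s = topi.cuda.schedule_conv2d_nchw([z1])

    func = tvm.build(s, [x, w1, z1], target)

    dev = tvm.cuda(0)  # tvm.gpu(0) on older releases such as the bundled tvm-0.7
    a = tvm.nd.array(np.random.uniform(size=(1, 128, 7, 7)).astype("float32"), dev)
    b = tvm.nd.array(np.random.uniform(size=(128, 128, 3, 3)).astype("float32"), dev)
    c = tvm.nd.empty((1, 128, 7, 7), "float32", dev)
    func(a, b, c)
    print("forward conv2d builds and runs with the default TOPI schedule")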

BTW, I checked the error code you received (code=1), and it's kShutdown. This means the RPC server was disconnected during the tuning, so you might also want to check the reason for that.

Thanks a lot @comaniac! I read the RPC-related issues and then located my bug.

Previously, I tried to parallelize my search using multiple scripts:

CUDA_VISIBLE_DEVICES=0 python3 my_tvm_searching.py ...
CUDA_VISIBLE_DEVICES=1 python3 my_tvm_searching.py ...
...

Each instance of my_tvm_searching.py launches its own local RPC, which is not the correct usage:

    measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=100,
        runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        builder=tvm.auto_scheduler.LocalBuilder(timeout=100),
        verbose=2,
    )
    task.tune(tune_option)
    del measure_ctx

After fixing the bug, my commands are as follows:

nohup python3 -m tvm.exec.rpc_tracker --host=0.0.0.0 --port=9190 & \
CUDA_VISIBLE_DEVICES=0 nohup python3 -m tvm.exec.rpc_server --tracker 127.0.0.1:9190 --key V100 --host 0.0.0.0 --port=9091 & \
CUDA_VISIBLE_DEVICES=1 nohup python3 -m tvm.exec.rpc_server --tracker 127.0.0.1:9190 --key V100 --host 0.0.0.0 --port=9092 & 
...
python3 my_tvm_searching.py

my_tvm_searching.py:

...
        runner = tvm.auto_scheduler.RPCRunner(key="V100", host="localhost", port=9190, n_parallel=8, min_repeat_ms=300, timeout=1000)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=100,  # change this to 1000 to achieve the best performance
        runner=runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        builder=tvm.auto_scheduler.LocalBuilder(timeout=1000),
        verbose=2,
    )
    task.tune(tune_option)
...

NOTE: set n_parallel to the number of GPUs, and set timeout to a suitably large value.

ref: https://tvm.apache.org/docs/how_to/tune_with_autoscheduler/tune_network_arm.html?highlight=parallel https://tvm.apache.org/docs/reference/api/python/auto_scheduler.html?highlight=tvm%20auto_scheduler%20rpcrunner#tvm.auto_scheduler.RPCRunner
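One more check I found useful (assuming the standard RPC tooling shipped with TVM) is to query the tracker and confirm that both V100 servers have registered and are free before starting the search:

python3 -m tvm.exec.query_rpc_tracker --host 127.0.0.1 --port 9190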

I cannot even find the log file after running the auto-scheduler:

task = auto_scheduler.SearchTask(
    func=conv2d_aa, args=(ic, oc, in_h, in_w, k, bs), target="llvm"
)
print(task.compute_dag)
log_file = 'a.json'
tune_options = auto_scheduler.TuningOptions(
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)

task.tune(tune_options)
print('tuned? where is the log?')
sch, args = task.apply_best(log_file)
print(tvm.lower(sch, args, simple_mode=True))

print:

EvolutionarySearch		#s: 128	Time elapsed: 47.22
tuned? where is the log?
a.json does not exist!
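If I remember the TuningOptions defaults correctly, num_measure_trials defaults to 0, which means the evolutionary search runs but no candidate is ever built or measured, so RecordToFile has no record to write and a.json is never created. A minimal sketch of the fix (assuming that is indeed the cause) is to pass an explicit, nonzero trial count:

    # Sketch: a nonzero num_measure_trials makes the tuner actually build and
    # measure candidates, so RecordToFile gets records to write into a.json.
    tune_options = auto_scheduler.TuningOptions(
        num_measure_trials=64,  # any value > 0; more trials usually give better schedules
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        verbose=2,
    )
    task.tune(tune_options)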