Assertion triggered when auto-scheduling

I’m trying to auto-schedule a simple operator graph as follows (the code below is adapted from the auto-scheduling tutorial).

import os

import numpy as np
import tvm
from tvm import te, auto_scheduler, topi

@auto_scheduler.register_workload
def gemm_fused_layer(M, K, N):
    w1 = te.placeholder((M, K), name="w1")
    w2 = te.placeholder((M, K), name="w2")

    m1 = te.placeholder((N, K), name="m1")
    m2 = te.placeholder((N, K), name="m2")

    k = te.reduce_axis((0, K), name="k")
    mm1 = te.compute((M, N), lambda i, j: te.sum(w1[i, k]*m1[j, k], axis=k), name="mm1")

    k = te.reduce_axis((0, K), name="k")
    mm2 = te.compute((M, N), lambda i, j: te.sum(w2[i, k]*m2[j, k], axis=k), name="mm2")

    o = te.compute((M, N), lambda i, j: mm1[i, j]*mm2[i, j], name="o")

    return [w1, w2, m1, m2, mm1, mm2, o]



######################################################################
# Create the search task
target = tvm.target.Target("cuda")

M, K, N = 512, 512, 512
task = auto_scheduler.SearchTask(func=gemm_fused_layer, args=(M, K, N), target=target)

# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)

log_file = "gemm_fused.json"
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,  # change this to 1000 to achieve the best performance
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)

task.tune(tune_option)
sch, args = task.apply_best(log_file)

# Kill the measurement process
del measure_ctx

func = tvm.build(sch, args, target)
# print(func.imported_modules[0].get_source())

# Evaluate execution time
device = tvm.gpu()
w1_np = np.random.uniform(size=(M, K)).astype(np.float32)
w2_np = np.random.uniform(size=(M, K)).astype(np.float32)
m1_np = np.random.uniform(size=(N, K)).astype(np.float32)
m2_np = np.random.uniform(size=(N, K)).astype(np.float32)
o_np = np.random.uniform(size=(M, N)).astype(np.float32)

w1_tvm = tvm.nd.empty(w1_np.shape, device=device)
w2_tvm = tvm.nd.empty(w2_np.shape, device=device)
m1_tvm = tvm.nd.empty(m1_np.shape, device=device)
m2_tvm = tvm.nd.empty(m2_np.shape, device=device)
o_tvm = tvm.nd.empty(o_np.shape, device=device)

evaluator = func.time_evaluator(func.entry_name, device, min_repeat_ms=500)
print(
    "Execution time of this operator: %.3f ms"
    % (np.median(evaluator(w1_tvm, w2_tvm, m1_tvm, m2_tvm, o_tvm).results) * 1000)
)

Running this triggers the following assertion failure:

Traceback (most recent call last):
  File "tune_gemm_fused_cuda.py", line 47, in <module>
    task.tune(tune_option)
  File "/home/ppf/rnn_compilers/lowering/tvm/tvm/python/tvm/auto_scheduler/search_task.py", line 452, in tune
    _ffi_api.AutoSchedule(search_policy, tuning_options)
  File "/home/ppf/rnn_compilers/lowering/tvm/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  7: TVMFuncCall
  6: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::runtime::Array<tvm::runtime::ObjectRef, void> (tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)>::AssignTypedLambda<tvm::auto_scheduler::{lambda(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)#3}>(tvm::auto_scheduler::{lambda(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)#3}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  5: tvm::auto_scheduler::AutoSchedule(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)
  4: tvm::auto_scheduler::SketchPolicyNode::Search(int, int, int, tvm::auto_scheduler::ProgramMeasurer)
  3: tvm::auto_scheduler::SketchPolicyNode::SearchOneRound(int, tvm::runtime::Array<tvm::auto_scheduler::State, void>*)
  2: tvm::auto_scheduler::SketchPolicyNode::GenerateSketches()
  1: tvm::auto_scheduler::RuleMultiLevelTilingWithFusion::Apply(tvm::auto_scheduler::SketchPolicyNode const&, tvm::auto_scheduler::State const&, int) const
  0: tvm::auto_scheduler::FollowTiling(tvm::auto_scheduler::State const&, int, std::vector<int, std::allocator<int> > const&, int)
  File "../src/auto_scheduler/search_policy/utils.cc", line 289
TVMError:
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------

  Check failed: state->stages[stage_id]->iters.size() - no_split_at_inner_name_in_stage_cnt == split_step_ids.size() (8 vs. 2) :

When I simplify the operator and remove the final element-wise stage (o), auto-scheduling works fine. Is there anything I’m doing wrong?
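
For reference, this is roughly what the simplified version that works looks like (a sketch only; gemm_single_layer is just an illustrative name, not part of the original script):

from tvm import te, auto_scheduler

@auto_scheduler.register_workload
def gemm_single_layer(M, K, N):
    # Same matmul as above, but with a single branch and no trailing
    # element-wise stage, so the compute DAG is a simple chain.
    w1 = te.placeholder((M, K), name="w1")
    m1 = te.placeholder((N, K), name="m1")
    k = te.reduce_axis((0, K), name="k")
    mm1 = te.compute((M, N), lambda i, j: te.sum(w1[i, k] * m1[j, k], axis=k), name="mm1")
    return [w1, m1, mm1]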

At first glance, this may be an issue with generating a sketch for a diamond-shaped compute DAG, but I forget the details…

cc @jcf94 @merrymercy

Is there any update on this? Thanks!

Hi! I ran into the same problem. My code is attached below (I just modified the tutorial on auto-scheduling conv2d):

import os

import numpy as np
import tvm
from tvm import te, auto_scheduler, topi
from tvm.topi.testing import conv2d_nchw_python

@auto_scheduler.register_workload
def my_func(N, H, W, CO, CI, KH, KW, stride, padding):
    data = te.placeholder((N, CI, H, W), name="data")
    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32")
    kernel2 = te.placeholder((CO, CI, KH, KW), name="kernel2")
    conv2 = topi.nn.conv2d_nchw(data, kernel2, stride, padding, dilation=1, out_dtype="float32")
    out = topi.add(conv, conv2)
    return [data, kernel, kernel2, out]

target = tvm.target.Target("cuda")

N, H, W, CI, CO, KH, KW, strides, padding = 1, 224, 224, 3, 64, 3, 3, (1, 1), (1, 1)
task = auto_scheduler.SearchTask(
    func=my_func, args=(N, H, W, CO, CI, KH, KW, strides, padding), target=target
)

# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)

log_file = "my_func.json"
measure_ctx = auto_scheduler.LocalRPCMeasureContext(repeat=1, min_repeat_ms=300, timeout=10)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=20000,
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)

# Run auto-tuning (search)
task.tune(tune_option)
# Apply the best schedule
sch, args = task.apply_best(log_file)

# Kill the measurement process
del measure_ctx

#print("Lowered TIR:")
#print(tvm.lower(sch, args, simple_mode=True))

func = tvm.build(sch, args, target)

# Check correctness
data_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
weight_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
weight2_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
conv_np = conv2d_nchw_python(data_np, weight_np, strides, padding)
conv2_np = conv2d_nchw_python(data_np, weight2_np, strides, padding)
out_np = np.add(conv_np, conv2_np)

ctx = tvm.gpu()
data_tvm = tvm.nd.array(data_np, ctx=ctx)
weight_tvm = tvm.nd.array(weight_np, ctx=ctx)
weight2_tvm = tvm.nd.array(weight2_np, ctx=ctx)
out_tvm = tvm.nd.empty(out_np.shape, ctx=ctx)
func(data_tvm, weight_tvm, weight2_tvm, out_tvm)

# Check results
np.testing.assert_allclose(out_np, out_tvm.asnumpy(), rtol=1e-3)

# Evaluate execution time
evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500)
print(
    "Execution time of this operator: %.3f ms"
    % (np.median(evaluator(data_tvm, weight_tvm, weight2_tvm, out_tvm).results) * 1000)
)

Here is the output including the error message and some intermediate outputs that I added into the auto-scheduler source code:

Computational DAG:
data = PLACEHOLDER [1, 3, 224, 224]
pad_temp(i0, i1, i2, i3) = tir.if_then_else(((((i2 >= 1) && (i2 < 225)) && (i3 >= 1)) && (i3 < 225)), data[i0, i1, (i2 - 1), (i3 - 1)], 0f)
kernel = PLACEHOLDER [64, 3, 3, 3]
compute(nn, ff, yy, xx) += (pad_temp[nn, rc, (yy + ry), (xx + rx)]*kernel[ff, rc, ry, rx])
pad_temp(i0, i1, i2, i3) = tir.if_then_else(((((i2 >= 1) && (i2 < 225)) && (i3 >= 1)) && (i3 < 225)), data[i0, i1, (i2 - 1), (i3 - 1)], 0f)
kernel2 = PLACEHOLDER [64, 3, 3, 3]
compute(nn, ff, yy, xx) += (pad_temp[nn, rc, (yy + ry), (xx + rx)]*kernel2[ff, rc, ry, rx])
T_add(ax0, ax1, ax2, ax3) = (compute[ax0, ax1, ax2, ax3] + compute[ax0, ax1, ax2, ax3])

----------------------------------------------------------------------
------------------------------  [ Search ]
----------------------------------------------------------------------
Initial state is Placeholder: data, kernel, kernel2
for i1 (0,3)
  for i2 (0,226)
    for i3 (0,226)
      pad_temp = ...
for ff (0,64)
  for yy (0,224)
    for xx (0,224)
      for rc (0,3)
        for ry (0,3)
          for rx (0,3)
            compute = ...
for i1 (0,3)
  for i2 (0,226)
    for i3 (0,226)
      pad_temp = ...
for ff (0,64)
  for yy (0,224)
    for xx (0,224)
      for rc (0,3)
        for ry (0,3)
          for rx (0,3)
            compute = ...
for ax1 (0,64)
  for ax2 (0,224)
    for ax3 (0,224)
      T_add = ...

Applied to stage 7 Sketch Generation rule RuleSkipStage. The index of the next stage is 6
Applied to stage 6 Sketch Generation rule RuleMultiLevelTilingWithFusion. The index of the next stage is 5
Applied to stage 5 Sketch Generation rule RuleAddCacheRead. The index of the next stage is 5
Applied to stage 5 Sketch Generation rule RuleSkipStage. The index of the next stage is 4
Applied to stage 4 Sketch Generation rule RuleAddCacheRead. The index of the next stage is 4
Applied to stage 4 Sketch Generation rule RuleAlwaysInline. The index of the next stage is 3
Get devices for measurement successfully!
Traceback (most recent call last):
  File "ansor_tune_sub116before.py", line 128, in <module>
    task.tune(tune_option)
  File "/home1/qinyiluo/tvm/python/tvm/auto_scheduler/search_task.py", line 445, in tune
    _ffi_api.AutoSchedule(search_policy, tuning_options)
  File "/home1/qinyiluo/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /home1/qinyiluo/tvm/build/libtvm.so(TVMFuncCall+0x5b) [0x7f341e92998b]
  [bt] (7) /home1/qinyiluo/tvm/build/libtvm.so(+0x5a21c2) [0x7f341db381c2]
  [bt] (6) /home1/qinyiluo/tvm/build/libtvm.so(tvm::auto_scheduler::AutoSchedule(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)+0x104) [0x7f341db37844]
  [bt] (5) /home1/qinyiluo/tvm/build/libtvm.so(tvm::auto_scheduler::SketchPolicyNode::Search(int, int, int, tvm::auto_scheduler::ProgramMeasurer)+0x266) [0x7f341dbe9636]
  [bt] (4) /home1/qinyiluo/tvm/build/libtvm.so(tvm::auto_scheduler::SketchPolicyNode::SearchOneRound(int, tvm::runtime::Array<tvm::auto_scheduler::State, void>*)+0x2fd) [0x7f341dbe8f0d]
  [bt] (3) /home1/qinyiluo/tvm/build/libtvm.so(tvm::auto_scheduler::SketchPolicyNode::GenerateSketches()+0x530) [0x7f341dbe4130]
  [bt] (2) /home1/qinyiluo/tvm/build/libtvm.so(tvm::auto_scheduler::RuleMultiLevelTilingWithFusion::Apply(tvm::auto_scheduler::SketchPolicyNode const&, tvm::auto_scheduler::State const&, int) const+0x1de) [0x7f341dbf586e]
  [bt] (1) /home1/qinyiluo/tvm/build/libtvm.so(tvm::auto_scheduler::FollowTiling(tvm::auto_scheduler::State const&, int, std::vector<int, std::allocator<int> > const&, int)+0x608) [0x7f341dc0b088]
  [bt] (0) /home1/qinyiluo/tvm/build/libtvm.so(+0x672922) [0x7f341dc08922]
  File "/home1/qinyiluo/tvm/src/auto_scheduler/search_policy/utils.cc", line 289
TVMError: 
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------

  Check failed: state->stages[stage_id]->iters.size() - no_split_at_inner_name_in_stage_cnt == split_step_ids.size() (16 vs. 4) : 
Process Process-1:
Traceback (most recent call last):
  File "/home1/qinyiluo/tvm/python/tvm/rpc/base.py", line 170, in connect_with_retry
    sock.connect(addr)
ConnectionRefusedError: [Errno 111] Connection refused

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/usr/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/home1/qinyiluo/tvm/python/tvm/rpc/server.py", line 230, in _listen_loop
    raise exc
  File "/home1/qinyiluo/tvm/python/tvm/rpc/server.py", line 211, in _listen_loop
    tracker_conn = base.connect_with_retry(tracker_addr)
  File "/home1/qinyiluo/tvm/python/tvm/rpc/base.py", line 177, in connect_with_retry
    raise RuntimeError("Failed to connect to server %s" % str(addr))
RuntimeError: Failed to connect to server ('0.0.0.0', 9000)

The error at tvm/src/auto_scheduler/search_policy/utils.cc line 289 is inside a function called FollowTiling, which was called by RuleMultiLevelTilingWithFusion::Apply. So it seems the error happens while the auto-scheduler tries to apply RuleMultiLevelTilingWithFusion to the first convolution op at stage 3.

Any help or advice on this is greatly appreciated. Thanks a lot!

Are there any plans to fix this bug?

It seems this is a bug in the auto-scheduler. We didn’t test these cases because the graph partition algorithm in Relay (FuseOps) never generates such patterns. As a workaround, you have to split the graph into two subgraphs, which is what the Relay pass does.
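
A minimal sketch of that workaround for the GEMM example at the top of this thread (the single_gemm workload name and the host-side combination are illustrative choices, not something prescribed above):

import tvm
from tvm import te, auto_scheduler

@auto_scheduler.register_workload
def single_gemm(M, K, N):
    # One matmul per subgraph; no fan-in / diamond pattern in the DAG.
    w = te.placeholder((M, K), name="w")
    m = te.placeholder((N, K), name="m")
    k = te.reduce_axis((0, K), name="k")
    mm = te.compute((M, N), lambda i, j: te.sum(w[i, k] * m[j, k], axis=k), name="mm")
    return [w, m, mm]

target = tvm.target.Target("cuda")
M, K, N = 512, 512, 512

# Both matmuls in the original graph have the same shape, so they map to the
# same task and a single tuning run covers both subgraphs; in general you
# would create and tune one SearchTask per distinct subgraph.
task = auto_scheduler.SearchTask(func=single_gemm, args=(M, K, N), target=target)

log_file = "single_gemm.json"
task.tune(auto_scheduler.TuningOptions(
    num_measure_trials=10,  # increase for better schedules
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
))
sch, args = task.apply_best(log_file)
gemm_kernel = tvm.build(sch, args, target)

# Run gemm_kernel once for (w1, m1) and once for (w2, m2), then perform the
# final element-wise multiplication of the two results as a separate step.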