[AutoScheduler] Cannot break Sample Init Population loop

In the function SampleInitPopulation in the file sketch_policy.cc, there is a while loop, shown in the code block below.

  while (static_cast<int>(out_states.size()) < sample_init_min_pop_) {
    std::vector<State> temp_states(population);

    // Sample a batch of states randomly
    support::parallel_for(0, population, [this, &temp_states, &sketches, &rand_gens](int index) {
      // Randomly choose a sketch
      State tmp_s = sketches[(rand_gens[index])() % sketches.size()];
      // Apply random annotation rules one by one
      bool valid = true;
      for (const auto& rule : init_rules) {
        if (rule->Apply(this, &tmp_s, &rand_gens[index]) ==
            PopulationGenerationRule::ResultKind::kInvalid) {
          valid = false;
          break;
        }
      }
      if (valid) {
        temp_states[index] = std::move(tmp_s);
      }
    });

    // Filter out the states that were failed to apply initial rules
    Array<State> cand_states;
    for (auto tmp_s : temp_states) {
      if (tmp_s.defined()) {
        cand_states.push_back(std::move(tmp_s));
      } else {
        fail_ct++;
      }
    }

    unchange_cnt++;
    if (!cand_states.empty()) {
      // Run the cost model to make filter out states that failed to extract features.
      // This may happen due to illegal schedules or the schedules that uses too much
      // memory on GPU.
      std::vector<float> pop_scores;
      pop_scores.reserve(cand_states.size());
      cand_states = search_task->compute_dag.InferBound(cand_states);
      PruneInvalidState(search_task, &cand_states);
      program_cost_model->Predict(search_task, cand_states, &pop_scores);

      for (size_t i = 0; i < cand_states.size(); i++) {
        const auto state_str = cand_states[i].ToStr();
        if (pop_scores[i] > -1e10 && explored_state_strs.count(state_str) == 0) {
          explored_state_strs.insert(state_str);
          out_states.push_back(std::move(cand_states[i]));
          unchange_cnt = 0;  // Reset the counter once we found a valid state
        } else {
          fail_ct++;
        }
      }
    }

    if (iter % 5 == 0) {
      double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
                            std::chrono::high_resolution_clock::now() - tic_begin)
                            .count();
      StdCout(verbose) << "Sample Iter: " << iter << std::fixed << std::setprecision(4)
                       << "\t#Pop: " << out_states.size() << "\t#Target: " << sample_init_min_pop_
                       << "\tfail_ct: " << fail_ct << "\tTime elapsed: " << std::fixed
                       << std::setprecision(2) << duration << std::endl;
    }

    if (unchange_cnt == 5) {
      // Reduce the target size to avoid too-long time in this phase if no valid state was found
      // in the past iterations
      if (sample_init_min_pop_ > 1) {
        sample_init_min_pop_ /= 2;
        StdCout(verbose) << "#Target has been reduced to " << sample_init_min_pop_
                         << " due to too many failures or duplications" << std::endl;
      }
      unchange_cnt = 0;
    }
    iter++;
  }

In my case, getting the feature scores of cand_states always results in an error, so the value of out_states.size() is always 0. Meanwhile, sample_init_min_pop_ is halved every 5 iterations until it finally reaches 1; it is never reduced further because of the `sample_init_min_pop_ > 1` check. Since the loop condition is (static_cast<int>(out_states.size()) < sample_init_min_pop_), which is then 0 < 1, the condition stays true forever and the loop never exits. Is this a bug?

What’s your compute_dag that causes this problem? If it is a bug, you are welcome to contribute a patch.

Here is the DAG in my case.

========== Task 1  (workload key: ["33a10bb123ba1d747c97e5296eee59ed", 256, 1, 256, 1, 16, 16, 256, 1, 1, 256]) ==========
compile_engine_const() = 0
placeholder = PLACEHOLDER [256, 1]
T_repeat(ax0, ax1) = placeholder[ax0, floordiv(ax1, 2)]
T_take(ax0) = T_repeat[ax0, (((compile_engine_const[] % 2) + 2) % 2)]
compile_engine_const() = 1
T_take(ax0) = T_repeat[ax0, (((compile_engine_const[] % 2) + 2) % 2)]
T_reshape(ax0, ax1) = T_take[floormod((ax0 + ax1), 256)]
T_subtract(ax0, ax1) = (T_take[ax1] - T_reshape[ax0, 0])
compile_engine_const() = 0.0625f
T_less(ax0, ax1) = (T_subtract[ax0, ax1] < compile_engine_const[])
T_cast(ax0, ax1) = T_less[ax0, ax1]
compile_engine_const() = -0.0625f
T_greater(ax0, ax1) = (T_subtract[ax0, ax1] > compile_engine_const[])
T_cast(ax0, ax1) = T_greater[ax0, ax1]
T_logical_and(ax0, ax1) = (T_cast[ax0, ax1] && T_cast[ax0, ax1])
T_cast(ax0, ax1) = T_logical_and[ax0, ax1]
compile_engine_const() = 0
placeholder = PLACEHOLDER [256, 1]
T_repeat(ax0, ax1) = placeholder[ax0, floordiv(ax1, 2)]
T_take(ax0) = T_repeat[ax0, (((compile_engine_const[] % 2) + 2) % 2)]
compile_engine_const() = 1
T_take(ax0) = T_repeat[ax0, (((compile_engine_const[] % 2) + 2) % 2)]
T_reshape(ax0, ax1) = T_take[floormod((ax0 + ax1), 256)]
T_subtract(ax0, ax1) = (T_take[ax1] - T_reshape[ax0, 0])
compile_engine_const() = 0.0625f
T_less(ax0, ax1) = (T_subtract[ax0, ax1] < compile_engine_const[])
T_cast(ax0, ax1) = T_less[ax0, ax1]
T_logical_and(ax0, ax1) = (T_cast[ax0, ax1] && T_cast[ax0, ax1])
T_cast(ax0, ax1) = T_logical_and[ax0, ax1]
compile_engine_const() = -0.0625f
T_greater(ax0, ax1) = (T_subtract[ax0, ax1] > compile_engine_const[])
T_cast(ax0, ax1) = T_greater[ax0, ax1]
T_logical_and(ax0, ax1) = (T_cast[ax0, ax1] && T_cast[ax0, ax1])
T_cast(ax0, ax1) = T_logical_and[ax0, ax1]
placeholder = PLACEHOLDER [16, 16]
T_cast(ax0, ax1) = float32(placeholder[ax0, ax1])
T_reshape(ax0, ax1) = T_cast[floormod(floordiv((ax0 + ax1), 16), 16), floormod((ax0 + ax1), 16)]
T_repeat(ax0, ax1) = T_reshape[ax0, floordiv(ax1, 256)]
T_cast(ax0, ax1) = bool(T_repeat[ax0, ax1])
T_reshape(ax0, ax1) = T_cast[floormod(floordiv(((ax0*256) + ax1), 16), 16), floormod(((ax0*256) + ax1), 16)]
T_repeat(ax0, ax1) = T_reshape[floordiv(ax0, 256), ax1]
T_cast(ax0, ax1) = bool(T_repeat[ax0, ax1])
T_logical_and(ax0, ax1) = (T_cast[ax0, ax1] && T_cast[ax0, ax1])
T_cast(ax0, ax1) = T_logical_and[ax0, ax1]
T_logical_and(ax0, ax1) = (T_cast[ax0, ax1] && T_cast[ax0, ax1])
placeholder = PLACEHOLDER [256, 1]
T_repeat(ax0, ax1) = placeholder[ax0, floordiv(ax1, 256)]
placeholder = PLACEHOLDER [1]
T_where(ax0, ax1) = select((int32(T_logical_and[ax0, ax1]) != 0), T_repeat[ax0, ax1], placeholder[0])
T_where_red_temp(ax0).v0 select((argmax_lhs_1 >= argmax_rhs_1), argmax_lhs_0, argmax_rhs_0)= (k1,T_where[ax0, k1])
T_where_red_temp(ax0).v1 select((argmax_lhs_1 >= argmax_rhs_1), argmax_lhs_1, argmax_rhs_1)= (k1,T_where[ax0, k1])
T_where_red(ax0) = T_where_red_temp.v0[ax0]

In the end, I excluded this task from the tuning tasks so that I could continue my tuning work. I think it would make more sense to skip the task than to block in the loop when the auto-scheduler cannot find a valid schedule.