In function SampleInitPopulation
there is a for loop as code block below in file sketch_policy.cc
.
while (static_cast<int>(out_states.size()) < sample_init_min_pop_) {
std::vector<State> temp_states(population);
// Sample a batch of states randomly
support::parallel_for(0, population, [this, &temp_states, &sketches, &rand_gens](int index) {
// Randomly choose a sketch
State tmp_s = sketches[(rand_gens[index])() % sketches.size()];
// Apply random annotation rules one by one
bool valid = true;
for (const auto& rule : init_rules) {
if (rule->Apply(this, &tmp_s, &rand_gens[index]) ==
PopulationGenerationRule::ResultKind::kInvalid) {
valid = false;
break;
}
}
if (valid) {
temp_states[index] = std::move(tmp_s);
}
});
// Filter out the states that were failed to apply initial rules
Array<State> cand_states;
for (auto tmp_s : temp_states) {
if (tmp_s.defined()) {
cand_states.push_back(std::move(tmp_s));
} else {
fail_ct++;
}
}
unchange_cnt++;
if (!cand_states.empty()) {
// Run the cost model to make filter out states that failed to extract features.
// This may happen due to illegal schedules or the schedules that uses too much
// memory on GPU.
std::vector<float> pop_scores;
pop_scores.reserve(cand_states.size());
cand_states = search_task->compute_dag.InferBound(cand_states);
PruneInvalidState(search_task, &cand_states);
program_cost_model->Predict(search_task, cand_states, &pop_scores);
for (size_t i = 0; i < cand_states.size(); i++) {
const auto state_str = cand_states[i].ToStr();
if (pop_scores[i] > -1e10 && explored_state_strs.count(state_str) == 0) {
explored_state_strs.insert(state_str);
out_states.push_back(std::move(cand_states[i]));
unchange_cnt = 0; // Reset the counter once we found a valid state
} else {
fail_ct++;
}
}
}
if (iter % 5 == 0) {
double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
std::chrono::high_resolution_clock::now() - tic_begin)
.count();
StdCout(verbose) << "Sample Iter: " << iter << std::fixed << std::setprecision(4)
<< "\t#Pop: " << out_states.size() << "\t#Target: " << sample_init_min_pop_
<< "\tfail_ct: " << fail_ct << "\tTime elapsed: " << std::fixed
<< std::setprecision(2) << duration << std::endl;
}
if (unchange_cnt == 5) {
// Reduce the target size to avoid too-long time in this phase if no valid state was found
// in the past iterations
if (sample_init_min_pop_ > 1) {
sample_init_min_pop_ /= 2;
StdCout(verbose) << "#Target has been reduced to " << sample_init_min_pop_
<< " due to too many failures or duplications" << std::endl;
}
unchange_cnt = 0;
}
iter++;
}
In my case, getting the feature score of cand_states
always results in an error. So the value of out_states.size()
is always 0
, and the value of sample_init_min_pop_
will be reduced every 5 iterations and will be 1
finally.
As the condition of this loop is (static_cast<int>(out_states.size()) < sample_init_min_pop_)
, so the loop cannot be broken.
Is this a bug?