27#include "absl/container/flat_hash_map.h"
28#include "absl/container/flat_hash_set.h"
29#include "absl/log/check.h"
30#include "absl/types/span.h"
33#include "ortools/sat/cp_model.pb.h"
36#include "ortools/sat/sat_parameters.pb.h"
45 int num_vars,
const absl::flat_hash_map<int, int64_t>&
solution) {
46 CHECK(!solution_is_loaded_);
47 CHECK(var_has_value_.empty());
48 CHECK(var_values_.empty());
49 solution_is_loaded_ =
true;
50 var_has_value_.resize(num_vars,
false);
51 var_values_.resize(num_vars, 0);
52 for (
const auto [var, value] :
solution) {
53 var_has_value_[var] =
true;
54 var_values_[var] = value;
59 if (!solution_is_loaded_)
return;
60 var_has_value_.resize(new_size,
false);
61 var_values_.resize(new_size, 0);
67 if (!solution_is_loaded_)
return;
68 if (!HasValue(
PositiveRef(literal)) && HasValue(var)) {
69 SetLiteralValue(literal, GetVarValue(var) == value);
74 int new_var, absl::Span<
const std::pair<int, int64_t>> linear,
76 if (!solution_is_loaded_)
return;
77 int64_t new_value = offset;
78 for (
const auto [var, coeff] : linear) {
79 if (!HasValue(var))
return;
80 new_value += coeff * GetVarValue(var);
82 SetVarValue(new_var, new_value);
86 absl::Span<const int> vars,
87 absl::Span<const int64_t> coeffs,
89 DCHECK_EQ(vars.size(), coeffs.size());
90 if (!solution_is_loaded_)
return;
91 int64_t new_value = offset;
92 for (
int i = 0;
i < vars.size(); ++
i) {
93 const int var = vars[
i];
94 const int64_t coeff = coeffs[
i];
95 if (!HasValue(var))
return;
96 new_value += coeff * GetVarValue(var);
98 SetVarValue(new_var, new_value);
102 if (!solution_is_loaded_)
return;
104 bool all_have_value =
true;
105 for (
const int literal : clause) {
107 if (!HasValue(var)) {
108 all_have_value =
false;
117 if (all_have_value) {
118 SetVarValue(new_var, new_value);
123 absl::Span<const int> conjunction) {
124 if (!solution_is_loaded_)
return;
126 bool all_have_value =
true;
127 for (
const int literal : conjunction) {
129 if (!HasValue(var)) {
130 all_have_value =
false;
139 if (all_have_value) {
140 SetVarValue(new_var, new_value);
145 int new_var, int64_t value,
146 absl::Span<
const std::pair<int, int64_t>> linear,
const Domain& domain) {
147 if (!solution_is_loaded_)
return;
148 int64_t linear_value = 0;
149 bool all_have_value =
true;
150 for (
const auto [var, coeff] : linear) {
151 if (!HasValue(var)) {
152 all_have_value =
false;
155 linear_value += GetVarValue(var) * coeff;
157 if (all_have_value && !domain.
Contains(linear_value)) {
158 SetVarValue(new_var, value);
163 int literal,
bool value, absl::Span<
const std::pair<int, int64_t>> linear,
177 int var,
const LinearExpressionProto& expr,
int condition_lit) {
178 if (!solution_is_loaded_)
return;
180 if (!GetLiteralValue(condition_lit))
return;
181 const std::optional<int64_t> expr_value = GetExpressionValue(expr);
182 if (expr_value.has_value()) {
183 SetVarValue(var, expr_value.value());
195 int var, absl::Span<const int> condition_lits, int64_t value_if_true,
196 int64_t value_if_false) {
197 if (!solution_is_loaded_)
return;
198 bool condition_value =
true;
199 for (
const int condition_lit : condition_lits) {
201 if (!GetLiteralValue(condition_lit)) {
202 condition_value =
false;
206 SetVarValue(var, condition_value ? value_if_true : value_if_false);
210 if (!solution_is_loaded_)
return;
212 SetLiteralValue(lit1, GetLiteralValue(lit2));
214 SetLiteralValue(lit2, GetLiteralValue(lit1));
219 if (!solution_is_loaded_)
return;
221 SetVarValue(var, domain.
ClosestValue(GetVarValue(var)));
231 const std::vector<std::pair<int, int64_t>> linear = {
233 const Domain domain =
Domain((sign1 == 1 ? 0 : -1) - (sign2 == 1 ? 0 : -1));
239 if (!solution_is_loaded_)
return;
243 if (GetLiteralValue(lit) && !GetLiteralValue(dominating_lit)) {
244 SetLiteralValue(lit,
false);
245 SetLiteralValue(dominating_lit,
true);
251 absl::Span<
const std::unique_ptr<SparsePermutation>> generators) {
252 if (!solution_is_loaded_)
return;
253 if (!HasValue(var))
return;
254 if (GetVarValue(var) ==
static_cast<int64_t
>(value))
return;
256 std::vector<int> schrier_vector;
257 std::vector<int> orbit;
260 bool found_target =
false;
262 for (
int v : orbit) {
263 if (HasValue(v) && GetVarValue(v) ==
static_cast<int64_t
>(value)) {
270 VLOG(1) <<
"Couldn't transform solution properly";
274 const std::vector<int> generator_idx =
275 TracePoint(target_var, schrier_vector, generators);
276 for (
const int i : generator_idx) {
277 PermuteVariables(*generators[
i]);
280 DCHECK(HasValue(var));
281 DCHECK_EQ(GetVarValue(var), value);
285 int ref, int64_t min_value, int64_t max_value,
286 absl::Span<
const std::pair<int, Domain>> dominating_refs) {
287 if (!solution_is_loaded_)
return;
288 const std::optional<int64_t> ref_value = GetRefValue(ref);
289 if (!ref_value.has_value())
return;
292 if (*ref_value < min_value)
return;
294 if (*ref_value <= max_value)
return;
296 const int64_t ref_value_delta = *ref_value - max_value;
298 SetRefValue(ref, *ref_value - ref_value_delta);
299 int64_t remaining_delta = ref_value_delta;
300 for (
const auto& [dominating_ref, dominating_ref_domain] : dominating_refs) {
301 const std::optional<int64_t> dominating_ref_value =
302 GetRefValue(dominating_ref);
303 if (!dominating_ref_value.has_value())
continue;
304 const int64_t new_dominating_ref_value =
305 dominating_ref_domain.ValueAtOrBefore(*dominating_ref_value +
308 if (!dominating_ref_domain.Contains(new_dominating_ref_value))
continue;
309 SetRefValue(dominating_ref, new_dominating_ref_value);
310 remaining_delta -= (new_dominating_ref_value - *dominating_ref_value);
311 if (remaining_delta == 0)
break;
316 std::optional<int> var_index, absl::Span<const int> vars,
317 absl::Span<const int64_t> coeffs, int64_t rhs) {
318 DCHECK_EQ(vars.size(), coeffs.size());
319 DCHECK(!var_index.has_value() || var_index.value() < vars.size());
320 if (!solution_is_loaded_)
return;
321 int64_t term_value = rhs;
322 for (
int i = 0;
i < vars.size(); ++
i) {
323 if (HasValue(vars[
i])) {
324 if (
i != var_index) {
325 term_value -= GetVarValue(vars[
i]) * coeffs[
i];
327 }
else if (!var_index.has_value()) {
329 }
else if (var_index.value() !=
i) {
333 if (!var_index.has_value())
return;
334 SetVarValue(vars[var_index.value()], term_value / coeffs[var_index.value()]);
337 if (term_value % coeffs[var_index.value()] != 0) {
338 std::stringstream lhs;
339 for (
int i = 0;
i < vars.size(); ++
i) {
340 lhs << (
i == var_index ?
"x" : std::to_string(GetVarValue(vars[
i])));
341 lhs <<
" * " << coeffs[
i];
342 if (
i < vars.size() - 1) lhs <<
" + ";
344 LOG(FATAL) <<
"Linear constraint incompatible with solution: " << lhs
351 const ReservoirConstraintProto& reservoir, int64_t min_level,
352 int64_t max_level, absl::Span<const int> level_vars,
353 const CircuitConstraintProto& circuit) {
354 if (!solution_is_loaded_)
return;
357 struct ReservoirEventValues {
360 int64_t level_change;
362 const int num_events = reservoir.time_exprs_size();
363 std::vector<ReservoirEventValues> active_event_values;
364 for (
int i = 0;
i < num_events; ++
i) {
365 if (!HasValue(
PositiveRef(reservoir.active_literals(
i))))
return;
366 if (GetLiteralValue(reservoir.active_literals(
i))) {
367 const std::optional<int64_t> time_value =
368 GetExpressionValue(reservoir.time_exprs(
i));
369 const std::optional<int64_t> change_value =
370 GetExpressionValue(reservoir.level_changes(
i));
371 if (!time_value.has_value() || !change_value.has_value())
return;
372 active_event_values.push_back(
373 {
i, time_value.value(), change_value.value()});
378 std::sort(active_event_values.begin(), active_event_values.end(),
379 [](
const ReservoirEventValues& a,
const ReservoirEventValues&
b) {
380 return a.time < b.time;
382 int64_t current_level = 0;
383 for (
int i = 0;
i < active_event_values.size(); ++
i) {
390 while (j < active_event_values.size() &&
391 active_event_values[j].time == active_event_values[
i].time &&
392 (current_level + active_event_values[j].level_change < min_level ||
393 current_level + active_event_values[j].level_change > max_level)) {
396 if (j < active_event_values.size() &&
397 active_event_values[j].time == active_event_values[
i].time) {
398 if (
i != j) std::swap(active_event_values[
i], active_event_values[j]);
399 current_level += active_event_values[
i].level_change;
400 SetVarValue(level_vars[active_event_values[
i].index], current_level);
408 std::vector<int> active_event_value_index(num_events, -1);
409 for (
int i = 0;
i < active_event_values.size(); ++
i) {
410 active_event_value_index[active_event_values[
i].index] =
i;
412 for (
int i = 0;
i < circuit.literals_size(); ++
i) {
413 const int head = circuit.heads(
i);
414 const int tail = circuit.tails(
i);
415 const int literal = circuit.literals(
i);
416 if (tail == num_events) {
417 if (head == num_events) {
419 SetLiteralValue(literal, active_event_values.empty());
422 SetLiteralValue(literal, !active_event_values.empty() &&
423 active_event_values.front().index == head);
425 }
else if (head == num_events) {
427 SetLiteralValue(literal, !active_event_values.empty() &&
428 active_event_values.back().index == tail);
429 }
else if (tail != head) {
431 const int tail_index = active_event_value_index[tail];
432 const int head_index = active_event_value_index[head];
433 SetLiteralValue(literal, tail_index != -1 && tail_index != -1 &&
434 head_index == tail_index + 1);
440 int var,
const LinearExpressionProto& time_i,
441 const LinearExpressionProto& time_j,
int active_i,
int active_j) {
442 if (!solution_is_loaded_)
return;
443 std::optional<int64_t> time_i_value = GetExpressionValue(time_i);
444 std::optional<int64_t> time_j_value = GetExpressionValue(time_j);
445 std::optional<int64_t> active_i_value = GetRefValue(active_i);
446 std::optional<int64_t> active_j_value = GetRefValue(active_j);
447 if (time_i_value.has_value() && time_j_value.has_value() &&
448 active_i_value.has_value() && active_j_value.has_value()) {
449 const bool reified_value = (active_i_value.value() != 0) &&
450 (active_j_value.value() != 0) &&
451 (time_i_value.value() <= time_j_value.value());
452 SetVarValue(var, reified_value);
457 int div_var,
int prod_var,
458 int64_t default_div_value,
459 int64_t default_prod_value) {
460 if (!solution_is_loaded_)
return;
461 bool enforced_value =
true;
462 for (
const int lit : ct.enforcement_literal()) {
464 enforced_value = enforced_value && GetLiteralValue(lit);
466 if (!enforced_value) {
467 SetVarValue(div_var, default_div_value);
468 SetVarValue(prod_var, default_prod_value);
471 const LinearArgumentProto& int_mod = ct.int_mod();
472 std::optional<int64_t> v = GetExpressionValue(int_mod.exprs(0));
473 if (!v.has_value())
return;
474 const int64_t expr_value = v.value();
476 v = GetExpressionValue(int_mod.exprs(1));
477 if (!v.has_value())
return;
478 const int64_t mod_expr_value = v.value();
480 v = GetExpressionValue(int_mod.target());
481 if (!v.has_value())
return;
482 const int64_t target_expr_value = v.value();
485 SetVarValue(div_var, expr_value / mod_expr_value);
486 SetVarValue(prod_var, expr_value - target_expr_value);
490 absl::Span<const int> prod_vars) {
491 DCHECK_EQ(int_prod.exprs_size(), prod_vars.size() + 2);
492 if (!solution_is_loaded_)
return;
493 std::optional<int64_t> v = GetExpressionValue(int_prod.exprs(0));
494 if (!v.has_value())
return;
495 int64_t last_prod_value = v.value();
496 for (
int i = 1;
i < int_prod.exprs_size() - 1; ++
i) {
497 v = GetExpressionValue(int_prod.exprs(
i));
498 if (!v.has_value())
return;
499 last_prod_value *= v.value();
500 SetVarValue(prod_vars[
i - 1], last_prod_value);
505 const LinearArgumentProto& lin_max,
506 absl::Span<const int> enforcement_lits) {
507 if (!solution_is_loaded_)
return;
508 DCHECK_EQ(enforcement_lits.size(), lin_max.exprs_size());
509 const std::optional<int64_t> target_value =
510 GetExpressionValue(lin_max.target());
511 if (!target_value.has_value())
return;
512 int enforcement_value_sum = 0;
513 for (
int i = 0;
i < enforcement_lits.size(); ++
i) {
514 const std::optional<int64_t> expr_value =
515 GetExpressionValue(lin_max.exprs(
i));
516 if (!expr_value.has_value())
return;
517 if (enforcement_value_sum == 0) {
518 const bool enforcement_value = target_value.value() <= expr_value.value();
519 SetLiteralValue(enforcement_lits[
i], enforcement_value);
520 enforcement_value_sum += enforcement_value;
522 SetLiteralValue(enforcement_lits[
i],
false);
528 const AutomatonConstraintProto& automaton,
529 absl::Span<const StateVar> state_vars,
530 absl::Span<const TransitionVar> transition_vars) {
531 if (!solution_is_loaded_)
return;
532 absl::flat_hash_map<std::pair<int64_t, int64_t>, int64_t> transitions;
533 for (
int i = 0;
i < automaton.transition_tail_size(); ++
i) {
534 transitions[{automaton.transition_tail(
i), automaton.transition_label(
i)}] =
535 automaton.transition_head(
i);
538 std::vector<int64_t> label_values;
539 std::vector<int64_t> state_values;
540 int64_t current_state = automaton.starting_state();
541 state_values.push_back(current_state);
542 for (
int i = 0;
i < automaton.exprs_size(); ++
i) {
543 const std::optional<int64_t> label_value =
544 GetExpressionValue(automaton.exprs(
i));
545 if (!label_value.has_value())
return;
546 label_values.push_back(label_value.value());
548 const auto it = transitions.find({current_state, label_value.value()});
549 if (it == transitions.end())
return;
550 current_state = it->second;
551 state_values.push_back(current_state);
554 for (
const auto& [var, time, state] : state_vars) {
555 SetVarValue(var, state_values[time] == state);
557 for (
const auto& [var, time, transition_tail, transition_label] :
559 SetVarValue(var, state_values[time] == transition_tail &&
560 label_values[time] == transition_label);
565 absl::Span<const int> column_vars, absl::Span<const int> existing_row_lits,
566 absl::Span<const TableRowLiteral> new_row_lits) {
567 if (!solution_is_loaded_)
return;
568 int row_lit_values_sum = 0;
569 for (
const int lit : existing_row_lits) {
571 row_lit_values_sum += GetLiteralValue(lit);
573 const int num_vars = column_vars.size();
574 for (
const auto& [lit, var_values] : new_row_lits) {
575 if (row_lit_values_sum >= 1) {
576 SetLiteralValue(lit,
false);
579 bool row_lit_value =
true;
580 for (
int var_index = 0; var_index < num_vars; ++var_index) {
581 const auto& values = var_values[var_index];
582 if (!values.empty() &&
583 std::find(values.begin(), values.end(),
584 GetVarValue(column_vars[var_index])) == values.end()) {
585 row_lit_value =
false;
589 SetLiteralValue(lit, row_lit_value);
590 row_lit_values_sum += row_lit_value;
595 const LinearConstraintProto& linear, absl::Span<const int> bucket_lits) {
596 if (!solution_is_loaded_)
return;
597 int64_t expr_value = 0;
598 for (
int i = 0;
i < linear.vars_size(); ++
i) {
599 const int var = linear.vars(
i);
600 if (!HasValue(var))
return;
601 expr_value += linear.coeffs(
i) * GetVarValue(var);
603 DCHECK_LE(bucket_lits.size(), linear.domain_size() / 2);
604 for (
int i = 0;
i < bucket_lits.size(); ++
i) {
605 const int64_t lb = linear.domain(2 *
i);
606 const int64_t ub = linear.domain(2 *
i + 1);
607 SetLiteralValue(bucket_lits[
i], expr_value >= lb && expr_value <= ub);
612 if (!solution_is_loaded_)
return;
613 model.clear_solution_hint();
614 for (
int i = 0;
i < var_values_.size(); ++
i) {
615 if (var_has_value_[
i]) {
616 model.mutable_solution_hint()->add_vars(
i);
617 model.mutable_solution_hint()->add_values(var_values_[
i]);
623 CHECK(solution_is_loaded_);
630 absl::Span<const int> x_intervals, absl::Span<const int> y_intervals,
631 absl::Span<const BoxInAreaLiteral> box_in_area_lits) {
632 struct RectangleTypeAndIndex {
641 std::vector<Rectangle> rectangles_for_intersections;
642 std::vector<RectangleTypeAndIndex> rectangles_index;
644 for (
int i = 0;
i < x_intervals.size(); ++
i) {
645 const ConstraintProto& x_ct = model.constraints(x_intervals[
i]);
646 const ConstraintProto& y_ct = model.constraints(y_intervals[
i]);
648 const std::optional<int64_t> x_min =
649 GetExpressionValue(x_ct.interval().start());
650 const std::optional<int64_t> x_max =
651 GetExpressionValue(x_ct.interval().end());
652 const std::optional<int64_t> y_min =
653 GetExpressionValue(y_ct.interval().start());
654 const std::optional<int64_t> y_max =
655 GetExpressionValue(y_ct.interval().end());
657 if (!x_min.has_value() || !x_max.has_value() || !y_min.has_value() ||
658 !y_max.has_value()) {
661 if (*x_min > *x_max || *y_min > *y_max) {
662 VLOG(2) <<
"Hinted no_overlap_2d coordinate has max lower than min";
665 const Rectangle box = {.x_min = x_min.value(),
666 .x_max = x_max.value(),
667 .y_min = y_min.value(),
668 .y_max = y_max.value()};
669 rectangles_for_intersections.push_back(box);
670 rectangles_index.push_back(
671 {.index =
i, .type = RectangleTypeAndIndex::Type::kHintedBox});
674 for (
int i = 0;
i < areas.
size(); ++
i) {
676 rectangles_for_intersections.push_back(area);
677 rectangles_index.push_back(
678 {.index =
i, .type = RectangleTypeAndIndex::Type::kArea});
682 const std::vector<std::pair<int, int>> intersections =
685 absl::flat_hash_set<std::pair<int, int>> box_to_area_pairs;
687 for (
const auto& [rec1_index, rec2_index] : intersections) {
688 RectangleTypeAndIndex rec1 = rectangles_index[rec1_index];
689 RectangleTypeAndIndex rec2 = rectangles_index[rec2_index];
690 if (rec1.type == rec2.type) {
691 DCHECK(rec1.type == RectangleTypeAndIndex::Type::kHintedBox);
692 VLOG(2) <<
"Hinted position of boxes in no_overlap_2d are overlapping";
695 if (rec1.type != RectangleTypeAndIndex::Type::kHintedBox) {
696 std::swap(rec1, rec2);
699 box_to_area_pairs.insert({rec1.index, rec2.index});
702 for (
const auto& [box_index, area_index, literal] : box_in_area_lits) {
703 SetLiteralValue(literal,
704 box_to_area_pairs.contains({box_index, area_index}));
bool Contains(int64_t value) const
int64_t FixedValue() const
int64_t ClosestValue(int64_t wanted) const
void ApplyToDenseCollection(Collection &span) const
size_t size() const
Size of the "key" space, always in [0, size()).
void LoadSolution(int num_vars, const absl::flat_hash_map< int, int64_t > &solution)
void SetIntModExpandedVars(const ConstraintProto &ct, int div_var, int prod_var, int64_t default_div_value, int64_t default_prod_value)
void SetLinearWithComplexDomainExpandedVars(const LinearConstraintProto &linear, absl::Span< const int > bucket_lits)
void StoreSolutionAsHint(CpModelProto &model) const
Stores the solution as a hint in the given model.
void SetLinMaxExpandedVars(const LinearArgumentProto &lin_max, absl::Span< const int > enforcement_lits)
void SetVarToConjunction(int var, absl::Span< const int > conjunction)
void SetReservoirCircuitVars(const ReservoirConstraintProto &reservoir, int64_t min_level, int64_t max_level, absl::Span< const int > level_vars, const CircuitConstraintProto &circuit)
void SetVarToValueIf(int var, int64_t value, int condition_lit)
Sets the value of var to value if the value of condition_lit is true.
void Resize(int new_size)
void AssignVariableToPackingArea(const CompactVectorVector< int, Rectangle > &areas, const CpModelProto &model, absl::Span< const int > x_intervals, absl::Span< const int > y_intervals, absl::Span< const BoxInAreaLiteral > box_in_area_lits)
void SetAutomatonExpandedVars(const AutomatonConstraintProto &automaton, absl::Span< const StateVar > state_vars, absl::Span< const TransitionVar > transition_vars)
void SetVarToValueIfLinearConstraintViolated(int var, int64_t value, absl::Span< const std::pair< int, int64_t > > linear, const Domain &domain)
void SetVarToReifiedPrecedenceLiteral(int var, const LinearExpressionProto &time_i, const LinearExpressionProto &time_j, int active_i, int active_j)
void MakeLiteralsEqual(int lit1, int lit2)
void UpdateLiteralsWithDominance(int lit, int dominating_lit)
void SetLiteralToValueIfLinearConstraintViolated(int literal, bool value, absl::Span< const std::pair< int, int64_t > > linear, const Domain &domain)
void UpdateRefsWithDominance(int ref, int64_t min_value, int64_t max_value, absl::Span< const std::pair< int, Domain > > dominating_refs)
void SetIntProdExpandedVars(const LinearArgumentProto &int_prod, absl::Span< const int > prod_vars)
void SetVarToLinearExpressionIf(int var, const LinearExpressionProto &expr, int condition_lit)
void MaybeSetLiteralToValueEncoding(int literal, int var, int64_t value)
void SetOrUpdateVarToDomain(int var, const Domain &domain)
void MaybeUpdateVarWithSymmetriesToValue(int var, bool value, absl::Span< const std::unique_ptr< SparsePermutation > > generators)
void SetVarToLinearExpression(int var, absl::Span< const std::pair< int, int64_t > > linear, int64_t offset=0)
void UpdateLiteralsToFalseIfDifferent(int lit1, int lit2)
void SetLiteralToValueIf(int literal, bool value, int condition_lit)
void SetTableExpandedVars(absl::Span< const int > column_vars, absl::Span< const int > existing_row_lits, absl::Span< const TableRowLiteral > new_row_lits)
void SetVarToConditionalValue(int var, absl::Span< const int > condition_lits, int64_t value_if_true, int64_t value_if_false)
void SetVarToClause(int var, absl::Span< const int > clause)
void SetVarToLinearConstraintSolution(std::optional< int > var_index, absl::Span< const int > vars, absl::Span< const int64_t > coeffs, int64_t rhs)
void GetSchreierVectorAndOrbit(int point, absl::Span< const std::unique_ptr< SparsePermutation > > generators, std::vector< int > *schrier_vector, std::vector< int > *orbit)
bool RefIsPositive(int ref)
std::vector< int > TracePoint(int point, absl::Span< const int > schrier_vector, absl::Span< const std::unique_ptr< SparsePermutation > > generators)
std::vector< std::pair< int, int > > FindPartialRectangleIntersections(absl::Span< const Rectangle > rectangles)
In SWIG mode, we don't want anything besides these top-level includes.
Select next search node to expand. Select next item_i to add this new search node to the search. Generate a new search node where item_i is not in the knapsack. Check validity of this new partial solution (using propagators) - If valid