Google OR-Tools v9.11
a fast and portable software suite for combinatorial optimization
Loading...
Searching...
No Matches
expressions.cc
Go to the documentation of this file.
1// Copyright 2010-2024 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14#include <algorithm>
15#include <cmath>
16#include <cstdint>
17#include <cstdlib>
18#include <limits>
19#include <memory>
20#include <string>
21#include <utility>
22#include <vector>
23
24#include "absl/container/flat_hash_map.h"
25#include "absl/strings/str_cat.h"
26#include "absl/strings/str_format.h"
27#include "absl/strings/string_view.h"
28#include "absl/types/span.h"
34#include "ortools/base/types.h"
37#include "ortools/util/bitset.h"
40
41ABSL_FLAG(bool, cp_disable_expression_optimization, false,
42 "Disable special optimization when creating expressions.");
43ABSL_FLAG(bool, cp_share_int_consts, true,
44 "Share IntConst's with the same value.");
45
46#if defined(_MSC_VER)
47#pragma warning(disable : 4351 4355)
48#endif
49
50namespace operations_research {
51
52// ---------- IntExpr ----------
53
54IntVar* IntExpr::VarWithName(const std::string& name) {
55 IntVar* const var = Var();
56 var->set_name(name);
57 return var;
58}
59
60// ---------- IntVar ----------
61
62IntVar::IntVar(Solver* const s) : IntExpr(s), index_(s->GetNewIntVarIndex()) {}
63
64IntVar::IntVar(Solver* const s, const std::string& name)
65 : IntExpr(s), index_(s->GetNewIntVarIndex()) {
67}
68
69// ----- Boolean variable -----
70
72
73void BooleanVar::SetMin(int64_t m) {
74 if (m <= 0) return;
75 if (m > 1) solver()->Fail();
76 SetValue(1);
77}
78
79void BooleanVar::SetMax(int64_t m) {
80 if (m >= 1) return;
81 if (m < 0) solver()->Fail();
82 SetValue(0);
83}
84
85void BooleanVar::SetRange(int64_t mi, int64_t ma) {
86 if (mi > 1 || ma < 0 || mi > ma) {
87 solver()->Fail();
88 }
89 if (mi == 1) {
90 SetValue(1);
91 } else if (ma == 0) {
92 SetValue(0);
93 }
94}
95
96void BooleanVar::RemoveValue(int64_t v) {
98 if (v == 0) {
99 SetValue(1);
100 } else if (v == 1) {
101 SetValue(0);
102 }
103 } else if (v == value_) {
104 solver()->Fail();
105 }
106}
107
108void BooleanVar::RemoveInterval(int64_t l, int64_t u) {
109 if (u < l) return;
110 if (l <= 0 && u >= 1) {
111 solver()->Fail();
112 } else if (l == 1) {
113 SetValue(0);
114 } else if (u == 0) {
115 SetValue(1);
116 }
117}
118
122 delayed_bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
123 } else {
124 bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
125 }
126 }
127}
128
129uint64_t BooleanVar::Size() const {
130 return (1 + (value_ == kUnboundBooleanVarValue));
131}
132
133bool BooleanVar::Contains(int64_t v) const {
134 return ((v == 0 && value_ != 1) || (v == 1 && value_ != 0));
135}
136
137IntVar* BooleanVar::IsEqual(int64_t constant) {
138 if (constant > 1 || constant < 0) {
139 return solver()->MakeIntConst(0);
140 }
141 if (constant == 1) {
142 return this;
143 } else { // constant == 0.
144 return solver()->MakeDifference(1, this)->Var();
145 }
146}
147
148IntVar* BooleanVar::IsDifferent(int64_t constant) {
149 if (constant > 1 || constant < 0) {
150 return solver()->MakeIntConst(1);
151 }
152 if (constant == 1) {
153 return solver()->MakeDifference(1, this)->Var();
154 } else { // constant == 0.
155 return this;
156 }
157}
158
160 if (constant > 1) {
161 return solver()->MakeIntConst(0);
162 } else if (constant <= 0) {
163 return solver()->MakeIntConst(1);
164 } else {
165 return this;
166 }
167}
168
170 if (constant < 0) {
171 return solver()->MakeIntConst(0);
172 } else if (constant >= 1) {
173 return solver()->MakeIntConst(1);
174 } else {
175 return IsEqual(0);
176 }
177}
178
179std::string BooleanVar::DebugString() const {
180 std::string out;
181 const std::string& var_name = name();
182 if (!var_name.empty()) {
183 out = var_name + "(";
184 } else {
185 out = "BooleanVar(";
186 }
187 switch (value_) {
188 case 0:
189 out += "0";
190 break;
191 case 1:
192 out += "1";
193 break;
195 out += "0 .. 1";
196 break;
197 }
198 out += ")";
199 return out;
200}
201
202namespace {
203// ---------- Subclasses of IntVar ----------
204
205// ----- Domain Int Var: base class for variables -----
206// It Contains bounds and a bitset representation of possible values.
207class DomainIntVar : public IntVar {
208 public:
209 // Utility classes
210 class BitSetIterator : public BaseObject {
211 public:
212 BitSetIterator(uint64_t* const bitset, int64_t omin)
213 : bitset_(bitset),
214 omin_(omin),
215 max_(std::numeric_limits<int64_t>::min()),
216 current_(std::numeric_limits<int64_t>::max()) {}
217
218 ~BitSetIterator() override {}
219
220 void Init(int64_t min, int64_t max) {
221 max_ = max;
222 current_ = min;
223 }
224
225 bool Ok() const { return current_ <= max_; }
226
227 int64_t Value() const { return current_; }
228
229 void Next() {
230 if (++current_ <= max_) {
232 bitset_, current_ - omin_, max_ - omin_) +
233 omin_;
234 }
235 }
236
237 std::string DebugString() const override { return "BitSetIterator"; }
238
239 private:
240 uint64_t* const bitset_;
241 const int64_t omin_;
242 int64_t max_;
243 int64_t current_;
244 };
245
246 class BitSet : public BaseObject {
247 public:
248 explicit BitSet(Solver* const s) : solver_(s), holes_stamp_(0) {}
249 ~BitSet() override {}
250
251 virtual int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) = 0;
252 virtual int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) = 0;
253 virtual bool Contains(int64_t val) const = 0;
254 virtual bool SetValue(int64_t val) = 0;
255 virtual bool RemoveValue(int64_t val) = 0;
256 virtual uint64_t Size() const = 0;
257 virtual void DelayRemoveValue(int64_t val) = 0;
258 virtual void ApplyRemovedValues(DomainIntVar* var) = 0;
259 virtual void ClearRemovedValues() = 0;
260 virtual std::string pretty_DebugString(int64_t min, int64_t max) const = 0;
261 virtual BitSetIterator* MakeIterator() = 0;
262
263 void InitHoles() {
264 const uint64_t current_stamp = solver_->stamp();
265 if (holes_stamp_ < current_stamp) {
266 holes_.clear();
267 holes_stamp_ = current_stamp;
268 }
269 }
270
271 virtual void ClearHoles() { holes_.clear(); }
272
273 const std::vector<int64_t>& Holes() { return holes_; }
274
275 void AddHole(int64_t value) { holes_.push_back(value); }
276
277 int NumHoles() const {
278 return holes_stamp_ < solver_->stamp() ? 0 : holes_.size();
279 }
280
281 protected:
282 Solver* const solver_;
283
284 private:
285 std::vector<int64_t> holes_;
286 uint64_t holes_stamp_;
287 };
288
289 class QueueHandler : public Demon {
290 public:
291 explicit QueueHandler(DomainIntVar* const var) : var_(var) {}
292 ~QueueHandler() override {}
293 void Run(Solver* const s) override {
294 s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
295 var_->Process();
296 s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
297 }
298 Solver::DemonPriority priority() const override {
300 }
301 std::string DebugString() const override {
302 return absl::StrFormat("Handler(%s)", var_->DebugString());
303 }
304
305 private:
306 DomainIntVar* const var_;
307 };
308
309 // Bounds and Value watchers
310
311 // This class stores the watchers variables attached to values. It is
312 // reversible and it helps maintaining the set of 'active' watchers
313 // (variables not bound to a single value).
314 template <class T>
315 class RevIntPtrMap {
316 public:
317 RevIntPtrMap(Solver* const solver, int64_t rmin, int64_t rmax)
318 : solver_(solver), range_min_(rmin), start_(0) {}
319
320 ~RevIntPtrMap() {}
321
322 bool Empty() const { return start_.Value() == elements_.size(); }
323
324 void SortActive() { std::sort(elements_.begin(), elements_.end()); }
325
326 // Access with value API.
327
328 // Add the pointer to the map attached to the given value.
329 void UnsafeRevInsert(int64_t value, T* elem) {
330 elements_.push_back(std::make_pair(value, elem));
331 if (solver_->state() != Solver::OUTSIDE_SEARCH) {
332 solver_->AddBacktrackAction(
333 [this, value](Solver* s) { Uninsert(value); }, false);
334 }
335 }
336
337 T* FindPtrOrNull(int64_t value, int* position) {
338 for (int pos = start_.Value(); pos < elements_.size(); ++pos) {
339 if (elements_[pos].first == value) {
340 if (position != nullptr) *position = pos;
341 return At(pos).second;
342 }
343 }
344 return nullptr;
345 }
346
347 // Access map through the underlying vector.
348 void RemoveAt(int position) {
349 const int start = start_.Value();
350 DCHECK_GE(position, start);
351 DCHECK_LT(position, elements_.size());
352 if (position > start) {
353 // Swap the current element with the one at the start position, and
354 // increase start.
355 const std::pair<int64_t, T*> copy = elements_[start];
356 elements_[start] = elements_[position];
357 elements_[position] = copy;
358 }
359 start_.Incr(solver_);
360 }
361
362 const std::pair<int64_t, T*>& At(int position) const {
363 DCHECK_GE(position, start_.Value());
364 DCHECK_LT(position, elements_.size());
365 return elements_[position];
366 }
367
368 void RemoveAll() { start_.SetValue(solver_, elements_.size()); }
369
370 int start() const { return start_.Value(); }
371 int end() const { return elements_.size(); }
372 // Number of active elements.
373 int Size() const { return elements_.size() - start_.Value(); }
374
375 // Removes the object permanently from the map.
376 void Uninsert(int64_t value) {
377 for (int pos = 0; pos < elements_.size(); ++pos) {
378 if (elements_[pos].first == value) {
379 DCHECK_GE(pos, start_.Value());
380 const int last = elements_.size() - 1;
381 if (pos != last) { // Swap the current with the last.
382 elements_[pos] = elements_.back();
383 }
384 elements_.pop_back();
385 return;
386 }
387 }
388 LOG(FATAL) << "The element should have been removed";
389 }
390
391 private:
392 Solver* const solver_;
393 const int64_t range_min_;
394 NumericalRev<int> start_;
395 std::vector<std::pair<int64_t, T*>> elements_;
396 };
397
398 // Base class for value watchers
399 class BaseValueWatcher : public Constraint {
400 public:
401 explicit BaseValueWatcher(Solver* const solver) : Constraint(solver) {}
402
403 ~BaseValueWatcher() override {}
404
405 virtual IntVar* GetOrMakeValueWatcher(int64_t value) = 0;
406
407 virtual void SetValueWatcher(IntVar* boolvar, int64_t value) = 0;
408 };
409
410 // This class monitors the domain of the variable and updates the
411 // IsEqual/IsDifferent boolean variables accordingly.
412 class ValueWatcher : public BaseValueWatcher {
413 public:
414 class WatchDemon : public Demon {
415 public:
416 WatchDemon(ValueWatcher* const watcher, int64_t value, IntVar* var)
417 : value_watcher_(watcher), value_(value), var_(var) {}
418 ~WatchDemon() override {}
419
420 void Run(Solver* const solver) override {
421 value_watcher_->ProcessValueWatcher(value_, var_);
422 }
423
424 private:
425 ValueWatcher* const value_watcher_;
426 const int64_t value_;
427 IntVar* const var_;
428 };
429
430 class VarDemon : public Demon {
431 public:
432 explicit VarDemon(ValueWatcher* const watcher)
433 : value_watcher_(watcher) {}
434
435 ~VarDemon() override {}
436
437 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
438
439 private:
440 ValueWatcher* const value_watcher_;
441 };
442
443 ValueWatcher(Solver* const solver, DomainIntVar* const variable)
444 : BaseValueWatcher(solver),
445 variable_(variable),
446 hole_iterator_(variable_->MakeHoleIterator(true)),
447 var_demon_(nullptr),
448 watchers_(solver, variable->Min(), variable->Max()) {}
449
450 ~ValueWatcher() override {}
451
452 IntVar* GetOrMakeValueWatcher(int64_t value) override {
453 IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
454 if (watcher != nullptr) return watcher;
455 if (variable_->Contains(value)) {
456 if (variable_->Bound()) {
457 return solver()->MakeIntConst(1);
458 } else {
459 const std::string vname = variable_->HasName()
460 ? variable_->name()
461 : variable_->DebugString();
462 const std::string bname =
463 absl::StrFormat("Watch<%s == %d>", vname, value);
464 IntVar* const boolvar = solver()->MakeBoolVar(bname);
465 watchers_.UnsafeRevInsert(value, boolvar);
466 if (posted_.Switched()) {
467 boolvar->WhenBound(
468 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
469 var_demon_->desinhibit(solver());
470 }
471 return boolvar;
472 }
473 } else {
474 return variable_->solver()->MakeIntConst(0);
475 }
476 }
477
478 void SetValueWatcher(IntVar* const boolvar, int64_t value) override {
479 CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
480 if (!boolvar->Bound()) {
481 watchers_.UnsafeRevInsert(value, boolvar);
482 if (posted_.Switched() && !boolvar->Bound()) {
483 boolvar->WhenBound(
484 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
485 var_demon_->desinhibit(solver());
486 }
487 }
488 }
489
490 void Post() override {
491 var_demon_ = solver()->RevAlloc(new VarDemon(this));
492 variable_->WhenDomain(var_demon_);
493 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
494 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
495 const int64_t value = w.first;
496 IntVar* const boolvar = w.second;
497 if (!boolvar->Bound() && variable_->Contains(value)) {
498 boolvar->WhenBound(
499 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
500 }
501 }
502 posted_.Switch(solver());
503 }
504
505 void InitialPropagate() override {
506 if (variable_->Bound()) {
507 VariableBound();
508 } else {
509 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
510 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
511 const int64_t value = w.first;
512 IntVar* const boolvar = w.second;
513 if (!variable_->Contains(value)) {
514 boolvar->SetValue(0);
515 watchers_.RemoveAt(pos);
516 } else {
517 if (boolvar->Bound()) {
518 ProcessValueWatcher(value, boolvar);
519 watchers_.RemoveAt(pos);
520 }
521 }
522 }
523 CheckInhibit();
524 }
525 }
526
527 void ProcessValueWatcher(int64_t value, IntVar* boolvar) {
528 if (boolvar->Min() == 0) {
529 if (variable_->Size() < 0xFFFFFF) {
530 variable_->RemoveValue(value);
531 } else {
532 // Delay removal.
533 solver()->AddConstraint(solver()->MakeNonEquality(variable_, value));
534 }
535 } else {
536 variable_->SetValue(value);
537 }
538 }
539
540 void ProcessVar() {
541 const int kSmallList = 16;
542 if (variable_->Bound()) {
543 VariableBound();
544 } else if (watchers_.Size() <= kSmallList ||
545 variable_->Min() != variable_->OldMin() ||
546 variable_->Max() != variable_->OldMax()) {
547 // Brute force loop for small numbers of watchers, or if the bounds have
548 // changed, which would have required a sort (n log(n)) anyway to take
549 // advantage of.
550 ScanWatchers();
551 CheckInhibit();
552 } else {
553 // If there is no bitset, then there are no holes.
554 // In that case, the two loops above should have performed all
555 // propagation. Otherwise, scan the remaining watchers.
556 BitSet* const bitset = variable_->bitset();
557 if (bitset != nullptr && !watchers_.Empty()) {
558 if (bitset->NumHoles() * 2 < watchers_.Size()) {
559 for (const int64_t hole : InitAndGetValues(hole_iterator_)) {
560 int pos = 0;
561 IntVar* const boolvar = watchers_.FindPtrOrNull(hole, &pos);
562 if (boolvar != nullptr) {
563 boolvar->SetValue(0);
564 watchers_.RemoveAt(pos);
565 }
566 }
567 } else {
568 ScanWatchers();
569 }
570 }
571 CheckInhibit();
572 }
573 }
574
575 // Optimized case if the variable is bound.
576 void VariableBound() {
577 DCHECK(variable_->Bound());
578 const int64_t value = variable_->Min();
579 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
580 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
581 w.second->SetValue(w.first == value);
582 }
583 watchers_.RemoveAll();
584 var_demon_->inhibit(solver());
585 }
586
587 // Scans all the watchers to check and assign them.
588 void ScanWatchers() {
589 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
590 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
591 if (!variable_->Contains(w.first)) {
592 IntVar* const boolvar = w.second;
593 boolvar->SetValue(0);
594 watchers_.RemoveAt(pos);
595 }
596 }
597 }
598
599 // If the set of active watchers is empty, we can inhibit the demon on the
600 // main variable.
601 void CheckInhibit() {
602 if (watchers_.Empty()) {
603 var_demon_->inhibit(solver());
604 }
605 }
606
607 void Accept(ModelVisitor* const visitor) const override {
608 visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
609 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
610 variable_);
611 std::vector<int64_t> all_coefficients;
612 std::vector<IntVar*> all_bool_vars;
613 for (int position = watchers_.start(); position < watchers_.end();
614 ++position) {
615 const std::pair<int64_t, IntVar*>& w = watchers_.At(position);
616 all_coefficients.push_back(w.first);
617 all_bool_vars.push_back(w.second);
618 }
619 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
620 all_bool_vars);
621 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
622 all_coefficients);
623 visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
624 }
625
626 std::string DebugString() const override {
627 return absl::StrFormat("ValueWatcher(%s)", variable_->DebugString());
628 }
629
630 private:
631 DomainIntVar* const variable_;
632 IntVarIterator* const hole_iterator_;
633 RevSwitch posted_;
634 Demon* var_demon_;
635 RevIntPtrMap<IntVar> watchers_;
636 };
637
638 // Optimized case for small maps.
639 class DenseValueWatcher : public BaseValueWatcher {
640 public:
641 class WatchDemon : public Demon {
642 public:
643 WatchDemon(DenseValueWatcher* const watcher, int64_t value, IntVar* var)
644 : value_watcher_(watcher), value_(value), var_(var) {}
645 ~WatchDemon() override {}
646
647 void Run(Solver* const solver) override {
648 value_watcher_->ProcessValueWatcher(value_, var_);
649 }
650
651 private:
652 DenseValueWatcher* const value_watcher_;
653 const int64_t value_;
654 IntVar* const var_;
655 };
656
657 class VarDemon : public Demon {
658 public:
659 explicit VarDemon(DenseValueWatcher* const watcher)
660 : value_watcher_(watcher) {}
661
662 ~VarDemon() override {}
663
664 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
665
666 private:
667 DenseValueWatcher* const value_watcher_;
668 };
669
670 DenseValueWatcher(Solver* const solver, DomainIntVar* const variable)
671 : BaseValueWatcher(solver),
672 variable_(variable),
673 hole_iterator_(variable_->MakeHoleIterator(true)),
674 var_demon_(nullptr),
675 offset_(variable->Min()),
676 watchers_(variable->Max() - variable->Min() + 1, nullptr),
677 active_watchers_(0) {}
678
679 ~DenseValueWatcher() override {}
680
681 IntVar* GetOrMakeValueWatcher(int64_t value) override {
682 const int64_t var_max = offset_ + watchers_.size() - 1; // Bad cast.
683 if (value < offset_ || value > var_max) {
684 return solver()->MakeIntConst(0);
685 }
686 const int index = value - offset_;
687 IntVar* const watcher = watchers_[index];
688 if (watcher != nullptr) return watcher;
689 if (variable_->Contains(value)) {
690 if (variable_->Bound()) {
691 return solver()->MakeIntConst(1);
692 } else {
693 const std::string vname = variable_->HasName()
694 ? variable_->name()
695 : variable_->DebugString();
696 const std::string bname =
697 absl::StrFormat("Watch<%s == %d>", vname, value);
698 IntVar* const boolvar = solver()->MakeBoolVar(bname);
699 RevInsert(index, boolvar);
700 if (posted_.Switched()) {
701 boolvar->WhenBound(
702 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
703 var_demon_->desinhibit(solver());
704 }
705 return boolvar;
706 }
707 } else {
708 return variable_->solver()->MakeIntConst(0);
709 }
710 }
711
712 void SetValueWatcher(IntVar* const boolvar, int64_t value) override {
713 const int index = value - offset_;
714 CHECK(watchers_[index] == nullptr);
715 if (!boolvar->Bound()) {
716 RevInsert(index, boolvar);
717 if (posted_.Switched() && !boolvar->Bound()) {
718 boolvar->WhenBound(
719 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
720 var_demon_->desinhibit(solver());
721 }
722 }
723 }
724
725 void Post() override {
726 var_demon_ = solver()->RevAlloc(new VarDemon(this));
727 variable_->WhenDomain(var_demon_);
728 for (int pos = 0; pos < watchers_.size(); ++pos) {
729 const int64_t value = pos + offset_;
730 IntVar* const boolvar = watchers_[pos];
731 if (boolvar != nullptr && !boolvar->Bound() &&
732 variable_->Contains(value)) {
733 boolvar->WhenBound(
734 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
735 }
736 }
737 posted_.Switch(solver());
738 }
739
740 void InitialPropagate() override {
741 if (variable_->Bound()) {
742 VariableBound();
743 } else {
744 for (int pos = 0; pos < watchers_.size(); ++pos) {
745 IntVar* const boolvar = watchers_[pos];
746 if (boolvar == nullptr) continue;
747 const int64_t value = pos + offset_;
748 if (!variable_->Contains(value)) {
749 boolvar->SetValue(0);
750 RevRemove(pos);
751 } else if (boolvar->Bound()) {
752 ProcessValueWatcher(value, boolvar);
753 RevRemove(pos);
754 }
755 }
756 if (active_watchers_.Value() == 0) {
757 var_demon_->inhibit(solver());
758 }
759 }
760 }
761
762 void ProcessValueWatcher(int64_t value, IntVar* boolvar) {
763 if (boolvar->Min() == 0) {
764 variable_->RemoveValue(value);
765 } else {
766 variable_->SetValue(value);
767 }
768 }
769
770 void ProcessVar() {
771 if (variable_->Bound()) {
772 VariableBound();
773 } else {
774 // Brute force loop for small numbers of watchers.
775 ScanWatchers();
776 if (active_watchers_.Value() == 0) {
777 var_demon_->inhibit(solver());
778 }
779 }
780 }
781
782 // Optimized case if the variable is bound.
783 void VariableBound() {
784 DCHECK(variable_->Bound());
785 const int64_t value = variable_->Min();
786 for (int pos = 0; pos < watchers_.size(); ++pos) {
787 IntVar* const boolvar = watchers_[pos];
788 if (boolvar != nullptr) {
789 boolvar->SetValue(pos + offset_ == value);
790 RevRemove(pos);
791 }
792 }
793 var_demon_->inhibit(solver());
794 }
795
796 // Scans all the watchers to check and assign them.
797 void ScanWatchers() {
798 const int64_t old_min_index = variable_->OldMin() - offset_;
799 const int64_t old_max_index = variable_->OldMax() - offset_;
800 const int64_t min_index = variable_->Min() - offset_;
801 const int64_t max_index = variable_->Max() - offset_;
802 for (int pos = old_min_index; pos < min_index; ++pos) {
803 IntVar* const boolvar = watchers_[pos];
804 if (boolvar != nullptr) {
805 boolvar->SetValue(0);
806 RevRemove(pos);
807 }
808 }
809 for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
810 IntVar* const boolvar = watchers_[pos];
811 if (boolvar != nullptr) {
812 boolvar->SetValue(0);
813 RevRemove(pos);
814 }
815 }
816 BitSet* const bitset = variable_->bitset();
817 if (bitset != nullptr) {
818 if (bitset->NumHoles() * 2 < active_watchers_.Value()) {
819 for (const int64_t hole : InitAndGetValues(hole_iterator_)) {
820 IntVar* const boolvar = watchers_[hole - offset_];
821 if (boolvar != nullptr) {
822 boolvar->SetValue(0);
823 RevRemove(hole - offset_);
824 }
825 }
826 } else {
827 for (int pos = min_index + 1; pos < max_index; ++pos) {
828 IntVar* const boolvar = watchers_[pos];
829 if (boolvar != nullptr && !variable_->Contains(offset_ + pos)) {
830 boolvar->SetValue(0);
831 RevRemove(pos);
832 }
833 }
834 }
835 }
836 }
837
838 void RevRemove(int pos) {
839 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
840 watchers_[pos] = nullptr;
841 active_watchers_.Decr(solver());
842 }
843
844 void RevInsert(int pos, IntVar* boolvar) {
845 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
846 watchers_[pos] = boolvar;
847 active_watchers_.Incr(solver());
848 }
849
850 void Accept(ModelVisitor* const visitor) const override {
851 visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
852 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
853 variable_);
854 std::vector<int64_t> all_coefficients;
855 std::vector<IntVar*> all_bool_vars;
856 for (int position = 0; position < watchers_.size(); ++position) {
857 if (watchers_[position] != nullptr) {
858 all_coefficients.push_back(position + offset_);
859 all_bool_vars.push_back(watchers_[position]);
860 }
861 }
862 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
863 all_bool_vars);
864 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
865 all_coefficients);
866 visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
867 }
868
869 std::string DebugString() const override {
870 return absl::StrFormat("DenseValueWatcher(%s)", variable_->DebugString());
871 }
872
873 private:
874 DomainIntVar* const variable_;
875 IntVarIterator* const hole_iterator_;
876 RevSwitch posted_;
877 Demon* var_demon_;
878 const int64_t offset_;
879 std::vector<IntVar*> watchers_;
880 NumericalRev<int> active_watchers_;
881 };
882
883 class BaseUpperBoundWatcher : public Constraint {
884 public:
885 explicit BaseUpperBoundWatcher(Solver* const solver) : Constraint(solver) {}
886
887 ~BaseUpperBoundWatcher() override {}
888
889 virtual IntVar* GetOrMakeUpperBoundWatcher(int64_t value) = 0;
890
891 virtual void SetUpperBoundWatcher(IntVar* boolvar, int64_t value) = 0;
892 };
893
894 // This class watches the bounds of the variable and updates the
895 // IsGreater/IsGreaterOrEqual/IsLess/IsLessOrEqual demons
896 // accordingly.
897 class UpperBoundWatcher : public BaseUpperBoundWatcher {
898 public:
899 class WatchDemon : public Demon {
900 public:
901 WatchDemon(UpperBoundWatcher* const watcher, int64_t index,
902 IntVar* const var)
903 : value_watcher_(watcher), index_(index), var_(var) {}
904 ~WatchDemon() override {}
905
906 void Run(Solver* const solver) override {
907 value_watcher_->ProcessUpperBoundWatcher(index_, var_);
908 }
909
910 private:
911 UpperBoundWatcher* const value_watcher_;
912 const int64_t index_;
913 IntVar* const var_;
914 };
915
916 class VarDemon : public Demon {
917 public:
918 explicit VarDemon(UpperBoundWatcher* const watcher)
919 : value_watcher_(watcher) {}
920 ~VarDemon() override {}
921
922 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
923
924 private:
925 UpperBoundWatcher* const value_watcher_;
926 };
927
928 UpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
929 : BaseUpperBoundWatcher(solver),
930 variable_(variable),
931 var_demon_(nullptr),
932 watchers_(solver, variable->Min(), variable->Max()),
933 start_(0),
934 end_(0),
935 sorted_(false) {}
936
937 ~UpperBoundWatcher() override {}
938
939 IntVar* GetOrMakeUpperBoundWatcher(int64_t value) override {
940 IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
941 if (watcher != nullptr) {
942 return watcher;
943 }
944 if (variable_->Max() >= value) {
945 if (variable_->Min() >= value) {
946 return solver()->MakeIntConst(1);
947 } else {
948 const std::string vname = variable_->HasName()
949 ? variable_->name()
950 : variable_->DebugString();
951 const std::string bname =
952 absl::StrFormat("Watch<%s >= %d>", vname, value);
953 IntVar* const boolvar = solver()->MakeBoolVar(bname);
954 watchers_.UnsafeRevInsert(value, boolvar);
955 if (posted_.Switched()) {
956 boolvar->WhenBound(
957 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
958 var_demon_->desinhibit(solver());
959 sorted_ = false;
960 }
961 return boolvar;
962 }
963 } else {
964 return variable_->solver()->MakeIntConst(0);
965 }
966 }
967
968 void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) override {
969 CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
970 watchers_.UnsafeRevInsert(value, boolvar);
971 if (posted_.Switched() && !boolvar->Bound()) {
972 boolvar->WhenBound(
973 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
974 var_demon_->desinhibit(solver());
975 sorted_ = false;
976 }
977 }
978
979 void Post() override {
980 const int kTooSmallToSort = 8;
981 var_demon_ = solver()->RevAlloc(new VarDemon(this));
982 variable_->WhenRange(var_demon_);
983
984 if (watchers_.Size() > kTooSmallToSort) {
985 watchers_.SortActive();
986 sorted_ = true;
987 start_.SetValue(solver(), watchers_.start());
988 end_.SetValue(solver(), watchers_.end() - 1);
989 }
990
991 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
992 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
993 IntVar* const boolvar = w.second;
994 const int64_t value = w.first;
995 if (!boolvar->Bound() && value > variable_->Min() &&
997 boolvar->WhenBound(
998 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
999 }
1000 }
1001 posted_.Switch(solver());
1002 }
1003
1004 void InitialPropagate() override {
1005 const int64_t var_min = variable_->Min();
1006 const int64_t var_max = variable_->Max();
1007 if (sorted_) {
1008 while (start_.Value() <= end_.Value()) {
1009 const std::pair<int64_t, IntVar*>& w = watchers_.At(start_.Value());
1010 if (w.first <= var_min) {
1011 w.second->SetValue(1);
1012 start_.Incr(solver());
1013 } else {
1014 break;
1015 }
1016 }
1017 while (end_.Value() >= start_.Value()) {
1018 const std::pair<int64_t, IntVar*>& w = watchers_.At(end_.Value());
1019 if (w.first > var_max) {
1020 w.second->SetValue(0);
1021 end_.Decr(solver());
1022 } else {
1023 break;
1024 }
1025 }
1026 for (int i = start_.Value(); i <= end_.Value(); ++i) {
1027 const std::pair<int64_t, IntVar*>& w = watchers_.At(i);
1028 if (w.second->Bound()) {
1029 ProcessUpperBoundWatcher(w.first, w.second);
1030 }
1031 }
1032 if (start_.Value() > end_.Value()) {
1033 var_demon_->inhibit(solver());
1034 }
1035 } else {
1036 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1037 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
1038 const int64_t value = w.first;
1039 IntVar* const boolvar = w.second;
1040
1041 if (value <= var_min) {
1042 boolvar->SetValue(1);
1043 watchers_.RemoveAt(pos);
1044 } else if (value > var_max) {
1045 boolvar->SetValue(0);
1046 watchers_.RemoveAt(pos);
1047 } else if (boolvar->Bound()) {
1048 ProcessUpperBoundWatcher(value, boolvar);
1049 watchers_.RemoveAt(pos);
1050 }
1051 }
1052 }
1053 }
1054
1055 void Accept(ModelVisitor* const visitor) const override {
1056 visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1057 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1058 variable_);
1059 std::vector<int64_t> all_coefficients;
1060 std::vector<IntVar*> all_bool_vars;
1061 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1062 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
1063 all_coefficients.push_back(w.first);
1064 all_bool_vars.push_back(w.second);
1065 }
1066 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1067 all_bool_vars);
1068 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1069 all_coefficients);
1070 visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1071 }
1072
1073 std::string DebugString() const override {
1074 return absl::StrFormat("UpperBoundWatcher(%s)", variable_->DebugString());
1075 }
1076
1077 private:
1078 void ProcessUpperBoundWatcher(int64_t value, IntVar* const boolvar) {
1079 if (boolvar->Min() == 0) {
1080 variable_->SetMax(value - 1);
1081 } else {
1082 variable_->SetMin(value);
1083 }
1084 }
1085
1086 void ProcessVar() {
1087 const int64_t var_min = variable_->Min();
1088 const int64_t var_max = variable_->Max();
1089 if (sorted_) {
1090 while (start_.Value() <= end_.Value()) {
1091 const std::pair<int64_t, IntVar*>& w = watchers_.At(start_.Value());
1092 if (w.first <= var_min) {
1093 w.second->SetValue(1);
1094 start_.Incr(solver());
1095 } else {
1096 break;
1097 }
1098 }
1099 while (end_.Value() >= start_.Value()) {
1100 const std::pair<int64_t, IntVar*>& w = watchers_.At(end_.Value());
1101 if (w.first > var_max) {
1102 w.second->SetValue(0);
1103 end_.Decr(solver());
1104 } else {
1105 break;
1106 }
1107 }
1108 if (start_.Value() > end_.Value()) {
1109 var_demon_->inhibit(solver());
1110 }
1111 } else {
1112 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1113 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
1114 const int64_t value = w.first;
1115 IntVar* const boolvar = w.second;
1116
1117 if (value <= var_min) {
1118 boolvar->SetValue(1);
1119 watchers_.RemoveAt(pos);
1120 } else if (value > var_max) {
1121 boolvar->SetValue(0);
1122 watchers_.RemoveAt(pos);
1123 }
1124 }
1125 if (watchers_.Empty()) {
1126 var_demon_->inhibit(solver());
1127 }
1128 }
1129 }
1130
1131 DomainIntVar* const variable_;
1132 RevSwitch posted_;
1133 Demon* var_demon_;
1134 RevIntPtrMap<IntVar> watchers_;
1135 NumericalRev<int> start_;
1136 NumericalRev<int> end_;
1137 bool sorted_;
1138 };
1139
1140 // Optimized case for small maps.
1141 class DenseUpperBoundWatcher : public BaseUpperBoundWatcher {
1142 public:
1143 class WatchDemon : public Demon {
1144 public:
1145 WatchDemon(DenseUpperBoundWatcher* const watcher, int64_t value,
1146 IntVar* var)
1147 : value_watcher_(watcher), value_(value), var_(var) {}
1148 ~WatchDemon() override {}
1149
1150 void Run(Solver* const solver) override {
1151 value_watcher_->ProcessUpperBoundWatcher(value_, var_);
1152 }
1153
1154 private:
1155 DenseUpperBoundWatcher* const value_watcher_;
1156 const int64_t value_;
1157 IntVar* const var_;
1158 };
1159
1160 class VarDemon : public Demon {
1161 public:
1162 explicit VarDemon(DenseUpperBoundWatcher* const watcher)
1163 : value_watcher_(watcher) {}
1164
1165 ~VarDemon() override {}
1166
1167 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
1168
1169 private:
1170 DenseUpperBoundWatcher* const value_watcher_;
1171 };
1172
1173 DenseUpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
1174 : BaseUpperBoundWatcher(solver),
1175 variable_(variable),
1176 var_demon_(nullptr),
1177 offset_(variable->Min()),
1178 watchers_(variable->Max() - variable->Min() + 1, nullptr),
1179 active_watchers_(0) {}
1180
1181 ~DenseUpperBoundWatcher() override {}
1182
1183 IntVar* GetOrMakeUpperBoundWatcher(int64_t value) override {
1184 if (variable_->Max() >= value) {
1185 if (variable_->Min() >= value) {
1186 return solver()->MakeIntConst(1);
1187 } else {
1188 const std::string vname = variable_->HasName()
1189 ? variable_->name()
1190 : variable_->DebugString();
1191 const std::string bname =
1192 absl::StrFormat("Watch<%s >= %d>", vname, value);
1193 IntVar* const boolvar = solver()->MakeBoolVar(bname);
1194 RevInsert(value - offset_, boolvar);
1195 if (posted_.Switched()) {
1196 boolvar->WhenBound(
1197 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1198 var_demon_->desinhibit(solver());
1199 }
1200 return boolvar;
1201 }
1202 } else {
1203 return variable_->solver()->MakeIntConst(0);
1204 }
1205 }
1206
1207 void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) override {
1208 const int index = value - offset_;
1209 CHECK(watchers_[index] == nullptr);
1210 if (!boolvar->Bound()) {
1211 RevInsert(index, boolvar);
1212 if (posted_.Switched() && !boolvar->Bound()) {
1213 boolvar->WhenBound(
1214 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1215 var_demon_->desinhibit(solver());
1216 }
1217 }
1218 }
1219
1220 void Post() override {
1221 var_demon_ = solver()->RevAlloc(new VarDemon(this));
1222 variable_->WhenRange(var_demon_);
1223 for (int pos = 0; pos < watchers_.size(); ++pos) {
1224 const int64_t value = pos + offset_;
1225 IntVar* const boolvar = watchers_[pos];
1226 if (boolvar != nullptr && !boolvar->Bound() &&
1227 value > variable_->Min() && value <= variable_->Max()) {
1228 boolvar->WhenBound(
1229 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1230 }
1231 }
1232 posted_.Switch(solver());
1233 }
1234
1235 void InitialPropagate() override {
1236 for (int pos = 0; pos < watchers_.size(); ++pos) {
1237 IntVar* const boolvar = watchers_[pos];
1238 if (boolvar == nullptr) continue;
1239 const int64_t value = pos + offset_;
1240 if (value <= variable_->Min()) {
1241 boolvar->SetValue(1);
1242 RevRemove(pos);
1243 } else if (value > variable_->Max()) {
1244 boolvar->SetValue(0);
1245 RevRemove(pos);
1246 } else if (boolvar->Bound()) {
1247 ProcessUpperBoundWatcher(value, boolvar);
1248 RevRemove(pos);
1249 }
1250 }
1251 if (active_watchers_.Value() == 0) {
1252 var_demon_->inhibit(solver());
1253 }
1254 }
1255
1256 void ProcessUpperBoundWatcher(int64_t value, IntVar* boolvar) {
1257 if (boolvar->Min() == 0) {
1258 variable_->SetMax(value - 1);
1259 } else {
1260 variable_->SetMin(value);
1261 }
1262 }
1263
1264 void ProcessVar() {
1265 const int64_t old_min_index = variable_->OldMin() - offset_;
1266 const int64_t old_max_index = variable_->OldMax() - offset_;
1267 const int64_t min_index = variable_->Min() - offset_;
1268 const int64_t max_index = variable_->Max() - offset_;
1269 for (int pos = old_min_index; pos <= min_index; ++pos) {
1270 IntVar* const boolvar = watchers_[pos];
1271 if (boolvar != nullptr) {
1272 boolvar->SetValue(1);
1273 RevRemove(pos);
1274 }
1275 }
1276
1277 for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
1278 IntVar* const boolvar = watchers_[pos];
1279 if (boolvar != nullptr) {
1280 boolvar->SetValue(0);
1281 RevRemove(pos);
1282 }
1283 }
1284 if (active_watchers_.Value() == 0) {
1285 var_demon_->inhibit(solver());
1286 }
1287 }
1288
1289 void RevRemove(int pos) {
1290 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1291 watchers_[pos] = nullptr;
1292 active_watchers_.Decr(solver());
1293 }
1294
1295 void RevInsert(int pos, IntVar* boolvar) {
1296 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1297 watchers_[pos] = boolvar;
1298 active_watchers_.Incr(solver());
1299 }
1300
1301 void Accept(ModelVisitor* const visitor) const override {
1302 visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1303 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1304 variable_);
1305 std::vector<int64_t> all_coefficients;
1306 std::vector<IntVar*> all_bool_vars;
1307 for (int position = 0; position < watchers_.size(); ++position) {
1308 if (watchers_[position] != nullptr) {
1309 all_coefficients.push_back(position + offset_);
1310 all_bool_vars.push_back(watchers_[position]);
1311 }
1312 }
1313 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1314 all_bool_vars);
1315 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1316 all_coefficients);
1317 visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1318 }
1319
1320 std::string DebugString() const override {
1321 return absl::StrFormat("DenseUpperBoundWatcher(%s)",
1322 variable_->DebugString());
1323 }
1324
1325 private:
1326 DomainIntVar* const variable_;
1327 RevSwitch posted_;
1328 Demon* var_demon_;
1329 const int64_t offset_;
1330 std::vector<IntVar*> watchers_;
1331 NumericalRev<int> active_watchers_;
1332 };
1333
1334 // ----- Main Class -----
1335 DomainIntVar(Solver* s, int64_t vmin, int64_t vmax, const std::string& name);
1336 DomainIntVar(Solver* s, absl::Span<const int64_t> sorted_values,
1337 const std::string& name);
1338 ~DomainIntVar() override;
1339
1340 int64_t Min() const override { return min_.Value(); }
1341 void SetMin(int64_t m) override;
1342 int64_t Max() const override { return max_.Value(); }
1343 void SetMax(int64_t m) override;
1344 void SetRange(int64_t mi, int64_t ma) override;
1345 void SetValue(int64_t v) override;
1346 bool Bound() const override { return (min_.Value() == max_.Value()); }
1347 int64_t Value() const override {
1348 CHECK_EQ(min_.Value(), max_.Value())
1349 << " variable " << DebugString() << " is not bound.";
1350 return min_.Value();
1351 }
1352 void RemoveValue(int64_t v) override;
1353 void RemoveInterval(int64_t l, int64_t u) override;
1354 void CreateBits();
1355 void WhenBound(Demon* d) override {
1356 if (min_.Value() != max_.Value()) {
1357 if (d->priority() == Solver::DELAYED_PRIORITY) {
1358 delayed_bound_demons_.PushIfNotTop(solver(),
1359 solver()->RegisterDemon(d));
1360 } else {
1361 bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1362 }
1363 }
1364 }
1365 void WhenRange(Demon* d) override {
1366 if (min_.Value() != max_.Value()) {
1367 if (d->priority() == Solver::DELAYED_PRIORITY) {
1368 delayed_range_demons_.PushIfNotTop(solver(),
1369 solver()->RegisterDemon(d));
1370 } else {
1371 range_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1372 }
1373 }
1374 }
1375 void WhenDomain(Demon* d) override {
1376 if (min_.Value() != max_.Value()) {
1377 if (d->priority() == Solver::DELAYED_PRIORITY) {
1378 delayed_domain_demons_.PushIfNotTop(solver(),
1379 solver()->RegisterDemon(d));
1380 } else {
1381 domain_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1382 }
1383 }
1384 }
1385
1386 IntVar* IsEqual(int64_t constant) override {
1387 Solver* const s = solver();
1388 if (constant == min_.Value() && value_watcher_ == nullptr) {
1389 return s->MakeIsLessOrEqualCstVar(this, constant);
1390 }
1391 if (constant == max_.Value() && value_watcher_ == nullptr) {
1392 return s->MakeIsGreaterOrEqualCstVar(this, constant);
1393 }
1394 if (!Contains(constant)) {
1395 return s->MakeIntConst(int64_t{0});
1396 }
1397 if (Bound() && min_.Value() == constant) {
1398 return s->MakeIntConst(int64_t{1});
1399 }
1400 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1401 this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1402 if (cache != nullptr) {
1403 return cache->Var();
1404 } else {
1405 if (value_watcher_ == nullptr) {
1406 if (CapSub(Max(), Min()) <= 256) {
1407 solver()->SaveAndSetValue(
1408 reinterpret_cast<void**>(&value_watcher_),
1409 reinterpret_cast<void*>(
1410 solver()->RevAlloc(new DenseValueWatcher(solver(), this))));
1411
1412 } else {
1413 solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1414 reinterpret_cast<void*>(solver()->RevAlloc(
1415 new ValueWatcher(solver(), this))));
1416 }
1417 solver()->AddConstraint(value_watcher_);
1418 }
1419 IntVar* const boolvar = value_watcher_->GetOrMakeValueWatcher(constant);
1420 s->Cache()->InsertExprConstantExpression(
1421 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1422 return boolvar;
1423 }
1424 }
1425
1426 Constraint* SetIsEqual(absl::Span<const int64_t> values,
1427 const std::vector<IntVar*>& vars) {
1428 if (value_watcher_ == nullptr) {
1429 solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1430 reinterpret_cast<void*>(solver()->RevAlloc(
1431 new ValueWatcher(solver(), this))));
1432 for (int i = 0; i < vars.size(); ++i) {
1433 value_watcher_->SetValueWatcher(vars[i], values[i]);
1434 }
1435 }
1436 return value_watcher_;
1437 }
1438
1439 IntVar* IsDifferent(int64_t constant) override {
1440 Solver* const s = solver();
1441 if (constant == min_.Value() && value_watcher_ == nullptr) {
1442 return s->MakeIsGreaterOrEqualCstVar(this, constant + 1);
1443 }
1444 if (constant == max_.Value() && value_watcher_ == nullptr) {
1445 return s->MakeIsLessOrEqualCstVar(this, constant - 1);
1446 }
1447 if (!Contains(constant)) {
1448 return s->MakeIntConst(int64_t{1});
1449 }
1450 if (Bound() && min_.Value() == constant) {
1451 return s->MakeIntConst(int64_t{0});
1452 }
1453 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1455 if (cache != nullptr) {
1456 return cache->Var();
1457 } else {
1458 IntVar* const boolvar = s->MakeDifference(1, IsEqual(constant))->Var();
1459 s->Cache()->InsertExprConstantExpression(
1460 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
1461 return boolvar;
1462 }
1463 }
1464
1465 IntVar* IsGreaterOrEqual(int64_t constant) override {
1466 Solver* const s = solver();
1467 if (max_.Value() < constant) {
1468 return s->MakeIntConst(int64_t{0});
1469 }
1470 if (min_.Value() >= constant) {
1471 return s->MakeIntConst(int64_t{1});
1472 }
1473 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1475 if (cache != nullptr) {
1476 return cache->Var();
1477 } else {
1478 if (bound_watcher_ == nullptr) {
1479 if (CapSub(Max(), Min()) <= 256) {
1480 solver()->SaveAndSetValue(
1481 reinterpret_cast<void**>(&bound_watcher_),
1482 reinterpret_cast<void*>(solver()->RevAlloc(
1483 new DenseUpperBoundWatcher(solver(), this))));
1484 solver()->AddConstraint(bound_watcher_);
1485 } else {
1486 solver()->SaveAndSetValue(
1487 reinterpret_cast<void**>(&bound_watcher_),
1488 reinterpret_cast<void*>(
1489 solver()->RevAlloc(new UpperBoundWatcher(solver(), this))));
1490 solver()->AddConstraint(bound_watcher_);
1491 }
1492 }
1493 IntVar* const boolvar =
1494 bound_watcher_->GetOrMakeUpperBoundWatcher(constant);
1495 s->Cache()->InsertExprConstantExpression(
1496 boolvar, this, constant,
1498 return boolvar;
1499 }
1500 }
1501
1502 Constraint* SetIsGreaterOrEqual(absl::Span<const int64_t> values,
1503 const std::vector<IntVar*>& vars) {
1504 if (bound_watcher_ == nullptr) {
1505 if (CapSub(Max(), Min()) <= 256) {
1506 solver()->SaveAndSetValue(
1507 reinterpret_cast<void**>(&bound_watcher_),
1508 reinterpret_cast<void*>(solver()->RevAlloc(
1509 new DenseUpperBoundWatcher(solver(), this))));
1510 solver()->AddConstraint(bound_watcher_);
1511 } else {
1512 solver()->SaveAndSetValue(reinterpret_cast<void**>(&bound_watcher_),
1513 reinterpret_cast<void*>(solver()->RevAlloc(
1514 new UpperBoundWatcher(solver(), this))));
1515 solver()->AddConstraint(bound_watcher_);
1516 }
1517 for (int i = 0; i < values.size(); ++i) {
1518 bound_watcher_->SetUpperBoundWatcher(vars[i], values[i]);
1519 }
1520 }
1521 return bound_watcher_;
1522 }
1523
1524 IntVar* IsLessOrEqual(int64_t constant) override {
1525 Solver* const s = solver();
1526 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1528 if (cache != nullptr) {
1529 return cache->Var();
1530 } else {
1531 IntVar* const boolvar =
1532 s->MakeDifference(1, IsGreaterOrEqual(constant + 1))->Var();
1533 s->Cache()->InsertExprConstantExpression(
1534 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
1535 return boolvar;
1536 }
1537 }
1538
1539 void Process();
1540 void Push();
1541 void CleanInProcess();
1542 uint64_t Size() const override {
1543 if (bits_ != nullptr) return bits_->Size();
1544 return (static_cast<uint64_t>(max_.Value()) -
1545 static_cast<uint64_t>(min_.Value()) + 1);
1546 }
1547 bool Contains(int64_t v) const override {
1548 if (v < min_.Value() || v > max_.Value()) return false;
1549 return (bits_ == nullptr ? true : bits_->Contains(v));
1550 }
1551 IntVarIterator* MakeHoleIterator(bool reversible) const override;
1552 IntVarIterator* MakeDomainIterator(bool reversible) const override;
1553 int64_t OldMin() const override { return std::min(old_min_, min_.Value()); }
1554 int64_t OldMax() const override { return std::max(old_max_, max_.Value()); }
1555
1556 std::string DebugString() const override;
1557 BitSet* bitset() const { return bits_; }
1558 int VarType() const override { return DOMAIN_INT_VAR; }
1559 std::string BaseName() const override { return "IntegerVar"; }
1560
1561 friend class PlusCstDomainIntVar;
1562 friend class LinkExprAndDomainIntVar;
1563
1564 private:
1565 void CheckOldMin() {
1566 if (old_min_ > min_.Value()) {
1567 old_min_ = min_.Value();
1568 }
1569 }
1570 void CheckOldMax() {
1571 if (old_max_ < max_.Value()) {
1572 old_max_ = max_.Value();
1573 }
1574 }
1575 Rev<int64_t> min_;
1576 Rev<int64_t> max_;
1577 int64_t old_min_;
1578 int64_t old_max_;
1579 int64_t new_min_;
1580 int64_t new_max_;
1581 SimpleRevFIFO<Demon*> bound_demons_;
1582 SimpleRevFIFO<Demon*> range_demons_;
1583 SimpleRevFIFO<Demon*> domain_demons_;
1584 SimpleRevFIFO<Demon*> delayed_bound_demons_;
1585 SimpleRevFIFO<Demon*> delayed_range_demons_;
1586 SimpleRevFIFO<Demon*> delayed_domain_demons_;
1587 QueueHandler handler_;
1588 bool in_process_;
1589 BitSet* bits_;
1590 BaseValueWatcher* value_watcher_;
1591 BaseUpperBoundWatcher* bound_watcher_;
1592};
1593
1594// ----- BitSet -----
1595
1596// Return whether an integer interval [a..b] (inclusive) contains at most
1597// K values, i.e. b - a < K, in a way that's robust to overflows.
1598// For performance reasons, in opt mode it doesn't check that [a, b] is a
1599// valid interval, nor that K is nonnegative.
1600inline bool ClosedIntervalNoLargerThan(int64_t a, int64_t b, int64_t K) {
1601 DCHECK_LE(a, b);
1602 DCHECK_GE(K, 0);
1603 if (a > 0) {
1604 return a > b - K;
1605 } else {
1606 return a + K > b;
1607 }
1608}
1609
1610class SimpleBitSet : public DomainIntVar::BitSet {
1611 public:
1612 SimpleBitSet(Solver* const s, int64_t vmin, int64_t vmax)
1613 : BitSet(s),
1614 bits_(nullptr),
1615 stamps_(nullptr),
1616 omin_(vmin),
1617 omax_(vmax),
1618 size_(vmax - vmin + 1),
1619 bsize_(BitLength64(size_.Value())) {
1620 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1621 << "Bitset too large: [" << vmin << ", " << vmax << "]";
1622 bits_ = new uint64_t[bsize_];
1623 stamps_ = new uint64_t[bsize_];
1624 for (int i = 0; i < bsize_; ++i) {
1625 const int bs =
1626 (i == size_.Value() - 1) ? 63 - BitPos64(size_.Value()) : 0;
1627 bits_[i] = kAllBits64 >> bs;
1628 stamps_[i] = s->stamp() - 1;
1629 }
1630 }
1631
1632 SimpleBitSet(Solver* const s, absl::Span<const int64_t> sorted_values,
1633 int64_t vmin, int64_t vmax)
1634 : BitSet(s),
1635 bits_(nullptr),
1636 stamps_(nullptr),
1637 omin_(vmin),
1638 omax_(vmax),
1639 size_(sorted_values.size()),
1640 bsize_(BitLength64(vmax - vmin + 1)) {
1641 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1642 << "Bitset too large: [" << vmin << ", " << vmax << "]";
1643 bits_ = new uint64_t[bsize_];
1644 stamps_ = new uint64_t[bsize_];
1645 for (int i = 0; i < bsize_; ++i) {
1646 bits_[i] = uint64_t{0};
1647 stamps_[i] = s->stamp() - 1;
1648 }
1649 for (int i = 0; i < sorted_values.size(); ++i) {
1650 const int64_t val = sorted_values[i];
1651 DCHECK(!bit(val));
1652 const int offset = BitOffset64(val - omin_);
1653 const int pos = BitPos64(val - omin_);
1654 bits_[offset] |= OneBit64(pos);
1655 }
1656 }
1657
1658 ~SimpleBitSet() override {
1659 delete[] bits_;
1660 delete[] stamps_;
1661 }
1662
1663 bool bit(int64_t val) const { return IsBitSet64(bits_, val - omin_); }
1664
1665 int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) override {
1666 DCHECK_GE(nmin, cmin);
1667 DCHECK_LE(nmin, cmax);
1668 DCHECK_LE(cmin, cmax);
1669 DCHECK_GE(cmin, omin_);
1670 DCHECK_LE(cmax, omax_);
1671 const int64_t new_min =
1672 UnsafeLeastSignificantBitPosition64(bits_, nmin - omin_, cmax - omin_) +
1673 omin_;
1674 const uint64_t removed_bits =
1675 BitCountRange64(bits_, cmin - omin_, new_min - omin_ - 1);
1676 size_.Add(solver_, -removed_bits);
1677 return new_min;
1678 }
1679
1680 int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) override {
1681 DCHECK_GE(nmax, cmin);
1682 DCHECK_LE(nmax, cmax);
1683 DCHECK_LE(cmin, cmax);
1684 DCHECK_GE(cmin, omin_);
1685 DCHECK_LE(cmax, omax_);
1686 const int64_t new_max =
1687 UnsafeMostSignificantBitPosition64(bits_, cmin - omin_, nmax - omin_) +
1688 omin_;
1689 const uint64_t removed_bits =
1690 BitCountRange64(bits_, new_max - omin_ + 1, cmax - omin_);
1691 size_.Add(solver_, -removed_bits);
1692 return new_max;
1693 }
1694
1695 bool SetValue(int64_t val) override {
1696 DCHECK_GE(val, omin_);
1697 DCHECK_LE(val, omax_);
1698 if (bit(val)) {
1699 size_.SetValue(solver_, 1);
1700 return true;
1701 }
1702 return false;
1703 }
1704
1705 bool Contains(int64_t val) const override {
1706 DCHECK_GE(val, omin_);
1707 DCHECK_LE(val, omax_);
1708 return bit(val);
1709 }
1710
1711 bool RemoveValue(int64_t val) override {
1712 if (val < omin_ || val > omax_ || !bit(val)) {
1713 return false;
1714 }
1715 // Bitset.
1716 const int64_t val_offset = val - omin_;
1717 const int offset = BitOffset64(val_offset);
1718 const uint64_t current_stamp = solver_->stamp();
1719 if (stamps_[offset] < current_stamp) {
1720 stamps_[offset] = current_stamp;
1721 solver_->SaveValue(&bits_[offset]);
1722 }
1723 const int pos = BitPos64(val_offset);
1724 bits_[offset] &= ~OneBit64(pos);
1725 // Size.
1726 size_.Decr(solver_);
1727 // Holes.
1728 InitHoles();
1729 AddHole(val);
1730 return true;
1731 }
1732 uint64_t Size() const override { return size_.Value(); }
1733
1734 std::string DebugString() const override {
1735 std::string out;
1736 absl::StrAppendFormat(&out, "SimpleBitSet(%d..%d : ", omin_, omax_);
1737 for (int i = 0; i < bsize_; ++i) {
1738 absl::StrAppendFormat(&out, "%x", bits_[i]);
1739 }
1740 out += ")";
1741 return out;
1742 }
1743
1744 void DelayRemoveValue(int64_t val) override { removed_.push_back(val); }
1745
1746 void ApplyRemovedValues(DomainIntVar* var) override {
1747 std::sort(removed_.begin(), removed_.end());
1748 for (std::vector<int64_t>::iterator it = removed_.begin();
1749 it != removed_.end(); ++it) {
1750 var->RemoveValue(*it);
1751 }
1752 }
1753
1754 void ClearRemovedValues() override { removed_.clear(); }
1755
1756 std::string pretty_DebugString(int64_t min, int64_t max) const override {
1757 std::string out;
1758 DCHECK(bit(min));
1759 DCHECK(bit(max));
1760 if (max != min) {
1761 int cumul = true;
1762 int64_t start_cumul = min;
1763 for (int64_t v = min + 1; v < max; ++v) {
1764 if (bit(v)) {
1765 if (!cumul) {
1766 cumul = true;
1767 start_cumul = v;
1768 }
1769 } else {
1770 if (cumul) {
1771 if (v == start_cumul + 1) {
1772 absl::StrAppendFormat(&out, "%d ", start_cumul);
1773 } else if (v == start_cumul + 2) {
1774 absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1775 } else {
1776 absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1777 }
1778 cumul = false;
1779 }
1780 }
1781 }
1782 if (cumul) {
1783 if (max == start_cumul + 1) {
1784 absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1785 } else {
1786 absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1787 }
1788 } else {
1789 absl::StrAppendFormat(&out, "%d", max);
1790 }
1791 } else {
1792 absl::StrAppendFormat(&out, "%d", min);
1793 }
1794 return out;
1795 }
1796
1797 DomainIntVar::BitSetIterator* MakeIterator() override {
1798 return new DomainIntVar::BitSetIterator(bits_, omin_);
1799 }
1800
1801 private:
1802 uint64_t* bits_;
1803 uint64_t* stamps_;
1804 const int64_t omin_;
1805 const int64_t omax_;
1806 NumericalRev<int64_t> size_;
1807 const int bsize_;
1808 std::vector<int64_t> removed_;
1809};
1810
1811// This is a special case where the bitset fits into one 64 bit integer.
1812// In that case, there are no offset to compute.
1813// Overflows are caught by the robust ClosedIntervalNoLargerThan() method.
1814class SmallBitSet : public DomainIntVar::BitSet {
1815 public:
1816 SmallBitSet(Solver* const s, int64_t vmin, int64_t vmax)
1817 : BitSet(s),
1818 bits_(uint64_t{0}),
1819 stamp_(s->stamp() - 1),
1820 omin_(vmin),
1821 omax_(vmax),
1822 size_(vmax - vmin + 1) {
1823 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1824 bits_ = OneRange64(0, size_.Value() - 1);
1825 }
1826
1827 SmallBitSet(Solver* const s, absl::Span<const int64_t> sorted_values,
1828 int64_t vmin, int64_t vmax)
1829 : BitSet(s),
1830 bits_(uint64_t{0}),
1831 stamp_(s->stamp() - 1),
1832 omin_(vmin),
1833 omax_(vmax),
1834 size_(sorted_values.size()) {
1835 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1836 // We know the array is sorted and does not contains duplicate values.
1837 for (int i = 0; i < sorted_values.size(); ++i) {
1838 const int64_t val = sorted_values[i];
1839 DCHECK_GE(val, vmin);
1840 DCHECK_LE(val, vmax);
1841 DCHECK(!IsBitSet64(&bits_, val - omin_));
1842 bits_ |= OneBit64(val - omin_);
1843 }
1844 }
1845
1846 ~SmallBitSet() override {}
1847
1848 bool bit(int64_t val) const {
1849 DCHECK_GE(val, omin_);
1850 DCHECK_LE(val, omax_);
1851 return (bits_ & OneBit64(val - omin_)) != 0;
1852 }
1853
1854 int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) override {
1855 DCHECK_GE(nmin, cmin);
1856 DCHECK_LE(nmin, cmax);
1857 DCHECK_LE(cmin, cmax);
1858 DCHECK_GE(cmin, omin_);
1859 DCHECK_LE(cmax, omax_);
1860 // We do not clean the bits between cmin and nmin.
1861 // But we use mask to look only at 'active' bits.
1862
1863 // Create the mask and compute new bits
1864 const uint64_t new_bits = bits_ & OneRange64(nmin - omin_, cmax - omin_);
1865 if (new_bits != uint64_t{0}) {
1866 // Compute new size and new min
1867 size_.SetValue(solver_, BitCount64(new_bits));
1868 if (bit(nmin)) { // Common case, the new min is inside the bitset
1869 return nmin;
1870 }
1871 return LeastSignificantBitPosition64(new_bits) + omin_;
1872 } else { // == 0 -> Fail()
1873 solver_->Fail();
1874 return std::numeric_limits<int64_t>::max();
1875 }
1876 }
1877
1878 int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) override {
1879 DCHECK_GE(nmax, cmin);
1880 DCHECK_LE(nmax, cmax);
1881 DCHECK_LE(cmin, cmax);
1882 DCHECK_GE(cmin, omin_);
1883 DCHECK_LE(cmax, omax_);
1884 // We do not clean the bits between nmax and cmax.
1885 // But we use mask to look only at 'active' bits.
1886
1887 // Create the mask and compute new_bits
1888 const uint64_t new_bits = bits_ & OneRange64(cmin - omin_, nmax - omin_);
1889 if (new_bits != uint64_t{0}) {
1890 // Compute new size and new min
1891 size_.SetValue(solver_, BitCount64(new_bits));
1892 if (bit(nmax)) { // Common case, the new max is inside the bitset
1893 return nmax;
1894 }
1895 return MostSignificantBitPosition64(new_bits) + omin_;
1896 } else { // == 0 -> Fail()
1897 solver_->Fail();
1898 return std::numeric_limits<int64_t>::min();
1899 }
1900 }
1901
1902 bool SetValue(int64_t val) override {
1903 DCHECK_GE(val, omin_);
1904 DCHECK_LE(val, omax_);
1905 // We do not clean the bits. We will use masks to ignore the bits
1906 // that should have been cleaned.
1907 if (bit(val)) {
1908 size_.SetValue(solver_, 1);
1909 return true;
1910 }
1911 return false;
1912 }
1913
1914 bool Contains(int64_t val) const override {
1915 DCHECK_GE(val, omin_);
1916 DCHECK_LE(val, omax_);
1917 return bit(val);
1918 }
1919
1920 bool RemoveValue(int64_t val) override {
1921 DCHECK_GE(val, omin_);
1922 DCHECK_LE(val, omax_);
1923 if (bit(val)) {
1924 // Bitset.
1925 const uint64_t current_stamp = solver_->stamp();
1926 if (stamp_ < current_stamp) {
1927 stamp_ = current_stamp;
1928 solver_->SaveValue(&bits_);
1929 }
1930 bits_ &= ~OneBit64(val - omin_);
1931 DCHECK(!bit(val));
1932 // Size.
1933 size_.Decr(solver_);
1934 // Holes.
1935 InitHoles();
1936 AddHole(val);
1937 return true;
1938 } else {
1939 return false;
1940 }
1941 }
1942
1943 uint64_t Size() const override { return size_.Value(); }
1944
1945 std::string DebugString() const override {
1946 return absl::StrFormat("SmallBitSet(%d..%d : %llx)", omin_, omax_, bits_);
1947 }
1948
1949 void DelayRemoveValue(int64_t val) override {
1950 DCHECK_GE(val, omin_);
1951 DCHECK_LE(val, omax_);
1952 removed_.push_back(val);
1953 }
1954
1955 void ApplyRemovedValues(DomainIntVar* var) override {
1956 std::sort(removed_.begin(), removed_.end());
1957 for (std::vector<int64_t>::iterator it = removed_.begin();
1958 it != removed_.end(); ++it) {
1959 var->RemoveValue(*it);
1960 }
1961 }
1962
1963 void ClearRemovedValues() override { removed_.clear(); }
1964
1965 std::string pretty_DebugString(int64_t min, int64_t max) const override {
1966 std::string out;
1967 DCHECK(bit(min));
1968 DCHECK(bit(max));
1969 if (max != min) {
1970 int cumul = true;
1971 int64_t start_cumul = min;
1972 for (int64_t v = min + 1; v < max; ++v) {
1973 if (bit(v)) {
1974 if (!cumul) {
1975 cumul = true;
1976 start_cumul = v;
1977 }
1978 } else {
1979 if (cumul) {
1980 if (v == start_cumul + 1) {
1981 absl::StrAppendFormat(&out, "%d ", start_cumul);
1982 } else if (v == start_cumul + 2) {
1983 absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1984 } else {
1985 absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1986 }
1987 cumul = false;
1988 }
1989 }
1990 }
1991 if (cumul) {
1992 if (max == start_cumul + 1) {
1993 absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1994 } else {
1995 absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1996 }
1997 } else {
1998 absl::StrAppendFormat(&out, "%d", max);
1999 }
2000 } else {
2001 absl::StrAppendFormat(&out, "%d", min);
2002 }
2003 return out;
2004 }
2005
2006 DomainIntVar::BitSetIterator* MakeIterator() override {
2007 return new DomainIntVar::BitSetIterator(&bits_, omin_);
2008 }
2009
2010 private:
2011 uint64_t bits_;
2012 uint64_t stamp_;
2013 const int64_t omin_;
2014 const int64_t omax_;
2015 NumericalRev<int64_t> size_;
2016 std::vector<int64_t> removed_;
2017};
2018
2019class EmptyIterator : public IntVarIterator {
2020 public:
2021 ~EmptyIterator() override {}
2022 void Init() override {}
2023 bool Ok() const override { return false; }
2024 int64_t Value() const override {
2025 LOG(FATAL) << "Should not be called";
2026 return 0LL;
2027 }
2028 void Next() override {}
2029};
2030
2031class RangeIterator : public IntVarIterator {
2032 public:
2033 explicit RangeIterator(const IntVar* const var)
2034 : var_(var),
2035 min_(std::numeric_limits<int64_t>::max()),
2036 max_(std::numeric_limits<int64_t>::min()),
2037 current_(-1) {}
2038
2039 ~RangeIterator() override {}
2040
2041 void Init() override {
2042 min_ = var_->Min();
2043 max_ = var_->Max();
2044 current_ = min_;
2045 }
2046
2047 bool Ok() const override { return current_ <= max_; }
2048
2049 int64_t Value() const override { return current_; }
2050
2051 void Next() override { current_++; }
2052
2053 private:
2054 const IntVar* const var_;
2055 int64_t min_;
2056 int64_t max_;
2057 int64_t current_;
2058};
2059
2060class DomainIntVarHoleIterator : public IntVarIterator {
2061 public:
2062 explicit DomainIntVarHoleIterator(const DomainIntVar* const v)
2063 : var_(v), bits_(nullptr), values_(nullptr), size_(0), index_(0) {}
2064
2065 ~DomainIntVarHoleIterator() override {}
2066
2067 void Init() override {
2068 bits_ = var_->bitset();
2069 if (bits_ != nullptr) {
2070 bits_->InitHoles();
2071 values_ = bits_->Holes().data();
2072 size_ = bits_->Holes().size();
2073 } else {
2074 values_ = nullptr;
2075 size_ = 0;
2076 }
2077 index_ = 0;
2078 }
2079
2080 bool Ok() const override { return index_ < size_; }
2081
2082 int64_t Value() const override {
2083 DCHECK(bits_ != nullptr);
2084 DCHECK(index_ < size_);
2085 return values_[index_];
2086 }
2087
2088 void Next() override { index_++; }
2089
2090 private:
2091 const DomainIntVar* const var_;
2092 DomainIntVar::BitSet* bits_;
2093 const int64_t* values_;
2094 int size_;
2095 int index_;
2096};
2097
// Iterator over the current domain of a DomainIntVar.
//
// When the variable has a bitset and is not bound, iteration is delegated to
// a bitset iterator initialized on [Min(), Max()]. Otherwise the values
// Min()..Max() are enumerated directly through current_.
//
// In reversible mode the bitset iterator is allocated with RevAlloc() and the
// bitset_iterator_ pointer is registered with SaveValue() before it changes
// (presumably so the solver restores it on backtrack); it is never deleted
// here. In non-reversible mode this object owns its bitset iterator.
class DomainIntVarDomainIterator : public IntVarIterator {
 public:
  // min_ > max_ (and current_ = -1 <= kint64min being false) together with a
  // null bitset iterator make Ok() return false until Init() is called.
  explicit DomainIntVarDomainIterator(const DomainIntVar* const v,
                                      bool reversible)
      : var_(v),
        bitset_iterator_(nullptr),
        min_(std::numeric_limits<int64_t>::max()),
        max_(std::numeric_limits<int64_t>::min()),
        current_(-1),
        reversible_(reversible) {}

  // Only the non-reversible iterator owns bitset_iterator_.
  ~DomainIntVarDomainIterator() override {
    if (!reversible_ && bitset_iterator_) {
      delete bitset_iterator_;
    }
  }

  void Init() override {
    if (var_->bitset() != nullptr && !var_->Bound()) {
      // The domain may have holes: iterate through the bitset.
      if (reversible_) {
        // Allocate the solver-managed bitset iterator lazily, once; the
        // pointer is saved before being overwritten.
        if (!bitset_iterator_) {
          Solver* const solver = var_->solver();
          solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
          bitset_iterator_ = solver->RevAlloc(var_->bitset()->MakeIterator());
        }
      } else {
        // Owned iterator: replace any previous one.
        if (bitset_iterator_) {
          delete bitset_iterator_;
        }
        bitset_iterator_ = var_->bitset()->MakeIterator();
      }
      bitset_iterator_->Init(var_->Min(), var_->Max());
    } else {
      // No holes (no bitset) or a single value: iterate Min()..Max().
      if (bitset_iterator_) {
        if (reversible_) {
          // Solver-allocated: just save the pointer before clearing it.
          Solver* const solver = var_->solver();
          solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
        } else {
          delete bitset_iterator_;
        }
        bitset_iterator_ = nullptr;
      }
      min_ = var_->Min();
      max_ = var_->Max();
      current_ = min_;
    }
  }

  bool Ok() const override {
    return bitset_iterator_ ? bitset_iterator_->Ok() : (current_ <= max_);
  }

  int64_t Value() const override {
    return bitset_iterator_ ? bitset_iterator_->Value() : current_;
  }

  void Next() override {
    if (bitset_iterator_) {
      bitset_iterator_->Next();
    } else {
      current_++;
    }
  }

 private:
  const DomainIntVar* const var_;
  // Non-null only while iterating through the bitset (see Init()).
  DomainIntVar::BitSetIterator* bitset_iterator_;
  // Range used when bitset_iterator_ is null; min_ is only written by Init().
  int64_t min_;
  int64_t max_;
  int64_t current_;
  const bool reversible_;
};
2170
2171class UnaryIterator : public IntVarIterator {
2172 public:
2173 UnaryIterator(const IntVar* const v, bool hole, bool reversible)
2174 : iterator_(hole ? v->MakeHoleIterator(reversible)
2175 : v->MakeDomainIterator(reversible)),
2176 reversible_(reversible) {}
2177
2178 ~UnaryIterator() override {
2179 if (!reversible_) {
2180 delete iterator_;
2181 }
2182 }
2183
2184 void Init() override { iterator_->Init(); }
2185
2186 bool Ok() const override { return iterator_->Ok(); }
2187
2188 void Next() override { iterator_->Next(); }
2189
2190 protected:
2191 IntVarIterator* const iterator_;
2192 const bool reversible_;
2193};
2194
// Creates a variable whose domain is the interval [vmin, vmax]. No bitset is
// allocated here: bits_ stays null until an interior value is removed (see
// RemoveValue(), which calls CreateBits()).
DomainIntVar::DomainIntVar(Solver* const s, int64_t vmin, int64_t vmax,
                           const std::string& name)
    : IntVar(s, name),
      min_(vmin),
      max_(vmax),
      // The old/new bounds start equal to the current bounds: nothing has
      // changed and nothing is pending yet.
      old_min_(vmin),
      old_max_(vmax),
      new_min_(vmin),
      new_max_(vmax),
      handler_(this),
      in_process_(false),
      bits_(nullptr),
      value_watcher_(nullptr),
      bound_watcher_(nullptr) {}
2209
2210DomainIntVar::DomainIntVar(Solver* const s,
2211 absl::Span<const int64_t> sorted_values,
2212 const std::string& name)
2213 : IntVar(s, name),
2214 min_(std::numeric_limits<int64_t>::max()),
2215 max_(std::numeric_limits<int64_t>::min()),
2216 old_min_(std::numeric_limits<int64_t>::max()),
2217 old_max_(std::numeric_limits<int64_t>::min()),
2218 new_min_(std::numeric_limits<int64_t>::max()),
2219 new_max_(std::numeric_limits<int64_t>::min()),
2220 handler_(this),
2221 in_process_(false),
2222 bits_(nullptr),
2223 value_watcher_(nullptr),
2224 bound_watcher_(nullptr) {
2225 CHECK_GE(sorted_values.size(), 1);
2226 // We know that the vector is sorted and does not have duplicate values.
2227 const int64_t vmin = sorted_values.front();
2228 const int64_t vmax = sorted_values.back();
2229 const bool contiguous = vmax - vmin + 1 == sorted_values.size();
2230
2231 min_.SetValue(solver(), vmin);
2232 old_min_ = vmin;
2233 new_min_ = vmin;
2234 max_.SetValue(solver(), vmax);
2235 old_max_ = vmax;
2236 new_max_ = vmax;
2237
2238 if (!contiguous) {
2239 if (vmax - vmin + 1 < 65) {
2240 bits_ = solver()->RevAlloc(
2241 new SmallBitSet(solver(), sorted_values, vmin, vmax));
2242 } else {
2243 bits_ = solver()->RevAlloc(
2244 new SimpleBitSet(solver(), sorted_values, vmin, vmax));
2245 }
2246 }
2247}
2248
2249DomainIntVar::~DomainIntVar() {}
2250
// Raises the lower bound of the variable to `m`.
//
// This is a no-op when `m` does not exceed the current minimum. It calls
// solver()->Fail() when `m` exceeds the current maximum.
//
// While the variable is being processed (in_process_, set by Process()), the
// request is only accumulated in new_min_; Process() applies it after its
// demons have run. Otherwise the bound is committed immediately and the
// variable is pushed for propagation. When a bitset exists, the committed
// bound is adjusted by bits_->ComputeNewMin(), presumably to skip removed
// values.
void DomainIntVar::SetMin(int64_t m) {
  if (m <= min_.Value()) return;
  if (m > max_.Value()) solver()->Fail();
  if (in_process_) {
    // Delayed update: only tighten the pending lower bound.
    if (m > new_min_) {
      new_min_ = m;
      if (new_min_ > new_max_) {
        solver()->Fail();
      }
    }
  } else {
    CheckOldMin();
    const int64_t new_min =
        (bits_ == nullptr
             ? m
             : bits_->ComputeNewMin(m, min_.Value(), max_.Value()));
    min_.SetValue(solver(), new_min);
    if (min_.Value() > max_.Value()) {
      solver()->Fail();
    }
    Push();
  }
}
2274
// Lowers the upper bound of the variable to `m`.
//
// This is a no-op when `m` is not below the current maximum. It calls
// solver()->Fail() when `m` is below the current minimum.
//
// During Process() (in_process_) the request is only accumulated in new_max_
// and applied once the demons have run. Otherwise the bound is committed
// immediately and the variable is pushed for propagation. When a bitset
// exists, the committed bound is adjusted by bits_->ComputeNewMax(),
// presumably to skip removed values.
void DomainIntVar::SetMax(int64_t m) {
  if (m >= max_.Value()) return;
  if (m < min_.Value()) solver()->Fail();
  if (in_process_) {
    // Delayed update: only tighten the pending upper bound.
    if (m < new_max_) {
      new_max_ = m;
      if (new_max_ < new_min_) {
        solver()->Fail();
      }
    }
  } else {
    CheckOldMax();
    const int64_t new_max =
        (bits_ == nullptr
             ? m
             : bits_->ComputeNewMax(m, min_.Value(), max_.Value()));
    max_.SetValue(solver(), new_max);
    if (min_.Value() > max_.Value()) {
      solver()->Fail();
    }
    Push();
  }
}
2298
// Restricts the domain to [mi, ma].
//
// - A singleton range is delegated to SetValue().
// - An empty range, or one that does not intersect [Min(), Max()], fails.
// - A range that contains the current bounds is a no-op.
// - During Process() (in_process_), only the pending bounds new_min_/new_max_
//   are tightened.
// - Otherwise each bound that actually tightens is committed (adjusted
//   through the bitset when one exists) and the variable is pushed once.
void DomainIntVar::SetRange(int64_t mi, int64_t ma) {
  if (mi == ma) {
    SetValue(mi);
  } else {
    if (mi > ma || mi > max_.Value() || ma < min_.Value()) solver()->Fail();
    if (mi <= min_.Value() && ma >= max_.Value()) return;
    if (in_process_) {
      // Delayed update of both pending bounds.
      if (ma < new_max_) {
        new_max_ = ma;
      }
      if (mi > new_min_) {
        new_min_ = mi;
      }
      if (new_min_ > new_max_) {
        solver()->Fail();
      }
    } else {
      if (mi > min_.Value()) {
        CheckOldMin();
        const int64_t new_min =
            (bits_ == nullptr
                 ? mi
                 : bits_->ComputeNewMin(mi, min_.Value(), max_.Value()));
        min_.SetValue(solver(), new_min);
      }
      // The (possibly bitset-adjusted) new minimum may already exceed ma.
      if (min_.Value() > ma) {
        solver()->Fail();
      }
      if (ma < max_.Value()) {
        CheckOldMax();
        const int64_t new_max =
            (bits_ == nullptr
                 ? ma
                 : bits_->ComputeNewMax(ma, min_.Value(), max_.Value()));
        max_.SetValue(solver(), new_max);
      }
      if (min_.Value() > max_.Value()) {
        solver()->Fail();
      }
      Push();
    }
  }
}
2342
// Binds the variable to `v`.
//
// This is a no-op if the variable is already bound to `v`. It fails if `v` is
// outside [Min(), Max()].
//
// During Process() (in_process_), both pending bounds are set to `v`, which
// fails if `v` lies outside them. Otherwise the bitset, when present, must
// accept `v` (bits_->SetValue() returning false fails); then both bounds are
// committed and the variable is pushed.
void DomainIntVar::SetValue(int64_t v) {
  if (v != min_.Value() || v != max_.Value()) {
    if (v < min_.Value() || v > max_.Value()) {
      solver()->Fail();
    }
    if (in_process_) {
      if (v > new_max_ || v < new_min_) {
        solver()->Fail();
      }
      new_min_ = v;
      new_max_ = v;
    } else {
      if (bits_ && !bits_->SetValue(v)) {
        solver()->Fail();
      }
      CheckOldMin();
      CheckOldMax();
      min_.SetValue(solver(), v);
      max_.SetValue(solver(), v);
      Push();
    }
  }
}
2366
2367void DomainIntVar::RemoveValue(int64_t v) {
2368 if (v < min_.Value() || v > max_.Value()) return;
2369 if (v == min_.Value()) {
2370 SetMin(v + 1);
2371 } else if (v == max_.Value()) {
2372 SetMax(v - 1);
2373 } else {
2374 if (bits_ == nullptr) {
2375 CreateBits();
2376 }
2377 if (in_process_) {
2378 if (v >= new_min_ && v <= new_max_ && bits_->Contains(v)) {
2379 bits_->DelayRemoveValue(v);
2380 }
2381 } else {
2382 if (bits_->RemoveValue(v)) {
2383 Push();
2384 }
2385 }
2386 }
2387}
2388
2389void DomainIntVar::RemoveInterval(int64_t l, int64_t u) {
2390 if (l <= min_.Value()) {
2391 SetMin(u + 1);
2392 } else if (u >= max_.Value()) {
2393 SetMax(l - 1);
2394 } else {
2395 for (int64_t v = l; v <= u; ++v) {
2396 RemoveValue(v);
2397 }
2398 }
2399}
2400
2401void DomainIntVar::CreateBits() {
2402 solver()->SaveValue(reinterpret_cast<void**>(&bits_));
2403 if (max_.Value() - min_.Value() < 64) {
2404 bits_ = solver()->RevAlloc(
2405 new SmallBitSet(solver(), min_.Value(), max_.Value()));
2406 } else {
2407 bits_ = solver()->RevAlloc(
2408 new SimpleBitSet(solver(), min_.Value(), max_.Value()));
2409 }
2410}
2411
2412void DomainIntVar::CleanInProcess() {
2413 in_process_ = false;
2414 if (bits_ != nullptr) {
2415 bits_->ClearHoles();
2416 }
2417}
2418
// Enqueues the variable's handler demon so its demons get processed.
void DomainIntVar::Push() {
  const bool in_process = in_process_;
  EnqueueVar(&handler_);
  // Enqueueing must not change whether the variable is being processed.
  CHECK_EQ(in_process, in_process_);
}
2424
// Runs the demons of this variable after its bounds or domain changed
// (presumably invoked by handler_ once Push() has enqueued it).
//
// While the demons run, in_process_ is set. In that state SetMin, SetMax,
// SetRange and SetValue only record the requested bounds in
// new_min_/new_max_, and RemoveValue delays removals through the bitset.
//
// Once all demons have been executed or enqueued, the variable leaves that
// state and records the current bounds as the old ones. It then re-applies
// the pending bounds and removals; SetMin/SetMax push the variable again when
// they tighten it.
void DomainIntVar::Process() {
  CHECK(!in_process_);
  in_process_ = true;
  if (bits_ != nullptr) {
    bits_->ClearRemovedValues();
  }
  // Registered so the in-process state can be cleaned if a failure interrupts
  // the demons below (cleared again on success).
  set_variable_to_clean_on_fail(this);
  new_min_ = min_.Value();
  new_max_ = max_.Value();
  const bool is_bound = min_.Value() == max_.Value();
  const bool range_changed =
      min_.Value() != OldMin() || max_.Value() != OldMax();
  // Process immediate demons.
  if (is_bound) {
    ExecuteAll(bound_demons_);
  }
  if (range_changed) {
    ExecuteAll(range_demons_);
  }
  ExecuteAll(domain_demons_);

  // Process delayed demons.
  if (is_bound) {
    EnqueueAll(delayed_bound_demons_);
  }
  if (range_changed) {
    EnqueueAll(delayed_range_demons_);
  }
  EnqueueAll(delayed_domain_demons_);

  // Everything went well if we arrive here. Let's clean the variable.
  set_variable_to_clean_on_fail(nullptr);
  CleanInProcess();
  old_min_ = min_.Value();
  old_max_ = max_.Value();
  // Apply the bounds requested by the demons while in_process_ was set.
  if (min_.Value() < new_min_) {
    SetMin(new_min_);
  }
  if (max_.Value() > new_max_) {
    SetMax(new_max_);
  }
  // Apply the removals delayed with DelayRemoveValue().
  if (bits_ != nullptr) {
    bits_->ApplyRemovedValues(this);
  }
}
2470
2471template <typename T>
2472T* CondRevAlloc(Solver* solver, bool reversible, T* object) {
2473 return reversible ? solver->RevAlloc(object) : object;
2474}
2475
2476IntVarIterator* DomainIntVar::MakeHoleIterator(bool reversible) const {
2477 return CondRevAlloc(solver(), reversible, new DomainIntVarHoleIterator(this));
2478}
2479
2480IntVarIterator* DomainIntVar::MakeDomainIterator(bool reversible) const {
2481 return CondRevAlloc(solver(), reversible,
2482 new DomainIntVarDomainIterator(this, reversible));
2483}
2484
2485std::string DomainIntVar::DebugString() const {
2486 std::string out;
2487 const std::string& var_name = name();
2488 if (!var_name.empty()) {
2489 out = var_name + "(";
2490 } else {
2491 out = "DomainIntVar(";
2492 }
2493 if (min_.Value() == max_.Value()) {
2494 absl::StrAppendFormat(&out, "%d", min_.Value());
2495 } else if (bits_ != nullptr) {
2496 out.append(bits_->pretty_DebugString(min_.Value(), max_.Value()));
2497 } else {
2498 absl::StrAppendFormat(&out, "%d..%d", min_.Value(), max_.Value());
2499 }
2500 out += ")";
2501 return out;
2502}
2503
2504// ----- Real Boolean Var -----
2505
2506class ConcreteBooleanVar : public BooleanVar {
2507 public:
2508 // Utility classes
2509 class Handler : public Demon {
2510 public:
2511 explicit Handler(ConcreteBooleanVar* const var) : Demon(), var_(var) {}
2512 ~Handler() override {}
2513 void Run(Solver* const s) override {
2514 s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
2515 var_->Process();
2516 s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
2517 }
2518 Solver::DemonPriority priority() const override {
2519 return Solver::VAR_PRIORITY;
2520 }
2521 std::string DebugString() const override {
2522 return absl::StrFormat("Handler(%s)", var_->DebugString());
2523 }
2524
2525 private:
2526 ConcreteBooleanVar* const var_;
2527 };
2528
2529 ConcreteBooleanVar(Solver* const s, const std::string& name)
2530 : BooleanVar(s, name), handler_(this) {}
2531
2532 ~ConcreteBooleanVar() override {}
2533
2534 void SetValue(int64_t v) override {
2535 if (value_ == kUnboundBooleanVarValue) {
2536 if ((v & 0xfffffffffffffffe) == 0) {
2537 InternalSaveBooleanVarValue(solver(), this);
2538 value_ = static_cast<int>(v);
2539 EnqueueVar(&handler_);
2540 return;
2541 }
2542 } else if (v == value_) {
2543 return;
2544 }
2545 solver()->Fail();
2546 }
2547
2548 void Process() {
2549 DCHECK_NE(value_, kUnboundBooleanVarValue);
2550 ExecuteAll(bound_demons_);
2551 for (SimpleRevFIFO<Demon*>::Iterator it(&delayed_bound_demons_); it.ok();
2552 ++it) {
2553 EnqueueDelayedDemon(*it);
2554 }
2555 }
2556
2557 int64_t OldMin() const override { return 0LL; }
2558 int64_t OldMax() const override { return 1LL; }
2559 void RestoreValue() override { value_ = kUnboundBooleanVarValue; }
2560
2561 private:
2562 Handler handler_;
2563};
2564
2565// ----- IntConst -----
2566
2567class IntConst : public IntVar {
2568 public:
2569 IntConst(Solver* const s, int64_t value, const std::string& name = "")
2570 : IntVar(s, name), value_(value) {}
2571 ~IntConst() override {}
2572
2573 int64_t Min() const override { return value_; }
2574 void SetMin(int64_t m) override {
2575 if (m > value_) {
2576 solver()->Fail();
2577 }
2578 }
2579 int64_t Max() const override { return value_; }
2580 void SetMax(int64_t m) override {
2581 if (m < value_) {
2582 solver()->Fail();
2583 }
2584 }
2585 void SetRange(int64_t l, int64_t u) override {
2586 if (l > value_ || u < value_) {
2587 solver()->Fail();
2588 }
2589 }
2590 void SetValue(int64_t v) override {
2591 if (v != value_) {
2592 solver()->Fail();
2593 }
2594 }
2595 bool Bound() const override { return true; }
2596 int64_t Value() const override { return value_; }
2597 void RemoveValue(int64_t v) override {
2598 if (v == value_) {
2599 solver()->Fail();
2600 }
2601 }
2602 void RemoveInterval(int64_t l, int64_t u) override {
2603 if (l <= value_ && value_ <= u) {
2604 solver()->Fail();
2605 }
2606 }
2607 void WhenBound(Demon* d) override {}
2608 void WhenRange(Demon* d) override {}
2609 void WhenDomain(Demon* d) override {}
2610 uint64_t Size() const override { return 1; }
2611 bool Contains(int64_t v) const override { return (v == value_); }
2612 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2613 return CondRevAlloc(solver(), reversible, new EmptyIterator());
2614 }
2615 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2616 return CondRevAlloc(solver(), reversible, new RangeIterator(this));
2617 }
2618 int64_t OldMin() const override { return value_; }
2619 int64_t OldMax() const override { return value_; }
2620 std::string DebugString() const override {
2621 std::string out;
2622 if (solver()->HasName(this)) {
2623 const std::string& var_name = name();
2624 absl::StrAppendFormat(&out, "%s(%d)", var_name, value_);
2625 } else {
2626 absl::StrAppendFormat(&out, "IntConst(%d)", value_);
2627 }
2628 return out;
2629 }
2630
2631 int VarType() const override { return CONST_VAR; }
2632
2633 IntVar* IsEqual(int64_t constant) override {
2634 if (constant == value_) {
2635 return solver()->MakeIntConst(1);
2636 } else {
2637 return solver()->MakeIntConst(0);
2638 }
2639 }
2640
2641 IntVar* IsDifferent(int64_t constant) override {
2642 if (constant == value_) {
2643 return solver()->MakeIntConst(0);
2644 } else {
2645 return solver()->MakeIntConst(1);
2646 }
2647 }
2648
2649 IntVar* IsGreaterOrEqual(int64_t constant) override {
2650 return solver()->MakeIntConst(value_ >= constant);
2651 }
2652
2653 IntVar* IsLessOrEqual(int64_t constant) override {
2654 return solver()->MakeIntConst(value_ <= constant);
2655 }
2656
2657 std::string name() const override {
2658 if (solver()->HasName(this)) {
2659 return PropagationBaseObject::name();
2660 } else {
2661 return absl::StrCat(value_);
2662 }
2663 }
2664
2665 private:
2666 int64_t value_;
2667};
2668
2669// ----- x + c variable, optimized case -----
2670
2671class PlusCstVar : public IntVar {
2672 public:
2673 PlusCstVar(Solver* const s, IntVar* v, int64_t c)
2674 : IntVar(s), var_(v), cst_(c) {}
2675
2676 ~PlusCstVar() override {}
2677
2678 void WhenRange(Demon* d) override { var_->WhenRange(d); }
2679
2680 void WhenBound(Demon* d) override { var_->WhenBound(d); }
2681
2682 void WhenDomain(Demon* d) override { var_->WhenDomain(d); }
2683
2684 int64_t OldMin() const override { return CapAdd(var_->OldMin(), cst_); }
2685
2686 int64_t OldMax() const override { return CapAdd(var_->OldMax(), cst_); }
2687
2688 std::string DebugString() const override {
2689 if (HasName()) {
2690 return absl::StrFormat("%s(%s + %d)", name(), var_->DebugString(), cst_);
2691 } else {
2692 return absl::StrFormat("(%s + %d)", var_->DebugString(), cst_);
2693 }
2694 }
2695
2696 int VarType() const override { return VAR_ADD_CST; }
2697
2698 void Accept(ModelVisitor* const visitor) const override {
2699 visitor->VisitIntegerVariable(this, ModelVisitor::kSumOperation, cst_,
2700 var_);
2701 }
2702
2703 IntVar* IsEqual(int64_t constant) override {
2704 return var_->IsEqual(constant - cst_);
2705 }
2706
2707 IntVar* IsDifferent(int64_t constant) override {
2708 return var_->IsDifferent(constant - cst_);
2709 }
2710
2711 IntVar* IsGreaterOrEqual(int64_t constant) override {
2712 return var_->IsGreaterOrEqual(constant - cst_);
2713 }
2714
2715 IntVar* IsLessOrEqual(int64_t constant) override {
2716 return var_->IsLessOrEqual(constant - cst_);
2717 }
2718
2719 IntVar* SubVar() const { return var_; }
2720
2721 int64_t Constant() const { return cst_; }
2722
2723 protected:
2724 IntVar* const var_;
2725 const int64_t cst_;
2726};
2727
2728class PlusCstIntVar : public PlusCstVar {
2729 public:
2730 class PlusCstIntVarIterator : public UnaryIterator {
2731 public:
2732 PlusCstIntVarIterator(const IntVar* const v, int64_t c, bool hole, bool rev)
2733 : UnaryIterator(v, hole, rev), cst_(c) {}
2734
2735 ~PlusCstIntVarIterator() override {}
2736
2737 int64_t Value() const override { return iterator_->Value() + cst_; }
2738
2739 private:
2740 const int64_t cst_;
2741 };
2742
2743 PlusCstIntVar(Solver* const s, IntVar* v, int64_t c) : PlusCstVar(s, v, c) {}
2744
2745 ~PlusCstIntVar() override {}
2746
2747 int64_t Min() const override { return var_->Min() + cst_; }
2748
2749 void SetMin(int64_t m) override { var_->SetMin(CapSub(m, cst_)); }
2750
2751 int64_t Max() const override { return var_->Max() + cst_; }
2752
2753 void SetMax(int64_t m) override { var_->SetMax(CapSub(m, cst_)); }
2754
2755 void SetRange(int64_t l, int64_t u) override {
2756 var_->SetRange(CapSub(l, cst_), CapSub(u, cst_));
2757 }
2758
2759 void SetValue(int64_t v) override { var_->SetValue(v - cst_); }
2760
2761 int64_t Value() const override { return var_->Value() + cst_; }
2762
2763 bool Bound() const override { return var_->Bound(); }
2764
2765 void RemoveValue(int64_t v) override { var_->RemoveValue(v - cst_); }
2766
2767 void RemoveInterval(int64_t l, int64_t u) override {
2768 var_->RemoveInterval(l - cst_, u - cst_);
2769 }
2770
2771 uint64_t Size() const override { return var_->Size(); }
2772
2773 bool Contains(int64_t v) const override { return var_->Contains(v - cst_); }
2774
2775 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2776 return CondRevAlloc(
2777 solver(), reversible,
2778 new PlusCstIntVarIterator(var_, cst_, true, reversible));
2779 }
2780 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2781 return CondRevAlloc(
2782 solver(), reversible,
2783 new PlusCstIntVarIterator(var_, cst_, false, reversible));
2784 }
2785};
2786
2787class PlusCstDomainIntVar : public PlusCstVar {
2788 public:
2789 class PlusCstDomainIntVarIterator : public UnaryIterator {
2790 public:
2791 PlusCstDomainIntVarIterator(const IntVar* const v, int64_t c, bool hole,
2792 bool reversible)
2793 : UnaryIterator(v, hole, reversible), cst_(c) {}
2794
2795 ~PlusCstDomainIntVarIterator() override {}
2796
2797 int64_t Value() const override { return iterator_->Value() + cst_; }
2798
2799 private:
2800 const int64_t cst_;
2801 };
2802
2803 PlusCstDomainIntVar(Solver* const s, DomainIntVar* v, int64_t c)
2804 : PlusCstVar(s, v, c) {}
2805
2806 ~PlusCstDomainIntVar() override {}
2807
2808 int64_t Min() const override;
2809 void SetMin(int64_t m) override;
2810 int64_t Max() const override;
2811 void SetMax(int64_t m) override;
2812 void SetRange(int64_t l, int64_t u) override;
2813 void SetValue(int64_t v) override;
2814 bool Bound() const override;
2815 int64_t Value() const override;
2816 void RemoveValue(int64_t v) override;
2817 void RemoveInterval(int64_t l, int64_t u) override;
2818 uint64_t Size() const override;
2819 bool Contains(int64_t v) const override;
2820
2821 DomainIntVar* domain_int_var() const {
2822 return reinterpret_cast<DomainIntVar*>(var_);
2823 }
2824
2825 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2826 return CondRevAlloc(
2827 solver(), reversible,
2828 new PlusCstDomainIntVarIterator(var_, cst_, true, reversible));
2829 }
2830 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2831 return CondRevAlloc(
2832 solver(), reversible,
2833 new PlusCstDomainIntVarIterator(var_, cst_, false, reversible));
2834 }
2835};
2836
2837int64_t PlusCstDomainIntVar::Min() const {
2838 return domain_int_var()->min_.Value() + cst_;
2839}
2840
2841void PlusCstDomainIntVar::SetMin(int64_t m) {
2842 domain_int_var()->DomainIntVar::SetMin(CapSub(m, cst_));
2843}
2844
2845int64_t PlusCstDomainIntVar::Max() const {
2846 return domain_int_var()->max_.Value() + cst_;
2847}
2848
2849void PlusCstDomainIntVar::SetMax(int64_t m) {
2850 domain_int_var()->DomainIntVar::SetMax(CapSub(m, cst_));
2851}
2852
2853void PlusCstDomainIntVar::SetRange(int64_t l, int64_t u) {
2854 domain_int_var()->DomainIntVar::SetRange(l - cst_, u - cst_);
2855}
2856
2857void PlusCstDomainIntVar::SetValue(int64_t v) {
2858 domain_int_var()->DomainIntVar::SetValue(v - cst_);
2859}
2860
2861bool PlusCstDomainIntVar::Bound() const {
2862 return domain_int_var()->min_.Value() == domain_int_var()->max_.Value();
2863}
2864
2865int64_t PlusCstDomainIntVar::Value() const {
2866 CHECK_EQ(domain_int_var()->min_.Value(), domain_int_var()->max_.Value())
2867 << " variable is not bound";
2868 return domain_int_var()->min_.Value() + cst_;
2869}
2870
2871void PlusCstDomainIntVar::RemoveValue(int64_t v) {
2872 domain_int_var()->DomainIntVar::RemoveValue(v - cst_);
2873}
2874
2875void PlusCstDomainIntVar::RemoveInterval(int64_t l, int64_t u) {
2876 domain_int_var()->DomainIntVar::RemoveInterval(l - cst_, u - cst_);
2877}
2878
2879uint64_t PlusCstDomainIntVar::Size() const {
2880 return domain_int_var()->DomainIntVar::Size();
2881}
2882
2883bool PlusCstDomainIntVar::Contains(int64_t v) const {
2884 return domain_int_var()->DomainIntVar::Contains(v - cst_);
2885}
2886
2887// c - x variable, optimized case
2888
2889class SubCstIntVar : public IntVar {
2890 public:
2891 class SubCstIntVarIterator : public UnaryIterator {
2892 public:
2893 SubCstIntVarIterator(const IntVar* const v, int64_t c, bool hole, bool rev)
2894 : UnaryIterator(v, hole, rev), cst_(c) {}
2895 ~SubCstIntVarIterator() override {}
2896
2897 int64_t Value() const override { return cst_ - iterator_->Value(); }
2898
2899 private:
2900 const int64_t cst_;
2901 };
2902
2903 SubCstIntVar(Solver* s, IntVar* v, int64_t c);
2904 ~SubCstIntVar() override;
2905
2906 int64_t Min() const override;
2907 void SetMin(int64_t m) override;
2908 int64_t Max() const override;
2909 void SetMax(int64_t m) override;
2910 void SetRange(int64_t l, int64_t u) override;
2911 void SetValue(int64_t v) override;
2912 bool Bound() const override;
2913 int64_t Value() const override;
2914 void RemoveValue(int64_t v) override;
2915 void RemoveInterval(int64_t l, int64_t u) override;
2916 uint64_t Size() const override;
2917 bool Contains(int64_t v) const override;
2918 void WhenRange(Demon* d) override;
2919 void WhenBound(Demon* d) override;
2920 void WhenDomain(Demon* d) override;
2921 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2922 return CondRevAlloc(solver(), reversible,
2923 new SubCstIntVarIterator(var_, cst_, true, reversible));
2924 }
2925 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2926 return CondRevAlloc(
2927 solver(), reversible,
2928 new SubCstIntVarIterator(var_, cst_, false, reversible));
2929 }
2930 int64_t OldMin() const override { return CapSub(cst_, var_->OldMax()); }
2931 int64_t OldMax() const override { return CapSub(cst_, var_->OldMin()); }
2932 std::string DebugString() const override;
2933 std::string name() const override;
2934 int VarType() const override { return CST_SUB_VAR; }
2935
2936 void Accept(ModelVisitor* const visitor) const override {
2937 visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation,
2938 cst_, var_);
2939 }
2940
2941 IntVar* IsEqual(int64_t constant) override {
2942 return var_->IsEqual(cst_ - constant);
2943 }
2944
2945 IntVar* IsDifferent(int64_t constant) override {
2946 return var_->IsDifferent(cst_ - constant);
2947 }
2948
2949 IntVar* IsGreaterOrEqual(int64_t constant) override {
2950 return var_->IsLessOrEqual(cst_ - constant);
2951 }
2952
2953 IntVar* IsLessOrEqual(int64_t constant) override {
2954 return var_->IsGreaterOrEqual(cst_ - constant);
2955 }
2956
2957 IntVar* SubVar() const { return var_; }
2958 int64_t Constant() const { return cst_; }
2959
2960 private:
2961 IntVar* const var_;
2962 const int64_t cst_;
2963};
2964
2965SubCstIntVar::SubCstIntVar(Solver* const s, IntVar* v, int64_t c)
2966 : IntVar(s), var_(v), cst_(c) {}
2967
2968SubCstIntVar::~SubCstIntVar() {}
2969
2970int64_t SubCstIntVar::Min() const { return cst_ - var_->Max(); }
2971
2972void SubCstIntVar::SetMin(int64_t m) { var_->SetMax(CapSub(cst_, m)); }
2973
2974int64_t SubCstIntVar::Max() const { return cst_ - var_->Min(); }
2975
2976void SubCstIntVar::SetMax(int64_t m) { var_->SetMin(CapSub(cst_, m)); }
2977
2978void SubCstIntVar::SetRange(int64_t l, int64_t u) {
2979 var_->SetRange(CapSub(cst_, u), CapSub(cst_, l));
2980}
2981
2982void SubCstIntVar::SetValue(int64_t v) { var_->SetValue(cst_ - v); }
2983
2984bool SubCstIntVar::Bound() const { return var_->Bound(); }
2985
2986void SubCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
2987
2988int64_t SubCstIntVar::Value() const { return cst_ - var_->Value(); }
2989
2990void SubCstIntVar::RemoveValue(int64_t v) { var_->RemoveValue(cst_ - v); }
2991
2992void SubCstIntVar::RemoveInterval(int64_t l, int64_t u) {
2993 var_->RemoveInterval(cst_ - u, cst_ - l);
2994}
2995
2996void SubCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
2997
2998void SubCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
2999
3000uint64_t SubCstIntVar::Size() const { return var_->Size(); }
3001
3002bool SubCstIntVar::Contains(int64_t v) const {
3003 return var_->Contains(cst_ - v);
3004}
3005
3006std::string SubCstIntVar::DebugString() const {
3007 if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
3008 return absl::StrFormat("Not(%s)", var_->DebugString());
3009 } else {
3010 return absl::StrFormat("(%d - %s)", cst_, var_->DebugString());
3011 }
3012}
3013
3014std::string SubCstIntVar::name() const {
3015 if (solver()->HasName(this)) {
3016 return PropagationBaseObject::name();
3017 } else if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
3018 return absl::StrFormat("Not(%s)", var_->name());
3019 } else {
3020 return absl::StrFormat("(%d - %s)", cst_, var_->name());
3021 }
3022}
3023
3024// -x variable, optimized case
3025
3026class OppIntVar : public IntVar {
3027 public:
3028 class OppIntVarIterator : public UnaryIterator {
3029 public:
3030 OppIntVarIterator(const IntVar* const v, bool hole, bool reversible)
3031 : UnaryIterator(v, hole, reversible) {}
3032 ~OppIntVarIterator() override {}
3033
3034 int64_t Value() const override { return -iterator_->Value(); }
3035 };
3036
3037 OppIntVar(Solver* s, IntVar* v);
3038 ~OppIntVar() override;
3039
3040 int64_t Min() const override;
3041 void SetMin(int64_t m) override;
3042 int64_t Max() const override;
3043 void SetMax(int64_t m) override;
3044 void SetRange(int64_t l, int64_t u) override;
3045 void SetValue(int64_t v) override;
3046 bool Bound() const override;
3047 int64_t Value() const override;
3048 void RemoveValue(int64_t v) override;
3049 void RemoveInterval(int64_t l, int64_t u) override;
3050 uint64_t Size() const override;
3051 bool Contains(int64_t v) const override;
3052 void WhenRange(Demon* d) override;
3053 void WhenBound(Demon* d) override;
3054 void WhenDomain(Demon* d) override;
3055 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3056 return CondRevAlloc(solver(), reversible,
3057 new OppIntVarIterator(var_, true, reversible));
3058 }
3059 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3060 return CondRevAlloc(solver(), reversible,
3061 new OppIntVarIterator(var_, false, reversible));
3062 }
3063 int64_t OldMin() const override { return CapOpp(var_->OldMax()); }
3064 int64_t OldMax() const override { return CapOpp(var_->OldMin()); }
3065 std::string DebugString() const override;
3066 int VarType() const override { return OPP_VAR; }
3067
3068 void Accept(ModelVisitor* const visitor) const override {
3069 visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation, 0,
3070 var_);
3071 }
3072
3073 IntVar* IsEqual(int64_t constant) override {
3074 return var_->IsEqual(-constant);
3075 }
3076
3077 IntVar* IsDifferent(int64_t constant) override {
3078 return var_->IsDifferent(-constant);
3079 }
3080
3081 IntVar* IsGreaterOrEqual(int64_t constant) override {
3082 return var_->IsLessOrEqual(-constant);
3083 }
3084
3085 IntVar* IsLessOrEqual(int64_t constant) override {
3086 return var_->IsGreaterOrEqual(-constant);
3087 }
3088
3089 IntVar* SubVar() const { return var_; }
3090
3091 private:
3092 IntVar* const var_;
3093};
3094
3095OppIntVar::OppIntVar(Solver* const s, IntVar* v) : IntVar(s), var_(v) {}
3096
3097OppIntVar::~OppIntVar() {}
3098
3099int64_t OppIntVar::Min() const { return -var_->Max(); }
3100
3101void OppIntVar::SetMin(int64_t m) { var_->SetMax(CapOpp(m)); }
3102
3103int64_t OppIntVar::Max() const { return -var_->Min(); }
3104
3105void OppIntVar::SetMax(int64_t m) { var_->SetMin(CapOpp(m)); }
3106
3107void OppIntVar::SetRange(int64_t l, int64_t u) {
3108 var_->SetRange(CapOpp(u), CapOpp(l));
3109}
3110
3111void OppIntVar::SetValue(int64_t v) { var_->SetValue(CapOpp(v)); }
3112
3113bool OppIntVar::Bound() const { return var_->Bound(); }
3114
3115void OppIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3116
3117int64_t OppIntVar::Value() const { return -var_->Value(); }
3118
3119void OppIntVar::RemoveValue(int64_t v) { var_->RemoveValue(-v); }
3120
3121void OppIntVar::RemoveInterval(int64_t l, int64_t u) {
3122 var_->RemoveInterval(-u, -l);
3123}
3124
3125void OppIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3126
3127void OppIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3128
3129uint64_t OppIntVar::Size() const { return var_->Size(); }
3130
3131bool OppIntVar::Contains(int64_t v) const { return var_->Contains(-v); }
3132
3133std::string OppIntVar::DebugString() const {
3134 return absl::StrFormat("-(%s)", var_->DebugString());
3135}
3136
3137// ----- Utility functions -----
3138
3139// x * c variable, optimized case
3140
3141class TimesCstIntVar : public IntVar {
3142 public:
3143 TimesCstIntVar(Solver* const s, IntVar* v, int64_t c)
3144 : IntVar(s), var_(v), cst_(c) {}
3145 ~TimesCstIntVar() override {}
3146
3147 IntVar* SubVar() const { return var_; }
3148 int64_t Constant() const { return cst_; }
3149
3150 void Accept(ModelVisitor* const visitor) const override {
3151 visitor->VisitIntegerVariable(this, ModelVisitor::kProductOperation, cst_,
3152 var_);
3153 }
3154
3155 IntVar* IsEqual(int64_t constant) override {
3156 if (constant % cst_ == 0) {
3157 return var_->IsEqual(constant / cst_);
3158 } else {
3159 return solver()->MakeIntConst(0);
3160 }
3161 }
3162
3163 IntVar* IsDifferent(int64_t constant) override {
3164 if (constant % cst_ == 0) {
3165 return var_->IsDifferent(constant / cst_);
3166 } else {
3167 return solver()->MakeIntConst(1);
3168 }
3169 }
3170
3171 IntVar* IsGreaterOrEqual(int64_t constant) override {
3172 if (cst_ > 0) {
3173 return var_->IsGreaterOrEqual(PosIntDivUp(constant, cst_));
3174 } else {
3175 return var_->IsLessOrEqual(PosIntDivDown(-constant, -cst_));
3176 }
3177 }
3178
3179 IntVar* IsLessOrEqual(int64_t constant) override {
3180 if (cst_ > 0) {
3181 return var_->IsLessOrEqual(PosIntDivDown(constant, cst_));
3182 } else {
3183 return var_->IsGreaterOrEqual(PosIntDivUp(-constant, -cst_));
3184 }
3185 }
3186
3187 std::string DebugString() const override {
3188 return absl::StrFormat("(%s * %d)", var_->DebugString(), cst_);
3189 }
3190
3191 int VarType() const override { return VAR_TIMES_CST; }
3192
3193 protected:
3194 IntVar* const var_;
3195 const int64_t cst_;
3196};
3197
3198class TimesPosCstIntVar : public TimesCstIntVar {
3199 public:
3200 class TimesPosCstIntVarIterator : public UnaryIterator {
3201 public:
3202 TimesPosCstIntVarIterator(const IntVar* const v, int64_t c, bool hole,
3203 bool reversible)
3204 : UnaryIterator(v, hole, reversible), cst_(c) {}
3205 ~TimesPosCstIntVarIterator() override {}
3206
3207 int64_t Value() const override { return iterator_->Value() * cst_; }
3208
3209 private:
3210 const int64_t cst_;
3211 };
3212
3213 TimesPosCstIntVar(Solver* s, IntVar* v, int64_t c);
3214 ~TimesPosCstIntVar() override;
3215
3216 int64_t Min() const override;
3217 void SetMin(int64_t m) override;
3218 int64_t Max() const override;
3219 void SetMax(int64_t m) override;
3220 void SetRange(int64_t l, int64_t u) override;
3221 void SetValue(int64_t v) override;
3222 bool Bound() const override;
3223 int64_t Value() const override;
3224 void RemoveValue(int64_t v) override;
3225 void RemoveInterval(int64_t l, int64_t u) override;
3226 uint64_t Size() const override;
3227 bool Contains(int64_t v) const override;
3228 void WhenRange(Demon* d) override;
3229 void WhenBound(Demon* d) override;
3230 void WhenDomain(Demon* d) override;
3231 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3232 return CondRevAlloc(
3233 solver(), reversible,
3234 new TimesPosCstIntVarIterator(var_, cst_, true, reversible));
3235 }
3236 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3237 return CondRevAlloc(
3238 solver(), reversible,
3239 new TimesPosCstIntVarIterator(var_, cst_, false, reversible));
3240 }
3241 int64_t OldMin() const override { return CapProd(var_->OldMin(), cst_); }
3242 int64_t OldMax() const override { return CapProd(var_->OldMax(), cst_); }
3243};
3244
3245// ----- TimesPosCstIntVar -----
3246
3247TimesPosCstIntVar::TimesPosCstIntVar(Solver* const s, IntVar* v, int64_t c)
3248 : TimesCstIntVar(s, v, c) {}
3249
3250TimesPosCstIntVar::~TimesPosCstIntVar() {}
3251
3252int64_t TimesPosCstIntVar::Min() const { return CapProd(var_->Min(), cst_); }
3253
3254void TimesPosCstIntVar::SetMin(int64_t m) {
3255 if (m != std::numeric_limits<int64_t>::min()) {
3256 var_->SetMin(PosIntDivUp(m, cst_));
3257 }
3258}
3259
3260int64_t TimesPosCstIntVar::Max() const { return CapProd(var_->Max(), cst_); }
3261
3262void TimesPosCstIntVar::SetMax(int64_t m) {
3263 if (m != std::numeric_limits<int64_t>::max()) {
3264 var_->SetMax(PosIntDivDown(m, cst_));
3265 }
3266}
3267
3268void TimesPosCstIntVar::SetRange(int64_t l, int64_t u) {
3269 var_->SetRange(PosIntDivUp(l, cst_), PosIntDivDown(u, cst_));
3270}
3271
3272void TimesPosCstIntVar::SetValue(int64_t v) {
3273 if (v % cst_ != 0) {
3274 solver()->Fail();
3275 }
3276 var_->SetValue(v / cst_);
3277}
3278
3279bool TimesPosCstIntVar::Bound() const { return var_->Bound(); }
3280
3281void TimesPosCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3282
3283int64_t TimesPosCstIntVar::Value() const {
3284 return CapProd(var_->Value(), cst_);
3285}
3286
3287void TimesPosCstIntVar::RemoveValue(int64_t v) {
3288 if (v % cst_ == 0) {
3289 var_->RemoveValue(v / cst_);
3290 }
3291}
3292
3293void TimesPosCstIntVar::RemoveInterval(int64_t l, int64_t u) {
3294 for (int64_t v = l; v <= u; ++v) {
3295 RemoveValue(v);
3296 }
3297 // TODO(user) : Improve me
3298}
3299
3300void TimesPosCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3301
3302void TimesPosCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3303
3304uint64_t TimesPosCstIntVar::Size() const { return var_->Size(); }
3305
3306bool TimesPosCstIntVar::Contains(int64_t v) const {
3307 return (v % cst_ == 0 && var_->Contains(v / cst_));
3308}
3309
3310// b * c variable, optimized case
3311
3312class TimesPosCstBoolVar : public TimesCstIntVar {
3313 public:
3314 class TimesPosCstBoolVarIterator : public UnaryIterator {
3315 public:
3316 // TODO(user) : optimize this.
3317 TimesPosCstBoolVarIterator(const IntVar* const v, int64_t c, bool hole,
3318 bool reversible)
3319 : UnaryIterator(v, hole, reversible), cst_(c) {}
3320 ~TimesPosCstBoolVarIterator() override {}
3321
3322 int64_t Value() const override { return iterator_->Value() * cst_; }
3323
3324 private:
3325 const int64_t cst_;
3326 };
3327
3328 TimesPosCstBoolVar(Solver* s, BooleanVar* v, int64_t c);
3329 ~TimesPosCstBoolVar() override;
3330
3331 int64_t Min() const override;
3332 void SetMin(int64_t m) override;
3333 int64_t Max() const override;
3334 void SetMax(int64_t m) override;
3335 void SetRange(int64_t l, int64_t u) override;
3336 void SetValue(int64_t v) override;
3337 bool Bound() const override;
3338 int64_t Value() const override;
3339 void RemoveValue(int64_t v) override;
3340 void RemoveInterval(int64_t l, int64_t u) override;
3341 uint64_t Size() const override;
3342 bool Contains(int64_t v) const override;
3343 void WhenRange(Demon* d) override;
3344 void WhenBound(Demon* d) override;
3345 void WhenDomain(Demon* d) override;
3346 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3347 return CondRevAlloc(solver(), reversible, new EmptyIterator());
3348 }
3349 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3350 return CondRevAlloc(
3351 solver(), reversible,
3352 new TimesPosCstBoolVarIterator(boolean_var(), cst_, false, reversible));
3353 }
3354 int64_t OldMin() const override { return 0; }
3355 int64_t OldMax() const override { return cst_; }
3356
3357 BooleanVar* boolean_var() const {
3358 return reinterpret_cast<BooleanVar*>(var_);
3359 }
3360};
3361
3362// ----- TimesPosCstBoolVar -----
3363
3364TimesPosCstBoolVar::TimesPosCstBoolVar(Solver* const s, BooleanVar* v,
3365 int64_t c)
3366 : TimesCstIntVar(s, v, c) {}
3367
3368TimesPosCstBoolVar::~TimesPosCstBoolVar() {}
3369
3370int64_t TimesPosCstBoolVar::Min() const {
3371 return (boolean_var()->RawValue() == 1) * cst_;
3372}
3373
3374void TimesPosCstBoolVar::SetMin(int64_t m) {
3375 if (m > cst_) {
3376 solver()->Fail();
3377 } else if (m > 0) {
3378 boolean_var()->SetMin(1);
3379 }
3380}
3381
3382int64_t TimesPosCstBoolVar::Max() const {
3383 return (boolean_var()->RawValue() != 0) * cst_;
3384}
3385
3386void TimesPosCstBoolVar::SetMax(int64_t m) {
3387 if (m < 0) {
3388 solver()->Fail();
3389 } else if (m < cst_) {
3390 boolean_var()->SetMax(0);
3391 }
3392}
3393
3394void TimesPosCstBoolVar::SetRange(int64_t l, int64_t u) {
3395 if (u < 0 || l > cst_ || l > u) {
3396 solver()->Fail();
3397 }
3398 if (l > 0) {
3399 boolean_var()->SetMin(1);
3400 } else if (u < cst_) {
3401 boolean_var()->SetMax(0);
3402 }
3403}
3404
3405void TimesPosCstBoolVar::SetValue(int64_t v) {
3406 if (v == 0) {
3407 boolean_var()->SetValue(0);
3408 } else if (v == cst_) {
3409 boolean_var()->SetValue(1);
3410 } else {
3411 solver()->Fail();
3412 }
3413}
3414
3415bool TimesPosCstBoolVar::Bound() const {
3416 return boolean_var()->RawValue() != BooleanVar::kUnboundBooleanVarValue;
3417}
3418
3419void TimesPosCstBoolVar::WhenRange(Demon* d) { boolean_var()->WhenRange(d); }
3420
3421int64_t TimesPosCstBoolVar::Value() const {
3422 CHECK_NE(boolean_var()->RawValue(), BooleanVar::kUnboundBooleanVarValue)
3423 << " variable is not bound";
3424 return boolean_var()->RawValue() * cst_;
3425}
3426
3427void TimesPosCstBoolVar::RemoveValue(int64_t v) {
3428 if (v == 0) {
3429 boolean_var()->RemoveValue(0);
3430 } else if (v == cst_) {
3431 boolean_var()->RemoveValue(1);
3432 }
3433}
3434
3435void TimesPosCstBoolVar::RemoveInterval(int64_t l, int64_t u) {
3436 if (l <= 0 && u >= 0) {
3437 boolean_var()->RemoveValue(0);
3438 }
3439 if (l <= cst_ && u >= cst_) {
3440 boolean_var()->RemoveValue(1);
3441 }
3442}
3443
3444void TimesPosCstBoolVar::WhenBound(Demon* d) { boolean_var()->WhenBound(d); }
3445
3446void TimesPosCstBoolVar::WhenDomain(Demon* d) { boolean_var()->WhenDomain(d); }
3447
3448uint64_t TimesPosCstBoolVar::Size() const {
3449 return (1 +
3450 (boolean_var()->RawValue() == BooleanVar::kUnboundBooleanVarValue));
3451}
3452
3453bool TimesPosCstBoolVar::Contains(int64_t v) const {
3454 if (v == 0) {
3455 return boolean_var()->RawValue() != 1;
3456 } else if (v == cst_) {
3457 return boolean_var()->RawValue() != 0;
3458 }
3459 return false;
3460}
3461
3462// TimesNegCstIntVar
3463
3464class TimesNegCstIntVar : public TimesCstIntVar {
3465 public:
3466 class TimesNegCstIntVarIterator : public UnaryIterator {
3467 public:
3468 TimesNegCstIntVarIterator(const IntVar* const v, int64_t c, bool hole,
3469 bool reversible)
3470 : UnaryIterator(v, hole, reversible), cst_(c) {}
3471 ~TimesNegCstIntVarIterator() override {}
3472
3473 int64_t Value() const override { return iterator_->Value() * cst_; }
3474
3475 private:
3476 const int64_t cst_;
3477 };
3478
3479 TimesNegCstIntVar(Solver* s, IntVar* v, int64_t c);
3480 ~TimesNegCstIntVar() override;
3481
3482 int64_t Min() const override;
3483 void SetMin(int64_t m) override;
3484 int64_t Max() const override;
3485 void SetMax(int64_t m) override;
3486 void SetRange(int64_t l, int64_t u) override;
3487 void SetValue(int64_t v) override;
3488 bool Bound() const override;
3489 int64_t Value() const override;
3490 void RemoveValue(int64_t v) override;
3491 void RemoveInterval(int64_t l, int64_t u) override;
3492 uint64_t Size() const override;
3493 bool Contains(int64_t v) const override;
3494 void WhenRange(Demon* d) override;
3495 void WhenBound(Demon* d) override;
3496 void WhenDomain(Demon* d) override;
3497 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3498 return CondRevAlloc(
3499 solver(), reversible,
3500 new TimesNegCstIntVarIterator(var_, cst_, true, reversible));
3501 }
3502 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3503 return CondRevAlloc(
3504 solver(), reversible,
3505 new TimesNegCstIntVarIterator(var_, cst_, false, reversible));
3506 }
3507 int64_t OldMin() const override { return CapProd(var_->OldMax(), cst_); }
3508 int64_t OldMax() const override { return CapProd(var_->OldMin(), cst_); }
3509};
3510
3511// ----- TimesNegCstIntVar -----
3512
3513TimesNegCstIntVar::TimesNegCstIntVar(Solver* const s, IntVar* v, int64_t c)
3514 : TimesCstIntVar(s, v, c) {}
3515
3516TimesNegCstIntVar::~TimesNegCstIntVar() {}
3517
3518int64_t TimesNegCstIntVar::Min() const { return CapProd(var_->Max(), cst_); }
3519
3520void TimesNegCstIntVar::SetMin(int64_t m) {
3521 if (m != std::numeric_limits<int64_t>::min()) {
3522 var_->SetMax(PosIntDivDown(-m, -cst_));
3523 }
3524}
3525
3526int64_t TimesNegCstIntVar::Max() const { return CapProd(var_->Min(), cst_); }
3527
3528void TimesNegCstIntVar::SetMax(int64_t m) {
3529 if (m != std::numeric_limits<int64_t>::max()) {
3530 var_->SetMin(PosIntDivUp(-m, -cst_));
3531 }
3532}
3533
3534void TimesNegCstIntVar::SetRange(int64_t l, int64_t u) {
3535 var_->SetRange(PosIntDivUp(CapOpp(u), CapOpp(cst_)),
3537}
3538
3539void TimesNegCstIntVar::SetValue(int64_t v) {
3540 if (v % cst_ != 0) {
3541 solver()->Fail();
3542 }
3543 var_->SetValue(v / cst_);
3544}
3545
3546bool TimesNegCstIntVar::Bound() const { return var_->Bound(); }
3547
3548void TimesNegCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3549
3550int64_t TimesNegCstIntVar::Value() const {
3551 return CapProd(var_->Value(), cst_);
3552}
3553
3554void TimesNegCstIntVar::RemoveValue(int64_t v) {
3555 if (v % cst_ == 0) {
3556 var_->RemoveValue(v / cst_);
3557 }
3558}
3559
3560void TimesNegCstIntVar::RemoveInterval(int64_t l, int64_t u) {
3561 for (int64_t v = l; v <= u; ++v) {
3562 RemoveValue(v);
3563 }
3564 // TODO(user) : Improve me
3565}
3566
3567void TimesNegCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3568
3569void TimesNegCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3570
3571uint64_t TimesNegCstIntVar::Size() const { return var_->Size(); }
3572
3573bool TimesNegCstIntVar::Contains(int64_t v) const {
3574 return (v % cst_ == 0 && var_->Contains(v / cst_));
3575}
3576
3577// ---------- arithmetic expressions ----------
3578
3579// ----- PlusIntExpr -----
3580
3581class PlusIntExpr : public BaseIntExpr {
3582 public:
3583 PlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3584 : BaseIntExpr(s), left_(l), right_(r) {}
3585
3586 ~PlusIntExpr() override {}
3587
3588 int64_t Min() const override { return left_->Min() + right_->Min(); }
3589
3590 void SetMin(int64_t m) override {
3591 if (m > left_->Min() + right_->Min()) {
3592 // Catching potential overflow.
3593 if (m > right_->Max() + left_->Max()) solver()->Fail();
3594 left_->SetMin(m - right_->Max());
3595 right_->SetMin(m - left_->Max());
3596 }
3597 }
3598
3599 void SetRange(int64_t l, int64_t u) override {
3600 const int64_t left_min = left_->Min();
3601 const int64_t right_min = right_->Min();
3602 const int64_t left_max = left_->Max();
3603 const int64_t right_max = right_->Max();
3604 if (l > left_min + right_min) {
3605 // Catching potential overflow.
3606 if (l > right_max + left_max) solver()->Fail();
3607 left_->SetMin(l - right_max);
3608 right_->SetMin(l - left_max);
3609 }
3610 if (u < left_max + right_max) {
3611 // Catching potential overflow.
3612 if (u < right_min + left_min) solver()->Fail();
3613 left_->SetMax(u - right_min);
3614 right_->SetMax(u - left_min);
3615 }
3616 }
3617
3618 int64_t Max() const override { return left_->Max() + right_->Max(); }
3619
3620 void SetMax(int64_t m) override {
3621 if (m < left_->Max() + right_->Max()) {
3622 // Catching potential overflow.
3623 if (m < right_->Min() + left_->Min()) solver()->Fail();
3624 left_->SetMax(m - right_->Min());
3625 right_->SetMax(m - left_->Min());
3626 }
3627 }
3628
3629 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3630
3631 void Range(int64_t* const mi, int64_t* const ma) override {
3632 *mi = left_->Min() + right_->Min();
3633 *ma = left_->Max() + right_->Max();
3634 }
3635
3636 std::string name() const override {
3637 return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3638 }
3639
3640 std::string DebugString() const override {
3641 return absl::StrFormat("(%s + %s)", left_->DebugString(),
3642 right_->DebugString());
3643 }
3644
3645 void WhenRange(Demon* d) override {
3646 left_->WhenRange(d);
3647 right_->WhenRange(d);
3648 }
3649
3650 void ExpandPlusIntExpr(IntExpr* const expr, std::vector<IntExpr*>* subs) {
3651 PlusIntExpr* const casted = dynamic_cast<PlusIntExpr*>(expr);
3652 if (casted != nullptr) {
3653 ExpandPlusIntExpr(casted->left_, subs);
3654 ExpandPlusIntExpr(casted->right_, subs);
3655 } else {
3656 subs->push_back(expr);
3657 }
3658 }
3659
3660 IntVar* CastToVar() override {
3661 if (dynamic_cast<PlusIntExpr*>(left_) != nullptr ||
3662 dynamic_cast<PlusIntExpr*>(right_) != nullptr) {
3663 std::vector<IntExpr*> sub_exprs;
3664 ExpandPlusIntExpr(left_, &sub_exprs);
3665 ExpandPlusIntExpr(right_, &sub_exprs);
3666 if (sub_exprs.size() >= 3) {
3667 std::vector<IntVar*> sub_vars(sub_exprs.size());
3668 for (int i = 0; i < sub_exprs.size(); ++i) {
3669 sub_vars[i] = sub_exprs[i]->Var();
3670 }
3671 return solver()->MakeSum(sub_vars)->Var();
3672 }
3673 }
3674 return BaseIntExpr::CastToVar();
3675 }
3676
3677 void Accept(ModelVisitor* const visitor) const override {
3678 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3679 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3680 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3681 right_);
3682 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3683 }
3684
3685 private:
3686 IntExpr* const left_;
3687 IntExpr* const right_;
3688};
3689
3690class SafePlusIntExpr : public BaseIntExpr {
3691 public:
3692 SafePlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3693 : BaseIntExpr(s), left_(l), right_(r) {}
3694
3695 ~SafePlusIntExpr() override {}
3696
3697 int64_t Min() const override { return CapAdd(left_->Min(), right_->Min()); }
3698
3699 void SetMin(int64_t m) override {
3700 left_->SetMin(CapSub(m, right_->Max()));
3701 right_->SetMin(CapSub(m, left_->Max()));
3702 }
3703
3704 void SetRange(int64_t l, int64_t u) override {
3705 const int64_t left_min = left_->Min();
3706 const int64_t right_min = right_->Min();
3707 const int64_t left_max = left_->Max();
3708 const int64_t right_max = right_->Max();
3709 if (l > CapAdd(left_min, right_min)) {
3710 left_->SetMin(CapSub(l, right_max));
3711 right_->SetMin(CapSub(l, left_max));
3712 }
3713 if (u < CapAdd(left_max, right_max)) {
3714 left_->SetMax(CapSub(u, right_min));
3715 right_->SetMax(CapSub(u, left_min));
3716 }
3717 }
3718
3719 int64_t Max() const override { return CapAdd(left_->Max(), right_->Max()); }
3720
3721 void SetMax(int64_t m) override {
3722 left_->SetMax(CapSub(m, right_->Min()));
3723 right_->SetMax(CapSub(m, left_->Min()));
3724 }
3725
3726 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3727
3728 std::string name() const override {
3729 return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3730 }
3731
3732 std::string DebugString() const override {
3733 return absl::StrFormat("(%s + %s)", left_->DebugString(),
3734 right_->DebugString());
3735 }
3736
3737 void WhenRange(Demon* d) override {
3738 left_->WhenRange(d);
3739 right_->WhenRange(d);
3740 }
3741
3742 void Accept(ModelVisitor* const visitor) const override {
3743 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3744 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3745 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3746 right_);
3747 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3748 }
3749
3750 private:
3751 IntExpr* const left_;
3752 IntExpr* const right_;
3753};
3754
3755// ----- PlusIntCstExpr -----
3756
3757class PlusIntCstExpr : public BaseIntExpr {
3758 public:
3759 PlusIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3760 : BaseIntExpr(s), expr_(e), value_(v) {}
3761 ~PlusIntCstExpr() override {}
3762 int64_t Min() const override { return CapAdd(expr_->Min(), value_); }
3763 void SetMin(int64_t m) override { expr_->SetMin(CapSub(m, value_)); }
3764 int64_t Max() const override { return CapAdd(expr_->Max(), value_); }
3765 void SetMax(int64_t m) override { expr_->SetMax(CapSub(m, value_)); }
3766 bool Bound() const override { return (expr_->Bound()); }
3767 std::string name() const override {
3768 return absl::StrFormat("(%s + %d)", expr_->name(), value_);
3769 }
3770 std::string DebugString() const override {
3771 return absl::StrFormat("(%s + %d)", expr_->DebugString(), value_);
3772 }
3773 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3774 IntVar* CastToVar() override;
3775 void Accept(ModelVisitor* const visitor) const override {
3776 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3777 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3778 expr_);
3779 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3780 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3781 }
3782
3783 private:
3784 IntExpr* const expr_;
3785 const int64_t value_;
3786};
3787
3788IntVar* PlusIntCstExpr::CastToVar() {
3789 Solver* const s = solver();
3790 IntVar* const var = expr_->Var();
3791 IntVar* cast = nullptr;
3792 if (AddOverflows(value_, expr_->Max()) ||
3793 AddOverflows(value_, expr_->Min())) {
3794 return BaseIntExpr::CastToVar();
3795 }
3796 switch (var->VarType()) {
3797 case DOMAIN_INT_VAR:
3798 cast = s->RegisterIntVar(s->RevAlloc(new PlusCstDomainIntVar(
3799 s, reinterpret_cast<DomainIntVar*>(var), value_)));
3800 // FIXME: Break was inserted during fallthrough cleanup. Please check.
3801 break;
3802 default:
3803 cast = s->RegisterIntVar(s->RevAlloc(new PlusCstIntVar(s, var, value_)));
3804 break;
3805 }
3806 return cast;
3807}
3808
3809// ----- SubIntExpr -----
3810
3811class SubIntExpr : public BaseIntExpr {
3812 public:
3813 SubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3814 : BaseIntExpr(s), left_(l), right_(r) {}
3815
3816 ~SubIntExpr() override {}
3817
3818 int64_t Min() const override { return left_->Min() - right_->Max(); }
3819
3820 void SetMin(int64_t m) override {
3821 left_->SetMin(CapAdd(m, right_->Min()));
3822 right_->SetMax(CapSub(left_->Max(), m));
3823 }
3824
3825 int64_t Max() const override { return left_->Max() - right_->Min(); }
3826
3827 void SetMax(int64_t m) override {
3828 left_->SetMax(CapAdd(m, right_->Max()));
3829 right_->SetMin(CapSub(left_->Min(), m));
3830 }
3831
3832 void Range(int64_t* mi, int64_t* ma) override {
3833 *mi = left_->Min() - right_->Max();
3834 *ma = left_->Max() - right_->Min();
3835 }
3836
3837 void SetRange(int64_t l, int64_t u) override {
3838 const int64_t left_min = left_->Min();
3839 const int64_t right_min = right_->Min();
3840 const int64_t left_max = left_->Max();
3841 const int64_t right_max = right_->Max();
3842 if (l > left_min - right_max) {
3843 left_->SetMin(CapAdd(l, right_min));
3844 right_->SetMax(CapSub(left_max, l));
3845 }
3846 if (u < left_max - right_min) {
3847 left_->SetMax(CapAdd(u, right_max));
3848 right_->SetMin(CapSub(left_min, u));
3849 }
3850 }
3851
3852 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3853
3854 std::string name() const override {
3855 return absl::StrFormat("(%s - %s)", left_->name(), right_->name());
3856 }
3857
3858 std::string DebugString() const override {
3859 return absl::StrFormat("(%s - %s)", left_->DebugString(),
3860 right_->DebugString());
3861 }
3862
3863 void WhenRange(Demon* d) override {
3864 left_->WhenRange(d);
3865 right_->WhenRange(d);
3866 }
3867
3868 void Accept(ModelVisitor* const visitor) const override {
3869 visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3870 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3871 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3872 right_);
3873 visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3874 }
3875
3876 IntExpr* left() const { return left_; }
3877 IntExpr* right() const { return right_; }
3878
3879 protected:
3880 IntExpr* const left_;
3881 IntExpr* const right_;
3882};
3883
3884class SafeSubIntExpr : public SubIntExpr {
3885 public:
3886 SafeSubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3887 : SubIntExpr(s, l, r) {}
3888
3889 ~SafeSubIntExpr() override {}
3890
3891 int64_t Min() const override { return CapSub(left_->Min(), right_->Max()); }
3892
3893 void SetMin(int64_t m) override {
3894 left_->SetMin(CapAdd(m, right_->Min()));
3895 right_->SetMax(CapSub(left_->Max(), m));
3896 }
3897
3898 void SetRange(int64_t l, int64_t u) override {
3899 const int64_t left_min = left_->Min();
3900 const int64_t right_min = right_->Min();
3901 const int64_t left_max = left_->Max();
3902 const int64_t right_max = right_->Max();
3903 if (l > CapSub(left_min, right_max)) {
3904 left_->SetMin(CapAdd(l, right_min));
3905 right_->SetMax(CapSub(left_max, l));
3906 }
3907 if (u < CapSub(left_max, right_min)) {
3908 left_->SetMax(CapAdd(u, right_max));
3909 right_->SetMin(CapSub(left_min, u));
3910 }
3911 }
3912
3913 void Range(int64_t* mi, int64_t* ma) override {
3914 *mi = CapSub(left_->Min(), right_->Max());
3915 *ma = CapSub(left_->Max(), right_->Min());
3916 }
3917
3918 int64_t Max() const override { return CapSub(left_->Max(), right_->Min()); }
3919
3920 void SetMax(int64_t m) override {
3921 left_->SetMax(CapAdd(m, right_->Max()));
3922 right_->SetMin(CapSub(left_->Min(), m));
3923 }
3924};
3925
3926// l - r
3927
3928// ----- SubIntCstExpr -----
3929
3930class SubIntCstExpr : public BaseIntExpr {
3931 public:
3932 SubIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3933 : BaseIntExpr(s), expr_(e), value_(v) {}
3934 ~SubIntCstExpr() override {}
3935 int64_t Min() const override { return CapSub(value_, expr_->Max()); }
3936 void SetMin(int64_t m) override { expr_->SetMax(CapSub(value_, m)); }
3937 int64_t Max() const override { return CapSub(value_, expr_->Min()); }
3938 void SetMax(int64_t m) override { expr_->SetMin(CapSub(value_, m)); }
3939 bool Bound() const override { return (expr_->Bound()); }
3940 std::string name() const override {
3941 return absl::StrFormat("(%d - %s)", value_, expr_->name());
3942 }
3943 std::string DebugString() const override {
3944 return absl::StrFormat("(%d - %s)", value_, expr_->DebugString());
3945 }
3946 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3947 IntVar* CastToVar() override;
3948
3949 void Accept(ModelVisitor* const visitor) const override {
3950 visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3951 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3952 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3953 expr_);
3954 visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3955 }
3956
3957 private:
3958 IntExpr* const expr_;
3959 const int64_t value_;
3960};
3961
3962IntVar* SubIntCstExpr::CastToVar() {
3963 if (SubOverflows(value_, expr_->Min()) ||
3964 SubOverflows(value_, expr_->Max())) {
3965 return BaseIntExpr::CastToVar();
3966 }
3967 Solver* const s = solver();
3968 IntVar* const var =
3969 s->RegisterIntVar(s->RevAlloc(new SubCstIntVar(s, expr_->Var(), value_)));
3970 return var;
3971}
3972
3973// ----- OppIntExpr -----
3974
3975class OppIntExpr : public BaseIntExpr {
3976 public:
3977 OppIntExpr(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
3978 ~OppIntExpr() override {}
3979 int64_t Min() const override { return (CapOpp(expr_->Max())); }
3980 void SetMin(int64_t m) override { expr_->SetMax(CapOpp(m)); }
3981 int64_t Max() const override { return (CapOpp(expr_->Min())); }
3982 void SetMax(int64_t m) override { expr_->SetMin(CapOpp(m)); }
3983 bool Bound() const override { return (expr_->Bound()); }
3984 std::string name() const override {
3985 return absl::StrFormat("(-%s)", expr_->name());
3986 }
3987 std::string DebugString() const override {
3988 return absl::StrFormat("(-%s)", expr_->DebugString());
3989 }
3990 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3991 IntVar* CastToVar() override;
3992
3993 void Accept(ModelVisitor* const visitor) const override {
3994 visitor->BeginVisitIntegerExpression(ModelVisitor::kOpposite, this);
3995 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3996 expr_);
3997 visitor->EndVisitIntegerExpression(ModelVisitor::kOpposite, this);
3998 }
3999
4000 private:
4001 IntExpr* const expr_;
4002};
4003
4004IntVar* OppIntExpr::CastToVar() {
4005 Solver* const s = solver();
4006 IntVar* const var =
4007 s->RegisterIntVar(s->RevAlloc(new OppIntVar(s, expr_->Var())));
4008 return var;
4009}
4010
4011// ----- TimesIntCstExpr -----
4012
4013class TimesIntCstExpr : public BaseIntExpr {
4014 public:
4015 TimesIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4016 : BaseIntExpr(s), expr_(e), value_(v) {}
4017
4018 ~TimesIntCstExpr() override {}
4019
4020 bool Bound() const override { return (expr_->Bound()); }
4021
4022 std::string name() const override {
4023 return absl::StrFormat("(%s * %d)", expr_->name(), value_);
4024 }
4025
4026 std::string DebugString() const override {
4027 return absl::StrFormat("(%s * %d)", expr_->DebugString(), value_);
4028 }
4029
4030 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4031
4032 IntExpr* Expr() const { return expr_; }
4033
4034 int64_t Constant() const { return value_; }
4035
4036 void Accept(ModelVisitor* const visitor) const override {
4037 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4038 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4039 expr_);
4040 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4041 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4042 }
4043
4044 protected:
4045 IntExpr* const expr_;
4046 const int64_t value_;
4047};
4048
4049// ----- TimesPosIntCstExpr -----
4050
4051class TimesPosIntCstExpr : public TimesIntCstExpr {
4052 public:
4053 TimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4054 : TimesIntCstExpr(s, e, v) {
4055 CHECK_GT(v, 0);
4056 }
4057
4058 ~TimesPosIntCstExpr() override {}
4059
4060 int64_t Min() const override { return expr_->Min() * value_; }
4061
4062 void SetMin(int64_t m) override { expr_->SetMin(PosIntDivUp(m, value_)); }
4063
4064 int64_t Max() const override { return expr_->Max() * value_; }
4065
4066 void SetMax(int64_t m) override { expr_->SetMax(PosIntDivDown(m, value_)); }
4067
4068 IntVar* CastToVar() override {
4069 Solver* const s = solver();
4070 IntVar* var = nullptr;
4071 if (expr_->IsVar() &&
4072 reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4073 var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4074 s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4075 } else {
4076 var = s->RegisterIntVar(
4077 s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4078 }
4079 return var;
4080 }
4081};
4082
// This expression adds overflow-safe (saturated) arithmetic compared
// to the previous one.
4085class SafeTimesPosIntCstExpr : public TimesIntCstExpr {
4086 public:
4087 SafeTimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4088 : TimesIntCstExpr(s, e, v) {
4089 CHECK_GT(v, 0);
4090 }
4091
4092 ~SafeTimesPosIntCstExpr() override {}
4093
4094 int64_t Min() const override { return CapProd(expr_->Min(), value_); }
4095
4096 void SetMin(int64_t m) override {
4097 if (m != std::numeric_limits<int64_t>::min()) {
4098 expr_->SetMin(PosIntDivUp(m, value_));
4099 }
4100 }
4101
4102 int64_t Max() const override { return CapProd(expr_->Max(), value_); }
4103
4104 void SetMax(int64_t m) override {
4105 if (m != std::numeric_limits<int64_t>::max()) {
4106 expr_->SetMax(PosIntDivDown(m, value_));
4107 }
4108 }
4109
4110 IntVar* CastToVar() override {
4111 Solver* const s = solver();
4112 IntVar* var = nullptr;
4113 if (expr_->IsVar() &&
4114 reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4115 var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4116 s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4117 } else {
4118 // TODO(user): Check overflows.
4119 var = s->RegisterIntVar(
4120 s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4121 }
4122 return var;
4123 }
4124};
4125
4126// ----- TimesIntNegCstExpr -----
4127
4128class TimesIntNegCstExpr : public TimesIntCstExpr {
4129 public:
4130 TimesIntNegCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4131 : TimesIntCstExpr(s, e, v) {
4132 CHECK_LT(v, 0);
4133 }
4134
4135 ~TimesIntNegCstExpr() override {}
4136
4137 int64_t Min() const override { return CapProd(expr_->Max(), value_); }
4138
4139 void SetMin(int64_t m) override {
4140 if (m != std::numeric_limits<int64_t>::min()) {
4141 expr_->SetMax(PosIntDivDown(-m, -value_));
4142 }
4143 }
4144
4145 int64_t Max() const override { return CapProd(expr_->Min(), value_); }
4146
4147 void SetMax(int64_t m) override {
4148 if (m != std::numeric_limits<int64_t>::max()) {
4149 expr_->SetMin(PosIntDivUp(-m, -value_));
4150 }
4151 }
4152
4153 IntVar* CastToVar() override {
4154 Solver* const s = solver();
4155 IntVar* var = nullptr;
4156 var = s->RegisterIntVar(
4157 s->RevAlloc(new TimesNegCstIntVar(s, expr_->Var(), value_)));
4158 return var;
4159 }
4160};
4161
4162// ----- Utilities for product expression -----
4163
4164// Propagates set_min on left * right, left and right >= 0.
4165void SetPosPosMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4166 DCHECK_GE(left->Min(), 0);
4167 DCHECK_GE(right->Min(), 0);
4168 const int64_t lmax = left->Max();
4169 const int64_t rmax = right->Max();
4170 if (m > CapProd(lmax, rmax)) {
4171 left->solver()->Fail();
4172 }
4173 if (m > CapProd(left->Min(), right->Min())) {
4174 // Ok for m == 0 due to left and right being positive
4175 if (0 != rmax) {
4176 left->SetMin(PosIntDivUp(m, rmax));
4177 }
4178 if (0 != lmax) {
4179 right->SetMin(PosIntDivUp(m, lmax));
4180 }
4181 }
4182}
4183
4184// Propagates set_max on left * right, left and right >= 0.
4185void SetPosPosMaxExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4186 DCHECK_GE(left->Min(), 0);
4187 DCHECK_GE(right->Min(), 0);
4188 const int64_t lmin = left->Min();
4189 const int64_t rmin = right->Min();
4190 if (m < CapProd(lmin, rmin)) {
4191 left->solver()->Fail();
4192 }
4193 if (m < CapProd(left->Max(), right->Max())) {
4194 if (0 != lmin) {
4195 right->SetMax(PosIntDivDown(m, lmin));
4196 }
4197 if (0 != rmin) {
4198 left->SetMax(PosIntDivDown(m, rmin));
4199 }
4200 // else do nothing: 0 is supporting any value from other expr.
4201 }
4202}
4203
4204// Propagates set_min on left * right, left >= 0, right across 0.
4205void SetPosGenMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4206 DCHECK_GE(left->Min(), 0);
4207 DCHECK_GT(right->Max(), 0);
4208 DCHECK_LT(right->Min(), 0);
4209 const int64_t lmax = left->Max();
4210 const int64_t rmax = right->Max();
4211 if (m > CapProd(lmax, rmax)) {
4212 left->solver()->Fail();
4213 }
4214 if (left->Max() == 0) { // left is bound to 0, product is bound to 0.
4215 DCHECK_EQ(0, left->Min());
4216 DCHECK_LE(m, 0);
4217 } else {
4218 if (m > 0) { // We deduce right > 0.
4219 left->SetMin(PosIntDivUp(m, rmax));
4220 right->SetMin(PosIntDivUp(m, lmax));
4221 } else if (m == 0) {
4222 const int64_t lmin = left->Min();
4223 if (lmin > 0) {
4224 right->SetMin(0);
4225 }
4226 } else { // m < 0
4227 const int64_t lmin = left->Min();
4228 if (0 != lmin) { // We cannot deduce anything if 0 is in the domain.
4229 right->SetMin(-PosIntDivDown(-m, lmin));
4230 }
4231 }
4232 }
4233}
4234
4235// Propagates set_min on left * right, left and right across 0.
4236void SetGenGenMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4237 DCHECK_LT(left->Min(), 0);
4238 DCHECK_GT(left->Max(), 0);
4239 DCHECK_GT(right->Max(), 0);
4240 DCHECK_LT(right->Min(), 0);
4241 const int64_t lmin = left->Min();
4242 const int64_t lmax = left->Max();
4243 const int64_t rmin = right->Min();
4244 const int64_t rmax = right->Max();
4245 if (m > std::max(CapProd(lmin, rmin), CapProd(lmax, rmax))) {
4246 left->solver()->Fail();
4247 }
4248 if (m >
4249 CapProd(lmin, rmin)) { // Must be positive section * positive section.
4250 left->SetMin(PosIntDivUp(m, rmax));
4251 right->SetMin(PosIntDivUp(m, lmax));
4252 } else if (m > CapProd(lmax, rmax)) { // Negative section * negative section.
4253 left->SetMax(CapOpp(PosIntDivUp(m, CapOpp(rmin))));
4254 right->SetMax(CapOpp(PosIntDivUp(m, CapOpp(lmin))));
4255 }
4256}
4257
4258void TimesSetMin(IntExpr* const left, IntExpr* const right,
4259 IntExpr* const minus_left, IntExpr* const minus_right,
4260 int64_t m) {
4261 if (left->Min() >= 0) {
4262 if (right->Min() >= 0) {
4263 SetPosPosMinExpr(left, right, m);
4264 } else if (right->Max() <= 0) {
4265 SetPosPosMaxExpr(left, minus_right, -m);
4266 } else { // right->Min() < 0 && right->Max() > 0
4267 SetPosGenMinExpr(left, right, m);
4268 }
4269 } else if (left->Max() <= 0) {
4270 if (right->Min() >= 0) {
4271 SetPosPosMaxExpr(right, minus_left, -m);
4272 } else if (right->Max() <= 0) {
4273 SetPosPosMinExpr(minus_left, minus_right, m);
4274 } else { // right->Min() < 0 && right->Max() > 0
4275 SetPosGenMinExpr(minus_left, minus_right, m);
4276 }
4277 } else if (right->Min() >= 0) { // left->Min() < 0 && left->Max() > 0
4278 SetPosGenMinExpr(right, left, m);
4279 } else if (right->Max() <= 0) { // left->Min() < 0 && left->Max() > 0
4280 SetPosGenMinExpr(minus_right, minus_left, m);
4281 } else { // left->Min() < 0 && left->Max() > 0 &&
4282 // right->Min() < 0 && right->Max() > 0
4283 SetGenGenMinExpr(left, right, m);
4284 }
4285}
4286
4287class TimesIntExpr : public BaseIntExpr {
4288 public:
4289 TimesIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4290 : BaseIntExpr(s),
4291 left_(l),
4292 right_(r),
4293 minus_left_(s->MakeOpposite(left_)),
4294 minus_right_(s->MakeOpposite(right_)) {}
4295 ~TimesIntExpr() override {}
4296 int64_t Min() const override {
4297 const int64_t lmin = left_->Min();
4298 const int64_t lmax = left_->Max();
4299 const int64_t rmin = right_->Min();
4300 const int64_t rmax = right_->Max();
4301 return std::min(std::min(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4302 std::min(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4303 }
4304 void SetMin(int64_t m) override;
4305 int64_t Max() const override {
4306 const int64_t lmin = left_->Min();
4307 const int64_t lmax = left_->Max();
4308 const int64_t rmin = right_->Min();
4309 const int64_t rmax = right_->Max();
4310 return std::max(std::max(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4311 std::max(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4312 }
4313 void SetMax(int64_t m) override;
4314 bool Bound() const override;
4315 std::string name() const override {
4316 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4317 }
4318 std::string DebugString() const override {
4319 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4320 right_->DebugString());
4321 }
4322 void WhenRange(Demon* d) override {
4323 left_->WhenRange(d);
4324 right_->WhenRange(d);
4325 }
4326
4327 void Accept(ModelVisitor* const visitor) const override {
4328 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4329 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4330 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4331 right_);
4332 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4333 }
4334
4335 private:
4336 IntExpr* const left_;
4337 IntExpr* const right_;
4338 IntExpr* const minus_left_;
4339 IntExpr* const minus_right_;
4340};
4341
4342void TimesIntExpr::SetMin(int64_t m) {
4343 if (m != std::numeric_limits<int64_t>::min()) {
4344 TimesSetMin(left_, right_, minus_left_, minus_right_, m);
4345 }
4346}
4347
4348void TimesIntExpr::SetMax(int64_t m) {
4349 if (m != std::numeric_limits<int64_t>::max()) {
4350 TimesSetMin(left_, minus_right_, minus_left_, right_, CapOpp(m));
4351 }
4352}
4353
4354bool TimesIntExpr::Bound() const {
4355 const bool left_bound = left_->Bound();
4356 const bool right_bound = right_->Bound();
4357 return ((left_bound && left_->Max() == 0) ||
4358 (right_bound && right_->Max() == 0) || (left_bound && right_bound));
4359}
4360
4361// ----- TimesPosIntExpr -----
4362
4363class TimesPosIntExpr : public BaseIntExpr {
4364 public:
4365 TimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4366 : BaseIntExpr(s), left_(l), right_(r) {}
4367 ~TimesPosIntExpr() override {}
4368 int64_t Min() const override { return (left_->Min() * right_->Min()); }
4369 void SetMin(int64_t m) override;
4370 int64_t Max() const override { return (left_->Max() * right_->Max()); }
4371 void SetMax(int64_t m) override;
4372 bool Bound() const override;
4373 std::string name() const override {
4374 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4375 }
4376 std::string DebugString() const override {
4377 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4378 right_->DebugString());
4379 }
4380 void WhenRange(Demon* d) override {
4381 left_->WhenRange(d);
4382 right_->WhenRange(d);
4383 }
4384
4385 void Accept(ModelVisitor* const visitor) const override {
4386 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4387 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4388 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4389 right_);
4390 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4391 }
4392
4393 private:
4394 IntExpr* const left_;
4395 IntExpr* const right_;
4396};
4397
4398void TimesPosIntExpr::SetMin(int64_t m) { SetPosPosMinExpr(left_, right_, m); }
4399
4400void TimesPosIntExpr::SetMax(int64_t m) { SetPosPosMaxExpr(left_, right_, m); }
4401
4402bool TimesPosIntExpr::Bound() const {
4403 return (left_->Max() == 0 || right_->Max() == 0 ||
4404 (left_->Bound() && right_->Bound()));
4405}
4406
4407// ----- SafeTimesPosIntExpr -----
4408
4409class SafeTimesPosIntExpr : public BaseIntExpr {
4410 public:
4411 SafeTimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4412 : BaseIntExpr(s), left_(l), right_(r) {}
4413 ~SafeTimesPosIntExpr() override {}
4414 int64_t Min() const override { return CapProd(left_->Min(), right_->Min()); }
4415 void SetMin(int64_t m) override {
4416 if (m != std::numeric_limits<int64_t>::min()) {
4417 SetPosPosMinExpr(left_, right_, m);
4418 }
4419 }
4420 int64_t Max() const override { return CapProd(left_->Max(), right_->Max()); }
4421 void SetMax(int64_t m) override {
4422 if (m != std::numeric_limits<int64_t>::max()) {
4423 SetPosPosMaxExpr(left_, right_, m);
4424 }
4425 }
4426 bool Bound() const override {
4427 return (left_->Max() == 0 || right_->Max() == 0 ||
4428 (left_->Bound() && right_->Bound()));
4429 }
4430 std::string name() const override {
4431 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4432 }
4433 std::string DebugString() const override {
4434 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4435 right_->DebugString());
4436 }
4437 void WhenRange(Demon* d) override {
4438 left_->WhenRange(d);
4439 right_->WhenRange(d);
4440 }
4441
4442 void Accept(ModelVisitor* const visitor) const override {
4443 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4444 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4445 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4446 right_);
4447 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4448 }
4449
4450 private:
4451 IntExpr* const left_;
4452 IntExpr* const right_;
4453};
4454
4455// ----- TimesBooleanPosIntExpr -----
4456
4457class TimesBooleanPosIntExpr : public BaseIntExpr {
4458 public:
4459 TimesBooleanPosIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4460 : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4461 ~TimesBooleanPosIntExpr() override {}
4462 int64_t Min() const override {
4463 return (boolvar_->RawValue() == 1 ? expr_->Min() : 0);
4464 }
4465 void SetMin(int64_t m) override;
4466 int64_t Max() const override {
4467 return (boolvar_->RawValue() == 0 ? 0 : expr_->Max());
4468 }
4469 void SetMax(int64_t m) override;
4470 void Range(int64_t* mi, int64_t* ma) override;
4471 void SetRange(int64_t mi, int64_t ma) override;
4472 bool Bound() const override;
4473 std::string name() const override {
4474 return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4475 }
4476 std::string DebugString() const override {
4477 return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4478 expr_->DebugString());
4479 }
4480 void WhenRange(Demon* d) override {
4481 boolvar_->WhenRange(d);
4482 expr_->WhenRange(d);
4483 }
4484
4485 void Accept(ModelVisitor* const visitor) const override {
4486 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4487 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4488 boolvar_);
4489 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4490 expr_);
4491 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4492 }
4493
4494 private:
4495 BooleanVar* const boolvar_;
4496 IntExpr* const expr_;
4497};
4498
4499void TimesBooleanPosIntExpr::SetMin(int64_t m) {
4500 if (m > 0) {
4501 boolvar_->SetValue(1);
4502 expr_->SetMin(m);
4503 }
4504}
4505
4506void TimesBooleanPosIntExpr::SetMax(int64_t m) {
4507 if (m < 0) {
4508 solver()->Fail();
4509 }
4510 if (m < expr_->Min()) {
4511 boolvar_->SetValue(0);
4512 }
4513 if (boolvar_->RawValue() == 1) {
4514 expr_->SetMax(m);
4515 }
4516}
4517
4518void TimesBooleanPosIntExpr::Range(int64_t* mi, int64_t* ma) {
4519 const int value = boolvar_->RawValue();
4520 if (value == 0) {
4521 *mi = 0;
4522 *ma = 0;
4523 } else if (value == 1) {
4524 expr_->Range(mi, ma);
4525 } else {
4526 *mi = 0;
4527 *ma = expr_->Max();
4528 }
4529}
4530
4531void TimesBooleanPosIntExpr::SetRange(int64_t mi, int64_t ma) {
4532 if (ma < 0 || mi > ma) {
4533 solver()->Fail();
4534 }
4535 if (mi > 0) {
4536 boolvar_->SetValue(1);
4537 expr_->SetMin(mi);
4538 }
4539 if (ma < expr_->Min()) {
4540 boolvar_->SetValue(0);
4541 }
4542 if (boolvar_->RawValue() == 1) {
4543 expr_->SetMax(ma);
4544 }
4545}
4546
4547bool TimesBooleanPosIntExpr::Bound() const {
4548 return (boolvar_->RawValue() == 0 || expr_->Max() == 0 ||
4549 (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue &&
4550 expr_->Bound()));
4551}
4552
4553// ----- TimesBooleanIntExpr -----
4554
4555class TimesBooleanIntExpr : public BaseIntExpr {
4556 public:
4557 TimesBooleanIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4558 : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4559 ~TimesBooleanIntExpr() override {}
4560 int64_t Min() const override {
4561 switch (boolvar_->RawValue()) {
4562 case 0: {
4563 return 0LL;
4564 }
4565 case 1: {
4566 return expr_->Min();
4567 }
4568 default: {
4569 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4570 return std::min(int64_t{0}, expr_->Min());
4571 }
4572 }
4573 }
4574 void SetMin(int64_t m) override;
4575 int64_t Max() const override {
4576 switch (boolvar_->RawValue()) {
4577 case 0: {
4578 return 0LL;
4579 }
4580 case 1: {
4581 return expr_->Max();
4582 }
4583 default: {
4584 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4585 return std::max(int64_t{0}, expr_->Max());
4586 }
4587 }
4588 }
4589 void SetMax(int64_t m) override;
4590 void Range(int64_t* mi, int64_t* ma) override;
4591 void SetRange(int64_t mi, int64_t ma) override;
4592 bool Bound() const override;
4593 std::string name() const override {
4594 return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4595 }
4596 std::string DebugString() const override {
4597 return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4598 expr_->DebugString());
4599 }
4600 void WhenRange(Demon* d) override {
4601 boolvar_->WhenRange(d);
4602 expr_->WhenRange(d);
4603 }
4604
4605 void Accept(ModelVisitor* const visitor) const override {
4606 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4607 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4608 boolvar_);
4609 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4610 expr_);
4611 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4612 }
4613
4614 private:
4615 BooleanVar* const boolvar_;
4616 IntExpr* const expr_;
4617};
4618
4619void TimesBooleanIntExpr::SetMin(int64_t m) {
4620 switch (boolvar_->RawValue()) {
4621 case 0: {
4622 if (m > 0) {
4623 solver()->Fail();
4624 }
4625 break;
4626 }
4627 case 1: {
4628 expr_->SetMin(m);
4629 break;
4630 }
4631 default: {
4632 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4633 if (m > 0) { // 0 is no longer possible for boolvar because min > 0.
4634 boolvar_->SetValue(1);
4635 expr_->SetMin(m);
4636 } else if (m <= 0 && expr_->Max() < m) {
4637 boolvar_->SetValue(0);
4638 }
4639 }
4640 }
4641}
4642
4643void TimesBooleanIntExpr::SetMax(int64_t m) {
4644 switch (boolvar_->RawValue()) {
4645 case 0: {
4646 if (m < 0) {
4647 solver()->Fail();
4648 }
4649 break;
4650 }
4651 case 1: {
4652 expr_->SetMax(m);
4653 break;
4654 }
4655 default: {
4656 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4657 if (m < 0) { // 0 is no longer possible for boolvar because max < 0.
4658 boolvar_->SetValue(1);
4659 expr_->SetMax(m);
4660 } else if (m >= 0 && expr_->Min() > m) {
4661 boolvar_->SetValue(0);
4662 }
4663 }
4664 }
4665}
4666
4667void TimesBooleanIntExpr::Range(int64_t* mi, int64_t* ma) {
4668 switch (boolvar_->RawValue()) {
4669 case 0: {
4670 *mi = 0;
4671 *ma = 0;
4672 break;
4673 }
4674 case 1: {
4675 *mi = expr_->Min();
4676 *ma = expr_->Max();
4677 break;
4678 }
4679 default: {
4680 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4681 *mi = std::min(int64_t{0}, expr_->Min());
4682 *ma = std::max(int64_t{0}, expr_->Max());
4683 break;
4684 }
4685 }
4686}
4687
4688void TimesBooleanIntExpr::SetRange(int64_t mi, int64_t ma) {
4689 if (mi > ma) {
4690 solver()->Fail();
4691 }
4692 switch (boolvar_->RawValue()) {
4693 case 0: {
4694 if (mi > 0 || ma < 0) {
4695 solver()->Fail();
4696 }
4697 break;
4698 }
4699 case 1: {
4700 expr_->SetRange(mi, ma);
4701 break;
4702 }
4703 default: {
4704 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4705 if (mi > 0) {
4706 boolvar_->SetValue(1);
4707 expr_->SetMin(mi);
4708 } else if (mi == 0 && expr_->Max() < 0) {
4709 boolvar_->SetValue(0);
4710 }
4711 if (ma < 0) {
4712 boolvar_->SetValue(1);
4713 expr_->SetMax(ma);
4714 } else if (ma == 0 && expr_->Min() > 0) {
4715 boolvar_->SetValue(0);
4716 }
4717 break;
4718 }
4719 }
4720}
4721
4722bool TimesBooleanIntExpr::Bound() const {
4723 return (boolvar_->RawValue() == 0 ||
4724 (expr_->Bound() &&
4725 (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue ||
4726 expr_->Max() == 0)));
4727}
4728
4729// ----- DivPosIntCstExpr -----
4730
4731class DivPosIntCstExpr : public BaseIntExpr {
4732 public:
4733 DivPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4734 : BaseIntExpr(s), expr_(e), value_(v) {
4735 CHECK_GE(v, 0);
4736 }
4737 ~DivPosIntCstExpr() override {}
4738
4739 int64_t Min() const override { return expr_->Min() / value_; }
4740
4741 void SetMin(int64_t m) override {
4742 if (m > 0) {
4743 expr_->SetMin(m * value_);
4744 } else {
4745 expr_->SetMin((m - 1) * value_ + 1);
4746 }
4747 }
4748 int64_t Max() const override { return expr_->Max() / value_; }
4749
4750 void SetMax(int64_t m) override {
4751 if (m >= 0) {
4752 expr_->SetMax((m + 1) * value_ - 1);
4753 } else {
4754 expr_->SetMax(m * value_);
4755 }
4756 }
4757
4758 std::string name() const override {
4759 return absl::StrFormat("(%s div %d)", expr_->name(), value_);
4760 }
4761
4762 std::string DebugString() const override {
4763 return absl::StrFormat("(%s div %d)", expr_->DebugString(), value_);
4764 }
4765
4766 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4767
4768 void Accept(ModelVisitor* const visitor) const override {
4769 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4770 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4771 expr_);
4772 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4773 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4774 }
4775
4776 private:
4777 IntExpr* const expr_;
4778 const int64_t value_;
4779};
4780
4781// DivPosIntExpr
4782
4783class DivPosIntExpr : public BaseIntExpr {
4784 public:
4785 DivPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4786 : BaseIntExpr(s),
4787 num_(num),
4788 denom_(denom),
4789 opp_num_(s->MakeOpposite(num)) {}
4790
4791 ~DivPosIntExpr() override {}
4792
4793 int64_t Min() const override {
4794 return num_->Min() >= 0
4795 ? num_->Min() / denom_->Max()
4796 : (denom_->Min() == 0 ? num_->Min()
4797 : num_->Min() / denom_->Min());
4798 }
4799
4800 int64_t Max() const override {
4801 return num_->Max() >= 0 ? (denom_->Min() == 0 ? num_->Max()
4802 : num_->Max() / denom_->Min())
4803 : num_->Max() / denom_->Max();
4804 }
4805
4806 static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64_t m) {
4807 num->SetMin(m * denom->Min());
4808 denom->SetMax(num->Max() / m);
4809 }
4810
4811 static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64_t m) {
4812 num->SetMax((m + 1) * denom->Max() - 1);
4813 denom->SetMin(num->Min() / (m + 1) + 1);
4814 }
4815
4816 void SetMin(int64_t m) override {
4817 if (m > 0) {
4818 SetPosMin(num_, denom_, m);
4819 } else {
4820 SetPosMax(opp_num_, denom_, -m);
4821 }
4822 }
4823
4824 void SetMax(int64_t m) override {
4825 if (m >= 0) {
4826 SetPosMax(num_, denom_, m);
4827 } else {
4828 SetPosMin(opp_num_, denom_, -m);
4829 }
4830 }
4831
4832 std::string name() const override {
4833 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4834 }
4835 std::string DebugString() const override {
4836 return absl::StrFormat("(%s div %s)", num_->DebugString(),
4837 denom_->DebugString());
4838 }
4839 void WhenRange(Demon* d) override {
4840 num_->WhenRange(d);
4841 denom_->WhenRange(d);
4842 }
4843
4844 void Accept(ModelVisitor* const visitor) const override {
4845 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4846 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4847 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4848 denom_);
4849 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4850 }
4851
4852 private:
4853 IntExpr* const num_;
4854 IntExpr* const denom_;
4855 IntExpr* const opp_num_;
4856};
4857
4858class DivPosPosIntExpr : public BaseIntExpr {
4859 public:
4860 DivPosPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4861 : BaseIntExpr(s), num_(num), denom_(denom) {}
4862
4863 ~DivPosPosIntExpr() override {}
4864
4865 int64_t Min() const override {
4866 if (denom_->Max() == 0) {
4867 solver()->Fail();
4868 }
4869 return num_->Min() / denom_->Max();
4870 }
4871
4872 int64_t Max() const override {
4873 if (denom_->Min() == 0) {
4874 return num_->Max();
4875 } else {
4876 return num_->Max() / denom_->Min();
4877 }
4878 }
4879
4880 void SetMin(int64_t m) override {
4881 if (m > 0) {
4882 num_->SetMin(m * denom_->Min());
4883 denom_->SetMax(num_->Max() / m);
4884 }
4885 }
4886
4887 void SetMax(int64_t m) override {
4888 if (m >= 0) {
4889 num_->SetMax((m + 1) * denom_->Max() - 1);
4890 denom_->SetMin(num_->Min() / (m + 1) + 1);
4891 } else {
4892 solver()->Fail();
4893 }
4894 }
4895
4896 std::string name() const override {
4897 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4898 }
4899
4900 std::string DebugString() const override {
4901 return absl::StrFormat("(%s div %s)", num_->DebugString(),
4902 denom_->DebugString());
4903 }
4904
4905 void WhenRange(Demon* d) override {
4906 num_->WhenRange(d);
4907 denom_->WhenRange(d);
4908 }
4909
4910 void Accept(ModelVisitor* const visitor) const override {
4911 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4912 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4913 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4914 denom_);
4915 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4916 }
4917
4918 private:
4919 IntExpr* const num_;
4920 IntExpr* const denom_;
4921};
4922
4923// DivIntExpr
4924
4925class DivIntExpr : public BaseIntExpr {
4926 public:
4927 DivIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4928 : BaseIntExpr(s),
4929 num_(num),
4930 denom_(denom),
4931 opp_num_(s->MakeOpposite(num)) {}
4932
4933 ~DivIntExpr() override {}
4934
4935 int64_t Min() const override {
4936 const int64_t num_min = num_->Min();
4937 const int64_t num_max = num_->Max();
4938 const int64_t denom_min = denom_->Min();
4939 const int64_t denom_max = denom_->Max();
4940
4941 if (denom_min == 0 && denom_max == 0) {
4942 return std::numeric_limits<int64_t>::max(); // TODO(user): Check this
4943 // convention.
4944 }
4945
4946 if (denom_min >= 0) { // Denominator strictly positive.
4947 DCHECK_GT(denom_max, 0);
4948 const int64_t adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4949 return num_min >= 0 ? num_min / denom_max : num_min / adjusted_denom_min;
4950 } else if (denom_max <= 0) { // Denominator strictly negative.
4951 DCHECK_LT(denom_min, 0);
4952 const int64_t adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4953 return num_max >= 0 ? num_max / adjusted_denom_max : num_max / denom_min;
4954 } else { // Denominator across 0.
4955 return std::min(num_min, -num_max);
4956 }
4957 }
4958
4959 int64_t Max() const override {
4960 const int64_t num_min = num_->Min();
4961 const int64_t num_max = num_->Max();
4962 const int64_t denom_min = denom_->Min();
4963 const int64_t denom_max = denom_->Max();
4964
4965 if (denom_min == 0 && denom_max == 0) {
4966 return std::numeric_limits<int64_t>::min(); // TODO(user): Check this
4967 // convention.
4968 }
4969
4970 if (denom_min >= 0) { // Denominator strictly positive.
4971 DCHECK_GT(denom_max, 0);
4972 const int64_t adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4973 return num_max >= 0 ? num_max / adjusted_denom_min : num_max / denom_max;
4974 } else if (denom_max <= 0) { // Denominator strictly negative.
4975 DCHECK_LT(denom_min, 0);
4976 const int64_t adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4977 return num_min >= 0 ? num_min / denom_min
4978 : -num_min / -adjusted_denom_max;
4979 } else { // Denominator across 0.
4980 return std::max(num_max, -num_min);
4981 }
4982 }
4983
4984 void AdjustDenominator() {
4985 if (denom_->Min() == 0) {
4986 denom_->SetMin(1);
4987 } else if (denom_->Max() == 0) {
4988 denom_->SetMax(-1);
4989 }
4990 }
4991
4992 // m > 0.
4993 static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64_t m) {
4994 DCHECK_GT(m, 0);
4995 const int64_t num_min = num->Min();
4996 const int64_t num_max = num->Max();
4997 const int64_t denom_min = denom->Min();
4998 const int64_t denom_max = denom->Max();
4999 DCHECK_NE(denom_min, 0);
5000 DCHECK_NE(denom_max, 0);
5001 if (denom_min > 0) { // Denominator strictly positive.
5002 num->SetMin(m * denom_min);
5003 denom->SetMax(num_max / m);
5004 } else if (denom_max < 0) { // Denominator strictly negative.
5005 num->SetMax(m * denom_max);
5006 denom->SetMin(num_min / m);
5007 } else { // Denominator across 0.
5008 if (num_min >= 0) {
5009 num->SetMin(m);
5010 denom->SetRange(1, num_max / m);
5011 } else if (num_max <= 0) {
5012 num->SetMax(-m);
5013 denom->SetRange(num_min / m, -1);
5014 } else {
5015 if (m > -num_min) { // Denominator is forced positive.
5016 num->SetMin(m);
5017 denom->SetRange(1, num_max / m);
5018 } else if (m > num_max) { // Denominator is forced negative.
5019 num->SetMax(-m);
5020 denom->SetRange(num_min / m, -1);
5021 } else {
5022 denom->SetRange(num_min / m, num_max / m);
5023 }
5024 }
5025 }
5026 }
5027
5028 // m >= 0.
5029 static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64_t m) {
5030 DCHECK_GE(m, 0);
5031 const int64_t num_min = num->Min();
5032 const int64_t num_max = num->Max();
5033 const int64_t denom_min = denom->Min();
5034 const int64_t denom_max = denom->Max();
5035 DCHECK_NE(denom_min, 0);
5036 DCHECK_NE(denom_max, 0);
5037 if (denom_min > 0) { // Denominator strictly positive.
5038 num->SetMax((m + 1) * denom_max - 1);
5039 denom->SetMin((num_min / (m + 1)) + 1);
5040 } else if (denom_max < 0) {
5041 num->SetMin((m + 1) * denom_min + 1);
5042 denom->SetMax(num_max / (m + 1) - 1);
5043 } else if (num_min > (m + 1) * denom_max - 1) {
5044 denom->SetMax(-1);
5045 } else if (num_max < (m + 1) * denom_min + 1) {
5046 denom->SetMin(1);
5047 }
5048 }
5049
5050 void SetMin(int64_t m) override {
5051 AdjustDenominator();
5052 if (m > 0) {
5053 SetPosMin(num_, denom_, m);
5054 } else {
5055 SetPosMax(opp_num_, denom_, -m);
5056 }
5057 }
5058
5059 void SetMax(int64_t m) override {
5060 AdjustDenominator();
5061 if (m >= 0) {
5062 SetPosMax(num_, denom_, m);
5063 } else {
5064 SetPosMin(opp_num_, denom_, -m);
5065 }
5066 }
5067
5068 std::string name() const override {
5069 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
5070 }
5071 std::string DebugString() const override {
5072 return absl::StrFormat("(%s div %s)", num_->DebugString(),
5073 denom_->DebugString());
5074 }
5075 void WhenRange(Demon* d) override {
5076 num_->WhenRange(d);
5077 denom_->WhenRange(d);
5078 }
5079
5080 void Accept(ModelVisitor* const visitor) const override {
5081 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
5082 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
5083 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5084 denom_);
5085 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
5086 }
5087
5088 private:
5089 IntExpr* const num_;
5090 IntExpr* const denom_;
5091 IntExpr* const opp_num_;
5092};
5093
5094// ----- IntAbs And IntAbsConstraint ------
5095
5096class IntAbsConstraint : public CastConstraint {
5097 public:
5098 IntAbsConstraint(Solver* const s, IntVar* const sub, IntVar* const target)
5099 : CastConstraint(s, target), sub_(sub) {}
5100
5101 ~IntAbsConstraint() override {}
5102
5103 void Post() override {
5104 Demon* const sub_demon = MakeConstraintDemon0(
5105 solver(), this, &IntAbsConstraint::PropagateSub, "PropagateSub");
5106 sub_->WhenRange(sub_demon);
5107 Demon* const target_demon = MakeConstraintDemon0(
5108 solver(), this, &IntAbsConstraint::PropagateTarget, "PropagateTarget");
5109 target_var_->WhenRange(target_demon);
5110 }
5111
5112 void InitialPropagate() override {
5113 PropagateSub();
5114 PropagateTarget();
5115 }
5116
5117 void PropagateSub() {
5118 const int64_t smin = sub_->Min();
5119 const int64_t smax = sub_->Max();
5120 if (smax <= 0) {
5121 target_var_->SetRange(-smax, -smin);
5122 } else if (smin >= 0) {
5123 target_var_->SetRange(smin, smax);
5124 } else {
5125 target_var_->SetRange(0, std::max(-smin, smax));
5126 }
5127 }
5128
5129 void PropagateTarget() {
5130 const int64_t target_max = target_var_->Max();
5131 sub_->SetRange(-target_max, target_max);
5132 const int64_t target_min = target_var_->Min();
5133 if (target_min > 0) {
5134 if (sub_->Min() > -target_min) {
5135 sub_->SetMin(target_min);
5136 } else if (sub_->Max() < target_min) {
5137 sub_->SetMax(-target_min);
5138 }
5139 }
5140 }
5141
5142 std::string DebugString() const override {
5143 return absl::StrFormat("IntAbsConstraint(%s, %s)", sub_->DebugString(),
5144 target_var_->DebugString());
5145 }
5146
5147 void Accept(ModelVisitor* const visitor) const override {
5148 visitor->BeginVisitConstraint(ModelVisitor::kAbsEqual, this);
5149 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5150 sub_);
5151 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
5152 target_var_);
5153 visitor->EndVisitConstraint(ModelVisitor::kAbsEqual, this);
5154 }
5155
5156 private:
5157 IntVar* const sub_;
5158};
5159
5160class IntAbs : public BaseIntExpr {
5161 public:
5162 IntAbs(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5163
5164 ~IntAbs() override {}
5165
5166 int64_t Min() const override {
5167 int64_t emin = 0;
5168 int64_t emax = 0;
5169 expr_->Range(&emin, &emax);
5170 if (emin >= 0) {
5171 return emin;
5172 }
5173 if (emax <= 0) {
5174 return -emax;
5175 }
5176 return 0;
5177 }
5178
5179 void SetMin(int64_t m) override {
5180 if (m > 0) {
5181 int64_t emin = 0;
5182 int64_t emax = 0;
5183 expr_->Range(&emin, &emax);
5184 if (emin > -m) {
5185 expr_->SetMin(m);
5186 } else if (emax < m) {
5187 expr_->SetMax(-m);
5188 }
5189 }
5190 }
5191
5192 int64_t Max() const override {
5193 int64_t emin = 0;
5194 int64_t emax = 0;
5195 expr_->Range(&emin, &emax);
5196 return std::max(-emin, emax);
5197 }
5198
5199 void SetMax(int64_t m) override { expr_->SetRange(-m, m); }
5200
5201 void SetRange(int64_t mi, int64_t ma) override {
5202 expr_->SetRange(-ma, ma);
5203 if (mi > 0) {
5204 int64_t emin = 0;
5205 int64_t emax = 0;
5206 expr_->Range(&emin, &emax);
5207 if (emin > -mi) {
5208 expr_->SetMin(mi);
5209 } else if (emax < mi) {
5210 expr_->SetMax(-mi);
5211 }
5212 }
5213 }
5214
5215 void Range(int64_t* mi, int64_t* ma) override {
5216 int64_t emin = 0;
5217 int64_t emax = 0;
5218 expr_->Range(&emin, &emax);
5219 if (emin >= 0) {
5220 *mi = emin;
5221 *ma = emax;
5222 } else if (emax <= 0) {
5223 *mi = -emax;
5224 *ma = -emin;
5225 } else {
5226 *mi = 0;
5227 *ma = std::max(-emin, emax);
5228 }
5229 }
5230
5231 bool Bound() const override { return expr_->Bound(); }
5232
5233 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5234
5235 std::string name() const override {
5236 return absl::StrFormat("IntAbs(%s)", expr_->name());
5237 }
5238
5239 std::string DebugString() const override {
5240 return absl::StrFormat("IntAbs(%s)", expr_->DebugString());
5241 }
5242
5243 void Accept(ModelVisitor* const visitor) const override {
5244 visitor->BeginVisitIntegerExpression(ModelVisitor::kAbs, this);
5245 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5246 expr_);
5247 visitor->EndVisitIntegerExpression(ModelVisitor::kAbs, this);
5248 }
5249
5250 IntVar* CastToVar() override {
5251 int64_t min_value = 0;
5252 int64_t max_value = 0;
5253 Range(&min_value, &max_value);
5254 Solver* const s = solver();
5255 const std::string name = absl::StrFormat("AbsVar(%s)", expr_->name());
5256 IntVar* const target = s->MakeIntVar(min_value, max_value, name);
5257 CastConstraint* const ct =
5258 s->RevAlloc(new IntAbsConstraint(s, expr_->Var(), target));
5259 s->AddCastConstraint(ct, target, this);
5260 return target;
5261 }
5262
5263 private:
5264 IntExpr* const expr_;
5265};
5266
5267// ----- Square -----
5268
5269// TODO(user): shouldn't we compare to kint32max^2 instead of kint64max?
5270class IntSquare : public BaseIntExpr {
5271 public:
5272 IntSquare(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5273 ~IntSquare() override {}
5274
5275 int64_t Min() const override {
5276 const int64_t emin = expr_->Min();
5277 if (emin >= 0) {
5278 return emin >= std::numeric_limits<int32_t>::max()
5279 ? std::numeric_limits<int64_t>::max()
5280 : emin * emin;
5281 }
5282 const int64_t emax = expr_->Max();
5283 if (emax < 0) {
5284 return emax <= -std::numeric_limits<int32_t>::max()
5285 ? std::numeric_limits<int64_t>::max()
5286 : emax * emax;
5287 }
5288 return 0LL;
5289 }
5290 void SetMin(int64_t m) override {
5291 if (m <= 0) {
5292 return;
5293 }
5294 // TODO(user): What happens if m is kint64max?
5295 const int64_t emin = expr_->Min();
5296 const int64_t emax = expr_->Max();
5297 const int64_t root =
5298 static_cast<int64_t>(ceil(sqrt(static_cast<double>(m))));
5299 if (emin >= 0) {
5300 expr_->SetMin(root);
5301 } else if (emax <= 0) {
5302 expr_->SetMax(-root);
5303 } else if (expr_->IsVar()) {
5304 reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5305 }
5306 }
5307 int64_t Max() const override {
5308 const int64_t emax = expr_->Max();
5309 const int64_t emin = expr_->Min();
5310 if (emax >= std::numeric_limits<int32_t>::max() ||
5311 emin <= -std::numeric_limits<int32_t>::max()) {
5312 return std::numeric_limits<int64_t>::max();
5313 }
5314 return std::max(emin * emin, emax * emax);
5315 }
5316 void SetMax(int64_t m) override {
5317 if (m < 0) {
5318 solver()->Fail();
5319 }
5320 if (m == std::numeric_limits<int64_t>::max()) {
5321 return;
5322 }
5323 const int64_t root =
5324 static_cast<int64_t>(floor(sqrt(static_cast<double>(m))));
5325 expr_->SetRange(-root, root);
5326 }
5327 bool Bound() const override { return expr_->Bound(); }
5328 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5329 std::string name() const override {
5330 return absl::StrFormat("IntSquare(%s)", expr_->name());
5331 }
5332 std::string DebugString() const override {
5333 return absl::StrFormat("IntSquare(%s)", expr_->DebugString());
5334 }
5335
5336 void Accept(ModelVisitor* const visitor) const override {
5337 visitor->BeginVisitIntegerExpression(ModelVisitor::kSquare, this);
5338 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5339 expr_);
5340 visitor->EndVisitIntegerExpression(ModelVisitor::kSquare, this);
5341 }
5342
5343 IntExpr* expr() const { return expr_; }
5344
5345 protected:
5346 IntExpr* const expr_;
5347};
5348
5349class PosIntSquare : public IntSquare {
5350 public:
5351 PosIntSquare(Solver* const s, IntExpr* const e) : IntSquare(s, e) {}
5352 ~PosIntSquare() override {}
5353
5354 int64_t Min() const override {
5355 const int64_t emin = expr_->Min();
5356 return emin >= std::numeric_limits<int32_t>::max()
5357 ? std::numeric_limits<int64_t>::max()
5358 : emin * emin;
5359 }
5360 void SetMin(int64_t m) override {
5361 if (m <= 0) {
5362 return;
5363 }
5364 int64_t root = static_cast<int64_t>(ceil(sqrt(static_cast<double>(m))));
5365 if (CapProd(root, root) < m) {
5366 root++;
5367 }
5368 expr_->SetMin(root);
5369 }
5370 int64_t Max() const override {
5371 const int64_t emax = expr_->Max();
5372 return emax >= std::numeric_limits<int32_t>::max()
5373 ? std::numeric_limits<int64_t>::max()
5374 : emax * emax;
5375 }
5376 void SetMax(int64_t m) override {
5377 if (m < 0) {
5378 solver()->Fail();
5379 }
5380 if (m == std::numeric_limits<int64_t>::max()) {
5381 return;
5382 }
5383 int64_t root = static_cast<int64_t>(floor(sqrt(static_cast<double>(m))));
5384 if (CapProd(root, root) > m) {
5385 root--;
5386 }
5387
5388 expr_->SetMax(root);
5389 }
5390};
5391
5392// ----- EvenPower -----
5393
// Returns value^power using binary exponentiation (O(log power)
// multiplications).
// - power == 0 yields 1.
// - Negative powers are unsupported; value is returned unchanged, as before.
// No overflow check is performed: callers must make sure the result fits in
// int64_t (see OverflowLimit()). Intermediate squares never exceed
// |value|^power, so the function cannot overflow when the final result fits.
int64_t IntPower(int64_t value, int64_t power) {
  if (power == 0) {
    return 1;
  }
  if (power < 0) {
    return value;  // Legacy behavior for unsupported exponents.
  }
  int64_t result = 1;
  int64_t base = value;
  while (true) {
    if (power & 1) {
      result *= base;
    }
    power >>= 1;
    if (power == 0) {
      break;
    }
    // Only square when more bits remain, to avoid a needless (and possibly
    // overflowing) final squaring.
    base *= base;
  }
  return result;
}
5402
// Returns a threshold L such that |x|^power is treated as overflowing for
// |x| >= L (see BasePower::Pown), computed as floor(kint64max^(1/power)).
//
// For power <= 1 no int64_t value overflows, so kint64max is returned. This
// guard matters for two reasons:
// - For power == 1, exp(log(kint64max)) rounds to about 2^63, and converting
//   that double back to int64_t is undefined behavior.
// - For power == 0, the formula would divide by zero.
int64_t OverflowLimit(int64_t power) {
  if (power <= 1) {
    return std::numeric_limits<int64_t>::max();
  }
  const double log_max =
      std::log(static_cast<double>(std::numeric_limits<int64_t>::max()));
  // For power >= 2 the result is at most ~3.04e9, so the cast is safe.
  return static_cast<int64_t>(std::floor(std::exp(log_max / power)));
}
5407
// Common base for expressions of the form expr^pow_ with pow_ > 0.
// Subclasses implement bound propagation (Min/SetMin/Max/SetMax). This class
// provides:
// - Pown(): a saturating power.
// - SqrnDown() / SqrnUp(): integer pow_-th roots.
class BasePower : public BaseIntExpr {
 public:
  // limit_ is the |value| threshold above which Pown() saturates, computed
  // once by OverflowLimit(n).
  BasePower(Solver* const s, IntExpr* const e, int64_t n)
      : BaseIntExpr(s), expr_(e), pow_(n), limit_(OverflowLimit(n)) {
    CHECK_GT(n, 0);
  }

  ~BasePower() override {}

  // Bound whenever the underlying expression is bound.
  bool Bound() const override { return expr_->Bound(); }

  IntExpr* expr() const { return expr_; }

  // Returns the exponent pow_.
  int64_t exponant() const { return pow_; }

  void WhenRange(Demon* d) override { expr_->WhenRange(d); }

  std::string name() const override {
    return absl::StrFormat("IntPower(%s, %d)", expr_->name(), pow_);
  }

  std::string DebugString() const override {
    return absl::StrFormat("IntPower(%s, %d)", expr_->DebugString(), pow_);
  }

  void Accept(ModelVisitor* const visitor) const override {
    visitor->BeginVisitIntegerExpression(ModelVisitor::kPower, this);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
                                            expr_);
    visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, pow_);
    visitor->EndVisitIntegerExpression(ModelVisitor::kPower, this);
  }

 protected:
  // Returns value^pow_ with saturation:
  // - value >= limit_ yields kint64max.
  // - value <= -limit_ yields kint64max for even pow_ and kint64min for odd
  //   pow_.
  // Otherwise the exact power is computed by IntPower().
  int64_t Pown(int64_t value) const {
    if (value >= limit_) {
      return std::numeric_limits<int64_t>::max();
    }
    if (value <= -limit_) {
      if (pow_ % 2 == 0) {
        return std::numeric_limits<int64_t>::max();
      } else {
        return std::numeric_limits<int64_t>::min();
      }
    }
    return IntPower(value, pow_);
  }

  // Integer pow_-th root rounded down: the intended result is the largest r
  // with Pown(r) <= value.
  // - kint64min and kint64max map to themselves.
  // - A negative value requires an odd pow_ (CHECKed).
  // A floating-point estimate is computed first, then bumped by one if
  // Pown(res + 1) still fits under value.
  // NOTE(review): only a one-step underestimate is corrected; this presumably
  // relies on the estimate never being too large.
  int64_t SqrnDown(int64_t value) const {
    if (value == std::numeric_limits<int64_t>::min()) {
      return std::numeric_limits<int64_t>::min();
    }
    if (value == std::numeric_limits<int64_t>::max()) {
      return std::numeric_limits<int64_t>::max();
    }
    int64_t res = 0;
    const double d_value = static_cast<double>(value);
    if (value >= 0) {
      const double sq = exp(log(d_value) / pow_);
      res = static_cast<int64_t>(floor(sq));
    } else {
      // Odd root of a negative number: -(|value|^(1/pow_)) rounded down.
      CHECK_EQ(1, pow_ % 2);
      const double sq = exp(log(-d_value) / pow_);
      res = -static_cast<int64_t>(ceil(sq));
    }
    const int64_t pow_res = Pown(res + 1);
    if (pow_res <= value) {
      return res + 1;
    } else {
      return res;
    }
  }

  // Integer pow_-th root rounded up: the intended result is the smallest r
  // with Pown(r) >= value.
  // - kint64min and kint64max map to themselves.
  // - A negative value requires an odd pow_ (CHECKed).
  // A floating-point estimate is computed first, then lowered by one if
  // Pown(res - 1) still reaches value.
  // NOTE(review): only a one-step overestimate is corrected.
  int64_t SqrnUp(int64_t value) const {
    if (value == std::numeric_limits<int64_t>::min()) {
      return std::numeric_limits<int64_t>::min();
    }
    if (value == std::numeric_limits<int64_t>::max()) {
      return std::numeric_limits<int64_t>::max();
    }
    int64_t res = 0;
    const double d_value = static_cast<double>(value);
    if (value >= 0) {
      const double sq = exp(log(d_value) / pow_);
      res = static_cast<int64_t>(ceil(sq));
    } else {
      // Odd root of a negative number: -(|value|^(1/pow_)) rounded up.
      CHECK_EQ(1, pow_ % 2);
      const double sq = exp(log(-d_value) / pow_);
      res = -static_cast<int64_t>(floor(sq));
    }
    const int64_t pow_res = Pown(res - 1);
    if (pow_res >= value) {
      return res - 1;
    } else {
      return res;
    }
  }

  IntExpr* const expr_;
  // Exponent, > 0.
  const int64_t pow_;
  // Saturation threshold used by Pown().
  const int64_t limit_;
};
5510
5511class IntEvenPower : public BasePower {
5512 public:
5513 IntEvenPower(Solver* const s, IntExpr* const e, int64_t n)
5514 : BasePower(s, e, n) {
5515 CHECK_EQ(0, n % 2);
5516 }
5517
5518 ~IntEvenPower() override {}
5519
5520 int64_t Min() const override {
5521 int64_t emin = 0;
5522 int64_t emax = 0;
5523 expr_->Range(&emin, &emax);
5524 if (emin >= 0) {
5525 return Pown(emin);
5526 }
5527 if (emax < 0) {
5528 return Pown(emax);
5529 }
5530 return 0LL;
5531 }
5532 void SetMin(int64_t m) override {
5533 if (m <= 0) {
5534 return;
5535 }
5536 int64_t emin = 0;
5537 int64_t emax = 0;
5538 expr_->Range(&emin, &emax);
5539 const int64_t root = SqrnUp(m);
5540 if (emin > -root) {
5541 expr_->SetMin(root);
5542 } else if (emax < root) {
5543 expr_->SetMax(-root);
5544 } else if (expr_->IsVar()) {
5545 reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5546 }
5547 }
5548
5549 int64_t Max() const override {
5550 return std::max(Pown(expr_->Min()), Pown(expr_->Max()));
5551 }
5552
5553 void SetMax(int64_t m) override {
5554 if (m < 0) {
5555 solver()->Fail();
5556 }
5557 if (m == std::numeric_limits<int64_t>::max()) {
5558 return;
5559 }
5560 const int64_t root = SqrnDown(m);
5561 expr_->SetRange(-root, root);
5562 }
5563};
5564
5565class PosIntEvenPower : public BasePower {
5566 public:
5567 PosIntEvenPower(Solver* const s, IntExpr* const e, int64_t pow)
5568 : BasePower(s, e, pow) {
5569 CHECK_EQ(0, pow % 2);
5570 }
5571
5572 ~PosIntEvenPower() override {}
5573
5574 int64_t Min() const override { return Pown(expr_->Min()); }
5575
5576 void SetMin(int64_t m) override {
5577 if (m <= 0) {
5578 return;
5579 }
5580 expr_->SetMin(SqrnUp(m));
5581 }
5582 int64_t Max() const override { return Pown(expr_->Max()); }
5583
5584 void SetMax(int64_t m) override {
5585 if (m < 0) {
5586 solver()->Fail();
5587 }
5588 if (m == std::numeric_limits<int64_t>::max()) {
5589 return;
5590 }
5591 expr_->SetMax(SqrnDown(m));
5592 }
5593};
5594
5595class IntOddPower : public BasePower {
5596 public:
5597 IntOddPower(Solver* const s, IntExpr* const e, int64_t n)
5598 : BasePower(s, e, n) {
5599 CHECK_EQ(1, n % 2);
5600 }
5601
5602 ~IntOddPower() override {}
5603
5604 int64_t Min() const override { return Pown(expr_->Min()); }
5605
5606 void SetMin(int64_t m) override { expr_->SetMin(SqrnUp(m)); }
5607
5608 int64_t Max() const override { return Pown(expr_->Max()); }
5609
5610 void SetMax(int64_t m) override { expr_->SetMax(SqrnDown(m)); }
5611};
5612
5613// ----- Min(expr, expr) -----
5614
5615class MinIntExpr : public BaseIntExpr {
5616 public:
5617 MinIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5618 : BaseIntExpr(s), left_(l), right_(r) {}
5619 ~MinIntExpr() override {}
5620 int64_t Min() const override {
5621 const int64_t lmin = left_->Min();
5622 const int64_t rmin = right_->Min();
5623 return std::min(lmin, rmin);
5624 }
5625 void SetMin(int64_t m) override {
5626 left_->SetMin(m);
5627 right_->SetMin(m);
5628 }
5629 int64_t Max() const override {
5630 const int64_t lmax = left_->Max();
5631 const int64_t rmax = right_->Max();
5632 return std::min(lmax, rmax);
5633 }
5634 void SetMax(int64_t m) override {
5635 if (left_->Min() > m) {
5636 right_->SetMax(m);
5637 }
5638 if (right_->Min() > m) {
5639 left_->SetMax(m);
5640 }
5641 }
5642 std::string name() const override {
5643 return absl::StrFormat("MinIntExpr(%s, %s)", left_->name(), right_->name());
5644 }
5645 std::string DebugString() const override {
5646 return absl::StrFormat("MinIntExpr(%s, %s)", left_->DebugString(),
5647 right_->DebugString());
5648 }
5649 void WhenRange(Demon* d) override {
5650 left_->WhenRange(d);
5651 right_->WhenRange(d);
5652 }
5653
5654 void Accept(ModelVisitor* const visitor) const override {
5655 visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5656 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5657 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5658 right_);
5659 visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5660 }
5661
5662 private:
5663 IntExpr* const left_;
5664 IntExpr* const right_;
5665};
5666
5667// ----- Min(expr, constant) -----
5668
5669class MinCstIntExpr : public BaseIntExpr {
5670 public:
5671 MinCstIntExpr(Solver* const s, IntExpr* const e, int64_t v)
5672 : BaseIntExpr(s), expr_(e), value_(v) {}
5673
5674 ~MinCstIntExpr() override {}
5675
5676 int64_t Min() const override { return std::min(expr_->Min(), value_); }
5677
5678 void SetMin(int64_t m) override {
5679 if (m > value_) {
5680 solver()->Fail();
5681 }
5682 expr_->SetMin(m);
5683 }
5684
5685 int64_t Max() const override { return std::min(expr_->Max(), value_); }
5686
5687 void SetMax(int64_t m) override {
5688 if (value_ > m) {
5689 expr_->SetMax(m);
5690 }
5691 }
5692
5693 bool Bound() const override {
5694 return (expr_->Bound() || expr_->Min() >= value_);
5695 }
5696
5697 std::string name() const override {
5698 return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->name(), value_);
5699 }
5700
5701 std::string DebugString() const override {
5702 return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->DebugString(),
5703 value_);
5704 }
5705
5706 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5707
5708 void Accept(ModelVisitor* const visitor) const override {
5709 visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5710 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5711 expr_);
5712 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5713 visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5714 }
5715
5716 private:
5717 IntExpr* const expr_;
5718 const int64_t value_;
5719};
5720
5721// ----- Max(expr, expr) -----
5722
5723class MaxIntExpr : public BaseIntExpr {
5724 public:
5725 MaxIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5726 : BaseIntExpr(s), left_(l), right_(r) {}
5727
5728 ~MaxIntExpr() override {}
5729
5730 int64_t Min() const override { return std::max(left_->Min(), right_->Min()); }
5731
5732 void SetMin(int64_t m) override {
5733 if (left_->Max() < m) {
5734 right_->SetMin(m);
5735 } else {
5736 if (right_->Max() < m) {
5737 left_->SetMin(m);
5738 }
5739 }
5740 }
5741
5742 int64_t Max() const override { return std::max(left_->Max(), right_->Max()); }
5743
5744 void SetMax(int64_t m) override {
5745 left_->SetMax(m);
5746 right_->SetMax(m);
5747 }
5748
5749 std::string name() const override {
5750 return absl::StrFormat("MaxIntExpr(%s, %s)", left_->name(), right_->name());
5751 }
5752
5753 std::string DebugString() const override {
5754 return absl::StrFormat("MaxIntExpr(%s, %s)", left_->DebugString(),
5755 right_->DebugString());
5756 }
5757
5758 void WhenRange(Demon* d) override {
5759 left_->WhenRange(d);
5760 right_->WhenRange(d);
5761 }
5762
5763 void Accept(ModelVisitor* const visitor) const override {
5764 visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5765 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5766 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5767 right_);
5768 visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5769 }
5770
5771 private:
5772 IntExpr* const left_;
5773 IntExpr* const right_;
5774};
5775
5776// ----- Max(expr, constant) -----
5777
5778class MaxCstIntExpr : public BaseIntExpr {
5779 public:
5780 MaxCstIntExpr(Solver* const s, IntExpr* const e, int64_t v)
5781 : BaseIntExpr(s), expr_(e), value_(v) {}
5782
5783 ~MaxCstIntExpr() override {}
5784
5785 int64_t Min() const override { return std::max(expr_->Min(), value_); }
5786
5787 void SetMin(int64_t m) override {
5788 if (value_ < m) {
5789 expr_->SetMin(m);
5790 }
5791 }
5792
5793 int64_t Max() const override { return std::max(expr_->Max(), value_); }
5794
5795 void SetMax(int64_t m) override {
5796 if (m < value_) {
5797 solver()->Fail();
5798 }
5799 expr_->SetMax(m);
5800 }
5801
5802 bool Bound() const override {
5803 return (expr_->Bound() || expr_->Max() <= value_);
5804 }
5805
5806 std::string name() const override {
5807 return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->name(), value_);
5808 }
5809
5810 std::string DebugString() const override {
5811 return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->DebugString(),
5812 value_);
5813 }
5814
5815 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5816
5817 void Accept(ModelVisitor* const visitor) const override {
5818 visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5819 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5820 expr_);
5821 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5822 visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5823 }
5824
5825 private:
5826 IntExpr* const expr_;
5827 const int64_t value_;
5828};
5829
5830// ----- Convex Piecewise -----
5831
5832// This class is a very simple convex piecewise linear function. The
5833// argument of the function is the expression. Between early_date and
5834// late_date, the value of the function is 0. Before early date, it
5835// is affine and the cost is early_cost * (early_date - x). After
5836// late_date, the cost is late_cost * (x - late_date).
5837
5838class SimpleConvexPiecewiseExpr : public BaseIntExpr {
5839 public:
5840 SimpleConvexPiecewiseExpr(Solver* const s, IntExpr* const e, int64_t ec,
5841 int64_t ed, int64_t ld, int64_t lc)
5842 : BaseIntExpr(s),
5843 expr_(e),
5844 early_cost_(ec),
5845 early_date_(ec == 0 ? std::numeric_limits<int64_t>::min() : ed),
5846 late_date_(lc == 0 ? std::numeric_limits<int64_t>::max() : ld),
5847 late_cost_(lc) {
5848 DCHECK_GE(ec, int64_t{0});
5849 DCHECK_GE(lc, int64_t{0});
5850 DCHECK_GE(ld, ed);
5851
5852 // If the penalty is 0, we can push the "confort zone or zone
5853 // of no cost towards infinity.
5854 }
5855
5856 ~SimpleConvexPiecewiseExpr() override {}
5857
5858 int64_t Min() const override {
5859 const int64_t vmin = expr_->Min();
5860 const int64_t vmax = expr_->Max();
5861 if (vmin >= late_date_) {
5862 return (vmin - late_date_) * late_cost_;
5863 } else if (vmax <= early_date_) {
5864 return (early_date_ - vmax) * early_cost_;
5865 } else {
5866 return 0LL;
5867 }
5868 }
5869
5870 void SetMin(int64_t m) override {
5871 if (m <= 0) {
5872 return;
5873 }
5874 int64_t vmin = 0;
5875 int64_t vmax = 0;
5876 expr_->Range(&vmin, &vmax);
5877
5878 const int64_t rb =
5879 (late_cost_ == 0 ? vmax : late_date_ + PosIntDivUp(m, late_cost_) - 1);
5880 const int64_t lb =
5881 (early_cost_ == 0 ? vmin
5882 : early_date_ - PosIntDivUp(m, early_cost_) + 1);
5883
5884 if (expr_->IsVar()) {
5885 expr_->Var()->RemoveInterval(lb, rb);
5886 }
5887 }
5888
5889 int64_t Max() const override {
5890 const int64_t vmin = expr_->Min();
5891 const int64_t vmax = expr_->Max();
5892 const int64_t mr = vmax > late_date_ ? (vmax - late_date_) * late_cost_ : 0;
5893 const int64_t ml =
5894 vmin < early_date_ ? (early_date_ - vmin) * early_cost_ : 0;
5895 return std::max(mr, ml);
5896 }
5897
5898 void SetMax(int64_t m) override {
5899 if (m < 0) {
5900 solver()->Fail();
5901 }
5902 if (late_cost_ != 0LL) {
5903 const int64_t rb = late_date_ + PosIntDivDown(m, late_cost_);
5904 if (early_cost_ != 0LL) {
5905 const int64_t lb = early_date_ - PosIntDivDown(m, early_cost_);
5906 expr_->SetRange(lb, rb);
5907 } else {
5908 expr_->SetMax(rb);
5909 }
5910 } else {
5911 if (early_cost_ != 0LL) {
5912 const int64_t lb = early_date_ - PosIntDivDown(m, early_cost_);
5913 expr_->SetMin(lb);
5914 }
5915 }
5916 }
5917
5918 std::string name() const override {
5919 return absl::StrFormat(
5920 "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5921 expr_->name(), early_cost_, early_date_, late_date_, late_cost_);
5922 }
5923
5924 std::string DebugString() const override {
5925 return absl::StrFormat(
5926 "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5927 expr_->DebugString(), early_cost_, early_date_, late_date_, late_cost_);
5928 }
5929
5930 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5931
5932 void Accept(ModelVisitor* const visitor) const override {
5933 visitor->BeginVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5934 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5935 expr_);
5936 visitor->VisitIntegerArgument(ModelVisitor::kEarlyCostArgument,
5937 early_cost_);
5938 visitor->VisitIntegerArgument(ModelVisitor::kEarlyDateArgument,
5939 early_date_);
5940 visitor->VisitIntegerArgument(ModelVisitor::kLateCostArgument, late_cost_);
5941 visitor->VisitIntegerArgument(ModelVisitor::kLateDateArgument, late_date_);
5942 visitor->EndVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5943 }
5944
5945 private:
5946 IntExpr* const expr_;
5947 const int64_t early_cost_;
5948 const int64_t early_date_;
5949 const int64_t late_date_;
5950 const int64_t late_cost_;
5951};
5952
5953// ----- Semi Continuous -----
5954
5955class SemiContinuousExpr : public BaseIntExpr {
5956 public:
5957 SemiContinuousExpr(Solver* const s, IntExpr* const e, int64_t fixed_charge,
5958 int64_t step)
5959 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge), step_(step) {
5960 DCHECK_GE(fixed_charge, int64_t{0});
5961 DCHECK_GT(step, int64_t{0});
5962 }
5963
5964 ~SemiContinuousExpr() override {}
5965
5966 int64_t Value(int64_t x) const {
5967 if (x <= 0) {
5968 return 0;
5969 } else {
5970 return CapAdd(fixed_charge_, CapProd(x, step_));
5971 }
5972 }
5973
5974 int64_t Min() const override { return Value(expr_->Min()); }
5975
5976 void SetMin(int64_t m) override {
5977 if (m >= CapAdd(fixed_charge_, step_)) {
5978 const int64_t y = PosIntDivUp(CapSub(m, fixed_charge_), step_);
5979 expr_->SetMin(y);
5980 } else if (m > 0) {
5981 expr_->SetMin(1);
5982 }
5983 }
5984
5985 int64_t Max() const override { return Value(expr_->Max()); }
5986
5987 void SetMax(int64_t m) override {
5988 if (m < 0) {
5989 solver()->Fail();
5990 }
5991 if (m == std::numeric_limits<int64_t>::max()) {
5992 return;
5993 }
5994 if (m < CapAdd(fixed_charge_, step_)) {
5995 expr_->SetMax(0);
5996 } else {
5997 const int64_t y = PosIntDivDown(CapSub(m, fixed_charge_), step_);
5998 expr_->SetMax(y);
5999 }
6000 }
6001
6002 std::string name() const override {
6003 return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
6004 expr_->name(), fixed_charge_, step_);
6005 }
6006
6007 std::string DebugString() const override {
6008 return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
6009 expr_->DebugString(), fixed_charge_, step_);
6010 }
6011
6012 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6013
6014 void Accept(ModelVisitor* const visitor) const override {
6015 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6016 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6017 expr_);
6018 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6019 fixed_charge_);
6020 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, step_);
6021 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6022 }
6023
6024 private:
6025 IntExpr* const expr_;
6026 const int64_t fixed_charge_;
6027 const int64_t step_;
6028};
6029
6030class SemiContinuousStepOneExpr : public BaseIntExpr {
6031 public:
6032 SemiContinuousStepOneExpr(Solver* const s, IntExpr* const e,
6033 int64_t fixed_charge)
6034 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6035 DCHECK_GE(fixed_charge, int64_t{0});
6036 }
6037
6038 ~SemiContinuousStepOneExpr() override {}
6039
6040 int64_t Value(int64_t x) const {
6041 if (x <= 0) {
6042 return 0;
6043 } else {
6044 return fixed_charge_ + x;
6045 }
6046 }
6047
6048 int64_t Min() const override { return Value(expr_->Min()); }
6049
6050 void SetMin(int64_t m) override {
6051 if (m >= fixed_charge_ + 1) {
6052 expr_->SetMin(m - fixed_charge_);
6053 } else if (m > 0) {
6054 expr_->SetMin(1);
6055 }
6056 }
6057
6058 int64_t Max() const override { return Value(expr_->Max()); }
6059
6060 void SetMax(int64_t m) override {
6061 if (m < 0) {
6062 solver()->Fail();
6063 }
6064 if (m < fixed_charge_ + 1) {
6065 expr_->SetMax(0);
6066 } else {
6067 expr_->SetMax(m - fixed_charge_);
6068 }
6069 }
6070
6071 std::string name() const override {
6072 return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6073 expr_->name(), fixed_charge_);
6074 }
6075
6076 std::string DebugString() const override {
6077 return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6078 expr_->DebugString(), fixed_charge_);
6079 }
6080
6081 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6082
6083 void Accept(ModelVisitor* const visitor) const override {
6084 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6085 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6086 expr_);
6087 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6088 fixed_charge_);
6089 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 1);
6090 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6091 }
6092
6093 private:
6094 IntExpr* const expr_;
6095 const int64_t fixed_charge_;
6096};
6097
6098class SemiContinuousStepZeroExpr : public BaseIntExpr {
6099 public:
6100 SemiContinuousStepZeroExpr(Solver* const s, IntExpr* const e,
6101 int64_t fixed_charge)
6102 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6103 DCHECK_GT(fixed_charge, int64_t{0});
6104 }
6105
6106 ~SemiContinuousStepZeroExpr() override {}
6107
6108 int64_t Value(int64_t x) const {
6109 if (x <= 0) {
6110 return 0;
6111 } else {
6112 return fixed_charge_;
6113 }
6114 }
6115
6116 int64_t Min() const override { return Value(expr_->Min()); }
6117
6118 void SetMin(int64_t m) override {
6119 if (m >= fixed_charge_) {
6120 solver()->Fail();
6121 } else if (m > 0) {
6122 expr_->SetMin(1);
6123 }
6124 }
6125
6126 int64_t Max() const override { return Value(expr_->Max()); }
6127
6128 void SetMax(int64_t m) override {
6129 if (m < 0) {
6130 solver()->Fail();
6131 }
6132 if (m < fixed_charge_) {
6133 expr_->SetMax(0);
6134 }
6135 }
6136
6137 std::string name() const override {
6138 return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6139 expr_->name(), fixed_charge_);
6140 }
6141
6142 std::string DebugString() const override {
6143 return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6144 expr_->DebugString(), fixed_charge_);
6145 }
6146
6147 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6148
6149 void Accept(ModelVisitor* const visitor) const override {
6150 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6151 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6152 expr_);
6153 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6154 fixed_charge_);
6155 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 0);
6156 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6157 }
6158
6159 private:
6160 IntExpr* const expr_;
6161 const int64_t fixed_charge_;
6162};
6163
6164// This constraints links an expression and the variable it is casted into
6165class LinkExprAndVar : public CastConstraint {
6166 public:
6167 LinkExprAndVar(Solver* const s, IntExpr* const expr, IntVar* const var)
6168 : CastConstraint(s, var), expr_(expr) {}
6169
6170 ~LinkExprAndVar() override {}
6171
6172 void Post() override {
6173 Solver* const s = solver();
6174 Demon* d = s->MakeConstraintInitialPropagateCallback(this);
6175 expr_->WhenRange(d);
6176 target_var_->WhenRange(d);
6177 }
6178
6179 void InitialPropagate() override {
6180 expr_->SetRange(target_var_->Min(), target_var_->Max());
6181 int64_t l, u;
6182 expr_->Range(&l, &u);
6183 target_var_->SetRange(l, u);
6184 }
6185
6186 std::string DebugString() const override {
6187 return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6188 target_var_->DebugString());
6189 }
6190
6191 void Accept(ModelVisitor* const visitor) const override {
6192 visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6193 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6194 expr_);
6195 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6196 target_var_);
6197 visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6198 }
6199
6200 private:
6201 IntExpr* const expr_;
6202};
6203
6204// ----- Conditional Expression -----
6205
6206class ExprWithEscapeValue : public BaseIntExpr {
6207 public:
6208 ExprWithEscapeValue(Solver* const s, IntVar* const c, IntExpr* const e,
6209 int64_t unperformed_value)
6210 : BaseIntExpr(s),
6211 condition_(c),
6212 expression_(e),
6213 unperformed_value_(unperformed_value) {}
6214
6215 // This type is neither copyable nor movable.
6216 ExprWithEscapeValue(const ExprWithEscapeValue&) = delete;
6217 ExprWithEscapeValue& operator=(const ExprWithEscapeValue&) = delete;
6218
6219 ~ExprWithEscapeValue() override {}
6220
6221 int64_t Min() const override {
6222 if (condition_->Min() == 1) {
6223 return expression_->Min();
6224 } else if (condition_->Max() == 1) {
6225 return std::min(unperformed_value_, expression_->Min());
6226 } else {
6227 return unperformed_value_;
6228 }
6229 }
6230
6231 void SetMin(int64_t m) override {
6232 if (m > unperformed_value_) {
6233 condition_->SetValue(1);
6234 expression_->SetMin(m);
6235 } else if (condition_->Min() == 1) {
6236 expression_->SetMin(m);
6237 } else if (m > expression_->Max()) {
6238 condition_->SetValue(0);
6239 }
6240 }
6241
6242 int64_t Max() const override {
6243 if (condition_->Min() == 1) {
6244 return expression_->Max();
6245 } else if (condition_->Max() == 1) {
6246 return std::max(unperformed_value_, expression_->Max());
6247 } else {
6248 return unperformed_value_;
6249 }
6250 }
6251
6252 void SetMax(int64_t m) override {
6253 if (m < unperformed_value_) {
6254 condition_->SetValue(1);
6255 expression_->SetMax(m);
6256 } else if (condition_->Min() == 1) {
6257 expression_->SetMax(m);
6258 } else if (m < expression_->Min()) {
6259 condition_->SetValue(0);
6260 }
6261 }
6262
6263 void SetRange(int64_t mi, int64_t ma) override {
6264 if (ma < unperformed_value_ || mi > unperformed_value_) {
6265 condition_->SetValue(1);
6266 expression_->SetRange(mi, ma);
6267 } else if (condition_->Min() == 1) {
6268 expression_->SetRange(mi, ma);
6269 } else if (ma < expression_->Min() || mi > expression_->Max()) {
6270 condition_->SetValue(0);
6271 }
6272 }
6273
6274 void SetValue(int64_t v) override {
6275 if (v != unperformed_value_) {
6276 condition_->SetValue(1);
6277 expression_->SetValue(v);
6278 } else if (condition_->Min() == 1) {
6279 expression_->SetValue(v);
6280 } else if (v < expression_->Min() || v > expression_->Max()) {
6281 condition_->SetValue(0);
6282 }
6283 }
6284
6285 bool Bound() const override {
6286 return condition_->Max() == 0 || expression_->Bound();
6287 }
6288
6289 void WhenRange(Demon* d) override {
6290 expression_->WhenRange(d);
6291 condition_->WhenBound(d);
6292 }
6293
6294 std::string DebugString() const override {
6295 return absl::StrFormat("ConditionExpr(%s, %s, %d)",
6296 condition_->DebugString(),
6297 expression_->DebugString(), unperformed_value_);
6298 }
6299
6300 void Accept(ModelVisitor* const visitor) const override {
6301 visitor->BeginVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6302 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
6303 condition_);
6304 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6305 expression_);
6306 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument,
6307 unperformed_value_);
6308 visitor->EndVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6309 }
6310
6311 private:
6312 IntVar* const condition_;
6313 IntExpr* const expression_;
6314 const int64_t unperformed_value_;
6315};
6316
6317// ----- This is a specialized case when the variable exact type is known -----
6318class LinkExprAndDomainIntVar : public CastConstraint {
6319 public:
6320 LinkExprAndDomainIntVar(Solver* const s, IntExpr* const expr,
6321 DomainIntVar* const var)
6322 : CastConstraint(s, var),
6323 expr_(expr),
6324 cached_min_(std::numeric_limits<int64_t>::min()),
6325 cached_max_(std::numeric_limits<int64_t>::max()),
6326 fail_stamp_(uint64_t{0}) {}
6327
6328 ~LinkExprAndDomainIntVar() override {}
6329
6330 DomainIntVar* var() const {
6331 return reinterpret_cast<DomainIntVar*>(target_var_);
6332 }
6333
6334 void Post() override {
6335 Solver* const s = solver();
6336 Demon* const d = s->MakeConstraintInitialPropagateCallback(this);
6337 expr_->WhenRange(d);
6338 Demon* const target_var_demon = MakeConstraintDemon0(
6339 solver(), this, &LinkExprAndDomainIntVar::Propagate, "Propagate");
6340 target_var_->WhenRange(target_var_demon);
6341 }
6342
6343 void InitialPropagate() override {
6344 expr_->SetRange(var()->min_.Value(), var()->max_.Value());
6345 expr_->Range(&cached_min_, &cached_max_);
6346 var()->DomainIntVar::SetRange(cached_min_, cached_max_);
6347 }
6348
6349 void Propagate() {
6350 if (var()->min_.Value() > cached_min_ ||
6351 var()->max_.Value() < cached_max_ ||
6352 solver()->fail_stamp() != fail_stamp_) {
6353 InitialPropagate();
6354 fail_stamp_ = solver()->fail_stamp();
6355 }
6356 }
6357
6358 std::string DebugString() const override {
6359 return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6360 target_var_->DebugString());
6361 }
6362
6363 void Accept(ModelVisitor* const visitor) const override {
6364 visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6365 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6366 expr_);
6367 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6368 target_var_);
6369 visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6370 }
6371
6372 private:
6373 IntExpr* const expr_;
6374 int64_t cached_min_;
6375 int64_t cached_max_;
6376 uint64_t fail_stamp_;
6377};
6378} // namespace
6379
6380// ----- Misc -----
6381
6382IntVarIterator* BooleanVar::MakeHoleIterator(bool reversible) const {
6383 return CondRevAlloc(solver(), reversible, new EmptyIterator());
6384}
6385IntVarIterator* BooleanVar::MakeDomainIterator(bool reversible) const {
6386 return CondRevAlloc(solver(), reversible, new RangeIterator(this));
6387}
6388
6389// ----- API -----
6390
6392 DCHECK_EQ(DOMAIN_INT_VAR, var->VarType());
6393 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6394 dvar->CleanInProcess();
6395}
6396
6397Constraint* SetIsEqual(IntVar* const var, absl::Span<const int64_t> values,
6398 const std::vector<IntVar*>& vars) {
6399 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6400 CHECK(dvar != nullptr);
6401 return dvar->SetIsEqual(values, vars);
6402}
6403
6405 absl::Span<const int64_t> values,
6406 const std::vector<IntVar*>& vars) {
6407 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6408 CHECK(dvar != nullptr);
6409 return dvar->SetIsGreaterOrEqual(values, vars);
6410}
6411
6413 DCHECK_EQ(BOOLEAN_VAR, var->VarType());
6414 BooleanVar* const boolean_var = reinterpret_cast<BooleanVar*>(var);
6415 boolean_var->RestoreValue();
6416}
6417
6418// ----- API -----
6419
6420IntVar* Solver::MakeIntVar(int64_t min, int64_t max, const std::string& name) {
6421 if (min == max) {
6422 return MakeIntConst(min, name);
6423 }
6424 if (min == 0 && max == 1) {
6425 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6426 } else if (CapSub(max, min) == 1) {
6427 const std::string inner_name = "inner_" + name;
6428 return RegisterIntVar(
6429 MakeSum(RevAlloc(new ConcreteBooleanVar(this, inner_name)), min)
6430 ->VarWithName(name));
6431 } else {
6432 return RegisterIntVar(RevAlloc(new DomainIntVar(this, min, max, name)));
6433 }
6434}
6435
6436IntVar* Solver::MakeIntVar(int64_t min, int64_t max) {
6437 return MakeIntVar(min, max, "");
6438}
6439
6440IntVar* Solver::MakeBoolVar(const std::string& name) {
6441 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6442}
6443
6444IntVar* Solver::MakeBoolVar() {
6445 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, "")));
6446}
6447
6448IntVar* Solver::MakeIntVar(const std::vector<int64_t>& values,
6449 const std::string& name) {
6450 DCHECK(!values.empty());
6451 // Fast-track the case where we have a single value.
6452 if (values.size() == 1) return MakeIntConst(values[0], name);
6453 // Sort and remove duplicates.
6454 std::vector<int64_t> unique_sorted_values = values;
6455 gtl::STLSortAndRemoveDuplicates(&unique_sorted_values);
6456 // Case when we have a single value, after clean-up.
6457 if (unique_sorted_values.size() == 1) return MakeIntConst(values[0], name);
6458 // Case when the values are a dense interval of integers.
6459 if (unique_sorted_values.size() ==
6460 unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6461 return MakeIntVar(unique_sorted_values.front(), unique_sorted_values.back(),
6462 name);
6463 }
6464 // Compute the GCD: if it's not 1, we can express the variable's domain as
6465 // the product of the GCD and of a domain with smaller values.
6466 int64_t gcd = 0;
6467 for (const int64_t v : unique_sorted_values) {
6468 if (gcd == 0) {
6469 gcd = std::abs(v);
6470 } else {
6471 gcd = MathUtil::GCD64(gcd, std::abs(v)); // Supports v==0.
6472 }
6473 if (gcd == 1) {
6474 // If it's 1, though, we can't do anything special, so we
6475 // immediately return a new DomainIntVar.
6476 return RegisterIntVar(
6477 RevAlloc(new DomainIntVar(this, unique_sorted_values, name)));
6478 }
6479 }
6480 DCHECK_GT(gcd, 1);
6481 for (int64_t& v : unique_sorted_values) {
6482 DCHECK_EQ(0, v % gcd);
6483 v /= gcd;
6484 }
6485 const std::string new_name = name.empty() ? "" : "inner_" + name;
6486 // Catch the case where the divided values are a dense set of integers.
6487 IntVar* inner_intvar = nullptr;
6488 if (unique_sorted_values.size() ==
6489 unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6490 inner_intvar = MakeIntVar(unique_sorted_values.front(),
6491 unique_sorted_values.back(), new_name);
6492 } else {
6493 inner_intvar = RegisterIntVar(
6494 RevAlloc(new DomainIntVar(this, unique_sorted_values, new_name)));
6495 }
6496 return MakeProd(inner_intvar, gcd)->Var();
6497}
6498
6499IntVar* Solver::MakeIntVar(const std::vector<int64_t>& values) {
6500 return MakeIntVar(values, "");
6501}
6502
6503IntVar* Solver::MakeIntVar(const std::vector<int>& values,
6504 const std::string& name) {
6505 return MakeIntVar(ToInt64Vector(values), name);
6506}
6507
6508IntVar* Solver::MakeIntVar(const std::vector<int>& values) {
6509 return MakeIntVar(values, "");
6510}
6511
6512IntVar* Solver::MakeIntConst(int64_t val, const std::string& name) {
6513 // If IntConst is going to be named after its creation,
6514 // cp_share_int_consts should be set to false otherwise names can potentially
6515 // be overwritten.
6516 if (absl::GetFlag(FLAGS_cp_share_int_consts) && name.empty() &&
6517 val >= MIN_CACHED_INT_CONST && val <= MAX_CACHED_INT_CONST) {
6518 return cached_constants_[val - MIN_CACHED_INT_CONST];
6519 }
6520 return RevAlloc(new IntConst(this, val, name));
6521}
6522
6523IntVar* Solver::MakeIntConst(int64_t val) { return MakeIntConst(val, ""); }
6524
6525// ----- Int Var and associated methods -----
6526
6527namespace {
6528std::string IndexedName(absl::string_view prefix, int index, int max_index) {
6529#if 0
6530#if defined(_MSC_VER)
6531 const int digits = max_index > 0 ?
6532 static_cast<int>(log(1.0L * max_index) / log(10.0L)) + 1 :
6533 1;
6534#else
6535 const int digits = max_index > 0 ? static_cast<int>(log10(max_index)) + 1: 1;
6536#endif
6537 return absl::StrFormat("%s%0*d", prefix, digits, index);
6538#else
6539 return absl::StrCat(prefix, index);
6540#endif
6541}
6542} // namespace
6543
6544void Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6545 const std::string& name,
6546 std::vector<IntVar*>* vars) {
6547 for (int i = 0; i < var_count; ++i) {
6548 vars->push_back(MakeIntVar(vmin, vmax, IndexedName(name, i, var_count)));
6549 }
6550}
6551
6552void Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6553 std::vector<IntVar*>* vars) {
6554 for (int i = 0; i < var_count; ++i) {
6555 vars->push_back(MakeIntVar(vmin, vmax));
6556 }
6557}
6558
6559IntVar** Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6560 const std::string& name) {
6561 IntVar** vars = new IntVar*[var_count];
6562 for (int i = 0; i < var_count; ++i) {
6563 vars[i] = MakeIntVar(vmin, vmax, IndexedName(name, i, var_count));
6564 }
6565 return vars;
6566}
6567
6568void Solver::MakeBoolVarArray(int var_count, const std::string& name,
6569 std::vector<IntVar*>* vars) {
6570 for (int i = 0; i < var_count; ++i) {
6571 vars->push_back(MakeBoolVar(IndexedName(name, i, var_count)));
6572 }
6573}
6574
6575void Solver::MakeBoolVarArray(int var_count, std::vector<IntVar*>* vars) {
6576 for (int i = 0; i < var_count; ++i) {
6577 vars->push_back(MakeBoolVar());
6578 }
6579}
6580
6581IntVar** Solver::MakeBoolVarArray(int var_count, const std::string& name) {
6582 IntVar** vars = new IntVar*[var_count];
6583 for (int i = 0; i < var_count; ++i) {
6584 vars[i] = MakeBoolVar(IndexedName(name, i, var_count));
6585 }
6586 return vars;
6587}
6588
6589void Solver::InitCachedIntConstants() {
6590 for (int i = MIN_CACHED_INT_CONST; i <= MAX_CACHED_INT_CONST; ++i) {
6591 cached_constants_[i - MIN_CACHED_INT_CONST] =
6592 RevAlloc(new IntConst(this, i, "")); // note the empty name
6593 }
6594}
6595
6596IntExpr* Solver::MakeSum(IntExpr* const left, IntExpr* const right) {
6597 CHECK_EQ(this, left->solver());
6598 CHECK_EQ(this, right->solver());
6599 if (right->Bound()) {
6600 return MakeSum(left, right->Min());
6601 }
6602 if (left->Bound()) {
6603 return MakeSum(right, left->Min());
6604 }
6605 if (left == right) {
6606 return MakeProd(left, 2);
6607 }
6608 IntExpr* cache = model_cache_->FindExprExprExpression(
6609 left, right, ModelCache::EXPR_EXPR_SUM);
6610 if (cache == nullptr) {
6611 cache = model_cache_->FindExprExprExpression(right, left,
6612 ModelCache::EXPR_EXPR_SUM);
6613 }
6614 if (cache != nullptr) {
6615 return cache;
6616 } else {
6617 IntExpr* const result =
6618 AddOverflows(left->Max(), right->Max()) ||
6619 AddOverflows(left->Min(), right->Min())
6620 ? RegisterIntExpr(RevAlloc(new SafePlusIntExpr(this, left, right)))
6621 : RegisterIntExpr(RevAlloc(new PlusIntExpr(this, left, right)));
6622 model_cache_->InsertExprExprExpression(result, left, right,
6623 ModelCache::EXPR_EXPR_SUM);
6624 return result;
6625 }
6626}
6627
6628IntExpr* Solver::MakeSum(IntExpr* const expr, int64_t value) {
6629 CHECK_EQ(this, expr->solver());
6630 if (expr->Bound()) {
6631 return MakeIntConst(CapAdd(expr->Min(), value));
6632 }
6633 if (value == 0) {
6634 return expr;
6635 }
6636 IntExpr* result = Cache()->FindExprConstantExpression(
6637 expr, value, ModelCache::EXPR_CONSTANT_SUM);
6638 if (result == nullptr) {
6639 if (expr->IsVar() && !AddOverflows(value, expr->Max()) &&
6640 !AddOverflows(value, expr->Min())) {
6641 IntVar* const var = expr->Var();
6642 switch (var->VarType()) {
6643 case DOMAIN_INT_VAR: {
6644 result = RegisterIntExpr(RevAlloc(new PlusCstDomainIntVar(
6645 this, reinterpret_cast<DomainIntVar*>(var), value)));
6646 break;
6647 }
6648 case CONST_VAR: {
6649 result = RegisterIntExpr(MakeIntConst(var->Min() + value));
6650 break;
6651 }
6652 case VAR_ADD_CST: {
6653 PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6654 IntVar* const sub_var = add_var->SubVar();
6655 const int64_t new_constant = value + add_var->Constant();
6656 if (new_constant == 0) {
6657 result = sub_var;
6658 } else {
6659 if (sub_var->VarType() == DOMAIN_INT_VAR) {
6660 DomainIntVar* const dvar =
6661 reinterpret_cast<DomainIntVar*>(sub_var);
6662 result = RegisterIntExpr(
6663 RevAlloc(new PlusCstDomainIntVar(this, dvar, new_constant)));
6664 } else {
6665 result = RegisterIntExpr(
6666 RevAlloc(new PlusCstIntVar(this, sub_var, new_constant)));
6667 }
6668 }
6669 break;
6670 }
6671 case CST_SUB_VAR: {
6672 SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6673 IntVar* const sub_var = add_var->SubVar();
6674 const int64_t new_constant = value + add_var->Constant();
6675 result = RegisterIntExpr(
6676 RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6677 break;
6678 }
6679 case OPP_VAR: {
6680 OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6681 IntVar* const sub_var = add_var->SubVar();
6682 result =
6683 RegisterIntExpr(RevAlloc(new SubCstIntVar(this, sub_var, value)));
6684 break;
6685 }
6686 default:
6687 result =
6688 RegisterIntExpr(RevAlloc(new PlusCstIntVar(this, var, value)));
6689 }
6690 } else {
6691 result = RegisterIntExpr(RevAlloc(new PlusIntCstExpr(this, expr, value)));
6692 }
6693 Cache()->InsertExprConstantExpression(result, expr, value,
6694 ModelCache::EXPR_CONSTANT_SUM);
6695 }
6696 return result;
6697}
6698
6699IntExpr* Solver::MakeDifference(IntExpr* const left, IntExpr* const right) {
6700 CHECK_EQ(this, left->solver());
6701 CHECK_EQ(this, right->solver());
6702 if (left->Bound()) {
6703 return MakeDifference(left->Min(), right);
6704 }
6705 if (right->Bound()) {
6706 return MakeSum(left, -right->Min());
6707 }
6708 IntExpr* sub_left = nullptr;
6709 IntExpr* sub_right = nullptr;
6710 int64_t left_coef = 1;
6711 int64_t right_coef = 1;
6712 if (IsProduct(left, &sub_left, &left_coef) &&
6713 IsProduct(right, &sub_right, &right_coef)) {
6714 const int64_t abs_gcd =
6715 MathUtil::GCD64(std::abs(left_coef), std::abs(right_coef));
6716 if (abs_gcd != 0 && abs_gcd != 1) {
6717 return MakeProd(MakeDifference(MakeProd(sub_left, left_coef / abs_gcd),
6718 MakeProd(sub_right, right_coef / abs_gcd)),
6719 abs_gcd);
6720 }
6721 }
6722
6723 IntExpr* result = Cache()->FindExprExprExpression(
6724 left, right, ModelCache::EXPR_EXPR_DIFFERENCE);
6725 if (result == nullptr) {
6726 if (!SubOverflows(left->Min(), right->Max()) &&
6727 !SubOverflows(left->Max(), right->Min())) {
6728 result = RegisterIntExpr(RevAlloc(new SubIntExpr(this, left, right)));
6729 } else {
6730 result = RegisterIntExpr(RevAlloc(new SafeSubIntExpr(this, left, right)));
6731 }
6732 Cache()->InsertExprExprExpression(result, left, right,
6733 ModelCache::EXPR_EXPR_DIFFERENCE);
6734 }
6735 return result;
6736}
6737
6738// warning: this is 'value - expr'.
6739IntExpr* Solver::MakeDifference(int64_t value, IntExpr* const expr) {
6740 CHECK_EQ(this, expr->solver());
6741 if (expr->Bound()) {
6742 return MakeIntConst(value - expr->Min());
6743 }
6744 if (value == 0) {
6745 return MakeOpposite(expr);
6746 }
6747 IntExpr* result = Cache()->FindExprConstantExpression(
6748 expr, value, ModelCache::EXPR_CONSTANT_DIFFERENCE);
6749 if (result == nullptr) {
6750 if (expr->IsVar() && expr->Min() != std::numeric_limits<int64_t>::min() &&
6751 !SubOverflows(value, expr->Min()) &&
6752 !SubOverflows(value, expr->Max())) {
6753 IntVar* const var = expr->Var();
6754 switch (var->VarType()) {
6755 case VAR_ADD_CST: {
6756 PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6757 IntVar* const sub_var = add_var->SubVar();
6758 const int64_t new_constant = value - add_var->Constant();
6759 if (new_constant == 0) {
6760 result = sub_var;
6761 } else {
6762 result = RegisterIntExpr(
6763 RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6764 }
6765 break;
6766 }
6767 case CST_SUB_VAR: {
6768 SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6769 IntVar* const sub_var = add_var->SubVar();
6770 const int64_t new_constant = value - add_var->Constant();
6771 result = MakeSum(sub_var, new_constant);
6772 break;
6773 }
6774 case OPP_VAR: {
6775 OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6776 IntVar* const sub_var = add_var->SubVar();
6777 result = MakeSum(sub_var, value);
6778 break;
6779 }
6780 default:
6781 result =
6782 RegisterIntExpr(RevAlloc(new SubCstIntVar(this, var, value)));
6783 }
6784 } else {
6785 result = RegisterIntExpr(RevAlloc(new SubIntCstExpr(this, expr, value)));
6786 }
6787 Cache()->InsertExprConstantExpression(result, expr, value,
6788 ModelCache::EXPR_CONSTANT_DIFFERENCE);
6789 }
6790 return result;
6791}
6792
6793IntExpr* Solver::MakeOpposite(IntExpr* const expr) {
6794 CHECK_EQ(this, expr->solver());
6795 if (expr->Bound()) {
6796 return MakeIntConst(CapOpp(expr->Min()));
6797 }
6798 IntExpr* result =
6799 Cache()->FindExprExpression(expr, ModelCache::EXPR_OPPOSITE);
6800 if (result == nullptr) {
6801 if (expr->IsVar()) {
6802 result = RegisterIntVar(RevAlloc(new OppIntExpr(this, expr))->Var());
6803 } else {
6804 result = RegisterIntExpr(RevAlloc(new OppIntExpr(this, expr)));
6805 }
6806 Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_OPPOSITE);
6807 }
6808 return result;
6809}
6810
6811IntExpr* Solver::MakeProd(IntExpr* const expr, int64_t value) {
6812 CHECK_EQ(this, expr->solver());
6813 IntExpr* result = Cache()->FindExprConstantExpression(
6814 expr, value, ModelCache::EXPR_CONSTANT_PROD);
6815 if (result != nullptr) {
6816 return result;
6817 } else {
6818 IntExpr* m_expr = nullptr;
6819 int64_t coefficient = 1;
6820 if (IsProduct(expr, &m_expr, &coefficient)) {
6822 } else {
6823 m_expr = expr;
6825 }
6826 if (m_expr->Bound()) {
6827 return MakeIntConst(CapProd(coefficient, m_expr->Min()));
6828 } else if (coefficient == 1) {
6829 return m_expr;
6830 } else if (coefficient == -1) {
6831 return MakeOpposite(m_expr);
6832 } else if (coefficient > 0) {
6833 if (m_expr->Max() > std::numeric_limits<int64_t>::max() / coefficient ||
6834 m_expr->Min() < std::numeric_limits<int64_t>::min() / coefficient) {
6835 result = RegisterIntExpr(
6836 RevAlloc(new SafeTimesPosIntCstExpr(this, m_expr, coefficient)));
6837 } else {
6838 result = RegisterIntExpr(
6839 RevAlloc(new TimesPosIntCstExpr(this, m_expr, coefficient)));
6840 }
6841 } else if (coefficient == 0) {
6842 result = MakeIntConst(0);
6843 } else { // coefficient < 0.
6844 result = RegisterIntExpr(
6845 RevAlloc(new TimesIntNegCstExpr(this, m_expr, coefficient)));
6846 }
6847 if (m_expr->IsVar() &&
6848 !absl::GetFlag(FLAGS_cp_disable_expression_optimization)) {
6849 result = result->Var();
6850 }
6851 Cache()->InsertExprConstantExpression(result, expr, value,
6852 ModelCache::EXPR_CONSTANT_PROD);
6853 return result;
6854 }
6855}
6856
6857namespace {
6858void ExtractPower(IntExpr** const expr, int64_t* const exponant) {
6859 if (dynamic_cast<BasePower*>(*expr) != nullptr) {
6860 BasePower* const power = dynamic_cast<BasePower*>(*expr);
6861 *expr = power->expr();
6862 *exponant = power->exponant();
6863 }
6864 if (dynamic_cast<IntSquare*>(*expr) != nullptr) {
6865 IntSquare* const power = dynamic_cast<IntSquare*>(*expr);
6866 *expr = power->expr();
6867 *exponant = 2;
6868 }
6869 if ((*expr)->IsVar()) {
6870 IntVar* const var = (*expr)->Var();
6871 IntExpr* const sub = var->solver()->CastExpression(var);
6872 if (sub != nullptr && dynamic_cast<BasePower*>(sub) != nullptr) {
6873 BasePower* const power = dynamic_cast<BasePower*>(sub);
6874 *expr = power->expr();
6875 *exponant = power->exponant();
6876 }
6877 if (sub != nullptr && dynamic_cast<IntSquare*>(sub) != nullptr) {
6878 IntSquare* const power = dynamic_cast<IntSquare*>(sub);
6879 *expr = power->expr();
6880 *exponant = 2;
6881 }
6882 }
6883}
6884
6885void ExtractProduct(IntExpr** const expr, int64_t* const coefficient,
6886 bool* modified) {
6887 if (dynamic_cast<TimesCstIntVar*>(*expr) != nullptr) {
6888 TimesCstIntVar* const left_prod = dynamic_cast<TimesCstIntVar*>(*expr);
6889 *coefficient *= left_prod->Constant();
6890 *expr = left_prod->SubVar();
6891 *modified = true;
6892 } else if (dynamic_cast<TimesIntCstExpr*>(*expr) != nullptr) {
6893 TimesIntCstExpr* const left_prod = dynamic_cast<TimesIntCstExpr*>(*expr);
6894 *coefficient *= left_prod->Constant();
6895 *expr = left_prod->Expr();
6896 *modified = true;
6897 }
6898}
6899} // namespace
6900
6901IntExpr* Solver::MakeProd(IntExpr* const left, IntExpr* const right) {
6902 if (left->Bound()) {
6903 return MakeProd(right, left->Min());
6904 }
6905
6906 if (right->Bound()) {
6907 return MakeProd(left, right->Min());
6908 }
6909
6910 // ----- Discover squares and powers -----
6911
6912 IntExpr* m_left = left;
6913 IntExpr* m_right = right;
6914 int64_t left_exponant = 1;
6915 int64_t right_exponant = 1;
6916 ExtractPower(&m_left, &left_exponant);
6917 ExtractPower(&m_right, &right_exponant);
6918
6919 if (m_left == m_right) {
6920 return MakePower(m_left, left_exponant + right_exponant);
6921 }
6922
6923 // ----- Discover nested products -----
6924
6925 m_left = left;
6926 m_right = right;
6927 int64_t coefficient = 1;
6928 bool modified = false;
6929
6930 ExtractProduct(&m_left, &coefficient, &modified);
6931 ExtractProduct(&m_right, &coefficient, &modified);
6932 if (modified) {
6933 return MakeProd(MakeProd(m_left, m_right), coefficient);
6934 }
6935
6936 // ----- Standard build -----
6937
6938 CHECK_EQ(this, left->solver());
6939 CHECK_EQ(this, right->solver());
6940 IntExpr* result = model_cache_->FindExprExprExpression(
6941 left, right, ModelCache::EXPR_EXPR_PROD);
6942 if (result == nullptr) {
6943 result = model_cache_->FindExprExprExpression(right, left,
6944 ModelCache::EXPR_EXPR_PROD);
6945 }
6946 if (result != nullptr) {
6947 return result;
6948 }
6949 if (left->IsVar() && left->Var()->VarType() == BOOLEAN_VAR) {
6950 if (right->Min() >= 0) {
6951 result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6952 this, reinterpret_cast<BooleanVar*>(left), right)));
6953 } else {
6954 result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6955 this, reinterpret_cast<BooleanVar*>(left), right)));
6956 }
6957 } else if (right->IsVar() &&
6958 reinterpret_cast<IntVar*>(right)->VarType() == BOOLEAN_VAR) {
6959 if (left->Min() >= 0) {
6960 result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6961 this, reinterpret_cast<BooleanVar*>(right), left)));
6962 } else {
6963 result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6964 this, reinterpret_cast<BooleanVar*>(right), left)));
6965 }
6966 } else if (left->Min() >= 0 && right->Min() >= 0) {
6967 if (CapProd(left->Max(), right->Max()) ==
6968 std::numeric_limits<int64_t>::max()) { // Potential overflow.
6969 result =
6970 RegisterIntExpr(RevAlloc(new SafeTimesPosIntExpr(this, left, right)));
6971 } else {
6972 result =
6973 RegisterIntExpr(RevAlloc(new TimesPosIntExpr(this, left, right)));
6974 }
6975 } else {
6976 result = RegisterIntExpr(RevAlloc(new TimesIntExpr(this, left, right)));
6977 }
6978 model_cache_->InsertExprExprExpression(result, left, right,
6979 ModelCache::EXPR_EXPR_PROD);
6980 return result;
6981}
6982
6983IntExpr* Solver::MakeDiv(IntExpr* const numerator, IntExpr* const denominator) {
6984 CHECK(numerator != nullptr);
6985 CHECK(denominator != nullptr);
6986 if (denominator->Bound()) {
6987 return MakeDiv(numerator, denominator->Min());
6988 }
6989 IntExpr* result = model_cache_->FindExprExprExpression(
6990 numerator, denominator, ModelCache::EXPR_EXPR_DIV);
6991 if (result != nullptr) {
6992 return result;
6993 }
6994
6995 if (denominator->Min() <= 0 && denominator->Max() >= 0) {
6996 AddConstraint(MakeNonEquality(denominator, 0));
6997 }
6998
6999 if (denominator->Min() >= 0) {
7000 if (numerator->Min() >= 0) {
7001 result = RevAlloc(new DivPosPosIntExpr(this, numerator, denominator));
7002 } else {
7003 result = RevAlloc(new DivPosIntExpr(this, numerator, denominator));
7004 }
7005 } else if (denominator->Max() <= 0) {
7006 if (numerator->Max() <= 0) {
7007 result = RevAlloc(new DivPosPosIntExpr(this, MakeOpposite(numerator),
7008 MakeOpposite(denominator)));
7009 } else {
7010 result = MakeOpposite(RevAlloc(
7011 new DivPosIntExpr(this, numerator, MakeOpposite(denominator))));
7012 }
7013 } else {
7014 result = RevAlloc(new DivIntExpr(this, numerator, denominator));
7015 }
7016 model_cache_->InsertExprExprExpression(result, numerator, denominator,
7017 ModelCache::EXPR_EXPR_DIV);
7018 return result;
7019}
7020
7021IntExpr* Solver::MakeDiv(IntExpr* const expr, int64_t value) {
7022 CHECK(expr != nullptr);
7023 CHECK_EQ(this, expr->solver());
7024 if (expr->Bound()) {
7025 return MakeIntConst(expr->Min() / value);
7026 } else if (value == 1) {
7027 return expr;
7028 } else if (value == -1) {
7029 return MakeOpposite(expr);
7030 } else if (value > 0) {
7031 return RegisterIntExpr(RevAlloc(new DivPosIntCstExpr(this, expr, value)));
7032 } else if (value == 0) {
7033 LOG(FATAL) << "Cannot divide by 0";
7034 return nullptr;
7035 } else {
7036 return RegisterIntExpr(
7037 MakeOpposite(RevAlloc(new DivPosIntCstExpr(this, expr, -value))));
7038 // TODO(user) : implement special case.
7039 }
7040}
7041
7042Constraint* Solver::MakeAbsEquality(IntVar* const var, IntVar* const abs_var) {
7043 if (Cache()->FindExprExpression(var, ModelCache::EXPR_ABS) == nullptr) {
7044 Cache()->InsertExprExpression(abs_var, var, ModelCache::EXPR_ABS);
7045 }
7046 return RevAlloc(new IntAbsConstraint(this, var, abs_var));
7047}
7048
7049IntExpr* Solver::MakeAbs(IntExpr* const e) {
7050 CHECK_EQ(this, e->solver());
7051 if (e->Min() >= 0) {
7052 return e;
7053 } else if (e->Max() <= 0) {
7054 return MakeOpposite(e);
7055 }
7056 IntExpr* result = Cache()->FindExprExpression(e, ModelCache::EXPR_ABS);
7057 if (result == nullptr) {
7058 int64_t coefficient = 1;
7059 IntExpr* expr = nullptr;
7060 if (IsProduct(e, &expr, &coefficient)) {
7061 result = MakeProd(MakeAbs(expr), std::abs(coefficient));
7062 } else {
7063 result = RegisterIntExpr(RevAlloc(new IntAbs(this, e)));
7064 }
7065 Cache()->InsertExprExpression(result, e, ModelCache::EXPR_ABS);
7066 }
7067 return result;
7068}
7069
7070IntExpr* Solver::MakeSquare(IntExpr* const expr) {
7071 CHECK_EQ(this, expr->solver());
7072 if (expr->Bound()) {
7073 const int64_t v = expr->Min();
7074 return MakeIntConst(v * v);
7075 }
7076 IntExpr* result = Cache()->FindExprExpression(expr, ModelCache::EXPR_SQUARE);
7077 if (result == nullptr) {
7078 if (expr->Min() >= 0) {
7079 result = RegisterIntExpr(RevAlloc(new PosIntSquare(this, expr)));
7080 } else {
7081 result = RegisterIntExpr(RevAlloc(new IntSquare(this, expr)));
7082 }
7083 Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_SQUARE);
7084 }
7085 return result;
7086}
7087
7088IntExpr* Solver::MakePower(IntExpr* const expr, int64_t n) {
7089 CHECK_EQ(this, expr->solver());
7090 CHECK_GE(n, 0);
7091 if (expr->Bound()) {
7092 const int64_t v = expr->Min();
7093 if (v >= OverflowLimit(n)) { // Overflow.
7094 return MakeIntConst(std::numeric_limits<int64_t>::max());
7095 }
7096 return MakeIntConst(IntPower(v, n));
7097 }
7098 switch (n) {
7099 case 0:
7100 return MakeIntConst(1);
7101 case 1:
7102 return expr;
7103 case 2:
7104 return MakeSquare(expr);
7105 default: {
7106 IntExpr* result = nullptr;
7107 if (n % 2 == 0) { // even.
7108 if (expr->Min() >= 0) {
7109 result =
7110 RegisterIntExpr(RevAlloc(new PosIntEvenPower(this, expr, n)));
7111 } else {
7112 result = RegisterIntExpr(RevAlloc(new IntEvenPower(this, expr, n)));
7113 }
7114 } else {
7115 result = RegisterIntExpr(RevAlloc(new IntOddPower(this, expr, n)));
7116 }
7117 return result;
7118 }
7119 }
7120}
7121
7122IntExpr* Solver::MakeMin(IntExpr* const left, IntExpr* const right) {
7123 CHECK_EQ(this, left->solver());
7124 CHECK_EQ(this, right->solver());
7125 if (left->Bound()) {
7126 return MakeMin(right, left->Min());
7127 }
7128 if (right->Bound()) {
7129 return MakeMin(left, right->Min());
7130 }
7131 if (left->Min() >= right->Max()) {
7132 return right;
7133 }
7134 if (right->Min() >= left->Max()) {
7135 return left;
7136 }
7137 return RegisterIntExpr(RevAlloc(new MinIntExpr(this, left, right)));
7138}
7139
7140IntExpr* Solver::MakeMin(IntExpr* const expr, int64_t value) {
7141 CHECK_EQ(this, expr->solver());
7142 if (value <= expr->Min()) {
7143 return MakeIntConst(value);
7144 }
7145 if (expr->Bound()) {
7146 return MakeIntConst(std::min(expr->Min(), value));
7147 }
7148 if (expr->Max() <= value) {
7149 return expr;
7150 }
7151 return RegisterIntExpr(RevAlloc(new MinCstIntExpr(this, expr, value)));
7152}
7153
7154IntExpr* Solver::MakeMin(IntExpr* const expr, int value) {
7155 return MakeMin(expr, static_cast<int64_t>(value));
7156}
7157
7158IntExpr* Solver::MakeMax(IntExpr* const left, IntExpr* const right) {
7159 CHECK_EQ(this, left->solver());
7160 CHECK_EQ(this, right->solver());
7161 if (left->Bound()) {
7162 return MakeMax(right, left->Min());
7163 }
7164 if (right->Bound()) {
7165 return MakeMax(left, right->Min());
7166 }
7167 if (left->Min() >= right->Max()) {
7168 return left;
7169 }
7170 if (right->Min() >= left->Max()) {
7171 return right;
7172 }
7173 return RegisterIntExpr(RevAlloc(new MaxIntExpr(this, left, right)));
7174}
7175
7176IntExpr* Solver::MakeMax(IntExpr* const expr, int64_t value) {
7177 CHECK_EQ(this, expr->solver());
7178 if (expr->Bound()) {
7179 return MakeIntConst(std::max(expr->Min(), value));
7180 }
7181 if (value <= expr->Min()) {
7182 return expr;
7183 }
7184 if (expr->Max() <= value) {
7185 return MakeIntConst(value);
7186 }
7187 return RegisterIntExpr(RevAlloc(new MaxCstIntExpr(this, expr, value)));
7188}
7189
7190IntExpr* Solver::MakeMax(IntExpr* const expr, int value) {
7191 return MakeMax(expr, static_cast<int64_t>(value));
7192}
7193
7194IntExpr* Solver::MakeConvexPiecewiseExpr(IntExpr* expr, int64_t early_cost,
7195 int64_t early_date, int64_t late_date,
7196 int64_t late_cost) {
7197 return RegisterIntExpr(RevAlloc(new SimpleConvexPiecewiseExpr(
7198 this, expr, early_cost, early_date, late_date, late_cost)));
7199}
7200
7201IntExpr* Solver::MakeSemiContinuousExpr(IntExpr* const expr,
7202 int64_t fixed_charge, int64_t step) {
7203 if (step == 0) {
7204 if (fixed_charge == 0) {
7205 return MakeIntConst(int64_t{0});
7206 } else {
7207 return RegisterIntExpr(
7208 RevAlloc(new SemiContinuousStepZeroExpr(this, expr, fixed_charge)));
7209 }
7210 } else if (step == 1) {
7211 return RegisterIntExpr(
7212 RevAlloc(new SemiContinuousStepOneExpr(this, expr, fixed_charge)));
7213 } else {
7214 return RegisterIntExpr(
7215 RevAlloc(new SemiContinuousExpr(this, expr, fixed_charge, step)));
7216 }
7217 // TODO(user) : benchmark with virtualization of
7218 // PosIntDivDown and PosIntDivUp - or function pointers.
7219}
7220
7221// ----- Piecewise Linear -----
7222
7224 public:
7226 const PiecewiseLinearFunction& f)
7227 : BaseIntExpr(solver), expr_(expr), f_(f) {}
7229 int64_t Min() const override {
7230 return f_.GetMinimum(expr_->Min(), expr_->Max());
7231 }
7232 void SetMin(int64_t m) override {
7233 const auto& range =
7234 f_.GetSmallestRangeGreaterThanValue(expr_->Min(), expr_->Max(), m);
7235 expr_->SetRange(range.first, range.second);
7236 }
7237
7238 int64_t Max() const override {
7239 return f_.GetMaximum(expr_->Min(), expr_->Max());
7240 }
7241
7242 void SetMax(int64_t m) override {
7243 const auto& range =
7244 f_.GetSmallestRangeLessThanValue(expr_->Min(), expr_->Max(), m);
7245 expr_->SetRange(range.first, range.second);
7246 }
7247
7248 void SetRange(int64_t l, int64_t u) override {
7249 const auto& range =
7250 f_.GetSmallestRangeInValueRange(expr_->Min(), expr_->Max(), l, u);
7251 expr_->SetRange(range.first, range.second);
7252 }
7253 std::string name() const override {
7254 return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->name(),
7255 f_.DebugString());
7256 }
7257
7258 std::string DebugString() const override {
7259 return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->DebugString(),
7260 f_.DebugString());
7261 }
7262
7263 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
7264
7265 void Accept(ModelVisitor* const visitor) const override {
7266 // TODO(user): Implement visitor.
7267 }
7268
7269 private:
7270 IntExpr* const expr_;
7271 const PiecewiseLinearFunction f_;
7272};
7273
7274IntExpr* Solver::MakePiecewiseLinearExpr(IntExpr* expr,
7275 const PiecewiseLinearFunction& f) {
7276 return RegisterIntExpr(RevAlloc(new PiecewiseLinearExpr(this, expr, f)));
7277}
7278
7279// ----- Conditional Expression -----
7280
7281IntExpr* Solver::MakeConditionalExpression(IntVar* const condition,
7282 IntExpr* const expr,
7283 int64_t unperformed_value) {
7284 if (condition->Min() == 1) {
7285 return expr;
7286 } else if (condition->Max() == 0) {
7287 return MakeIntConst(unperformed_value);
7288 } else {
7289 IntExpr* cache = Cache()->FindExprExprConstantExpression(
7290 condition, expr, unperformed_value,
7291 ModelCache::EXPR_EXPR_CONSTANT_CONDITIONAL);
7292 if (cache == nullptr) {
7293 cache = RevAlloc(
7294 new ExprWithEscapeValue(this, condition, expr, unperformed_value));
7295 Cache()->InsertExprExprConstantExpression(
7296 cache, condition, expr, unperformed_value,
7297 ModelCache::EXPR_EXPR_CONSTANT_CONDITIONAL);
7298 }
7299 return cache;
7300 }
7301}
7302
7303// ----- Modulo -----
7304
7305IntExpr* Solver::MakeModulo(IntExpr* const x, int64_t mod) {
7306 IntVar* const result =
7307 MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7308 if (mod >= 0) {
7309 AddConstraint(MakeBetweenCt(result, 0, mod - 1));
7310 } else {
7311 AddConstraint(MakeBetweenCt(result, mod + 1, 0));
7312 }
7313 return result;
7314}
7315
7316IntExpr* Solver::MakeModulo(IntExpr* const x, IntExpr* const mod) {
7317 if (mod->Bound()) {
7318 return MakeModulo(x, mod->Min());
7319 }
7320 IntVar* const result =
7321 MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7322 AddConstraint(MakeLess(result, MakeAbs(mod)));
7323 AddConstraint(MakeGreater(result, MakeOpposite(MakeAbs(mod))));
7324 return result;
7325}
7326
7327// --------- IntVar ---------
7328
7329int IntVar::VarType() const { return UNSPECIFIED; }
7330
7331void IntVar::RemoveValues(const std::vector<int64_t>& values) {
7332 // TODO(user): Check and maybe inline this code.
7333 const int size = values.size();
7334 DCHECK_GE(size, 0);
7335 switch (size) {
7336 case 0: {
7337 return;
7338 }
7339 case 1: {
7340 RemoveValue(values[0]);
7341 return;
7342 }
7343 case 2: {
7344 RemoveValue(values[0]);
7345 RemoveValue(values[1]);
7346 return;
7347 }
7348 case 3: {
7349 RemoveValue(values[0]);
7350 RemoveValue(values[1]);
7351 RemoveValue(values[2]);
7352 return;
7353 }
7354 default: {
7355 // 4 values, let's start doing some more clever things.
7356 // TODO(user) : Sort values!
7357 int start_index = 0;
7358 int64_t new_min = Min();
7359 if (values[start_index] <= new_min) {
7360 while (start_index < size - 1 &&
7361 values[start_index + 1] == values[start_index] + 1) {
7362 new_min = values[start_index + 1] + 1;
7363 start_index++;
7364 }
7365 }
7366 int end_index = size - 1;
7367 int64_t new_max = Max();
7368 if (values[end_index] >= new_max) {
7369 while (end_index > start_index + 1 &&
7370 values[end_index - 1] == values[end_index] - 1) {
7371 new_max = values[end_index - 1] - 1;
7372 end_index--;
7373 }
7374 }
7375 SetRange(new_min, new_max);
7376 for (int i = start_index; i <= end_index; ++i) {
7377 RemoveValue(values[i]);
7378 }
7379 }
7380 }
7381}
7382
7383void IntVar::Accept(ModelVisitor* const visitor) const {
7384 IntExpr* const casted = solver()->CastExpression(this);
7385 visitor->VisitIntegerVariable(this, casted);
7386}
7387
7388void IntVar::SetValues(const std::vector<int64_t>& values) {
7389 switch (values.size()) {
7390 case 0: {
7391 solver()->Fail();
7392 break;
7393 }
7394 case 1: {
7395 SetValue(values.back());
7396 break;
7397 }
7398 case 2: {
7399 if (Contains(values[0])) {
7400 if (Contains(values[1])) {
7401 const int64_t l = std::min(values[0], values[1]);
7402 const int64_t u = std::max(values[0], values[1]);
7403 SetRange(l, u);
7404 if (u > l + 1) {
7405 RemoveInterval(l + 1, u - 1);
7406 }
7407 } else {
7408 SetValue(values[0]);
7409 }
7410 } else {
7411 SetValue(values[1]);
7412 }
7413 break;
7414 }
7415 default: {
7416 // TODO(user): use a clean and safe SortedUniqueCopy() class
7417 // that uses a global, static shared (and locked) storage.
7418 // TODO(user): [optional] consider porting
7419 // STLSortAndRemoveDuplicates from ortools/base/stl_util.h to the
7420 // existing base/stl_util.h and using it here.
7421 // TODO(user): We could filter out values not in the var.
7422 std::vector<int64_t>& tmp = solver()->tmp_vector_;
7423 tmp.clear();
7424 tmp.insert(tmp.end(), values.begin(), values.end());
7425 std::sort(tmp.begin(), tmp.end());
7426 tmp.erase(std::unique(tmp.begin(), tmp.end()), tmp.end());
7427 const int size = tmp.size();
7428 const int64_t vmin = Min();
7429 const int64_t vmax = Max();
7430 int first = 0;
7431 int last = size - 1;
7432 if (tmp.front() > vmax || tmp.back() < vmin) {
7433 solver()->Fail();
7434 }
7435 // TODO(user) : We could find the first position >= vmin by dichotomy.
7436 while (tmp[first] < vmin || !Contains(tmp[first])) {
7437 ++first;
7438 if (first > last || tmp[first] > vmax) {
7439 solver()->Fail();
7440 }
7441 }
7442 while (last > first && (tmp[last] > vmax || !Contains(tmp[last]))) {
7443 // Note that last >= first implies tmp[last] >= vmin.
7444 --last;
7445 }
7446 DCHECK_GE(last, first);
7447 SetRange(tmp[first], tmp[last]);
7448 while (first < last) {
7449 const int64_t start = tmp[first] + 1;
7450 const int64_t end = tmp[first + 1] - 1;
7451 if (start <= end) {
7452 RemoveInterval(start, end);
7453 }
7454 first++;
7455 }
7456 }
7457 }
7458}
7459// ---------- BaseIntExpr ---------
7460
7462 if (!var->Bound()) {
7463 if (var->VarType() == DOMAIN_INT_VAR) {
7464 DomainIntVar* dvar = reinterpret_cast<DomainIntVar*>(var);
7466 s->RevAlloc(new LinkExprAndDomainIntVar(s, expr, dvar)), dvar, expr);
7467 } else {
7468 s->AddCastConstraint(s->RevAlloc(new LinkExprAndVar(s, expr, var)), var,
7469 expr);
7470 }
7471 }
7472}
7473
7474IntVar* BaseIntExpr::Var() {
7475 if (var_ == nullptr) {
7476 solver()->SaveValue(reinterpret_cast<void**>(&var_));
7477 var_ = CastToVar();
7478 }
7479 return var_;
7480}
7481
7482IntVar* BaseIntExpr::CastToVar() {
7483 int64_t vmin, vmax;
7484 Range(&vmin, &vmax);
7485 IntVar* const var = solver()->MakeIntVar(vmin, vmax);
7486 LinkVarExpr(solver(), this, var);
7487 return var;
7488}
7489
7490// Discovery methods
7491bool Solver::IsADifference(IntExpr* expr, IntExpr** const left,
7492 IntExpr** const right) {
7493 if (expr->IsVar()) {
7494 IntVar* const expr_var = expr->Var();
7495 expr = CastExpression(expr_var);
7496 }
7497 // This is a dynamic cast to check the type of expr.
7498 // It returns nullptr is expr is not a subclass of SubIntExpr.
7499 SubIntExpr* const sub_expr = dynamic_cast<SubIntExpr*>(expr);
7500 if (sub_expr != nullptr) {
7501 *left = sub_expr->left();
7502 *right = sub_expr->right();
7503 return true;
7504 }
7505 return false;
7506}
7507
7508bool Solver::IsBooleanVar(IntExpr* const expr, IntVar** inner_var,
7509 bool* is_negated) const {
7510 if (expr->IsVar() && expr->Var()->VarType() == BOOLEAN_VAR) {
7511 *inner_var = expr->Var();
7512 *is_negated = false;
7513 return true;
7514 } else if (expr->IsVar() && expr->Var()->VarType() == CST_SUB_VAR) {
7515 SubCstIntVar* const sub_var = reinterpret_cast<SubCstIntVar*>(expr);
7516 if (sub_var != nullptr && sub_var->Constant() == 1 &&
7517 sub_var->SubVar()->VarType() == BOOLEAN_VAR) {
7518 *is_negated = true;
7519 *inner_var = sub_var->SubVar();
7520 return true;
7521 }
7522 }
7523 return false;
7524}
7525
7526bool Solver::IsProduct(IntExpr* const expr, IntExpr** inner_expr,
7527 int64_t* coefficient) {
7528 if (dynamic_cast<TimesCstIntVar*>(expr) != nullptr) {
7529 TimesCstIntVar* const var = dynamic_cast<TimesCstIntVar*>(expr);
7530 *coefficient = var->Constant();
7531 *inner_expr = var->SubVar();
7532 return true;
7533 } else if (dynamic_cast<TimesIntCstExpr*>(expr) != nullptr) {
7534 TimesIntCstExpr* const prod = dynamic_cast<TimesIntCstExpr*>(expr);
7535 *coefficient = prod->Constant();
7536 *inner_expr = prod->Expr();
7537 return true;
7538 }
7539 *inner_expr = expr;
7540 *coefficient = 1;
7541 return false;
7542}
7543
7544} // namespace operations_research
IntegerValue y
IntegerValue size
int64_t max
int64_t min
bool Contains(int64_t v) const override
IntVar * IsGreaterOrEqual(int64_t constant) override
IntVar * IsEqual(int64_t constant) override
IsEqual.
void RemoveValue(int64_t v) override
This method removes the value 'v' from the domain of the variable.
IntVar * IsDifferent(int64_t constant) override
SimpleRevFIFO< Demon * > delayed_bound_demons_
void SetMax(int64_t m) override
void SetRange(int64_t mi, int64_t ma) override
This method sets both the min and the max of the expression.
void RemoveInterval(int64_t l, int64_t u) override
void SetMin(int64_t m) override
SimpleRevFIFO< Demon * > bound_demons_
uint64_t Size() const override
This method returns the number of values in the domain of the variable.
static const int kUnboundBooleanVarValue
--— Boolean variable --—
std::string DebugString() const override
void WhenBound(Demon *d) override
IntVar * IsLessOrEqual(int64_t constant) override
virtual Solver::DemonPriority priority() const
---------------— Demon class -------------—
virtual void SetValue(int64_t v)
This method sets the value of the expression.
virtual bool Bound() const
Returns true if the min and the max of the expression are equal.
virtual bool IsVar() const
Returns true if the expression is indeed a variable.
virtual int64_t Min() const =0
IntVar * VarWithName(const std::string &name)
-------— IntExpr -------—
virtual IntVar * Var()=0
Creates a variable from the expression.
virtual int64_t Max() const =0
IntVar(Solver *s)
-------— IntVar -------—
IntVar * Var() override
Creates a variable from the expression.
virtual int VarType() const
------— IntVar ------—
virtual void VisitIntegerVariable(const IntVar *variable, IntExpr *delegate)
void SetRange(int64_t l, int64_t u) override
This method sets both the min and the max of the expression.
void Accept(ModelVisitor *const visitor) const override
Accepts the given visitor.
PiecewiseLinearExpr(Solver *solver, IntExpr *expr, const PiecewiseLinearFunction &f)
void WhenRange(Demon *d) override
Attach a demon that will watch the min or the max of the expression.
std::string name() const override
Object naming.
std::string DebugString() const override
virtual std::string name() const
Object naming.
For the time being, Solver is neither MT_SAFE nor MT_HOT.
IntExpr * MakeDifference(IntExpr *left, IntExpr *right)
left - right
void Fail()
Abandon the current branch in the search tree. A backtrack will follow.
@ VAR_PRIORITY
VAR_PRIORITY is between DELAYED_PRIORITY and NORMAL_PRIORITY.
@ OUTSIDE_SEARCH
Before search, after search.
IntVar * MakeIntConst(int64_t val, const std::string &name)
IntConst will create a constant expression.
void AddCastConstraint(CastConstraint *constraint, IntVar *target_var, IntExpr *expr)
int64_t b
Definition table.cc:45
int64_t a
Definition table.cc:44
const std::string name
A name for logging purposes.
const Constraint * ct
int64_t value
IntVar *const expr_
Definition element.cc:88
IntVar * var
const int64_t limit_
const int64_t pow_
ABSL_FLAG(bool, cp_disable_expression_optimization, false, "Disable special optimization when creating expressions.")
const int64_t cst_
int index
int RemoveAt(RepeatedType *array, const IndexContainer &indices)
const Collection::value_type::second_type FindPtrOrNull(const Collection &collection, const typename Collection::value_type::first_type &key)
Definition map_util.h:94
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
Definition stl_util.h:58
For infeasible and unbounded see Not checked if options check_solutions_if_inf_or_unbounded and the If options first_solution_only is false
problem is infeasible or unbounded (default).
Definition matchers.h:468
std::pair< double, double > Range
A range of values, first is the minimum, second is the maximum.
Definition statistics.h:27
std::function< int64_t(const Model &)> Value(IntegerVariable v)
This checks that the variable is fixed.
Definition integer.h:1975
In SWIG mode, we don't want anything besides these top-level includes.
int64_t SubOverflows(int64_t x, int64_t y)
int64_t UnsafeMostSignificantBitPosition64(const uint64_t *bitset, uint64_t start, uint64_t end)
void InternalSaveBooleanVarValue(Solver *const solver, IntVar *const var)
int64_t CapAdd(int64_t x, int64_t y)
void RestoreBoolValue(IntVar *var)
--— Trail --—
int64_t UnsafeLeastSignificantBitPosition64(const uint64_t *bitset, uint64_t start, uint64_t end)
int64_t CapSub(int64_t x, int64_t y)
Constraint * SetIsGreaterOrEqual(IntVar *const var, absl::Span< const int64_t > values, const std::vector< IntVar * > &vars)
Constraint * SetIsEqual(IntVar *const var, absl::Span< const int64_t > values, const std::vector< IntVar * > &vars)
bool AddOverflows(int64_t x, int64_t y)
void RegisterDemon(Solver *const solver, Demon *const demon, DemonProfiler *const monitor)
--— Exported Methods for Unit Tests --—
static const uint64_t kAllBits64
Basic bit operations.
Definition bitset.h:37
void LinkVarExpr(Solver *s, IntExpr *expr, IntVar *var)
--— IntExprElement --—
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
int64_t CapProd(int64_t x, int64_t y)
std::vector< int64_t > ToInt64Vector(const std::vector< int > &input)
--— Vector manipulations --—
Definition utilities.cc:829
uint64_t OneRange64(uint64_t s, uint64_t e)
Returns a word with bits from s to e set.
Definition bitset.h:289
uint32_t BitPos64(uint64_t pos)
Bit operators used to manipulates bitsets.
Definition bitset.h:334
uint64_t BitCountRange64(const uint64_t *bitset, uint64_t start, uint64_t end)
Returns the number of bits set in bitset between positions start and end.
uint64_t BitCount64(uint64_t n)
Returns the number of bits set in n.
Definition bitset.h:46
bool IsBitSet64(const uint64_t *const bitset, uint64_t pos)
Returns true if the bit pos is set in bitset.
Definition bitset.h:350
uint64_t OneBit64(int pos)
Returns a word with only bit pos set.
Definition bitset.h:42
uint64_t BitOffset64(uint64_t pos)
Returns the word number corresponding to bit number pos.
Definition bitset.h:338
int64_t PosIntDivDown(int64_t e, int64_t v)
uint64_t BitLength64(uint64_t size)
Returns the number of words needed to store size bits.
Definition bitset.h:342
int LeastSignificantBitPosition64(uint64_t n)
Definition bitset.h:131
void CleanVariableOnFail(IntVar *var)
---------------— Queue class ---------------—
int64_t CapOpp(int64_t v)
Note(user): -kint64min != kint64max, but kint64max == ~kint64min.
int MostSignificantBitPosition64(uint64_t n)
Definition bitset.h:235
int64_t PosIntDivUp(int64_t e, int64_t v)
STL namespace.
false true
Definition numbers.cc:228
trees with all degrees equal w the current value of w
bool Next()
const Variable x
Definition qp_tests.cc:127
int64_t coefficient
int64_t stamp
Definition search.cc:3270
std::optional< int64_t > end
int64_t start
const std::optional< Range > & range
Definition statistics.cc:37