Google OR-Tools v9.12
a fast and portable software suite for combinatorial optimization
Loading...
Searching...
No Matches
expressions.cc
Go to the documentation of this file.
1// Copyright 2010-2025 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14#include <algorithm>
15#include <cmath>
16#include <cstdint>
17#include <cstdlib>
18#include <limits>
19#include <memory>
20#include <string>
21#include <utility>
22#include <vector>
23
24#include "absl/container/flat_hash_map.h"
25#include "absl/strings/str_cat.h"
26#include "absl/strings/str_format.h"
27#include "absl/strings/string_view.h"
28#include "absl/types/span.h"
34#include "ortools/base/types.h"
37#include "ortools/util/bitset.h"
40
// Command-line flags defined by this file. Per their help strings, the first
// (default false) disables special optimizations when creating expressions,
// and the second (default true) shares IntConst objects that have the same
// value.
41ABSL_FLAG(bool, cp_disable_expression_optimization, false,
42 "Disable special optimization when creating expressions.");
43ABSL_FLAG(bool, cp_share_int_consts, true,
44 "Share IntConst's with the same value.");
45
46#if defined(_MSC_VER)
47#pragma warning(disable : 4351 4355)
48#endif
49
50namespace operations_research {
51
52// ---------- IntExpr ----------
53
54IntVar* IntExpr::VarWithName(const std::string& name) {
55 IntVar* const var = Var();
56 var->set_name(name);
57 return var;
58}
59
60// ---------- IntVar ----------
61
62IntVar::IntVar(Solver* const s) : IntExpr(s), index_(s->GetNewIntVarIndex()) {}
63
64IntVar::IntVar(Solver* const s, const std::string& name)
65 : IntExpr(s), index_(s->GetNewIntVarIndex()) {
67}
68
69// ----- Boolean variable -----
70
72
73void BooleanVar::SetMin(int64_t m) {
74 if (m <= 0) return;
75 if (m > 1) solver()->Fail();
76 SetValue(1);
77}
78
79void BooleanVar::SetMax(int64_t m) {
80 if (m >= 1) return;
81 if (m < 0) solver()->Fail();
82 SetValue(0);
83}
84
85void BooleanVar::SetRange(int64_t mi, int64_t ma) {
86 if (mi > 1 || ma < 0 || mi > ma) {
87 solver()->Fail();
88 }
89 if (mi == 1) {
90 SetValue(1);
91 } else if (ma == 0) {
92 SetValue(0);
93 }
94}
95
96void BooleanVar::RemoveValue(int64_t v) {
98 if (v == 0) {
99 SetValue(1);
100 } else if (v == 1) {
101 SetValue(0);
102 }
103 } else if (v == value_) {
104 solver()->Fail();
105 }
106}
107
108void BooleanVar::RemoveInterval(int64_t l, int64_t u) {
109 if (u < l) return;
110 if (l <= 0 && u >= 1) {
111 solver()->Fail();
112 } else if (l == 1) {
113 SetValue(0);
114 } else if (u == 0) {
115 SetValue(1);
116 }
117}
118
122 delayed_bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
123 } else {
124 bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
125 }
126 }
127}
128
129uint64_t BooleanVar::Size() const {
130 return (1 + (value_ == kUnboundBooleanVarValue));
131}
132
133bool BooleanVar::Contains(int64_t v) const {
134 return ((v == 0 && value_ != 1) || (v == 1 && value_ != 0));
135}
136
137IntVar* BooleanVar::IsEqual(int64_t constant) {
138 if (constant > 1 || constant < 0) {
139 return solver()->MakeIntConst(0);
140 }
141 if (constant == 1) {
142 return this;
143 } else { // constant == 0.
144 return solver()->MakeDifference(1, this)->Var();
145 }
146}
147
148IntVar* BooleanVar::IsDifferent(int64_t constant) {
149 if (constant > 1 || constant < 0) {
150 return solver()->MakeIntConst(1);
151 }
152 if (constant == 1) {
153 return solver()->MakeDifference(1, this)->Var();
154 } else { // constant == 0.
155 return this;
156 }
157}
158
160 if (constant > 1) {
161 return solver()->MakeIntConst(0);
162 } else if (constant <= 0) {
163 return solver()->MakeIntConst(1);
164 } else {
165 return this;
166 }
167}
168
170 if (constant < 0) {
171 return solver()->MakeIntConst(0);
172 } else if (constant >= 1) {
173 return solver()->MakeIntConst(1);
174 } else {
175 return IsEqual(0);
176 }
177}
178
179std::string BooleanVar::DebugString() const {
180 std::string out;
181 const std::string& var_name = name();
182 if (!var_name.empty()) {
183 out = var_name + "(";
184 } else {
185 out = "BooleanVar(";
186 }
187 switch (value_) {
188 case 0:
189 out += "0";
190 break;
191 case 1:
192 out += "1";
193 break;
195 out += "0 .. 1";
196 break;
197 }
198 out += ")";
199 return out;
200}
201
202namespace {
203// ---------- Subclasses of IntVar ----------
204
205// ----- Domain Int Var: base class for variables -----
206// It Contains bounds and a bitset representation of possible values.
207class DomainIntVar : public IntVar {
208 public:
209 // Utility classes
210 class BitSetIterator : public BaseObject {
211 public:
212 BitSetIterator(uint64_t* const bitset, int64_t omin)
213 : bitset_(bitset),
214 omin_(omin),
215 max_(std::numeric_limits<int64_t>::min()),
216 current_(std::numeric_limits<int64_t>::max()) {}
217
218 ~BitSetIterator() override {}
219
220 void Init(int64_t min, int64_t max) {
221 max_ = max;
222 current_ = min;
223 }
224
225 bool Ok() const { return current_ <= max_; }
226
227 int64_t Value() const { return current_; }
228
229 void Next() {
230 if (++current_ <= max_) {
232 bitset_, current_ - omin_, max_ - omin_) +
233 omin_;
234 }
235 }
236
237 std::string DebugString() const override { return "BitSetIterator"; }
238
239 private:
240 uint64_t* const bitset_;
241 const int64_t omin_;
242 int64_t max_;
243 int64_t current_;
244 };
245
246 class BitSet : public BaseObject {
247 public:
248 explicit BitSet(Solver* const s) : solver_(s), holes_stamp_(0) {}
249 ~BitSet() override {}
250
251 virtual int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) = 0;
252 virtual int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) = 0;
253 virtual bool Contains(int64_t val) const = 0;
254 virtual bool SetValue(int64_t val) = 0;
255 virtual bool RemoveValue(int64_t val) = 0;
256 virtual uint64_t Size() const = 0;
257 virtual void DelayRemoveValue(int64_t val) = 0;
258 virtual void ApplyRemovedValues(DomainIntVar* var) = 0;
259 virtual void ClearRemovedValues() = 0;
260 virtual std::string pretty_DebugString(int64_t min, int64_t max) const = 0;
261 virtual BitSetIterator* MakeIterator() = 0;
262
263 void InitHoles() {
264 const uint64_t current_stamp = solver_->stamp();
265 if (holes_stamp_ < current_stamp) {
266 holes_.clear();
267 holes_stamp_ = current_stamp;
268 }
269 }
270
271 virtual void ClearHoles() { holes_.clear(); }
272
273 const std::vector<int64_t>& Holes() { return holes_; }
274
275 void AddHole(int64_t value) { holes_.push_back(value); }
276
277 int NumHoles() const {
278 return holes_stamp_ < solver_->stamp() ? 0 : holes_.size();
279 }
280
281 protected:
282 Solver* const solver_;
283
284 private:
285 std::vector<int64_t> holes_;
286 uint64_t holes_stamp_;
287 };
288
289 class QueueHandler : public Demon {
290 public:
291 explicit QueueHandler(DomainIntVar* const var) : var_(var) {}
292 ~QueueHandler() override {}
293 void Run(Solver* const s) override {
294 s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
295 var_->Process();
296 s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
297 }
298 Solver::DemonPriority priority() const override {
300 }
301 std::string DebugString() const override {
302 return absl::StrFormat("Handler(%s)", var_->DebugString());
303 }
304
305 private:
306 DomainIntVar* const var_;
307 };
308
309 // Bounds and Value watchers
310
311 // This class stores the watchers variables attached to values. It is
312 // reversible and it helps maintaining the set of 'active' watchers
313 // (variables not bound to a single value).
314 template <class T>
315 class RevIntPtrMap {
316 public:
317 RevIntPtrMap(Solver* const solver, int64_t rmin, int64_t rmax)
318 : solver_(solver), range_min_(rmin), start_(0) {}
319
320 ~RevIntPtrMap() {}
321
322 bool Empty() const { return start_.Value() == elements_.size(); }
323
324 void SortActive() { std::sort(elements_.begin(), elements_.end()); }
325
326 // Access with value API.
327
328 // Add the pointer to the map attached to the given value.
329 void UnsafeRevInsert(int64_t value, T* elem) {
330 elements_.push_back(std::make_pair(value, elem));
331 if (solver_->state() != Solver::OUTSIDE_SEARCH) {
332 solver_->AddBacktrackAction(
333 [this, value](Solver* s) { Uninsert(value); }, false);
334 }
335 }
336
337 T* FindPtrOrNull(int64_t value, int* position) {
338 for (int pos = start_.Value(); pos < elements_.size(); ++pos) {
339 if (elements_[pos].first == value) {
340 if (position != nullptr) *position = pos;
341 return At(pos).second;
342 }
343 }
344 return nullptr;
345 }
346
347 // Access map through the underlying vector.
348 void RemoveAt(int position) {
349 const int start = start_.Value();
350 DCHECK_GE(position, start);
351 DCHECK_LT(position, elements_.size());
352 if (position > start) {
353 // Swap the current element with the one at the start position, and
354 // increase start.
355 const std::pair<int64_t, T*> copy = elements_[start];
356 elements_[start] = elements_[position];
357 elements_[position] = copy;
358 }
359 start_.Incr(solver_);
360 }
361
362 const std::pair<int64_t, T*>& At(int position) const {
363 DCHECK_GE(position, start_.Value());
364 DCHECK_LT(position, elements_.size());
365 return elements_[position];
366 }
367
368 void RemoveAll() { start_.SetValue(solver_, elements_.size()); }
369
370 int start() const { return start_.Value(); }
371 int end() const { return elements_.size(); }
372 // Number of active elements.
373 int Size() const { return elements_.size() - start_.Value(); }
374
375 // Removes the object permanently from the map.
376 void Uninsert(int64_t value) {
377 for (int pos = 0; pos < elements_.size(); ++pos) {
378 if (elements_[pos].first == value) {
379 DCHECK_GE(pos, start_.Value());
380 const int last = elements_.size() - 1;
381 if (pos != last) { // Swap the current with the last.
382 elements_[pos] = elements_.back();
383 }
384 elements_.pop_back();
385 return;
386 }
387 }
388 LOG(FATAL) << "The element should have been removed";
389 }
390
391 private:
392 Solver* const solver_;
393 const int64_t range_min_;
394 NumericalRev<int> start_;
395 std::vector<std::pair<int64_t, T*>> elements_;
396 };
397
398 // Base class for value watchers
399 class BaseValueWatcher : public Constraint {
400 public:
401 explicit BaseValueWatcher(Solver* const solver) : Constraint(solver) {}
402
403 ~BaseValueWatcher() override {}
404
405 virtual IntVar* GetOrMakeValueWatcher(int64_t value) = 0;
406
407 virtual void SetValueWatcher(IntVar* boolvar, int64_t value) = 0;
408 };
409
410 // This class monitors the domain of the variable and updates the
411 // IsEqual/IsDifferent boolean variables accordingly.
412 class ValueWatcher : public BaseValueWatcher {
413 public:
414 class WatchDemon : public Demon {
415 public:
416 WatchDemon(ValueWatcher* const watcher, int64_t value, IntVar* var)
417 : value_watcher_(watcher), value_(value), var_(var) {}
418 ~WatchDemon() override {}
419
420 void Run(Solver* const solver) override {
421 value_watcher_->ProcessValueWatcher(value_, var_);
422 }
423
424 private:
425 ValueWatcher* const value_watcher_;
426 const int64_t value_;
427 IntVar* const var_;
428 };
429
430 class VarDemon : public Demon {
431 public:
432 explicit VarDemon(ValueWatcher* const watcher)
433 : value_watcher_(watcher) {}
434
435 ~VarDemon() override {}
436
437 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
438
439 private:
440 ValueWatcher* const value_watcher_;
441 };
442
443 ValueWatcher(Solver* const solver, DomainIntVar* const variable)
444 : BaseValueWatcher(solver),
445 variable_(variable),
446 hole_iterator_(variable_->MakeHoleIterator(true)),
447 var_demon_(nullptr),
448 watchers_(solver, variable->Min(), variable->Max()) {}
449
450 ~ValueWatcher() override {}
451
452 IntVar* GetOrMakeValueWatcher(int64_t value) override {
453 IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
454 if (watcher != nullptr) return watcher;
455 if (variable_->Contains(value)) {
456 if (variable_->Bound()) {
457 return solver()->MakeIntConst(1);
458 } else {
459 const std::string vname = variable_->HasName()
460 ? variable_->name()
461 : variable_->DebugString();
462 const std::string bname =
463 absl::StrFormat("Watch<%s == %d>", vname, value);
464 IntVar* const boolvar = solver()->MakeBoolVar(bname);
465 watchers_.UnsafeRevInsert(value, boolvar);
466 if (posted_.Switched()) {
467 boolvar->WhenBound(
468 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
469 var_demon_->desinhibit(solver());
470 }
471 return boolvar;
472 }
473 } else {
474 return variable_->solver()->MakeIntConst(0);
475 }
476 }
477
478 void SetValueWatcher(IntVar* const boolvar, int64_t value) override {
479 CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
480 if (!boolvar->Bound()) {
481 watchers_.UnsafeRevInsert(value, boolvar);
482 if (posted_.Switched() && !boolvar->Bound()) {
483 boolvar->WhenBound(
484 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
485 var_demon_->desinhibit(solver());
486 }
487 }
488 }
489
490 void Post() override {
491 var_demon_ = solver()->RevAlloc(new VarDemon(this));
492 variable_->WhenDomain(var_demon_);
493 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
494 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
495 const int64_t value = w.first;
496 IntVar* const boolvar = w.second;
497 if (!boolvar->Bound() && variable_->Contains(value)) {
498 boolvar->WhenBound(
499 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
500 }
501 }
502 posted_.Switch(solver());
503 }
504
505 void InitialPropagate() override {
506 if (variable_->Bound()) {
507 VariableBound();
508 } else {
509 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
510 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
511 const int64_t value = w.first;
512 IntVar* const boolvar = w.second;
513 if (!variable_->Contains(value)) {
514 boolvar->SetValue(0);
515 watchers_.RemoveAt(pos);
516 } else {
517 if (boolvar->Bound()) {
518 ProcessValueWatcher(value, boolvar);
519 watchers_.RemoveAt(pos);
520 }
521 }
522 }
523 CheckInhibit();
524 }
525 }
526
527 void ProcessValueWatcher(int64_t value, IntVar* boolvar) {
528 if (boolvar->Min() == 0) {
529 if (variable_->Size() < 0xFFFFFF) {
530 variable_->RemoveValue(value);
531 } else {
532 // Delay removal.
533 solver()->AddConstraint(solver()->MakeNonEquality(variable_, value));
534 }
535 } else {
536 variable_->SetValue(value);
537 }
538 }
539
540 void ProcessVar() {
541 const int kSmallList = 16;
542 if (variable_->Bound()) {
543 VariableBound();
544 } else if (watchers_.Size() <= kSmallList ||
545 variable_->Min() != variable_->OldMin() ||
546 variable_->Max() != variable_->OldMax()) {
547 // Brute force loop for small numbers of watchers, or if the bounds have
548 // changed, which would have required a sort (n log(n)) anyway to take
549 // advantage of.
550 ScanWatchers();
551 CheckInhibit();
552 } else {
553 // If there is no bitset, then there are no holes.
554 // In that case, the two loops above should have performed all
555 // propagation. Otherwise, scan the remaining watchers.
556 BitSet* const bitset = variable_->bitset();
557 if (bitset != nullptr && !watchers_.Empty()) {
558 if (bitset->NumHoles() * 2 < watchers_.Size()) {
559 for (const int64_t hole : InitAndGetValues(hole_iterator_)) {
560 int pos = 0;
561 IntVar* const boolvar = watchers_.FindPtrOrNull(hole, &pos);
562 if (boolvar != nullptr) {
563 boolvar->SetValue(0);
564 watchers_.RemoveAt(pos);
565 }
566 }
567 } else {
568 ScanWatchers();
569 }
570 }
571 CheckInhibit();
572 }
573 }
574
575 // Optimized case if the variable is bound.
576 void VariableBound() {
577 DCHECK(variable_->Bound());
578 const int64_t value = variable_->Min();
579 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
580 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
581 w.second->SetValue(w.first == value);
582 }
583 watchers_.RemoveAll();
584 var_demon_->inhibit(solver());
585 }
586
587 // Scans all the watchers to check and assign them.
588 void ScanWatchers() {
589 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
590 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
591 if (!variable_->Contains(w.first)) {
592 IntVar* const boolvar = w.second;
593 boolvar->SetValue(0);
594 watchers_.RemoveAt(pos);
595 }
596 }
597 }
598
599 // If the set of active watchers is empty, we can inhibit the demon on the
600 // main variable.
601 void CheckInhibit() {
602 if (watchers_.Empty()) {
603 var_demon_->inhibit(solver());
604 }
605 }
606
607 void Accept(ModelVisitor* const visitor) const override {
608 visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
609 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
610 variable_);
611 std::vector<int64_t> all_coefficients;
612 std::vector<IntVar*> all_bool_vars;
613 for (int position = watchers_.start(); position < watchers_.end();
614 ++position) {
615 const std::pair<int64_t, IntVar*>& w = watchers_.At(position);
616 all_coefficients.push_back(w.first);
617 all_bool_vars.push_back(w.second);
618 }
619 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
620 all_bool_vars);
621 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
622 all_coefficients);
623 visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
624 }
625
626 std::string DebugString() const override {
627 return absl::StrFormat("ValueWatcher(%s)", variable_->DebugString());
628 }
629
630 private:
631 DomainIntVar* const variable_;
632 IntVarIterator* const hole_iterator_;
633 RevSwitch posted_;
634 Demon* var_demon_;
635 RevIntPtrMap<IntVar> watchers_;
636 };
637
638 // Optimized case for small maps.
639 class DenseValueWatcher : public BaseValueWatcher {
640 public:
641 class WatchDemon : public Demon {
642 public:
643 WatchDemon(DenseValueWatcher* const watcher, int64_t value, IntVar* var)
644 : value_watcher_(watcher), value_(value), var_(var) {}
645 ~WatchDemon() override {}
646
647 void Run(Solver* const solver) override {
648 value_watcher_->ProcessValueWatcher(value_, var_);
649 }
650
651 private:
652 DenseValueWatcher* const value_watcher_;
653 const int64_t value_;
654 IntVar* const var_;
655 };
656
657 class VarDemon : public Demon {
658 public:
659 explicit VarDemon(DenseValueWatcher* const watcher)
660 : value_watcher_(watcher) {}
661
662 ~VarDemon() override {}
663
664 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
665
666 private:
667 DenseValueWatcher* const value_watcher_;
668 };
669
670 DenseValueWatcher(Solver* const solver, DomainIntVar* const variable)
671 : BaseValueWatcher(solver),
672 variable_(variable),
673 hole_iterator_(variable_->MakeHoleIterator(true)),
674 var_demon_(nullptr),
675 offset_(variable->Min()),
676 watchers_(variable->Max() - variable->Min() + 1, nullptr),
677 active_watchers_(0) {}
678
679 ~DenseValueWatcher() override {}
680
681 IntVar* GetOrMakeValueWatcher(int64_t value) override {
682 const int64_t var_max = offset_ + watchers_.size() - 1; // Bad cast.
683 if (value < offset_ || value > var_max) {
684 return solver()->MakeIntConst(0);
685 }
686 const int index = value - offset_;
687 IntVar* const watcher = watchers_[index];
688 if (watcher != nullptr) return watcher;
689 if (variable_->Contains(value)) {
690 if (variable_->Bound()) {
691 return solver()->MakeIntConst(1);
692 } else {
693 const std::string vname = variable_->HasName()
694 ? variable_->name()
695 : variable_->DebugString();
696 const std::string bname =
697 absl::StrFormat("Watch<%s == %d>", vname, value);
698 IntVar* const boolvar = solver()->MakeBoolVar(bname);
699 RevInsert(index, boolvar);
700 if (posted_.Switched()) {
701 boolvar->WhenBound(
702 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
703 var_demon_->desinhibit(solver());
704 }
705 return boolvar;
706 }
707 } else {
708 return variable_->solver()->MakeIntConst(0);
709 }
710 }
711
712 void SetValueWatcher(IntVar* const boolvar, int64_t value) override {
713 const int index = value - offset_;
714 CHECK(watchers_[index] == nullptr);
715 if (!boolvar->Bound()) {
716 RevInsert(index, boolvar);
717 if (posted_.Switched() && !boolvar->Bound()) {
718 boolvar->WhenBound(
719 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
720 var_demon_->desinhibit(solver());
721 }
722 }
723 }
724
725 void Post() override {
726 var_demon_ = solver()->RevAlloc(new VarDemon(this));
727 variable_->WhenDomain(var_demon_);
728 for (int pos = 0; pos < watchers_.size(); ++pos) {
729 const int64_t value = pos + offset_;
730 IntVar* const boolvar = watchers_[pos];
731 if (boolvar != nullptr && !boolvar->Bound() &&
732 variable_->Contains(value)) {
733 boolvar->WhenBound(
734 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
735 }
736 }
737 posted_.Switch(solver());
738 }
739
740 void InitialPropagate() override {
741 if (variable_->Bound()) {
742 VariableBound();
743 } else {
744 for (int pos = 0; pos < watchers_.size(); ++pos) {
745 IntVar* const boolvar = watchers_[pos];
746 if (boolvar == nullptr) continue;
747 const int64_t value = pos + offset_;
748 if (!variable_->Contains(value)) {
749 boolvar->SetValue(0);
750 RevRemove(pos);
751 } else if (boolvar->Bound()) {
752 ProcessValueWatcher(value, boolvar);
753 RevRemove(pos);
754 }
755 }
756 if (active_watchers_.Value() == 0) {
757 var_demon_->inhibit(solver());
758 }
759 }
760 }
761
762 void ProcessValueWatcher(int64_t value, IntVar* boolvar) {
763 if (boolvar->Min() == 0) {
764 variable_->RemoveValue(value);
765 } else {
766 variable_->SetValue(value);
767 }
768 }
769
770 void ProcessVar() {
771 if (variable_->Bound()) {
772 VariableBound();
773 } else {
774 // Brute force loop for small numbers of watchers.
775 ScanWatchers();
776 if (active_watchers_.Value() == 0) {
777 var_demon_->inhibit(solver());
778 }
779 }
780 }
781
782 // Optimized case if the variable is bound.
783 void VariableBound() {
784 DCHECK(variable_->Bound());
785 const int64_t value = variable_->Min();
786 for (int pos = 0; pos < watchers_.size(); ++pos) {
787 IntVar* const boolvar = watchers_[pos];
788 if (boolvar != nullptr) {
789 boolvar->SetValue(pos + offset_ == value);
790 RevRemove(pos);
791 }
792 }
793 var_demon_->inhibit(solver());
794 }
795
796 // Scans all the watchers to check and assign them.
797 void ScanWatchers() {
798 const int64_t old_min_index = variable_->OldMin() - offset_;
799 const int64_t old_max_index = variable_->OldMax() - offset_;
800 const int64_t min_index = variable_->Min() - offset_;
801 const int64_t max_index = variable_->Max() - offset_;
802 for (int pos = old_min_index; pos < min_index; ++pos) {
803 IntVar* const boolvar = watchers_[pos];
804 if (boolvar != nullptr) {
805 boolvar->SetValue(0);
806 RevRemove(pos);
807 }
808 }
809 for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
810 IntVar* const boolvar = watchers_[pos];
811 if (boolvar != nullptr) {
812 boolvar->SetValue(0);
813 RevRemove(pos);
814 }
815 }
816 BitSet* const bitset = variable_->bitset();
817 if (bitset != nullptr) {
818 if (bitset->NumHoles() * 2 < active_watchers_.Value()) {
819 for (const int64_t hole : InitAndGetValues(hole_iterator_)) {
820 IntVar* const boolvar = watchers_[hole - offset_];
821 if (boolvar != nullptr) {
822 boolvar->SetValue(0);
823 RevRemove(hole - offset_);
824 }
825 }
826 } else {
827 for (int pos = min_index + 1; pos < max_index; ++pos) {
828 IntVar* const boolvar = watchers_[pos];
829 if (boolvar != nullptr && !variable_->Contains(offset_ + pos)) {
830 boolvar->SetValue(0);
831 RevRemove(pos);
832 }
833 }
834 }
835 }
836 }
837
838 void RevRemove(int pos) {
839 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
840 watchers_[pos] = nullptr;
841 active_watchers_.Decr(solver());
842 }
843
844 void RevInsert(int pos, IntVar* boolvar) {
845 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
846 watchers_[pos] = boolvar;
847 active_watchers_.Incr(solver());
848 }
849
850 void Accept(ModelVisitor* const visitor) const override {
851 visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
852 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
853 variable_);
854 std::vector<int64_t> all_coefficients;
855 std::vector<IntVar*> all_bool_vars;
856 for (int position = 0; position < watchers_.size(); ++position) {
857 if (watchers_[position] != nullptr) {
858 all_coefficients.push_back(position + offset_);
859 all_bool_vars.push_back(watchers_[position]);
860 }
861 }
862 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
863 all_bool_vars);
864 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
865 all_coefficients);
866 visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
867 }
868
869 std::string DebugString() const override {
870 return absl::StrFormat("DenseValueWatcher(%s)", variable_->DebugString());
871 }
872
873 private:
874 DomainIntVar* const variable_;
875 IntVarIterator* const hole_iterator_;
876 RevSwitch posted_;
877 Demon* var_demon_;
878 const int64_t offset_;
879 std::vector<IntVar*> watchers_;
880 NumericalRev<int> active_watchers_;
881 };
882
883 class BaseUpperBoundWatcher : public Constraint {
884 public:
885 explicit BaseUpperBoundWatcher(Solver* const solver) : Constraint(solver) {}
886
887 ~BaseUpperBoundWatcher() override {}
888
889 virtual IntVar* GetOrMakeUpperBoundWatcher(int64_t value) = 0;
890
891 virtual void SetUpperBoundWatcher(IntVar* boolvar, int64_t value) = 0;
892 };
893
894 // This class watches the bounds of the variable and updates the
895 // IsGreater/IsGreaterOrEqual/IsLess/IsLessOrEqual demons
896 // accordingly.
897 class UpperBoundWatcher : public BaseUpperBoundWatcher {
898 public:
899 class WatchDemon : public Demon {
900 public:
901 WatchDemon(UpperBoundWatcher* const watcher, int64_t index,
902 IntVar* const var)
903 : value_watcher_(watcher), index_(index), var_(var) {}
904 ~WatchDemon() override {}
905
906 void Run(Solver* const solver) override {
907 value_watcher_->ProcessUpperBoundWatcher(index_, var_);
908 }
909
910 private:
911 UpperBoundWatcher* const value_watcher_;
912 const int64_t index_;
913 IntVar* const var_;
914 };
915
916 class VarDemon : public Demon {
917 public:
918 explicit VarDemon(UpperBoundWatcher* const watcher)
919 : value_watcher_(watcher) {}
920 ~VarDemon() override {}
921
922 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
923
924 private:
925 UpperBoundWatcher* const value_watcher_;
926 };
927
928 UpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
929 : BaseUpperBoundWatcher(solver),
930 variable_(variable),
931 var_demon_(nullptr),
932 watchers_(solver, variable->Min(), variable->Max()),
933 start_(0),
934 end_(0),
935 sorted_(false) {}
936
937 ~UpperBoundWatcher() override {}
938
939 IntVar* GetOrMakeUpperBoundWatcher(int64_t value) override {
940 IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
941 if (watcher != nullptr) {
942 return watcher;
943 }
944 if (variable_->Max() >= value) {
945 if (variable_->Min() >= value) {
946 return solver()->MakeIntConst(1);
947 } else {
948 const std::string vname = variable_->HasName()
949 ? variable_->name()
950 : variable_->DebugString();
951 const std::string bname =
952 absl::StrFormat("Watch<%s >= %d>", vname, value);
953 IntVar* const boolvar = solver()->MakeBoolVar(bname);
954 watchers_.UnsafeRevInsert(value, boolvar);
955 if (posted_.Switched()) {
956 boolvar->WhenBound(
957 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
958 var_demon_->desinhibit(solver());
959 sorted_ = false;
960 }
961 return boolvar;
962 }
963 } else {
964 return variable_->solver()->MakeIntConst(0);
965 }
966 }
967
968 void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) override {
969 CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
970 watchers_.UnsafeRevInsert(value, boolvar);
971 if (posted_.Switched() && !boolvar->Bound()) {
972 boolvar->WhenBound(
973 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
974 var_demon_->desinhibit(solver());
975 sorted_ = false;
976 }
977 }
978
979 void Post() override {
980 const int kTooSmallToSort = 8;
981 var_demon_ = solver()->RevAlloc(new VarDemon(this));
982 variable_->WhenRange(var_demon_);
983
984 if (watchers_.Size() > kTooSmallToSort) {
985 watchers_.SortActive();
986 sorted_ = true;
987 start_.SetValue(solver(), watchers_.start());
988 end_.SetValue(solver(), watchers_.end() - 1);
989 }
990
991 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
992 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
993 IntVar* const boolvar = w.second;
994 const int64_t value = w.first;
995 if (!boolvar->Bound() && value > variable_->Min() &&
996 value <= variable_->Max()) {
997 boolvar->WhenBound(
998 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
999 }
1000 }
1001 posted_.Switch(solver());
1002 }
1003
1004 void InitialPropagate() override {
1005 const int64_t var_min = variable_->Min();
1006 const int64_t var_max = variable_->Max();
1007 if (sorted_) {
1008 while (start_.Value() <= end_.Value()) {
1009 const std::pair<int64_t, IntVar*>& w = watchers_.At(start_.Value());
1010 if (w.first <= var_min) {
1011 w.second->SetValue(1);
1012 start_.Incr(solver());
1013 } else {
1014 break;
1015 }
1016 }
1017 while (end_.Value() >= start_.Value()) {
1018 const std::pair<int64_t, IntVar*>& w = watchers_.At(end_.Value());
1019 if (w.first > var_max) {
1020 w.second->SetValue(0);
1021 end_.Decr(solver());
1022 } else {
1023 break;
1024 }
1025 }
1026 for (int i = start_.Value(); i <= end_.Value(); ++i) {
1027 const std::pair<int64_t, IntVar*>& w = watchers_.At(i);
1028 if (w.second->Bound()) {
1029 ProcessUpperBoundWatcher(w.first, w.second);
1030 }
1031 }
1032 if (start_.Value() > end_.Value()) {
1033 var_demon_->inhibit(solver());
1034 }
1035 } else {
1036 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1037 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
1038 const int64_t value = w.first;
1039 IntVar* const boolvar = w.second;
1040
1041 if (value <= var_min) {
1042 boolvar->SetValue(1);
1043 watchers_.RemoveAt(pos);
1044 } else if (value > var_max) {
1045 boolvar->SetValue(0);
1046 watchers_.RemoveAt(pos);
1047 } else if (boolvar->Bound()) {
1048 ProcessUpperBoundWatcher(value, boolvar);
1049 watchers_.RemoveAt(pos);
1050 }
1051 }
1052 }
1053 }
1054
1055 void Accept(ModelVisitor* const visitor) const override {
1056 visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1057 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1058 variable_);
1059 std::vector<int64_t> all_coefficients;
1060 std::vector<IntVar*> all_bool_vars;
1061 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1062 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
1063 all_coefficients.push_back(w.first);
1064 all_bool_vars.push_back(w.second);
1065 }
1066 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1067 all_bool_vars);
1068 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1069 all_coefficients);
1070 visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1071 }
1072
1073 std::string DebugString() const override {
1074 return absl::StrFormat("UpperBoundWatcher(%s)", variable_->DebugString());
1075 }
1076
1077 private:
1078 void ProcessUpperBoundWatcher(int64_t value, IntVar* const boolvar) {
1079 if (boolvar->Min() == 0) {
1080 variable_->SetMax(value - 1);
1081 } else {
1082 variable_->SetMin(value);
1083 }
1084 }
1085
1086 void ProcessVar() {
1087 const int64_t var_min = variable_->Min();
1088 const int64_t var_max = variable_->Max();
1089 if (sorted_) {
1090 while (start_.Value() <= end_.Value()) {
1091 const std::pair<int64_t, IntVar*>& w = watchers_.At(start_.Value());
1092 if (w.first <= var_min) {
1093 w.second->SetValue(1);
1094 start_.Incr(solver());
1095 } else {
1096 break;
1097 }
1098 }
1099 while (end_.Value() >= start_.Value()) {
1100 const std::pair<int64_t, IntVar*>& w = watchers_.At(end_.Value());
1101 if (w.first > var_max) {
1102 w.second->SetValue(0);
1103 end_.Decr(solver());
1104 } else {
1105 break;
1106 }
1107 }
1108 if (start_.Value() > end_.Value()) {
1109 var_demon_->inhibit(solver());
1110 }
1111 } else {
1112 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1113 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
1114 const int64_t value = w.first;
1115 IntVar* const boolvar = w.second;
1116
1117 if (value <= var_min) {
1118 boolvar->SetValue(1);
1119 watchers_.RemoveAt(pos);
1120 } else if (value > var_max) {
1121 boolvar->SetValue(0);
1122 watchers_.RemoveAt(pos);
1123 }
1124 }
1125 if (watchers_.Empty()) {
1126 var_demon_->inhibit(solver());
1127 }
1128 }
1129 }
1130
1131 DomainIntVar* const variable_;
1132 RevSwitch posted_;
1133 Demon* var_demon_;
1134 RevIntPtrMap<IntVar> watchers_;
1135 NumericalRev<int> start_;
1136 NumericalRev<int> end_;
1137 bool sorted_;
1138 };
1139
1140 // Optimized case for small maps.
1141 class DenseUpperBoundWatcher : public BaseUpperBoundWatcher {
1142 public:
1143 class WatchDemon : public Demon {
1144 public:
1145 WatchDemon(DenseUpperBoundWatcher* const watcher, int64_t value,
1146 IntVar* var)
1147 : value_watcher_(watcher), value_(value), var_(var) {}
1148 ~WatchDemon() override {}
1149
1150 void Run(Solver* const solver) override {
1151 value_watcher_->ProcessUpperBoundWatcher(value_, var_);
1152 }
1153
1154 private:
1155 DenseUpperBoundWatcher* const value_watcher_;
1156 const int64_t value_;
1157 IntVar* const var_;
1158 };
1159
1160 class VarDemon : public Demon {
1161 public:
1162 explicit VarDemon(DenseUpperBoundWatcher* const watcher)
1163 : value_watcher_(watcher) {}
1164
1165 ~VarDemon() override {}
1166
1167 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
1168
1169 private:
1170 DenseUpperBoundWatcher* const value_watcher_;
1171 };
1172
1173 DenseUpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
1174 : BaseUpperBoundWatcher(solver),
1175 variable_(variable),
1176 var_demon_(nullptr),
1177 offset_(variable->Min()),
1178 watchers_(variable->Max() - variable->Min() + 1, nullptr),
1179 active_watchers_(0) {}
1180
1181 ~DenseUpperBoundWatcher() override {}
1182
1183 IntVar* GetOrMakeUpperBoundWatcher(int64_t value) override {
1184 if (variable_->Max() >= value) {
1185 if (variable_->Min() >= value) {
1186 return solver()->MakeIntConst(1);
1187 } else {
1188 const std::string vname = variable_->HasName()
1189 ? variable_->name()
1190 : variable_->DebugString();
1191 const std::string bname =
1192 absl::StrFormat("Watch<%s >= %d>", vname, value);
1193 IntVar* const boolvar = solver()->MakeBoolVar(bname);
1194 RevInsert(value - offset_, boolvar);
1195 if (posted_.Switched()) {
1196 boolvar->WhenBound(
1197 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1198 var_demon_->desinhibit(solver());
1199 }
1200 return boolvar;
1201 }
1202 } else {
1203 return variable_->solver()->MakeIntConst(0);
1204 }
1205 }
1206
1207 void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) override {
1208 const int index = value - offset_;
1209 CHECK(watchers_[index] == nullptr);
1210 if (!boolvar->Bound()) {
1211 RevInsert(index, boolvar);
1212 if (posted_.Switched() && !boolvar->Bound()) {
1213 boolvar->WhenBound(
1214 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1215 var_demon_->desinhibit(solver());
1216 }
1217 }
1218 }
1219
1220 void Post() override {
1221 var_demon_ = solver()->RevAlloc(new VarDemon(this));
1222 variable_->WhenRange(var_demon_);
1223 for (int pos = 0; pos < watchers_.size(); ++pos) {
1224 const int64_t value = pos + offset_;
1225 IntVar* const boolvar = watchers_[pos];
1226 if (boolvar != nullptr && !boolvar->Bound() &&
1227 value > variable_->Min() && value <= variable_->Max()) {
1228 boolvar->WhenBound(
1229 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1230 }
1231 }
1232 posted_.Switch(solver());
1233 }
1234
1235 void InitialPropagate() override {
1236 for (int pos = 0; pos < watchers_.size(); ++pos) {
1237 IntVar* const boolvar = watchers_[pos];
1238 if (boolvar == nullptr) continue;
1239 const int64_t value = pos + offset_;
1240 if (value <= variable_->Min()) {
1241 boolvar->SetValue(1);
1242 RevRemove(pos);
1243 } else if (value > variable_->Max()) {
1244 boolvar->SetValue(0);
1245 RevRemove(pos);
1246 } else if (boolvar->Bound()) {
1247 ProcessUpperBoundWatcher(value, boolvar);
1248 RevRemove(pos);
1249 }
1250 }
1251 if (active_watchers_.Value() == 0) {
1252 var_demon_->inhibit(solver());
1253 }
1254 }
1255
1256 void ProcessUpperBoundWatcher(int64_t value, IntVar* boolvar) {
1257 if (boolvar->Min() == 0) {
1258 variable_->SetMax(value - 1);
1259 } else {
1260 variable_->SetMin(value);
1261 }
1262 }
1263
1264 void ProcessVar() {
1265 const int64_t old_min_index = variable_->OldMin() - offset_;
1266 const int64_t old_max_index = variable_->OldMax() - offset_;
1267 const int64_t min_index = variable_->Min() - offset_;
1268 const int64_t max_index = variable_->Max() - offset_;
1269 for (int pos = old_min_index; pos <= min_index; ++pos) {
1270 IntVar* const boolvar = watchers_[pos];
1271 if (boolvar != nullptr) {
1272 boolvar->SetValue(1);
1273 RevRemove(pos);
1274 }
1275 }
1276
1277 for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
1278 IntVar* const boolvar = watchers_[pos];
1279 if (boolvar != nullptr) {
1280 boolvar->SetValue(0);
1281 RevRemove(pos);
1282 }
1283 }
1284 if (active_watchers_.Value() == 0) {
1285 var_demon_->inhibit(solver());
1286 }
1287 }
1288
1289 void RevRemove(int pos) {
1290 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1291 watchers_[pos] = nullptr;
1292 active_watchers_.Decr(solver());
1293 }
1294
1295 void RevInsert(int pos, IntVar* boolvar) {
1296 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1297 watchers_[pos] = boolvar;
1298 active_watchers_.Incr(solver());
1299 }
1300
1301 void Accept(ModelVisitor* const visitor) const override {
1302 visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1303 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1304 variable_);
1305 std::vector<int64_t> all_coefficients;
1306 std::vector<IntVar*> all_bool_vars;
1307 for (int position = 0; position < watchers_.size(); ++position) {
1308 if (watchers_[position] != nullptr) {
1309 all_coefficients.push_back(position + offset_);
1310 all_bool_vars.push_back(watchers_[position]);
1311 }
1312 }
1313 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1314 all_bool_vars);
1315 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1316 all_coefficients);
1317 visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1318 }
1319
1320 std::string DebugString() const override {
1321 return absl::StrFormat("DenseUpperBoundWatcher(%s)",
1322 variable_->DebugString());
1323 }
1324
1325 private:
1326 DomainIntVar* const variable_;
1327 RevSwitch posted_;
1328 Demon* var_demon_;
1329 const int64_t offset_;
1330 std::vector<IntVar*> watchers_;
1331 NumericalRev<int> active_watchers_;
1332 };
1333
1334 // ----- Main Class -----
1335 DomainIntVar(Solver* s, int64_t vmin, int64_t vmax, const std::string& name);
1336 DomainIntVar(Solver* s, absl::Span<const int64_t> sorted_values,
1337 const std::string& name);
1338 ~DomainIntVar() override;
1339
1340 int64_t Min() const override { return min_.Value(); }
1341 void SetMin(int64_t m) override;
1342 int64_t Max() const override { return max_.Value(); }
1343 void SetMax(int64_t m) override;
1344 void SetRange(int64_t mi, int64_t ma) override;
1345 void SetValue(int64_t v) override;
1346 bool Bound() const override { return (min_.Value() == max_.Value()); }
1347 int64_t Value() const override {
1348 CHECK_EQ(min_.Value(), max_.Value())
1349 << " variable " << DebugString() << " is not bound.";
1350 return min_.Value();
1351 }
1352 void RemoveValue(int64_t v) override;
1353 void RemoveInterval(int64_t l, int64_t u) override;
1354 void CreateBits();
1355 void WhenBound(Demon* d) override {
1356 if (min_.Value() != max_.Value()) {
1357 if (d->priority() == Solver::DELAYED_PRIORITY) {
1358 delayed_bound_demons_.PushIfNotTop(solver(),
1359 solver()->RegisterDemon(d));
1360 } else {
1361 bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1362 }
1363 }
1364 }
1365 void WhenRange(Demon* d) override {
1366 if (min_.Value() != max_.Value()) {
1367 if (d->priority() == Solver::DELAYED_PRIORITY) {
1368 delayed_range_demons_.PushIfNotTop(solver(),
1369 solver()->RegisterDemon(d));
1370 } else {
1371 range_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1372 }
1373 }
1374 }
1375 void WhenDomain(Demon* d) override {
1376 if (min_.Value() != max_.Value()) {
1377 if (d->priority() == Solver::DELAYED_PRIORITY) {
1378 delayed_domain_demons_.PushIfNotTop(solver(),
1379 solver()->RegisterDemon(d));
1380 } else {
1381 domain_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1382 }
1383 }
1384 }
1385
1386 IntVar* IsEqual(int64_t constant) override {
1387 Solver* const s = solver();
1388 if (constant == min_.Value() && value_watcher_ == nullptr) {
1389 return s->MakeIsLessOrEqualCstVar(this, constant);
1390 }
1391 if (constant == max_.Value() && value_watcher_ == nullptr) {
1392 return s->MakeIsGreaterOrEqualCstVar(this, constant);
1393 }
1394 if (!Contains(constant)) {
1395 return s->MakeIntConst(int64_t{0});
1396 }
1397 if (Bound() && min_.Value() == constant) {
1398 return s->MakeIntConst(int64_t{1});
1399 }
1400 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1401 this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1402 if (cache != nullptr) {
1403 return cache->Var();
1404 } else {
1405 if (value_watcher_ == nullptr) {
1406 if (CapSub(Max(), Min()) <= 256) {
1407 solver()->SaveAndSetValue(
1408 reinterpret_cast<void**>(&value_watcher_),
1409 reinterpret_cast<void*>(
1410 solver()->RevAlloc(new DenseValueWatcher(solver(), this))));
1411
1412 } else {
1413 solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1414 reinterpret_cast<void*>(solver()->RevAlloc(
1415 new ValueWatcher(solver(), this))));
1416 }
1417 solver()->AddConstraint(value_watcher_);
1418 }
1419 IntVar* const boolvar = value_watcher_->GetOrMakeValueWatcher(constant);
1420 s->Cache()->InsertExprConstantExpression(
1421 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1422 return boolvar;
1423 }
1424 }
1425
1426 Constraint* SetIsEqual(absl::Span<const int64_t> values,
1427 const std::vector<IntVar*>& vars) {
1428 if (value_watcher_ == nullptr) {
1429 solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1430 reinterpret_cast<void*>(solver()->RevAlloc(
1431 new ValueWatcher(solver(), this))));
1432 for (int i = 0; i < vars.size(); ++i) {
1433 value_watcher_->SetValueWatcher(vars[i], values[i]);
1434 }
1435 }
1436 return value_watcher_;
1437 }
1438
1439 IntVar* IsDifferent(int64_t constant) override {
1440 Solver* const s = solver();
1441 if (constant == min_.Value() && value_watcher_ == nullptr) {
1442 return s->MakeIsGreaterOrEqualCstVar(this, constant + 1);
1443 }
1444 if (constant == max_.Value() && value_watcher_ == nullptr) {
1445 return s->MakeIsLessOrEqualCstVar(this, constant - 1);
1446 }
1447 if (!Contains(constant)) {
1448 return s->MakeIntConst(int64_t{1});
1449 }
1450 if (Bound() && min_.Value() == constant) {
1451 return s->MakeIntConst(int64_t{0});
1452 }
1453 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1455 if (cache != nullptr) {
1456 return cache->Var();
1457 } else {
1458 IntVar* const boolvar = s->MakeDifference(1, IsEqual(constant))->Var();
1459 s->Cache()->InsertExprConstantExpression(
1460 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
1461 return boolvar;
1462 }
1463 }
1464
1465 IntVar* IsGreaterOrEqual(int64_t constant) override {
1466 Solver* const s = solver();
1467 if (max_.Value() < constant) {
1468 return s->MakeIntConst(int64_t{0});
1469 }
1470 if (min_.Value() >= constant) {
1471 return s->MakeIntConst(int64_t{1});
1472 }
1473 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1475 if (cache != nullptr) {
1476 return cache->Var();
1477 } else {
1478 if (bound_watcher_ == nullptr) {
1479 if (CapSub(Max(), Min()) <= 256) {
1480 solver()->SaveAndSetValue(
1481 reinterpret_cast<void**>(&bound_watcher_),
1482 reinterpret_cast<void*>(solver()->RevAlloc(
1483 new DenseUpperBoundWatcher(solver(), this))));
1484 solver()->AddConstraint(bound_watcher_);
1485 } else {
1486 solver()->SaveAndSetValue(
1487 reinterpret_cast<void**>(&bound_watcher_),
1488 reinterpret_cast<void*>(
1489 solver()->RevAlloc(new UpperBoundWatcher(solver(), this))));
1490 solver()->AddConstraint(bound_watcher_);
1491 }
1492 }
1493 IntVar* const boolvar =
1494 bound_watcher_->GetOrMakeUpperBoundWatcher(constant);
1495 s->Cache()->InsertExprConstantExpression(
1496 boolvar, this, constant,
1498 return boolvar;
1499 }
1500 }
1501
1502 Constraint* SetIsGreaterOrEqual(absl::Span<const int64_t> values,
1503 const std::vector<IntVar*>& vars) {
1504 if (bound_watcher_ == nullptr) {
1505 if (CapSub(Max(), Min()) <= 256) {
1506 solver()->SaveAndSetValue(
1507 reinterpret_cast<void**>(&bound_watcher_),
1508 reinterpret_cast<void*>(solver()->RevAlloc(
1509 new DenseUpperBoundWatcher(solver(), this))));
1510 solver()->AddConstraint(bound_watcher_);
1511 } else {
1512 solver()->SaveAndSetValue(reinterpret_cast<void**>(&bound_watcher_),
1513 reinterpret_cast<void*>(solver()->RevAlloc(
1514 new UpperBoundWatcher(solver(), this))));
1515 solver()->AddConstraint(bound_watcher_);
1516 }
1517 for (int i = 0; i < values.size(); ++i) {
1518 bound_watcher_->SetUpperBoundWatcher(vars[i], values[i]);
1519 }
1520 }
1521 return bound_watcher_;
1522 }
1523
1524 IntVar* IsLessOrEqual(int64_t constant) override {
1525 Solver* const s = solver();
1526 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1528 if (cache != nullptr) {
1529 return cache->Var();
1530 } else {
1531 IntVar* const boolvar =
1532 s->MakeDifference(1, IsGreaterOrEqual(constant + 1))->Var();
1533 s->Cache()->InsertExprConstantExpression(
1534 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
1535 return boolvar;
1536 }
1537 }
1538
1539 void Process();
1540 void Push();
1541 void CleanInProcess();
1542 uint64_t Size() const override {
1543 if (bits_ != nullptr) return bits_->Size();
1544 return (static_cast<uint64_t>(max_.Value()) -
1545 static_cast<uint64_t>(min_.Value()) + 1);
1546 }
1547 bool Contains(int64_t v) const override {
1548 if (v < min_.Value() || v > max_.Value()) return false;
1549 return (bits_ == nullptr ? true : bits_->Contains(v));
1550 }
1551 IntVarIterator* MakeHoleIterator(bool reversible) const override;
1552 IntVarIterator* MakeDomainIterator(bool reversible) const override;
1553 int64_t OldMin() const override { return std::min(old_min_, min_.Value()); }
1554 int64_t OldMax() const override { return std::max(old_max_, max_.Value()); }
1555
1556 std::string DebugString() const override;
1557 BitSet* bitset() const { return bits_; }
1558 int VarType() const override { return DOMAIN_INT_VAR; }
1559 std::string BaseName() const override { return "IntegerVar"; }
1560
1561 friend class PlusCstDomainIntVar;
1562 friend class LinkExprAndDomainIntVar;
1563
1564 private:
1565 void CheckOldMin() {
1566 if (old_min_ > min_.Value()) {
1567 old_min_ = min_.Value();
1568 }
1569 }
1570 void CheckOldMax() {
1571 if (old_max_ < max_.Value()) {
1572 old_max_ = max_.Value();
1573 }
1574 }
1575 Rev<int64_t> min_;
1576 Rev<int64_t> max_;
1577 int64_t old_min_;
1578 int64_t old_max_;
1579 int64_t new_min_;
1580 int64_t new_max_;
1581 SimpleRevFIFO<Demon*> bound_demons_;
1582 SimpleRevFIFO<Demon*> range_demons_;
1583 SimpleRevFIFO<Demon*> domain_demons_;
1584 SimpleRevFIFO<Demon*> delayed_bound_demons_;
1585 SimpleRevFIFO<Demon*> delayed_range_demons_;
1586 SimpleRevFIFO<Demon*> delayed_domain_demons_;
1587 QueueHandler handler_;
1588 bool in_process_;
1589 BitSet* bits_;
1590 BaseValueWatcher* value_watcher_;
1591 BaseUpperBoundWatcher* bound_watcher_;
1592};
1593
1594// ----- BitSet -----
1595
1596// Return whether an integer interval [a..b] (inclusive) contains at most
1597// K values, i.e. b - a < K, in a way that's robust to overflows.
1598// For performance reasons, in opt mode it doesn't check that [a, b] is a
1599// valid interval, nor that K is nonnegative.
1600inline bool ClosedIntervalNoLargerThan(int64_t a, int64_t b, int64_t K) {
1601 DCHECK_LE(a, b);
1602 DCHECK_GE(K, 0);
1603 if (a > 0) {
1604 return a > b - K;
1605 } else {
1606 return a + K > b;
1607 }
1608}
1609
1610class SimpleBitSet : public DomainIntVar::BitSet {
1611 public:
1612 SimpleBitSet(Solver* const s, int64_t vmin, int64_t vmax)
1613 : BitSet(s),
1614 bits_(nullptr),
1615 stamps_(nullptr),
1616 omin_(vmin),
1617 omax_(vmax),
1618 size_(vmax - vmin + 1),
1619 bsize_(BitLength64(size_.Value())) {
1620 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1621 << "Bitset too large: [" << vmin << ", " << vmax << "]";
1622 bits_ = new uint64_t[bsize_];
1623 stamps_ = new uint64_t[bsize_];
1624 for (int i = 0; i < bsize_; ++i) {
1625 const int bs =
1626 (i == size_.Value() - 1) ? 63 - BitPos64(size_.Value()) : 0;
1627 bits_[i] = kAllBits64 >> bs;
1628 stamps_[i] = s->stamp() - 1;
1629 }
1630 }
1631
1632 SimpleBitSet(Solver* const s, absl::Span<const int64_t> sorted_values,
1633 int64_t vmin, int64_t vmax)
1634 : BitSet(s),
1635 bits_(nullptr),
1636 stamps_(nullptr),
1637 omin_(vmin),
1638 omax_(vmax),
1639 size_(sorted_values.size()),
1640 bsize_(BitLength64(vmax - vmin + 1)) {
1641 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1642 << "Bitset too large: [" << vmin << ", " << vmax << "]";
1643 bits_ = new uint64_t[bsize_];
1644 stamps_ = new uint64_t[bsize_];
1645 for (int i = 0; i < bsize_; ++i) {
1646 bits_[i] = uint64_t{0};
1647 stamps_[i] = s->stamp() - 1;
1648 }
1649 for (int i = 0; i < sorted_values.size(); ++i) {
1650 const int64_t val = sorted_values[i];
1651 DCHECK(!bit(val));
1652 const int offset = BitOffset64(val - omin_);
1653 const int pos = BitPos64(val - omin_);
1654 bits_[offset] |= OneBit64(pos);
1655 }
1656 }
1657
1658 ~SimpleBitSet() override {
1659 delete[] bits_;
1660 delete[] stamps_;
1661 }
1662
1663 bool bit(int64_t val) const { return IsBitSet64(bits_, val - omin_); }
1664
1665 int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) override {
1666 DCHECK_GE(nmin, cmin);
1667 DCHECK_LE(nmin, cmax);
1668 DCHECK_LE(cmin, cmax);
1669 DCHECK_GE(cmin, omin_);
1670 DCHECK_LE(cmax, omax_);
1671 const int64_t new_min =
1672 UnsafeLeastSignificantBitPosition64(bits_, nmin - omin_, cmax - omin_) +
1673 omin_;
1674 const uint64_t removed_bits =
1675 BitCountRange64(bits_, cmin - omin_, new_min - omin_ - 1);
1676 size_.Add(solver_, -removed_bits);
1677 return new_min;
1678 }
1679
1680 int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) override {
1681 DCHECK_GE(nmax, cmin);
1682 DCHECK_LE(nmax, cmax);
1683 DCHECK_LE(cmin, cmax);
1684 DCHECK_GE(cmin, omin_);
1685 DCHECK_LE(cmax, omax_);
1686 const int64_t new_max =
1687 UnsafeMostSignificantBitPosition64(bits_, cmin - omin_, nmax - omin_) +
1688 omin_;
1689 const uint64_t removed_bits =
1690 BitCountRange64(bits_, new_max - omin_ + 1, cmax - omin_);
1691 size_.Add(solver_, -removed_bits);
1692 return new_max;
1693 }
1694
1695 bool SetValue(int64_t val) override {
1696 DCHECK_GE(val, omin_);
1697 DCHECK_LE(val, omax_);
1698 if (bit(val)) {
1699 size_.SetValue(solver_, 1);
1700 return true;
1701 }
1702 return false;
1703 }
1704
1705 bool Contains(int64_t val) const override {
1706 DCHECK_GE(val, omin_);
1707 DCHECK_LE(val, omax_);
1708 return bit(val);
1709 }
1710
1711 bool RemoveValue(int64_t val) override {
1712 if (val < omin_ || val > omax_ || !bit(val)) {
1713 return false;
1714 }
1715 // Bitset.
1716 const int64_t val_offset = val - omin_;
1717 const int offset = BitOffset64(val_offset);
1718 const uint64_t current_stamp = solver_->stamp();
1719 if (stamps_[offset] < current_stamp) {
1720 stamps_[offset] = current_stamp;
1721 solver_->SaveValue(&bits_[offset]);
1722 }
1723 const int pos = BitPos64(val_offset);
1724 bits_[offset] &= ~OneBit64(pos);
1725 // Size.
1726 size_.Decr(solver_);
1727 // Holes.
1728 InitHoles();
1729 AddHole(val);
1730 return true;
1731 }
1732 uint64_t Size() const override { return size_.Value(); }
1733
1734 std::string DebugString() const override {
1735 std::string out;
1736 absl::StrAppendFormat(&out, "SimpleBitSet(%d..%d : ", omin_, omax_);
1737 for (int i = 0; i < bsize_; ++i) {
1738 absl::StrAppendFormat(&out, "%x", bits_[i]);
1739 }
1740 out += ")";
1741 return out;
1742 }
1743
1744 void DelayRemoveValue(int64_t val) override { removed_.push_back(val); }
1745
1746 void ApplyRemovedValues(DomainIntVar* var) override {
1747 std::sort(removed_.begin(), removed_.end());
1748 for (std::vector<int64_t>::iterator it = removed_.begin();
1749 it != removed_.end(); ++it) {
1750 var->RemoveValue(*it);
1751 }
1752 }
1753
1754 void ClearRemovedValues() override { removed_.clear(); }
1755
1756 std::string pretty_DebugString(int64_t min, int64_t max) const override {
1757 std::string out;
1758 DCHECK(bit(min));
1759 DCHECK(bit(max));
1760 if (max != min) {
1761 int cumul = true;
1762 int64_t start_cumul = min;
1763 for (int64_t v = min + 1; v < max; ++v) {
1764 if (bit(v)) {
1765 if (!cumul) {
1766 cumul = true;
1767 start_cumul = v;
1768 }
1769 } else {
1770 if (cumul) {
1771 if (v == start_cumul + 1) {
1772 absl::StrAppendFormat(&out, "%d ", start_cumul);
1773 } else if (v == start_cumul + 2) {
1774 absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1775 } else {
1776 absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1777 }
1778 cumul = false;
1779 }
1780 }
1781 }
1782 if (cumul) {
1783 if (max == start_cumul + 1) {
1784 absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1785 } else {
1786 absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1787 }
1788 } else {
1789 absl::StrAppendFormat(&out, "%d", max);
1790 }
1791 } else {
1792 absl::StrAppendFormat(&out, "%d", min);
1793 }
1794 return out;
1795 }
1796
1797 DomainIntVar::BitSetIterator* MakeIterator() override {
1798 return new DomainIntVar::BitSetIterator(bits_, omin_);
1799 }
1800
1801 private:
1802 uint64_t* bits_;
1803 uint64_t* stamps_;
1804 const int64_t omin_;
1805 const int64_t omax_;
1806 NumericalRev<int64_t> size_;
1807 const int bsize_;
1808 std::vector<int64_t> removed_;
1809};
1810
1811// This is a special case where the bitset fits into one 64 bit integer.
1812// In that case, there are no offset to compute.
1813// Overflows are caught by the robust ClosedIntervalNoLargerThan() method.
1814class SmallBitSet : public DomainIntVar::BitSet {
1815 public:
1816 SmallBitSet(Solver* const s, int64_t vmin, int64_t vmax)
1817 : BitSet(s),
1818 bits_(uint64_t{0}),
1819 stamp_(s->stamp() - 1),
1820 omin_(vmin),
1821 omax_(vmax),
1822 size_(vmax - vmin + 1) {
1823 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1824 bits_ = OneRange64(0, size_.Value() - 1);
1825 }
1826
1827 SmallBitSet(Solver* const s, absl::Span<const int64_t> sorted_values,
1828 int64_t vmin, int64_t vmax)
1829 : BitSet(s),
1830 bits_(uint64_t{0}),
1831 stamp_(s->stamp() - 1),
1832 omin_(vmin),
1833 omax_(vmax),
1834 size_(sorted_values.size()) {
1835 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1836 // We know the array is sorted and does not contains duplicate values.
1837 for (int i = 0; i < sorted_values.size(); ++i) {
1838 const int64_t val = sorted_values[i];
1839 DCHECK_GE(val, vmin);
1840 DCHECK_LE(val, vmax);
1841 DCHECK(!IsBitSet64(&bits_, val - omin_));
1842 bits_ |= OneBit64(val - omin_);
1843 }
1844 }
1845
1846 ~SmallBitSet() override {}
1847
1848 bool bit(int64_t val) const {
1849 DCHECK_GE(val, omin_);
1850 DCHECK_LE(val, omax_);
1851 return (bits_ & OneBit64(val - omin_)) != 0;
1852 }
1853
1854 int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) override {
1855 DCHECK_GE(nmin, cmin);
1856 DCHECK_LE(nmin, cmax);
1857 DCHECK_LE(cmin, cmax);
1858 DCHECK_GE(cmin, omin_);
1859 DCHECK_LE(cmax, omax_);
1860 // We do not clean the bits between cmin and nmin.
1861 // But we use mask to look only at 'active' bits.
1862
1863 // Create the mask and compute new bits
1864 const uint64_t new_bits = bits_ & OneRange64(nmin - omin_, cmax - omin_);
1865 if (new_bits != uint64_t{0}) {
1866 // Compute new size and new min
1867 size_.SetValue(solver_, BitCount64(new_bits));
1868 if (bit(nmin)) { // Common case, the new min is inside the bitset
1869 return nmin;
1870 }
1871 return LeastSignificantBitPosition64(new_bits) + omin_;
1872 } else { // == 0 -> Fail()
1873 solver_->Fail();
1874 return std::numeric_limits<int64_t>::max();
1875 }
1876 }
1877
1878 int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) override {
1879 DCHECK_GE(nmax, cmin);
1880 DCHECK_LE(nmax, cmax);
1881 DCHECK_LE(cmin, cmax);
1882 DCHECK_GE(cmin, omin_);
1883 DCHECK_LE(cmax, omax_);
1884 // We do not clean the bits between nmax and cmax.
1885 // But we use mask to look only at 'active' bits.
1886
1887 // Create the mask and compute new_bits
1888 const uint64_t new_bits = bits_ & OneRange64(cmin - omin_, nmax - omin_);
1889 if (new_bits != uint64_t{0}) {
1890 // Compute new size and new min
1891 size_.SetValue(solver_, BitCount64(new_bits));
1892 if (bit(nmax)) { // Common case, the new max is inside the bitset
1893 return nmax;
1894 }
1895 return MostSignificantBitPosition64(new_bits) + omin_;
1896 } else { // == 0 -> Fail()
1897 solver_->Fail();
1898 return std::numeric_limits<int64_t>::min();
1899 }
1900 }
1901
1902 bool SetValue(int64_t val) override {
1903 DCHECK_GE(val, omin_);
1904 DCHECK_LE(val, omax_);
1905 // We do not clean the bits. We will use masks to ignore the bits
1906 // that should have been cleaned.
1907 if (bit(val)) {
1908 size_.SetValue(solver_, 1);
1909 return true;
1910 }
1911 return false;
1912 }
1913
1914 bool Contains(int64_t val) const override {
1915 DCHECK_GE(val, omin_);
1916 DCHECK_LE(val, omax_);
1917 return bit(val);
1918 }
1919
1920 bool RemoveValue(int64_t val) override {
1921 DCHECK_GE(val, omin_);
1922 DCHECK_LE(val, omax_);
1923 if (bit(val)) {
1924 // Bitset.
1925 const uint64_t current_stamp = solver_->stamp();
1926 if (stamp_ < current_stamp) {
1927 stamp_ = current_stamp;
1928 solver_->SaveValue(&bits_);
1929 }
1930 bits_ &= ~OneBit64(val - omin_);
1931 DCHECK(!bit(val));
1932 // Size.
1933 size_.Decr(solver_);
1934 // Holes.
1935 InitHoles();
1936 AddHole(val);
1937 return true;
1938 } else {
1939 return false;
1940 }
1941 }
1942
1943 uint64_t Size() const override { return size_.Value(); }
1944
1945 std::string DebugString() const override {
1946 return absl::StrFormat("SmallBitSet(%d..%d : %llx)", omin_, omax_, bits_);
1947 }
1948
1949 void DelayRemoveValue(int64_t val) override {
1950 DCHECK_GE(val, omin_);
1951 DCHECK_LE(val, omax_);
1952 removed_.push_back(val);
1953 }
1954
1955 void ApplyRemovedValues(DomainIntVar* var) override {
1956 std::sort(removed_.begin(), removed_.end());
1957 for (std::vector<int64_t>::iterator it = removed_.begin();
1958 it != removed_.end(); ++it) {
1959 var->RemoveValue(*it);
1960 }
1961 }
1962
1963 void ClearRemovedValues() override { removed_.clear(); }
1964
1965 std::string pretty_DebugString(int64_t min, int64_t max) const override {
1966 std::string out;
1967 DCHECK(bit(min));
1968 DCHECK(bit(max));
1969 if (max != min) {
1970 int cumul = true;
1971 int64_t start_cumul = min;
1972 for (int64_t v = min + 1; v < max; ++v) {
1973 if (bit(v)) {
1974 if (!cumul) {
1975 cumul = true;
1976 start_cumul = v;
1977 }
1978 } else {
1979 if (cumul) {
1980 if (v == start_cumul + 1) {
1981 absl::StrAppendFormat(&out, "%d ", start_cumul);
1982 } else if (v == start_cumul + 2) {
1983 absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1984 } else {
1985 absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1986 }
1987 cumul = false;
1988 }
1989 }
1990 }
1991 if (cumul) {
1992 if (max == start_cumul + 1) {
1993 absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1994 } else {
1995 absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1996 }
1997 } else {
1998 absl::StrAppendFormat(&out, "%d", max);
1999 }
2000 } else {
2001 absl::StrAppendFormat(&out, "%d", min);
2002 }
2003 return out;
2004 }
2005
2006 DomainIntVar::BitSetIterator* MakeIterator() override {
2007 return new DomainIntVar::BitSetIterator(&bits_, omin_);
2008 }
2009
2010 private:
2011 uint64_t bits_;
2012 uint64_t stamp_;
2013 const int64_t omin_;
2014 const int64_t omax_;
2015 NumericalRev<int64_t> size_;
2016 std::vector<int64_t> removed_;
2017};
2018
2019class EmptyIterator : public IntVarIterator {
2020 public:
2021 ~EmptyIterator() override {}
2022 void Init() override {}
2023 bool Ok() const override { return false; }
2024 int64_t Value() const override {
2025 LOG(FATAL) << "Should not be called";
2026 return 0LL;
2027 }
2028 void Next() override {}
2029};
2030
2031class RangeIterator : public IntVarIterator {
2032 public:
2033 explicit RangeIterator(const IntVar* const var)
2034 : var_(var),
2035 min_(std::numeric_limits<int64_t>::max()),
2036 max_(std::numeric_limits<int64_t>::min()),
2037 current_(-1) {}
2038
2039 ~RangeIterator() override {}
2040
2041 void Init() override {
2042 min_ = var_->Min();
2043 max_ = var_->Max();
2044 current_ = min_;
2045 }
2046
2047 bool Ok() const override { return current_ <= max_; }
2048
2049 int64_t Value() const override { return current_; }
2050
2051 void Next() override { current_++; }
2052
2053 private:
2054 const IntVar* const var_;
2055 int64_t min_;
2056 int64_t max_;
2057 int64_t current_;
2058};
2059
2060class DomainIntVarHoleIterator : public IntVarIterator {
2061 public:
2062 explicit DomainIntVarHoleIterator(const DomainIntVar* const v)
2063 : var_(v), bits_(nullptr), values_(nullptr), size_(0), index_(0) {}
2064
2065 ~DomainIntVarHoleIterator() override {}
2066
2067 void Init() override {
2068 bits_ = var_->bitset();
2069 if (bits_ != nullptr) {
2070 bits_->InitHoles();
2071 values_ = bits_->Holes().data();
2072 size_ = bits_->Holes().size();
2073 } else {
2074 values_ = nullptr;
2075 size_ = 0;
2076 }
2077 index_ = 0;
2078 }
2079
2080 bool Ok() const override { return index_ < size_; }
2081
2082 int64_t Value() const override {
2083 DCHECK(bits_ != nullptr);
2084 DCHECK(index_ < size_);
2085 return values_[index_];
2086 }
2087
2088 void Next() override { index_++; }
2089
2090 private:
2091 const DomainIntVar* const var_;
2092 DomainIntVar::BitSet* bits_;
2093 const int64_t* values_;
2094 int size_;
2095 int index_;
2096};
2097
2098class DomainIntVarDomainIterator : public IntVarIterator {
2099 public:
2100 explicit DomainIntVarDomainIterator(const DomainIntVar* const v,
2101 bool reversible)
2102 : var_(v),
2103 bitset_iterator_(nullptr),
2104 min_(std::numeric_limits<int64_t>::max()),
2105 max_(std::numeric_limits<int64_t>::min()),
2106 current_(-1),
2107 reversible_(reversible) {}
2108
2109 ~DomainIntVarDomainIterator() override {
2110 if (!reversible_ && bitset_iterator_) {
2111 delete bitset_iterator_;
2112 }
2113 }
2114
2115 void Init() override {
2116 if (var_->bitset() != nullptr && !var_->Bound()) {
2117 if (reversible_) {
2118 if (!bitset_iterator_) {
2119 Solver* const solver = var_->solver();
2120 solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
2121 bitset_iterator_ = solver->RevAlloc(var_->bitset()->MakeIterator());
2122 }
2123 } else {
2124 if (bitset_iterator_) {
2125 delete bitset_iterator_;
2126 }
2127 bitset_iterator_ = var_->bitset()->MakeIterator();
2128 }
2129 bitset_iterator_->Init(var_->Min(), var_->Max());
2130 } else {
2131 if (bitset_iterator_) {
2132 if (reversible_) {
2133 Solver* const solver = var_->solver();
2134 solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
2135 } else {
2136 delete bitset_iterator_;
2137 }
2138 bitset_iterator_ = nullptr;
2139 }
2140 min_ = var_->Min();
2141 max_ = var_->Max();
2142 current_ = min_;
2143 }
2144 }
2145
2146 bool Ok() const override {
2147 return bitset_iterator_ ? bitset_iterator_->Ok() : (current_ <= max_);
2148 }
2149
2150 int64_t Value() const override {
2151 return bitset_iterator_ ? bitset_iterator_->Value() : current_;
2152 }
2153
2154 void Next() override {
2155 if (bitset_iterator_) {
2156 bitset_iterator_->Next();
2157 } else {
2158 current_++;
2159 }
2160 }
2161
2162 private:
2163 const DomainIntVar* const var_;
2164 DomainIntVar::BitSetIterator* bitset_iterator_;
2165 int64_t min_;
2166 int64_t max_;
2167 int64_t current_;
2168 const bool reversible_;
2169};
2170
2171class UnaryIterator : public IntVarIterator {
2172 public:
2173 UnaryIterator(const IntVar* const v, bool hole, bool reversible)
2174 : iterator_(hole ? v->MakeHoleIterator(reversible)
2175 : v->MakeDomainIterator(reversible)),
2176 reversible_(reversible) {}
2177
2178 ~UnaryIterator() override {
2179 if (!reversible_) {
2180 delete iterator_;
2181 }
2182 }
2183
2184 void Init() override { iterator_->Init(); }
2185
2186 bool Ok() const override { return iterator_->Ok(); }
2187
2188 void Next() override { iterator_->Next(); }
2189
2190 protected:
2191 IntVarIterator* const iterator_;
2192 const bool reversible_;
2193};
2194
2195DomainIntVar::DomainIntVar(Solver* const s, int64_t vmin, int64_t vmax,
2196 const std::string& name)
2197 : IntVar(s, name),
2198 min_(vmin),
2199 max_(vmax),
2200 old_min_(vmin),
2201 old_max_(vmax),
2202 new_min_(vmin),
2203 new_max_(vmax),
2204 handler_(this),
2205 in_process_(false),
2206 bits_(nullptr),
2207 value_watcher_(nullptr),
2208 bound_watcher_(nullptr) {}
2209
2210DomainIntVar::DomainIntVar(Solver* const s,
2211 absl::Span<const int64_t> sorted_values,
2212 const std::string& name)
2213 : IntVar(s, name),
2214 min_(std::numeric_limits<int64_t>::max()),
2215 max_(std::numeric_limits<int64_t>::min()),
2216 old_min_(std::numeric_limits<int64_t>::max()),
2217 old_max_(std::numeric_limits<int64_t>::min()),
2218 new_min_(std::numeric_limits<int64_t>::max()),
2219 new_max_(std::numeric_limits<int64_t>::min()),
2220 handler_(this),
2221 in_process_(false),
2222 bits_(nullptr),
2223 value_watcher_(nullptr),
2224 bound_watcher_(nullptr) {
2225 CHECK_GE(sorted_values.size(), 1);
2226 // We know that the vector is sorted and does not have duplicate values.
2227 const int64_t vmin = sorted_values.front();
2228 const int64_t vmax = sorted_values.back();
2229 const bool contiguous = vmax - vmin + 1 == sorted_values.size();
2230
2231 min_.SetValue(solver(), vmin);
2232 old_min_ = vmin;
2233 new_min_ = vmin;
2234 max_.SetValue(solver(), vmax);
2235 old_max_ = vmax;
2236 new_max_ = vmax;
2237
2238 if (!contiguous) {
2239 if (vmax - vmin + 1 < 65) {
2240 bits_ = solver()->RevAlloc(
2241 new SmallBitSet(solver(), sorted_values, vmin, vmax));
2242 } else {
2243 bits_ = solver()->RevAlloc(
2244 new SimpleBitSet(solver(), sorted_values, vmin, vmax));
2245 }
2246 }
2247}
2248
2249DomainIntVar::~DomainIntVar() {}
2250
2251void DomainIntVar::SetMin(int64_t m) {
2252 if (m <= min_.Value()) return;
2253 if (m > max_.Value()) solver()->Fail();
2254 if (in_process_) {
2255 if (m > new_min_) {
2256 new_min_ = m;
2257 if (new_min_ > new_max_) {
2258 solver()->Fail();
2259 }
2260 }
2261 } else {
2262 CheckOldMin();
2263 const int64_t new_min =
2264 (bits_ == nullptr
2265 ? m
2266 : bits_->ComputeNewMin(m, min_.Value(), max_.Value()));
2267 min_.SetValue(solver(), new_min);
2268 if (min_.Value() > max_.Value()) {
2269 solver()->Fail();
2270 }
2271 Push();
2272 }
2273}
2274
2275void DomainIntVar::SetMax(int64_t m) {
2276 if (m >= max_.Value()) return;
2277 if (m < min_.Value()) solver()->Fail();
2278 if (in_process_) {
2279 if (m < new_max_) {
2280 new_max_ = m;
2281 if (new_max_ < new_min_) {
2282 solver()->Fail();
2283 }
2284 }
2285 } else {
2286 CheckOldMax();
2287 const int64_t new_max =
2288 (bits_ == nullptr
2289 ? m
2290 : bits_->ComputeNewMax(m, min_.Value(), max_.Value()));
2291 max_.SetValue(solver(), new_max);
2292 if (min_.Value() > max_.Value()) {
2293 solver()->Fail();
2294 }
2295 Push();
2296 }
2297}
2298
2299void DomainIntVar::SetRange(int64_t mi, int64_t ma) {
2300 if (mi == ma) {
2301 SetValue(mi);
2302 } else {
2303 if (mi > ma || mi > max_.Value() || ma < min_.Value()) solver()->Fail();
2304 if (mi <= min_.Value() && ma >= max_.Value()) return;
2305 if (in_process_) {
2306 if (ma < new_max_) {
2307 new_max_ = ma;
2308 }
2309 if (mi > new_min_) {
2310 new_min_ = mi;
2311 }
2312 if (new_min_ > new_max_) {
2313 solver()->Fail();
2314 }
2315 } else {
2316 if (mi > min_.Value()) {
2317 CheckOldMin();
2318 const int64_t new_min =
2319 (bits_ == nullptr
2320 ? mi
2321 : bits_->ComputeNewMin(mi, min_.Value(), max_.Value()));
2322 min_.SetValue(solver(), new_min);
2323 }
2324 if (min_.Value() > ma) {
2325 solver()->Fail();
2326 }
2327 if (ma < max_.Value()) {
2328 CheckOldMax();
2329 const int64_t new_max =
2330 (bits_ == nullptr
2331 ? ma
2332 : bits_->ComputeNewMax(ma, min_.Value(), max_.Value()));
2333 max_.SetValue(solver(), new_max);
2334 }
2335 if (min_.Value() > max_.Value()) {
2336 solver()->Fail();
2337 }
2338 Push();
2339 }
2340 }
2341}
2342
2343void DomainIntVar::SetValue(int64_t v) {
2344 if (v != min_.Value() || v != max_.Value()) {
2345 if (v < min_.Value() || v > max_.Value()) {
2346 solver()->Fail();
2347 }
2348 if (in_process_) {
2349 if (v > new_max_ || v < new_min_) {
2350 solver()->Fail();
2351 }
2352 new_min_ = v;
2353 new_max_ = v;
2354 } else {
2355 if (bits_ && !bits_->SetValue(v)) {
2356 solver()->Fail();
2357 }
2358 CheckOldMin();
2359 CheckOldMax();
2360 min_.SetValue(solver(), v);
2361 max_.SetValue(solver(), v);
2362 Push();
2363 }
2364 }
2365}
2366
2367void DomainIntVar::RemoveValue(int64_t v) {
2368 if (v < min_.Value() || v > max_.Value()) return;
2369 if (v == min_.Value()) {
2370 SetMin(v + 1);
2371 } else if (v == max_.Value()) {
2372 SetMax(v - 1);
2373 } else {
2374 if (bits_ == nullptr) {
2375 CreateBits();
2376 }
2377 if (in_process_) {
2378 if (v >= new_min_ && v <= new_max_ && bits_->Contains(v)) {
2379 bits_->DelayRemoveValue(v);
2380 }
2381 } else {
2382 if (bits_->RemoveValue(v)) {
2383 Push();
2384 }
2385 }
2386 }
2387}
2388
2389void DomainIntVar::RemoveInterval(int64_t l, int64_t u) {
2390 if (l <= min_.Value()) {
2391 SetMin(u + 1);
2392 } else if (u >= max_.Value()) {
2393 SetMax(l - 1);
2394 } else {
2395 for (int64_t v = l; v <= u; ++v) {
2396 RemoveValue(v);
2397 }
2398 }
2399}
2400
2401void DomainIntVar::CreateBits() {
2402 solver()->SaveValue(reinterpret_cast<void**>(&bits_));
2403 if (max_.Value() - min_.Value() < 64) {
2404 bits_ = solver()->RevAlloc(
2405 new SmallBitSet(solver(), min_.Value(), max_.Value()));
2406 } else {
2407 bits_ = solver()->RevAlloc(
2408 new SimpleBitSet(solver(), min_.Value(), max_.Value()));
2409 }
2410}
2411
2412void DomainIntVar::CleanInProcess() {
2413 in_process_ = false;
2414 if (bits_ != nullptr) {
2415 bits_->ClearHoles();
2416 }
2417}
2418
2419void DomainIntVar::Push() {
2420 const bool in_process = in_process_;
2421 EnqueueVar(&handler_);
2422 CHECK_EQ(in_process, in_process_);
2423}
2424
2425void DomainIntVar::Process() {
2426 CHECK(!in_process_);
2427 in_process_ = true;
2428 if (bits_ != nullptr) {
2429 bits_->ClearRemovedValues();
2430 }
2431 set_variable_to_clean_on_fail(this);
2432 new_min_ = min_.Value();
2433 new_max_ = max_.Value();
2434 const bool is_bound = min_.Value() == max_.Value();
2435 const bool range_changed =
2436 min_.Value() != OldMin() || max_.Value() != OldMax();
2437 // Process immediate demons.
2438 if (is_bound) {
2439 ExecuteAll(bound_demons_);
2440 }
2441 if (range_changed) {
2442 ExecuteAll(range_demons_);
2443 }
2444 ExecuteAll(domain_demons_);
2445
2446 // Process delayed demons.
2447 if (is_bound) {
2448 EnqueueAll(delayed_bound_demons_);
2449 }
2450 if (range_changed) {
2451 EnqueueAll(delayed_range_demons_);
2452 }
2453 EnqueueAll(delayed_domain_demons_);
2454
2455 // Everything went well if we arrive here. Let's clean the variable.
2456 set_variable_to_clean_on_fail(nullptr);
2457 CleanInProcess();
2458 old_min_ = min_.Value();
2459 old_max_ = max_.Value();
2460 if (min_.Value() < new_min_) {
2461 SetMin(new_min_);
2462 }
2463 if (max_.Value() > new_max_) {
2464 SetMax(new_max_);
2465 }
2466 if (bits_ != nullptr) {
2467 bits_->ApplyRemovedValues(this);
2468 }
2469}
2470
2471template <typename T>
2472T* CondRevAlloc(Solver* solver, bool reversible, T* object) {
2473 return reversible ? solver->RevAlloc(object) : object;
2474}
2475
2476IntVarIterator* DomainIntVar::MakeHoleIterator(bool reversible) const {
2477 return CondRevAlloc(solver(), reversible, new DomainIntVarHoleIterator(this));
2478}
2479
2480IntVarIterator* DomainIntVar::MakeDomainIterator(bool reversible) const {
2481 return CondRevAlloc(solver(), reversible,
2482 new DomainIntVarDomainIterator(this, reversible));
2483}
2484
2485std::string DomainIntVar::DebugString() const {
2486 std::string out;
2487 const std::string& var_name = name();
2488 if (!var_name.empty()) {
2489 out = var_name + "(";
2490 } else {
2491 out = "DomainIntVar(";
2492 }
2493 if (min_.Value() == max_.Value()) {
2494 absl::StrAppendFormat(&out, "%d", min_.Value());
2495 } else if (bits_ != nullptr) {
2496 out.append(bits_->pretty_DebugString(min_.Value(), max_.Value()));
2497 } else {
2498 absl::StrAppendFormat(&out, "%d..%d", min_.Value(), max_.Value());
2499 }
2500 out += ")";
2501 return out;
2502}
2503
2504// ----- Real Boolean Var -----
2505
2506class ConcreteBooleanVar : public BooleanVar {
2507 public:
2508 // Utility classes
2509 class Handler : public Demon {
2510 public:
2511 explicit Handler(ConcreteBooleanVar* const var) : Demon(), var_(var) {}
2512 ~Handler() override {}
2513 void Run(Solver* const s) override {
2514 s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
2515 var_->Process();
2516 s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
2517 }
2518 Solver::DemonPriority priority() const override {
2519 return Solver::VAR_PRIORITY;
2520 }
2521 std::string DebugString() const override {
2522 return absl::StrFormat("Handler(%s)", var_->DebugString());
2523 }
2524
2525 private:
2526 ConcreteBooleanVar* const var_;
2527 };
2528
2529 ConcreteBooleanVar(Solver* const s, const std::string& name)
2530 : BooleanVar(s, name), handler_(this) {}
2531
2532 ~ConcreteBooleanVar() override {}
2533
2534 void SetValue(int64_t v) override {
2535 if (value_ == kUnboundBooleanVarValue) {
2536 if ((v & 0xfffffffffffffffe) == 0) {
2537 InternalSaveBooleanVarValue(solver(), this);
2538 value_ = static_cast<int>(v);
2539 EnqueueVar(&handler_);
2540 return;
2541 }
2542 } else if (v == value_) {
2543 return;
2544 }
2545 solver()->Fail();
2546 }
2547
2548 void Process() {
2549 DCHECK_NE(value_, kUnboundBooleanVarValue);
2550 ExecuteAll(bound_demons_);
2551 for (SimpleRevFIFO<Demon*>::Iterator it(&delayed_bound_demons_); it.ok();
2552 ++it) {
2553 EnqueueDelayedDemon(*it);
2554 }
2555 }
2556
2557 int64_t OldMin() const override { return 0LL; }
2558 int64_t OldMax() const override { return 1LL; }
2559 void RestoreValue() override { value_ = kUnboundBooleanVarValue; }
2560
2561 private:
2562 Handler handler_;
2563};
2564
2565// ----- IntConst -----
2566
2567class IntConst : public IntVar {
2568 public:
2569 IntConst(Solver* const s, int64_t value, const std::string& name = "")
2570 : IntVar(s, name), value_(value) {}
2571 ~IntConst() override {}
2572
2573 int64_t Min() const override { return value_; }
2574 void SetMin(int64_t m) override {
2575 if (m > value_) {
2576 solver()->Fail();
2577 }
2578 }
2579 int64_t Max() const override { return value_; }
2580 void SetMax(int64_t m) override {
2581 if (m < value_) {
2582 solver()->Fail();
2583 }
2584 }
2585 void SetRange(int64_t l, int64_t u) override {
2586 if (l > value_ || u < value_) {
2587 solver()->Fail();
2588 }
2589 }
2590 void SetValue(int64_t v) override {
2591 if (v != value_) {
2592 solver()->Fail();
2593 }
2594 }
2595 bool Bound() const override { return true; }
2596 int64_t Value() const override { return value_; }
2597 void RemoveValue(int64_t v) override {
2598 if (v == value_) {
2599 solver()->Fail();
2600 }
2601 }
2602 void RemoveInterval(int64_t l, int64_t u) override {
2603 if (l <= value_ && value_ <= u) {
2604 solver()->Fail();
2605 }
2606 }
2607 void WhenBound(Demon* d) override {}
2608 void WhenRange(Demon* d) override {}
2609 void WhenDomain(Demon* d) override {}
2610 uint64_t Size() const override { return 1; }
2611 bool Contains(int64_t v) const override { return (v == value_); }
2612 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2613 return CondRevAlloc(solver(), reversible, new EmptyIterator());
2614 }
2615 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2616 return CondRevAlloc(solver(), reversible, new RangeIterator(this));
2617 }
2618 int64_t OldMin() const override { return value_; }
2619 int64_t OldMax() const override { return value_; }
2620 std::string DebugString() const override {
2621 std::string out;
2622 if (solver()->HasName(this)) {
2623 const std::string& var_name = name();
2624 absl::StrAppendFormat(&out, "%s(%d)", var_name, value_);
2625 } else {
2626 absl::StrAppendFormat(&out, "IntConst(%d)", value_);
2627 }
2628 return out;
2629 }
2630
2631 int VarType() const override { return CONST_VAR; }
2632
2633 IntVar* IsEqual(int64_t constant) override {
2634 if (constant == value_) {
2635 return solver()->MakeIntConst(1);
2636 } else {
2637 return solver()->MakeIntConst(0);
2638 }
2639 }
2640
2641 IntVar* IsDifferent(int64_t constant) override {
2642 if (constant == value_) {
2643 return solver()->MakeIntConst(0);
2644 } else {
2645 return solver()->MakeIntConst(1);
2646 }
2647 }
2648
2649 IntVar* IsGreaterOrEqual(int64_t constant) override {
2650 return solver()->MakeIntConst(value_ >= constant);
2651 }
2652
2653 IntVar* IsLessOrEqual(int64_t constant) override {
2654 return solver()->MakeIntConst(value_ <= constant);
2655 }
2656
2657 std::string name() const override {
2658 if (solver()->HasName(this)) {
2659 return PropagationBaseObject::name();
2660 } else {
2661 return absl::StrCat(value_);
2662 }
2663 }
2664
2665 private:
2666 int64_t value_;
2667};
2668
2669// ----- x + c variable, optimized case -----
2670
2671class PlusCstVar : public IntVar {
2672 public:
2673 PlusCstVar(Solver* const s, IntVar* v, int64_t c)
2674 : IntVar(s), var_(v), cst_(c) {}
2675
2676 ~PlusCstVar() override {}
2677
2678 void WhenRange(Demon* d) override { var_->WhenRange(d); }
2679
2680 void WhenBound(Demon* d) override { var_->WhenBound(d); }
2681
2682 void WhenDomain(Demon* d) override { var_->WhenDomain(d); }
2683
2684 int64_t OldMin() const override { return CapAdd(var_->OldMin(), cst_); }
2685
2686 int64_t OldMax() const override { return CapAdd(var_->OldMax(), cst_); }
2687
2688 std::string DebugString() const override {
2689 if (HasName()) {
2690 return absl::StrFormat("%s(%s + %d)", name(), var_->DebugString(), cst_);
2691 } else {
2692 return absl::StrFormat("(%s + %d)", var_->DebugString(), cst_);
2693 }
2694 }
2695
2696 int VarType() const override { return VAR_ADD_CST; }
2697
2698 void Accept(ModelVisitor* const visitor) const override {
2699 visitor->VisitIntegerVariable(this, ModelVisitor::kSumOperation, cst_,
2700 var_);
2701 }
2702
2703 IntVar* IsEqual(int64_t constant) override {
2704 return var_->IsEqual(constant - cst_);
2705 }
2706
2707 IntVar* IsDifferent(int64_t constant) override {
2708 return var_->IsDifferent(constant - cst_);
2709 }
2710
2711 IntVar* IsGreaterOrEqual(int64_t constant) override {
2712 return var_->IsGreaterOrEqual(constant - cst_);
2713 }
2714
2715 IntVar* IsLessOrEqual(int64_t constant) override {
2716 return var_->IsLessOrEqual(constant - cst_);
2717 }
2718
2719 IntVar* SubVar() const { return var_; }
2720
2721 int64_t Constant() const { return cst_; }
2722
2723 protected:
2724 IntVar* const var_;
2725 const int64_t cst_;
2726};
2727
2728class PlusCstIntVar : public PlusCstVar {
2729 public:
2730 class PlusCstIntVarIterator : public UnaryIterator {
2731 public:
2732 PlusCstIntVarIterator(const IntVar* const v, int64_t c, bool hole, bool rev)
2733 : UnaryIterator(v, hole, rev), cst_(c) {}
2734
2735 ~PlusCstIntVarIterator() override {}
2736
2737 int64_t Value() const override { return iterator_->Value() + cst_; }
2738
2739 private:
2740 const int64_t cst_;
2741 };
2742
2743 PlusCstIntVar(Solver* const s, IntVar* v, int64_t c) : PlusCstVar(s, v, c) {}
2744
2745 ~PlusCstIntVar() override {}
2746
2747 int64_t Min() const override { return var_->Min() + cst_; }
2748
2749 void SetMin(int64_t m) override { var_->SetMin(CapSub(m, cst_)); }
2750
2751 int64_t Max() const override { return var_->Max() + cst_; }
2752
2753 void SetMax(int64_t m) override { var_->SetMax(CapSub(m, cst_)); }
2754
2755 void SetRange(int64_t l, int64_t u) override {
2756 var_->SetRange(CapSub(l, cst_), CapSub(u, cst_));
2757 }
2758
2759 void SetValue(int64_t v) override { var_->SetValue(v - cst_); }
2760
2761 int64_t Value() const override { return var_->Value() + cst_; }
2762
2763 bool Bound() const override { return var_->Bound(); }
2764
2765 void RemoveValue(int64_t v) override { var_->RemoveValue(v - cst_); }
2766
2767 void RemoveInterval(int64_t l, int64_t u) override {
2768 var_->RemoveInterval(l - cst_, u - cst_);
2769 }
2770
2771 uint64_t Size() const override { return var_->Size(); }
2772
2773 bool Contains(int64_t v) const override { return var_->Contains(v - cst_); }
2774
2775 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2776 return CondRevAlloc(
2777 solver(), reversible,
2778 new PlusCstIntVarIterator(var_, cst_, true, reversible));
2779 }
2780 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2781 return CondRevAlloc(
2782 solver(), reversible,
2783 new PlusCstIntVarIterator(var_, cst_, false, reversible));
2784 }
2785};
2786
2787class PlusCstDomainIntVar : public PlusCstVar {
2788 public:
2789 class PlusCstDomainIntVarIterator : public UnaryIterator {
2790 public:
2791 PlusCstDomainIntVarIterator(const IntVar* const v, int64_t c, bool hole,
2792 bool reversible)
2793 : UnaryIterator(v, hole, reversible), cst_(c) {}
2794
2795 ~PlusCstDomainIntVarIterator() override {}
2796
2797 int64_t Value() const override { return iterator_->Value() + cst_; }
2798
2799 private:
2800 const int64_t cst_;
2801 };
2802
2803 PlusCstDomainIntVar(Solver* const s, DomainIntVar* v, int64_t c)
2804 : PlusCstVar(s, v, c) {}
2805
2806 ~PlusCstDomainIntVar() override {}
2807
2808 int64_t Min() const override;
2809 void SetMin(int64_t m) override;
2810 int64_t Max() const override;
2811 void SetMax(int64_t m) override;
2812 void SetRange(int64_t l, int64_t u) override;
2813 void SetValue(int64_t v) override;
2814 bool Bound() const override;
2815 int64_t Value() const override;
2816 void RemoveValue(int64_t v) override;
2817 void RemoveInterval(int64_t l, int64_t u) override;
2818 uint64_t Size() const override;
2819 bool Contains(int64_t v) const override;
2820
2821 DomainIntVar* domain_int_var() const {
2822 return reinterpret_cast<DomainIntVar*>(var_);
2823 }
2824
2825 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2826 return CondRevAlloc(
2827 solver(), reversible,
2828 new PlusCstDomainIntVarIterator(var_, cst_, true, reversible));
2829 }
2830 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2831 return CondRevAlloc(
2832 solver(), reversible,
2833 new PlusCstDomainIntVarIterator(var_, cst_, false, reversible));
2834 }
2835};
2836
2837int64_t PlusCstDomainIntVar::Min() const {
2838 return domain_int_var()->min_.Value() + cst_;
2839}
2840
2841void PlusCstDomainIntVar::SetMin(int64_t m) {
2842 domain_int_var()->DomainIntVar::SetMin(CapSub(m, cst_));
2843}
2844
2845int64_t PlusCstDomainIntVar::Max() const {
2846 return domain_int_var()->max_.Value() + cst_;
2847}
2848
2849void PlusCstDomainIntVar::SetMax(int64_t m) {
2850 domain_int_var()->DomainIntVar::SetMax(CapSub(m, cst_));
2851}
2852
2853void PlusCstDomainIntVar::SetRange(int64_t l, int64_t u) {
2854 domain_int_var()->DomainIntVar::SetRange(l - cst_, u - cst_);
2855}
2856
2857void PlusCstDomainIntVar::SetValue(int64_t v) {
2858 domain_int_var()->DomainIntVar::SetValue(v - cst_);
2859}
2860
2861bool PlusCstDomainIntVar::Bound() const {
2862 return domain_int_var()->min_.Value() == domain_int_var()->max_.Value();
2863}
2864
2865int64_t PlusCstDomainIntVar::Value() const {
2866 CHECK_EQ(domain_int_var()->min_.Value(), domain_int_var()->max_.Value())
2867 << " variable is not bound";
2868 return domain_int_var()->min_.Value() + cst_;
2869}
2870
2871void PlusCstDomainIntVar::RemoveValue(int64_t v) {
2872 domain_int_var()->DomainIntVar::RemoveValue(v - cst_);
2873}
2874
2875void PlusCstDomainIntVar::RemoveInterval(int64_t l, int64_t u) {
2876 domain_int_var()->DomainIntVar::RemoveInterval(l - cst_, u - cst_);
2877}
2878
2879uint64_t PlusCstDomainIntVar::Size() const {
2880 return domain_int_var()->DomainIntVar::Size();
2881}
2882
2883bool PlusCstDomainIntVar::Contains(int64_t v) const {
2884 return domain_int_var()->DomainIntVar::Contains(v - cst_);
2885}
2886
2887// c - x variable, optimized case
2888
2889class SubCstIntVar : public IntVar {
2890 public:
2891 class SubCstIntVarIterator : public UnaryIterator {
2892 public:
2893 SubCstIntVarIterator(const IntVar* const v, int64_t c, bool hole, bool rev)
2894 : UnaryIterator(v, hole, rev), cst_(c) {}
2895 ~SubCstIntVarIterator() override {}
2896
2897 int64_t Value() const override { return cst_ - iterator_->Value(); }
2898
2899 private:
2900 const int64_t cst_;
2901 };
2902
2903 SubCstIntVar(Solver* s, IntVar* v, int64_t c);
2904 ~SubCstIntVar() override;
2905
2906 int64_t Min() const override;
2907 void SetMin(int64_t m) override;
2908 int64_t Max() const override;
2909 void SetMax(int64_t m) override;
2910 void SetRange(int64_t l, int64_t u) override;
2911 void SetValue(int64_t v) override;
2912 bool Bound() const override;
2913 int64_t Value() const override;
2914 void RemoveValue(int64_t v) override;
2915 void RemoveInterval(int64_t l, int64_t u) override;
2916 uint64_t Size() const override;
2917 bool Contains(int64_t v) const override;
2918 void WhenRange(Demon* d) override;
2919 void WhenBound(Demon* d) override;
2920 void WhenDomain(Demon* d) override;
2921 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2922 return CondRevAlloc(solver(), reversible,
2923 new SubCstIntVarIterator(var_, cst_, true, reversible));
2924 }
2925 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2926 return CondRevAlloc(
2927 solver(), reversible,
2928 new SubCstIntVarIterator(var_, cst_, false, reversible));
2929 }
2930 int64_t OldMin() const override { return CapSub(cst_, var_->OldMax()); }
2931 int64_t OldMax() const override { return CapSub(cst_, var_->OldMin()); }
2932 std::string DebugString() const override;
2933 std::string name() const override;
2934 int VarType() const override { return CST_SUB_VAR; }
2935
2936 void Accept(ModelVisitor* const visitor) const override {
2937 visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation,
2938 cst_, var_);
2939 }
2940
2941 IntVar* IsEqual(int64_t constant) override {
2942 return var_->IsEqual(cst_ - constant);
2943 }
2944
2945 IntVar* IsDifferent(int64_t constant) override {
2946 return var_->IsDifferent(cst_ - constant);
2947 }
2948
2949 IntVar* IsGreaterOrEqual(int64_t constant) override {
2950 return var_->IsLessOrEqual(cst_ - constant);
2951 }
2952
2953 IntVar* IsLessOrEqual(int64_t constant) override {
2954 return var_->IsGreaterOrEqual(cst_ - constant);
2955 }
2956
2957 IntVar* SubVar() const { return var_; }
2958 int64_t Constant() const { return cst_; }
2959
2960 private:
2961 IntVar* const var_;
2962 const int64_t cst_;
2963};
2964
2965SubCstIntVar::SubCstIntVar(Solver* const s, IntVar* v, int64_t c)
2966 : IntVar(s), var_(v), cst_(c) {}
2967
2968SubCstIntVar::~SubCstIntVar() {}
2969
2970int64_t SubCstIntVar::Min() const { return cst_ - var_->Max(); }
2971
2972void SubCstIntVar::SetMin(int64_t m) { var_->SetMax(CapSub(cst_, m)); }
2973
2974int64_t SubCstIntVar::Max() const { return cst_ - var_->Min(); }
2975
2976void SubCstIntVar::SetMax(int64_t m) { var_->SetMin(CapSub(cst_, m)); }
2977
2978void SubCstIntVar::SetRange(int64_t l, int64_t u) {
2979 var_->SetRange(CapSub(cst_, u), CapSub(cst_, l));
2980}
2981
2982void SubCstIntVar::SetValue(int64_t v) { var_->SetValue(cst_ - v); }
2983
2984bool SubCstIntVar::Bound() const { return var_->Bound(); }
2985
2986void SubCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
2987
2988int64_t SubCstIntVar::Value() const { return cst_ - var_->Value(); }
2989
2990void SubCstIntVar::RemoveValue(int64_t v) { var_->RemoveValue(cst_ - v); }
2991
2992void SubCstIntVar::RemoveInterval(int64_t l, int64_t u) {
2993 var_->RemoveInterval(cst_ - u, cst_ - l);
2994}
2995
2996void SubCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
2997
2998void SubCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
2999
3000uint64_t SubCstIntVar::Size() const { return var_->Size(); }
3001
3002bool SubCstIntVar::Contains(int64_t v) const {
3003 return var_->Contains(cst_ - v);
3004}
3005
3006std::string SubCstIntVar::DebugString() const {
3007 if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
3008 return absl::StrFormat("Not(%s)", var_->DebugString());
3009 } else {
3010 return absl::StrFormat("(%d - %s)", cst_, var_->DebugString());
3011 }
3012}
3013
3014std::string SubCstIntVar::name() const {
3015 if (solver()->HasName(this)) {
3016 return PropagationBaseObject::name();
3017 } else if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
3018 return absl::StrFormat("Not(%s)", var_->name());
3019 } else {
3020 return absl::StrFormat("(%d - %s)", cst_, var_->name());
3021 }
3022}
3023
3024// -x variable, optimized case
3025
3026class OppIntVar : public IntVar {
3027 public:
3028 class OppIntVarIterator : public UnaryIterator {
3029 public:
3030 OppIntVarIterator(const IntVar* const v, bool hole, bool reversible)
3031 : UnaryIterator(v, hole, reversible) {}
3032 ~OppIntVarIterator() override {}
3033
3034 int64_t Value() const override { return -iterator_->Value(); }
3035 };
3036
3037 OppIntVar(Solver* s, IntVar* v);
3038 ~OppIntVar() override;
3039
3040 int64_t Min() const override;
3041 void SetMin(int64_t m) override;
3042 int64_t Max() const override;
3043 void SetMax(int64_t m) override;
3044 void SetRange(int64_t l, int64_t u) override;
3045 void SetValue(int64_t v) override;
3046 bool Bound() const override;
3047 int64_t Value() const override;
3048 void RemoveValue(int64_t v) override;
3049 void RemoveInterval(int64_t l, int64_t u) override;
3050 uint64_t Size() const override;
3051 bool Contains(int64_t v) const override;
3052 void WhenRange(Demon* d) override;
3053 void WhenBound(Demon* d) override;
3054 void WhenDomain(Demon* d) override;
3055 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3056 return CondRevAlloc(solver(), reversible,
3057 new OppIntVarIterator(var_, true, reversible));
3058 }
3059 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3060 return CondRevAlloc(solver(), reversible,
3061 new OppIntVarIterator(var_, false, reversible));
3062 }
3063 int64_t OldMin() const override { return CapOpp(var_->OldMax()); }
3064 int64_t OldMax() const override { return CapOpp(var_->OldMin()); }
3065 std::string DebugString() const override;
3066 int VarType() const override { return OPP_VAR; }
3067
3068 void Accept(ModelVisitor* const visitor) const override {
3069 visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation, 0,
3070 var_);
3071 }
3072
3073 IntVar* IsEqual(int64_t constant) override {
3074 return var_->IsEqual(-constant);
3075 }
3076
3077 IntVar* IsDifferent(int64_t constant) override {
3078 return var_->IsDifferent(-constant);
3079 }
3080
3081 IntVar* IsGreaterOrEqual(int64_t constant) override {
3082 return var_->IsLessOrEqual(-constant);
3083 }
3084
3085 IntVar* IsLessOrEqual(int64_t constant) override {
3086 return var_->IsGreaterOrEqual(-constant);
3087 }
3088
3089 IntVar* SubVar() const { return var_; }
3090
3091 private:
3092 IntVar* const var_;
3093};
3094
3095OppIntVar::OppIntVar(Solver* const s, IntVar* v) : IntVar(s), var_(v) {}
3096
3097OppIntVar::~OppIntVar() {}
3098
3099int64_t OppIntVar::Min() const { return -var_->Max(); }
3100
3101void OppIntVar::SetMin(int64_t m) { var_->SetMax(CapOpp(m)); }
3102
3103int64_t OppIntVar::Max() const { return -var_->Min(); }
3104
3105void OppIntVar::SetMax(int64_t m) { var_->SetMin(CapOpp(m)); }
3106
3107void OppIntVar::SetRange(int64_t l, int64_t u) {
3108 var_->SetRange(CapOpp(u), CapOpp(l));
3109}
3110
3111void OppIntVar::SetValue(int64_t v) { var_->SetValue(CapOpp(v)); }
3112
3113bool OppIntVar::Bound() const { return var_->Bound(); }
3114
3115void OppIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3116
3117int64_t OppIntVar::Value() const { return -var_->Value(); }
3118
3119void OppIntVar::RemoveValue(int64_t v) { var_->RemoveValue(-v); }
3120
3121void OppIntVar::RemoveInterval(int64_t l, int64_t u) {
3122 var_->RemoveInterval(-u, -l);
3123}
3124
3125void OppIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3126
3127void OppIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3128
3129uint64_t OppIntVar::Size() const { return var_->Size(); }
3130
3131bool OppIntVar::Contains(int64_t v) const { return var_->Contains(-v); }
3132
3133std::string OppIntVar::DebugString() const {
3134 return absl::StrFormat("-(%s)", var_->DebugString());
3135}
3136
3137// ----- Utility functions -----
3138
3139// x * c variable, optimized case
3140
3141class TimesCstIntVar : public IntVar {
3142 public:
3143 TimesCstIntVar(Solver* const s, IntVar* v, int64_t c)
3144 : IntVar(s), var_(v), cst_(c) {}
3145 ~TimesCstIntVar() override {}
3146
3147 IntVar* SubVar() const { return var_; }
3148 int64_t Constant() const { return cst_; }
3149
3150 void Accept(ModelVisitor* const visitor) const override {
3151 visitor->VisitIntegerVariable(this, ModelVisitor::kProductOperation, cst_,
3152 var_);
3153 }
3154
3155 IntVar* IsEqual(int64_t constant) override {
3156 if (constant % cst_ == 0) {
3157 return var_->IsEqual(constant / cst_);
3158 } else {
3159 return solver()->MakeIntConst(0);
3160 }
3161 }
3162
3163 IntVar* IsDifferent(int64_t constant) override {
3164 if (constant % cst_ == 0) {
3165 return var_->IsDifferent(constant / cst_);
3166 } else {
3167 return solver()->MakeIntConst(1);
3168 }
3169 }
3170
3171 IntVar* IsGreaterOrEqual(int64_t constant) override {
3172 if (cst_ > 0) {
3173 return var_->IsGreaterOrEqual(PosIntDivUp(constant, cst_));
3174 } else {
3175 return var_->IsLessOrEqual(PosIntDivDown(-constant, -cst_));
3176 }
3177 }
3178
3179 IntVar* IsLessOrEqual(int64_t constant) override {
3180 if (cst_ > 0) {
3181 return var_->IsLessOrEqual(PosIntDivDown(constant, cst_));
3182 } else {
3183 return var_->IsGreaterOrEqual(PosIntDivUp(-constant, -cst_));
3184 }
3185 }
3186
3187 std::string DebugString() const override {
3188 return absl::StrFormat("(%s * %d)", var_->DebugString(), cst_);
3189 }
3190
3191 int VarType() const override { return VAR_TIMES_CST; }
3192
3193 protected:
3194 IntVar* const var_;
3195 const int64_t cst_;
3196};
3197
3198class TimesPosCstIntVar : public TimesCstIntVar {
3199 public:
3200 class TimesPosCstIntVarIterator : public UnaryIterator {
3201 public:
3202 TimesPosCstIntVarIterator(const IntVar* const v, int64_t c, bool hole,
3203 bool reversible)
3204 : UnaryIterator(v, hole, reversible), cst_(c) {}
3205 ~TimesPosCstIntVarIterator() override {}
3206
3207 int64_t Value() const override { return iterator_->Value() * cst_; }
3208
3209 private:
3210 const int64_t cst_;
3211 };
3212
3213 TimesPosCstIntVar(Solver* s, IntVar* v, int64_t c);
3214 ~TimesPosCstIntVar() override;
3215
3216 int64_t Min() const override;
3217 void SetMin(int64_t m) override;
3218 int64_t Max() const override;
3219 void SetMax(int64_t m) override;
3220 void SetRange(int64_t l, int64_t u) override;
3221 void SetValue(int64_t v) override;
3222 bool Bound() const override;
3223 int64_t Value() const override;
3224 void RemoveValue(int64_t v) override;
3225 void RemoveInterval(int64_t l, int64_t u) override;
3226 uint64_t Size() const override;
3227 bool Contains(int64_t v) const override;
3228 void WhenRange(Demon* d) override;
3229 void WhenBound(Demon* d) override;
3230 void WhenDomain(Demon* d) override;
3231 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3232 return CondRevAlloc(
3233 solver(), reversible,
3234 new TimesPosCstIntVarIterator(var_, cst_, true, reversible));
3235 }
3236 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3237 return CondRevAlloc(
3238 solver(), reversible,
3239 new TimesPosCstIntVarIterator(var_, cst_, false, reversible));
3240 }
3241 int64_t OldMin() const override { return CapProd(var_->OldMin(), cst_); }
3242 int64_t OldMax() const override { return CapProd(var_->OldMax(), cst_); }
3243};
3244
3245// ----- TimesPosCstIntVar -----
3246
3247TimesPosCstIntVar::TimesPosCstIntVar(Solver* const s, IntVar* v, int64_t c)
3248 : TimesCstIntVar(s, v, c) {}
3249
3250TimesPosCstIntVar::~TimesPosCstIntVar() {}
3251
3252int64_t TimesPosCstIntVar::Min() const { return CapProd(var_->Min(), cst_); }
3253
3254void TimesPosCstIntVar::SetMin(int64_t m) {
3255 if (m != std::numeric_limits<int64_t>::min()) {
3256 var_->SetMin(PosIntDivUp(m, cst_));
3257 }
3258}
3259
3260int64_t TimesPosCstIntVar::Max() const { return CapProd(var_->Max(), cst_); }
3261
3262void TimesPosCstIntVar::SetMax(int64_t m) {
3263 if (m != std::numeric_limits<int64_t>::max()) {
3264 var_->SetMax(PosIntDivDown(m, cst_));
3265 }
3266}
3267
3268void TimesPosCstIntVar::SetRange(int64_t l, int64_t u) {
3269 var_->SetRange(PosIntDivUp(l, cst_), PosIntDivDown(u, cst_));
3270}
3271
3272void TimesPosCstIntVar::SetValue(int64_t v) {
3273 if (v % cst_ != 0) {
3274 solver()->Fail();
3275 }
3276 var_->SetValue(v / cst_);
3277}
3278
3279bool TimesPosCstIntVar::Bound() const { return var_->Bound(); }
3280
3281void TimesPosCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3282
3283int64_t TimesPosCstIntVar::Value() const {
3284 return CapProd(var_->Value(), cst_);
3285}
3286
3287void TimesPosCstIntVar::RemoveValue(int64_t v) {
3288 if (v % cst_ == 0) {
3289 var_->RemoveValue(v / cst_);
3290 }
3291}
3292
3293void TimesPosCstIntVar::RemoveInterval(int64_t l, int64_t u) {
3294 for (int64_t v = l; v <= u; ++v) {
3295 RemoveValue(v);
3296 }
3297 // TODO(user) : Improve me
3298}
3299
3300void TimesPosCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3301
3302void TimesPosCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3303
3304uint64_t TimesPosCstIntVar::Size() const { return var_->Size(); }
3305
3306bool TimesPosCstIntVar::Contains(int64_t v) const {
3307 return (v % cst_ == 0 && var_->Contains(v / cst_));
3308}
3309
3310// b * c variable, optimized case
3311
3312class TimesPosCstBoolVar : public TimesCstIntVar {
3313 public:
3314 class TimesPosCstBoolVarIterator : public UnaryIterator {
3315 public:
3316 // TODO(user) : optimize this.
3317 TimesPosCstBoolVarIterator(const IntVar* const v, int64_t c, bool hole,
3318 bool reversible)
3319 : UnaryIterator(v, hole, reversible), cst_(c) {}
3320 ~TimesPosCstBoolVarIterator() override {}
3321
3322 int64_t Value() const override { return iterator_->Value() * cst_; }
3323
3324 private:
3325 const int64_t cst_;
3326 };
3327
3328 TimesPosCstBoolVar(Solver* s, BooleanVar* v, int64_t c);
3329 ~TimesPosCstBoolVar() override;
3330
3331 int64_t Min() const override;
3332 void SetMin(int64_t m) override;
3333 int64_t Max() const override;
3334 void SetMax(int64_t m) override;
3335 void SetRange(int64_t l, int64_t u) override;
3336 void SetValue(int64_t v) override;
3337 bool Bound() const override;
3338 int64_t Value() const override;
3339 void RemoveValue(int64_t v) override;
3340 void RemoveInterval(int64_t l, int64_t u) override;
3341 uint64_t Size() const override;
3342 bool Contains(int64_t v) const override;
3343 void WhenRange(Demon* d) override;
3344 void WhenBound(Demon* d) override;
3345 void WhenDomain(Demon* d) override;
3346 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3347 return CondRevAlloc(solver(), reversible, new EmptyIterator());
3348 }
3349 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3350 return CondRevAlloc(
3351 solver(), reversible,
3352 new TimesPosCstBoolVarIterator(boolean_var(), cst_, false, reversible));
3353 }
3354 int64_t OldMin() const override { return 0; }
3355 int64_t OldMax() const override { return cst_; }
3356
3357 BooleanVar* boolean_var() const {
3358 return reinterpret_cast<BooleanVar*>(var_);
3359 }
3360};
3361
3362// ----- TimesPosCstBoolVar -----
3363
3364TimesPosCstBoolVar::TimesPosCstBoolVar(Solver* const s, BooleanVar* v,
3365 int64_t c)
3366 : TimesCstIntVar(s, v, c) {}
3367
3368TimesPosCstBoolVar::~TimesPosCstBoolVar() {}
3369
3370int64_t TimesPosCstBoolVar::Min() const {
3371 return (boolean_var()->RawValue() == 1) * cst_;
3372}
3373
3374void TimesPosCstBoolVar::SetMin(int64_t m) {
3375 if (m > cst_) {
3376 solver()->Fail();
3377 } else if (m > 0) {
3378 boolean_var()->SetMin(1);
3379 }
3380}
3381
3382int64_t TimesPosCstBoolVar::Max() const {
3383 return (boolean_var()->RawValue() != 0) * cst_;
3384}
3385
3386void TimesPosCstBoolVar::SetMax(int64_t m) {
3387 if (m < 0) {
3388 solver()->Fail();
3389 } else if (m < cst_) {
3390 boolean_var()->SetMax(0);
3391 }
3392}
3393
3394void TimesPosCstBoolVar::SetRange(int64_t l, int64_t u) {
3395 if (u < 0 || l > cst_ || l > u) {
3396 solver()->Fail();
3397 }
3398 if (l > 0) {
3399 boolean_var()->SetMin(1);
3400 } else if (u < cst_) {
3401 boolean_var()->SetMax(0);
3402 }
3403}
3404
3405void TimesPosCstBoolVar::SetValue(int64_t v) {
3406 if (v == 0) {
3407 boolean_var()->SetValue(0);
3408 } else if (v == cst_) {
3409 boolean_var()->SetValue(1);
3410 } else {
3411 solver()->Fail();
3412 }
3413}
3414
3415bool TimesPosCstBoolVar::Bound() const {
3416 return boolean_var()->RawValue() != BooleanVar::kUnboundBooleanVarValue;
3417}
3418
3419void TimesPosCstBoolVar::WhenRange(Demon* d) { boolean_var()->WhenRange(d); }
3420
3421int64_t TimesPosCstBoolVar::Value() const {
3422 CHECK_NE(boolean_var()->RawValue(), BooleanVar::kUnboundBooleanVarValue)
3423 << " variable is not bound";
3424 return boolean_var()->RawValue() * cst_;
3425}
3426
3427void TimesPosCstBoolVar::RemoveValue(int64_t v) {
3428 if (v == 0) {
3429 boolean_var()->RemoveValue(0);
3430 } else if (v == cst_) {
3431 boolean_var()->RemoveValue(1);
3432 }
3433}
3434
3435void TimesPosCstBoolVar::RemoveInterval(int64_t l, int64_t u) {
3436 if (l <= 0 && u >= 0) {
3437 boolean_var()->RemoveValue(0);
3438 }
3439 if (l <= cst_ && u >= cst_) {
3440 boolean_var()->RemoveValue(1);
3441 }
3442}
3443
3444void TimesPosCstBoolVar::WhenBound(Demon* d) { boolean_var()->WhenBound(d); }
3445
3446void TimesPosCstBoolVar::WhenDomain(Demon* d) { boolean_var()->WhenDomain(d); }
3447
3448uint64_t TimesPosCstBoolVar::Size() const {
3449 return (1 +
3450 (boolean_var()->RawValue() == BooleanVar::kUnboundBooleanVarValue));
3451}
3452
3453bool TimesPosCstBoolVar::Contains(int64_t v) const {
3454 if (v == 0) {
3455 return boolean_var()->RawValue() != 1;
3456 } else if (v == cst_) {
3457 return boolean_var()->RawValue() != 0;
3458 }
3459 return false;
3460}
3461
3462// TimesNegCstIntVar
3463
3464class TimesNegCstIntVar : public TimesCstIntVar {
3465 public:
3466 class TimesNegCstIntVarIterator : public UnaryIterator {
3467 public:
3468 TimesNegCstIntVarIterator(const IntVar* const v, int64_t c, bool hole,
3469 bool reversible)
3470 : UnaryIterator(v, hole, reversible), cst_(c) {}
3471 ~TimesNegCstIntVarIterator() override {}
3472
3473 int64_t Value() const override { return iterator_->Value() * cst_; }
3474
3475 private:
3476 const int64_t cst_;
3477 };
3478
3479 TimesNegCstIntVar(Solver* s, IntVar* v, int64_t c);
3480 ~TimesNegCstIntVar() override;
3481
3482 int64_t Min() const override;
3483 void SetMin(int64_t m) override;
3484 int64_t Max() const override;
3485 void SetMax(int64_t m) override;
3486 void SetRange(int64_t l, int64_t u) override;
3487 void SetValue(int64_t v) override;
3488 bool Bound() const override;
3489 int64_t Value() const override;
3490 void RemoveValue(int64_t v) override;
3491 void RemoveInterval(int64_t l, int64_t u) override;
3492 uint64_t Size() const override;
3493 bool Contains(int64_t v) const override;
3494 void WhenRange(Demon* d) override;
3495 void WhenBound(Demon* d) override;
3496 void WhenDomain(Demon* d) override;
3497 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3498 return CondRevAlloc(
3499 solver(), reversible,
3500 new TimesNegCstIntVarIterator(var_, cst_, true, reversible));
3501 }
3502 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3503 return CondRevAlloc(
3504 solver(), reversible,
3505 new TimesNegCstIntVarIterator(var_, cst_, false, reversible));
3506 }
3507 int64_t OldMin() const override { return CapProd(var_->OldMax(), cst_); }
3508 int64_t OldMax() const override { return CapProd(var_->OldMin(), cst_); }
3509};
3510
3511// ----- TimesNegCstIntVar -----
3512
3513TimesNegCstIntVar::TimesNegCstIntVar(Solver* const s, IntVar* v, int64_t c)
3514 : TimesCstIntVar(s, v, c) {}
3515
3516TimesNegCstIntVar::~TimesNegCstIntVar() {}
3517
3518int64_t TimesNegCstIntVar::Min() const { return CapProd(var_->Max(), cst_); }
3519
3520void TimesNegCstIntVar::SetMin(int64_t m) {
3521 if (m != std::numeric_limits<int64_t>::min()) {
3522 var_->SetMax(PosIntDivDown(-m, -cst_));
3523 }
3524}
3525
3526int64_t TimesNegCstIntVar::Max() const { return CapProd(var_->Min(), cst_); }
3527
3528void TimesNegCstIntVar::SetMax(int64_t m) {
3529 if (m != std::numeric_limits<int64_t>::max()) {
3530 var_->SetMin(PosIntDivUp(-m, -cst_));
3531 }
3532}
3533
3534void TimesNegCstIntVar::SetRange(int64_t l, int64_t u) {
3535 var_->SetRange(PosIntDivUp(CapOpp(u), CapOpp(cst_)),
3536 PosIntDivDown(CapOpp(l), CapOpp(cst_)));
3537}
3538
3539void TimesNegCstIntVar::SetValue(int64_t v) {
3540 if (v % cst_ != 0) {
3541 solver()->Fail();
3542 }
3543 var_->SetValue(v / cst_);
3544}
3545
3546bool TimesNegCstIntVar::Bound() const { return var_->Bound(); }
3547
3548void TimesNegCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3549
3550int64_t TimesNegCstIntVar::Value() const {
3551 return CapProd(var_->Value(), cst_);
3552}
3553
3554void TimesNegCstIntVar::RemoveValue(int64_t v) {
3555 if (v % cst_ == 0) {
3556 var_->RemoveValue(v / cst_);
3557 }
3558}
3559
3560void TimesNegCstIntVar::RemoveInterval(int64_t l, int64_t u) {
3561 for (int64_t v = l; v <= u; ++v) {
3562 RemoveValue(v);
3563 }
3564 // TODO(user) : Improve me
3565}
3566
3567void TimesNegCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3568
3569void TimesNegCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3570
3571uint64_t TimesNegCstIntVar::Size() const { return var_->Size(); }
3572
3573bool TimesNegCstIntVar::Contains(int64_t v) const {
3574 return (v % cst_ == 0 && var_->Contains(v / cst_));
3575}
3576
3577// ---------- arithmetic expressions ----------
3578
3579// ----- PlusIntExpr -----
3580
3581class PlusIntExpr : public BaseIntExpr {
3582 public:
3583 PlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3584 : BaseIntExpr(s), left_(l), right_(r) {}
3585
3586 ~PlusIntExpr() override {}
3587
3588 int64_t Min() const override { return left_->Min() + right_->Min(); }
3589
3590 void SetMin(int64_t m) override {
3591 if (m > left_->Min() + right_->Min()) {
3592 // Catching potential overflow.
3593 if (m > right_->Max() + left_->Max()) solver()->Fail();
3594 left_->SetMin(m - right_->Max());
3595 right_->SetMin(m - left_->Max());
3596 }
3597 }
3598
3599 void SetRange(int64_t l, int64_t u) override {
3600 const int64_t left_min = left_->Min();
3601 const int64_t right_min = right_->Min();
3602 const int64_t left_max = left_->Max();
3603 const int64_t right_max = right_->Max();
3604 if (l > left_min + right_min) {
3605 // Catching potential overflow.
3606 if (l > right_max + left_max) solver()->Fail();
3607 left_->SetMin(l - right_max);
3608 right_->SetMin(l - left_max);
3609 }
3610 if (u < left_max + right_max) {
3611 // Catching potential overflow.
3612 if (u < right_min + left_min) solver()->Fail();
3613 left_->SetMax(u - right_min);
3614 right_->SetMax(u - left_min);
3615 }
3616 }
3617
3618 int64_t Max() const override { return left_->Max() + right_->Max(); }
3619
3620 void SetMax(int64_t m) override {
3621 if (m < left_->Max() + right_->Max()) {
3622 // Catching potential overflow.
3623 if (m < right_->Min() + left_->Min()) solver()->Fail();
3624 left_->SetMax(m - right_->Min());
3625 right_->SetMax(m - left_->Min());
3626 }
3627 }
3628
3629 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3630
3631 void Range(int64_t* const mi, int64_t* const ma) override {
3632 *mi = left_->Min() + right_->Min();
3633 *ma = left_->Max() + right_->Max();
3634 }
3635
3636 std::string name() const override {
3637 return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3638 }
3639
3640 std::string DebugString() const override {
3641 return absl::StrFormat("(%s + %s)", left_->DebugString(),
3642 right_->DebugString());
3643 }
3644
3645 void WhenRange(Demon* d) override {
3646 left_->WhenRange(d);
3647 right_->WhenRange(d);
3648 }
3649
3650 void ExpandPlusIntExpr(IntExpr* const expr, std::vector<IntExpr*>* subs) {
3651 PlusIntExpr* const casted = dynamic_cast<PlusIntExpr*>(expr);
3652 if (casted != nullptr) {
3653 ExpandPlusIntExpr(casted->left_, subs);
3654 ExpandPlusIntExpr(casted->right_, subs);
3655 } else {
3656 subs->push_back(expr);
3657 }
3658 }
3659
3660 IntVar* CastToVar() override {
3661 if (dynamic_cast<PlusIntExpr*>(left_) != nullptr ||
3662 dynamic_cast<PlusIntExpr*>(right_) != nullptr) {
3663 std::vector<IntExpr*> sub_exprs;
3664 ExpandPlusIntExpr(left_, &sub_exprs);
3665 ExpandPlusIntExpr(right_, &sub_exprs);
3666 if (sub_exprs.size() >= 3) {
3667 std::vector<IntVar*> sub_vars(sub_exprs.size());
3668 for (int i = 0; i < sub_exprs.size(); ++i) {
3669 sub_vars[i] = sub_exprs[i]->Var();
3670 }
3671 return solver()->MakeSum(sub_vars)->Var();
3672 }
3673 }
3674 return BaseIntExpr::CastToVar();
3675 }
3676
3677 void Accept(ModelVisitor* const visitor) const override {
3678 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3679 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3680 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3681 right_);
3682 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3683 }
3684
3685 private:
3686 IntExpr* const left_;
3687 IntExpr* const right_;
3688};
3689
3690class SafePlusIntExpr : public BaseIntExpr {
3691 public:
3692 SafePlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3693 : BaseIntExpr(s), left_(l), right_(r) {}
3694
3695 ~SafePlusIntExpr() override {}
3696
3697 int64_t Min() const override { return CapAdd(left_->Min(), right_->Min()); }
3698
3699 void SetMin(int64_t m) override {
3700 left_->SetMin(CapSub(m, right_->Max()));
3701 right_->SetMin(CapSub(m, left_->Max()));
3702 }
3703
3704 void SetRange(int64_t l, int64_t u) override {
3705 const int64_t left_min = left_->Min();
3706 const int64_t right_min = right_->Min();
3707 const int64_t left_max = left_->Max();
3708 const int64_t right_max = right_->Max();
3709 if (l > CapAdd(left_min, right_min)) {
3710 left_->SetMin(CapSub(l, right_max));
3711 right_->SetMin(CapSub(l, left_max));
3712 }
3713 if (u < CapAdd(left_max, right_max)) {
3714 left_->SetMax(CapSub(u, right_min));
3715 right_->SetMax(CapSub(u, left_min));
3716 }
3717 }
3718
3719 int64_t Max() const override { return CapAdd(left_->Max(), right_->Max()); }
3720
3721 void SetMax(int64_t m) override {
3722 left_->SetMax(CapSub(m, right_->Min()));
3723 right_->SetMax(CapSub(m, left_->Min()));
3724 }
3725
3726 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3727
3728 std::string name() const override {
3729 return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3730 }
3731
3732 std::string DebugString() const override {
3733 return absl::StrFormat("(%s + %s)", left_->DebugString(),
3734 right_->DebugString());
3735 }
3736
3737 void WhenRange(Demon* d) override {
3738 left_->WhenRange(d);
3739 right_->WhenRange(d);
3740 }
3741
3742 void Accept(ModelVisitor* const visitor) const override {
3743 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3744 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3745 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3746 right_);
3747 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3748 }
3749
3750 private:
3751 IntExpr* const left_;
3752 IntExpr* const right_;
3753};
3754
3755// ----- PlusIntCstExpr -----
3756
3757class PlusIntCstExpr : public BaseIntExpr {
3758 public:
3759 PlusIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3760 : BaseIntExpr(s), expr_(e), value_(v) {}
3761 ~PlusIntCstExpr() override {}
3762 int64_t Min() const override { return CapAdd(expr_->Min(), value_); }
3763 void SetMin(int64_t m) override { expr_->SetMin(CapSub(m, value_)); }
3764 int64_t Max() const override { return CapAdd(expr_->Max(), value_); }
3765 void SetMax(int64_t m) override { expr_->SetMax(CapSub(m, value_)); }
3766 bool Bound() const override { return (expr_->Bound()); }
3767 std::string name() const override {
3768 return absl::StrFormat("(%s + %d)", expr_->name(), value_);
3769 }
3770 std::string DebugString() const override {
3771 return absl::StrFormat("(%s + %d)", expr_->DebugString(), value_);
3772 }
3773 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3774 IntVar* CastToVar() override;
3775 void Accept(ModelVisitor* const visitor) const override {
3776 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3777 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3778 expr_);
3779 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3780 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3781 }
3782
3783 private:
3784 IntExpr* const expr_;
3785 const int64_t value_;
3786};
3787
3788IntVar* PlusIntCstExpr::CastToVar() {
3789 Solver* const s = solver();
3790 IntVar* const var = expr_->Var();
3791 IntVar* cast = nullptr;
3792 if (AddOverflows(value_, expr_->Max()) ||
3793 AddOverflows(value_, expr_->Min())) {
3794 return BaseIntExpr::CastToVar();
3795 }
3796 switch (var->VarType()) {
3797 case DOMAIN_INT_VAR:
3798 cast = s->RegisterIntVar(s->RevAlloc(new PlusCstDomainIntVar(
3799 s, reinterpret_cast<DomainIntVar*>(var), value_)));
3800 // FIXME: Break was inserted during fallthrough cleanup. Please check.
3801 break;
3802 default:
3803 cast = s->RegisterIntVar(s->RevAlloc(new PlusCstIntVar(s, var, value_)));
3804 break;
3805 }
3806 return cast;
3807}
3808
3809// ----- SubIntExpr -----
3810
3811class SubIntExpr : public BaseIntExpr {
3812 public:
3813 SubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3814 : BaseIntExpr(s), left_(l), right_(r) {}
3815
3816 ~SubIntExpr() override {}
3817
3818 int64_t Min() const override { return left_->Min() - right_->Max(); }
3819
3820 void SetMin(int64_t m) override {
3821 left_->SetMin(CapAdd(m, right_->Min()));
3822 right_->SetMax(CapSub(left_->Max(), m));
3823 }
3824
3825 int64_t Max() const override { return left_->Max() - right_->Min(); }
3826
3827 void SetMax(int64_t m) override {
3828 left_->SetMax(CapAdd(m, right_->Max()));
3829 right_->SetMin(CapSub(left_->Min(), m));
3830 }
3831
3832 void Range(int64_t* mi, int64_t* ma) override {
3833 *mi = left_->Min() - right_->Max();
3834 *ma = left_->Max() - right_->Min();
3835 }
3836
3837 void SetRange(int64_t l, int64_t u) override {
3838 const int64_t left_min = left_->Min();
3839 const int64_t right_min = right_->Min();
3840 const int64_t left_max = left_->Max();
3841 const int64_t right_max = right_->Max();
3842 if (l > left_min - right_max) {
3843 left_->SetMin(CapAdd(l, right_min));
3844 right_->SetMax(CapSub(left_max, l));
3845 }
3846 if (u < left_max - right_min) {
3847 left_->SetMax(CapAdd(u, right_max));
3848 right_->SetMin(CapSub(left_min, u));
3849 }
3850 }
3851
3852 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3853
3854 std::string name() const override {
3855 return absl::StrFormat("(%s - %s)", left_->name(), right_->name());
3856 }
3857
3858 std::string DebugString() const override {
3859 return absl::StrFormat("(%s - %s)", left_->DebugString(),
3860 right_->DebugString());
3861 }
3862
3863 void WhenRange(Demon* d) override {
3864 left_->WhenRange(d);
3865 right_->WhenRange(d);
3866 }
3867
3868 void Accept(ModelVisitor* const visitor) const override {
3869 visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3870 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3871 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3872 right_);
3873 visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3874 }
3875
3876 IntExpr* left() const { return left_; }
3877 IntExpr* right() const { return right_; }
3878
3879 protected:
3880 IntExpr* const left_;
3881 IntExpr* const right_;
3882};
3883
3884class SafeSubIntExpr : public SubIntExpr {
3885 public:
3886 SafeSubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3887 : SubIntExpr(s, l, r) {}
3888
3889 ~SafeSubIntExpr() override {}
3890
3891 int64_t Min() const override { return CapSub(left_->Min(), right_->Max()); }
3892
3893 void SetMin(int64_t m) override {
3894 left_->SetMin(CapAdd(m, right_->Min()));
3895 right_->SetMax(CapSub(left_->Max(), m));
3896 }
3897
3898 void SetRange(int64_t l, int64_t u) override {
3899 const int64_t left_min = left_->Min();
3900 const int64_t right_min = right_->Min();
3901 const int64_t left_max = left_->Max();
3902 const int64_t right_max = right_->Max();
3903 if (l > CapSub(left_min, right_max)) {
3904 left_->SetMin(CapAdd(l, right_min));
3905 right_->SetMax(CapSub(left_max, l));
3906 }
3907 if (u < CapSub(left_max, right_min)) {
3908 left_->SetMax(CapAdd(u, right_max));
3909 right_->SetMin(CapSub(left_min, u));
3910 }
3911 }
3912
3913 void Range(int64_t* mi, int64_t* ma) override {
3914 *mi = CapSub(left_->Min(), right_->Max());
3915 *ma = CapSub(left_->Max(), right_->Min());
3916 }
3917
3918 int64_t Max() const override { return CapSub(left_->Max(), right_->Min()); }
3919
3920 void SetMax(int64_t m) override {
3921 left_->SetMax(CapAdd(m, right_->Max()));
3922 right_->SetMin(CapSub(left_->Min(), m));
3923 }
3924};
3925
// cst - expr
3927
3928// ----- SubIntCstExpr -----
3929
3930class SubIntCstExpr : public BaseIntExpr {
3931 public:
3932 SubIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3933 : BaseIntExpr(s), expr_(e), value_(v) {}
3934 ~SubIntCstExpr() override {}
3935 int64_t Min() const override { return CapSub(value_, expr_->Max()); }
3936 void SetMin(int64_t m) override { expr_->SetMax(CapSub(value_, m)); }
3937 int64_t Max() const override { return CapSub(value_, expr_->Min()); }
3938 void SetMax(int64_t m) override { expr_->SetMin(CapSub(value_, m)); }
3939 bool Bound() const override { return (expr_->Bound()); }
3940 std::string name() const override {
3941 return absl::StrFormat("(%d - %s)", value_, expr_->name());
3942 }
3943 std::string DebugString() const override {
3944 return absl::StrFormat("(%d - %s)", value_, expr_->DebugString());
3945 }
3946 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3947 IntVar* CastToVar() override;
3948
3949 void Accept(ModelVisitor* const visitor) const override {
3950 visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3951 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3952 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3953 expr_);
3954 visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3955 }
3956
3957 private:
3958 IntExpr* const expr_;
3959 const int64_t value_;
3960};
3961
3962IntVar* SubIntCstExpr::CastToVar() {
3963 if (SubOverflows(value_, expr_->Min()) ||
3964 SubOverflows(value_, expr_->Max())) {
3965 return BaseIntExpr::CastToVar();
3966 }
3967 Solver* const s = solver();
3968 IntVar* const var =
3969 s->RegisterIntVar(s->RevAlloc(new SubCstIntVar(s, expr_->Var(), value_)));
3970 return var;
3971}
3972
3973// ----- OppIntExpr -----
3974
3975class OppIntExpr : public BaseIntExpr {
3976 public:
3977 OppIntExpr(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
3978 ~OppIntExpr() override {}
3979 int64_t Min() const override { return (CapOpp(expr_->Max())); }
3980 void SetMin(int64_t m) override { expr_->SetMax(CapOpp(m)); }
3981 int64_t Max() const override { return (CapOpp(expr_->Min())); }
3982 void SetMax(int64_t m) override { expr_->SetMin(CapOpp(m)); }
3983 bool Bound() const override { return (expr_->Bound()); }
3984 std::string name() const override {
3985 return absl::StrFormat("(-%s)", expr_->name());
3986 }
3987 std::string DebugString() const override {
3988 return absl::StrFormat("(-%s)", expr_->DebugString());
3989 }
3990 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3991 IntVar* CastToVar() override;
3992
3993 void Accept(ModelVisitor* const visitor) const override {
3994 visitor->BeginVisitIntegerExpression(ModelVisitor::kOpposite, this);
3995 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3996 expr_);
3997 visitor->EndVisitIntegerExpression(ModelVisitor::kOpposite, this);
3998 }
3999
4000 private:
4001 IntExpr* const expr_;
4002};
4003
4004IntVar* OppIntExpr::CastToVar() {
4005 Solver* const s = solver();
4006 IntVar* const var =
4007 s->RegisterIntVar(s->RevAlloc(new OppIntVar(s, expr_->Var())));
4008 return var;
4009}
4010
4011// ----- TimesIntCstExpr -----
4012
4013class TimesIntCstExpr : public BaseIntExpr {
4014 public:
4015 TimesIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4016 : BaseIntExpr(s), expr_(e), value_(v) {}
4017
4018 ~TimesIntCstExpr() override {}
4019
4020 bool Bound() const override { return (expr_->Bound()); }
4021
4022 std::string name() const override {
4023 return absl::StrFormat("(%s * %d)", expr_->name(), value_);
4024 }
4025
4026 std::string DebugString() const override {
4027 return absl::StrFormat("(%s * %d)", expr_->DebugString(), value_);
4028 }
4029
4030 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4031
4032 IntExpr* Expr() const { return expr_; }
4033
4034 int64_t Constant() const { return value_; }
4035
4036 void Accept(ModelVisitor* const visitor) const override {
4037 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4038 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4039 expr_);
4040 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4041 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4042 }
4043
4044 protected:
4045 IntExpr* const expr_;
4046 const int64_t value_;
4047};
4048
4049// ----- TimesPosIntCstExpr -----
4050
4051class TimesPosIntCstExpr : public TimesIntCstExpr {
4052 public:
4053 TimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4054 : TimesIntCstExpr(s, e, v) {
4055 CHECK_GT(v, 0);
4056 }
4057
4058 ~TimesPosIntCstExpr() override {}
4059
4060 int64_t Min() const override { return expr_->Min() * value_; }
4061
4062 void SetMin(int64_t m) override { expr_->SetMin(PosIntDivUp(m, value_)); }
4063
4064 int64_t Max() const override { return expr_->Max() * value_; }
4065
4066 void SetMax(int64_t m) override { expr_->SetMax(PosIntDivDown(m, value_)); }
4067
4068 IntVar* CastToVar() override {
4069 Solver* const s = solver();
4070 IntVar* var = nullptr;
4071 if (expr_->IsVar() &&
4072 reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4073 var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4074 s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4075 } else {
4076 var = s->RegisterIntVar(
4077 s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4078 }
4079 return var;
4080 }
4081};
4082
// This expression adds safe arithmetic (w.r.t. overflows) compared
// to the previous one.
4085class SafeTimesPosIntCstExpr : public TimesIntCstExpr {
4086 public:
4087 SafeTimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4088 : TimesIntCstExpr(s, e, v) {
4089 CHECK_GT(v, 0);
4090 }
4091
4092 ~SafeTimesPosIntCstExpr() override {}
4093
4094 int64_t Min() const override { return CapProd(expr_->Min(), value_); }
4095
4096 void SetMin(int64_t m) override {
4097 if (m != std::numeric_limits<int64_t>::min()) {
4098 expr_->SetMin(PosIntDivUp(m, value_));
4099 }
4100 }
4101
4102 int64_t Max() const override { return CapProd(expr_->Max(), value_); }
4103
4104 void SetMax(int64_t m) override {
4105 if (m != std::numeric_limits<int64_t>::max()) {
4106 expr_->SetMax(PosIntDivDown(m, value_));
4107 }
4108 }
4109
4110 IntVar* CastToVar() override {
4111 Solver* const s = solver();
4112 IntVar* var = nullptr;
4113 if (expr_->IsVar() &&
4114 reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4115 var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4116 s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4117 } else {
4118 // TODO(user): Check overflows.
4119 var = s->RegisterIntVar(
4120 s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4121 }
4122 return var;
4123 }
4124};
4125
4126// ----- TimesIntNegCstExpr -----
4127
4128class TimesIntNegCstExpr : public TimesIntCstExpr {
4129 public:
4130 TimesIntNegCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4131 : TimesIntCstExpr(s, e, v) {
4132 CHECK_LT(v, 0);
4133 }
4134
4135 ~TimesIntNegCstExpr() override {}
4136
4137 int64_t Min() const override { return CapProd(expr_->Max(), value_); }
4138
4139 void SetMin(int64_t m) override {
4140 if (m != std::numeric_limits<int64_t>::min()) {
4141 expr_->SetMax(PosIntDivDown(-m, -value_));
4142 }
4143 }
4144
4145 int64_t Max() const override { return CapProd(expr_->Min(), value_); }
4146
4147 void SetMax(int64_t m) override {
4148 if (m != std::numeric_limits<int64_t>::max()) {
4149 expr_->SetMin(PosIntDivUp(-m, -value_));
4150 }
4151 }
4152
4153 IntVar* CastToVar() override {
4154 Solver* const s = solver();
4155 IntVar* var = nullptr;
4156 var = s->RegisterIntVar(
4157 s->RevAlloc(new TimesNegCstIntVar(s, expr_->Var(), value_)));
4158 return var;
4159 }
4160};
4161
4162// ----- Utilities for product expression -----
4163
4164// Propagates set_min on left * right, left and right >= 0.
4165void SetPosPosMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4166 DCHECK_GE(left->Min(), 0);
4167 DCHECK_GE(right->Min(), 0);
4168 const int64_t lmax = left->Max();
4169 const int64_t rmax = right->Max();
4170 if (m > CapProd(lmax, rmax)) {
4171 left->solver()->Fail();
4172 }
4173 if (m > CapProd(left->Min(), right->Min())) {
4174 // Ok for m == 0 due to left and right being positive
4175 if (0 != rmax) {
4176 left->SetMin(PosIntDivUp(m, rmax));
4177 }
4178 if (0 != lmax) {
4179 right->SetMin(PosIntDivUp(m, lmax));
4180 }
4181 }
4182}
4183
4184// Propagates set_max on left * right, left and right >= 0.
4185void SetPosPosMaxExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4186 DCHECK_GE(left->Min(), 0);
4187 DCHECK_GE(right->Min(), 0);
4188 const int64_t lmin = left->Min();
4189 const int64_t rmin = right->Min();
4190 if (m < CapProd(lmin, rmin)) {
4191 left->solver()->Fail();
4192 }
4193 if (m < CapProd(left->Max(), right->Max())) {
4194 if (0 != lmin) {
4195 right->SetMax(PosIntDivDown(m, lmin));
4196 }
4197 if (0 != rmin) {
4198 left->SetMax(PosIntDivDown(m, rmin));
4199 }
4200 // else do nothing: 0 is supporting any value from other expr.
4201 }
4202}
4203
4204// Propagates set_min on left * right, left >= 0, right across 0.
4205void SetPosGenMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4206 DCHECK_GE(left->Min(), 0);
4207 DCHECK_GT(right->Max(), 0);
4208 DCHECK_LT(right->Min(), 0);
4209 const int64_t lmax = left->Max();
4210 const int64_t rmax = right->Max();
4211 if (m > CapProd(lmax, rmax)) {
4212 left->solver()->Fail();
4213 }
4214 if (left->Max() == 0) { // left is bound to 0, product is bound to 0.
4215 DCHECK_EQ(0, left->Min());
4216 DCHECK_LE(m, 0);
4217 } else {
4218 if (m > 0) { // We deduce right > 0.
4219 left->SetMin(PosIntDivUp(m, rmax));
4220 right->SetMin(PosIntDivUp(m, lmax));
4221 } else if (m == 0) {
4222 const int64_t lmin = left->Min();
4223 if (lmin > 0) {
4224 right->SetMin(0);
4225 }
4226 } else { // m < 0
4227 const int64_t lmin = left->Min();
4228 if (0 != lmin) { // We cannot deduce anything if 0 is in the domain.
4229 right->SetMin(-PosIntDivDown(-m, lmin));
4230 }
4231 }
4232 }
4233}
4234
4235// Propagates set_min on left * right, left and right across 0.
4236void SetGenGenMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4237 DCHECK_LT(left->Min(), 0);
4238 DCHECK_GT(left->Max(), 0);
4239 DCHECK_GT(right->Max(), 0);
4240 DCHECK_LT(right->Min(), 0);
4241 const int64_t lmin = left->Min();
4242 const int64_t lmax = left->Max();
4243 const int64_t rmin = right->Min();
4244 const int64_t rmax = right->Max();
4245 if (m > std::max(CapProd(lmin, rmin), CapProd(lmax, rmax))) {
4246 left->solver()->Fail();
4247 }
4248 if (m >
4249 CapProd(lmin, rmin)) { // Must be positive section * positive section.
4250 left->SetMin(PosIntDivUp(m, rmax));
4251 right->SetMin(PosIntDivUp(m, lmax));
4252 } else if (m > CapProd(lmax, rmax)) { // Negative section * negative section.
4253 left->SetMax(CapOpp(PosIntDivUp(m, CapOpp(rmin))));
4254 right->SetMax(CapOpp(PosIntDivUp(m, CapOpp(lmin))));
4255 }
4256}
4257
4258void TimesSetMin(IntExpr* const left, IntExpr* const right,
4259 IntExpr* const minus_left, IntExpr* const minus_right,
4260 int64_t m) {
4261 if (left->Min() >= 0) {
4262 if (right->Min() >= 0) {
4263 SetPosPosMinExpr(left, right, m);
4264 } else if (right->Max() <= 0) {
4265 SetPosPosMaxExpr(left, minus_right, -m);
4266 } else { // right->Min() < 0 && right->Max() > 0
4267 SetPosGenMinExpr(left, right, m);
4268 }
4269 } else if (left->Max() <= 0) {
4270 if (right->Min() >= 0) {
4271 SetPosPosMaxExpr(right, minus_left, -m);
4272 } else if (right->Max() <= 0) {
4273 SetPosPosMinExpr(minus_left, minus_right, m);
4274 } else { // right->Min() < 0 && right->Max() > 0
4275 SetPosGenMinExpr(minus_left, minus_right, m);
4276 }
4277 } else if (right->Min() >= 0) { // left->Min() < 0 && left->Max() > 0
4278 SetPosGenMinExpr(right, left, m);
4279 } else if (right->Max() <= 0) { // left->Min() < 0 && left->Max() > 0
4280 SetPosGenMinExpr(minus_right, minus_left, m);
4281 } else { // left->Min() < 0 && left->Max() > 0 &&
4282 // right->Min() < 0 && right->Max() > 0
4283 SetGenGenMinExpr(left, right, m);
4284 }
4285}
4286
4287class TimesIntExpr : public BaseIntExpr {
4288 public:
4289 TimesIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4290 : BaseIntExpr(s),
4291 left_(l),
4292 right_(r),
4293 minus_left_(s->MakeOpposite(left_)),
4294 minus_right_(s->MakeOpposite(right_)) {}
4295 ~TimesIntExpr() override {}
4296 int64_t Min() const override {
4297 const int64_t lmin = left_->Min();
4298 const int64_t lmax = left_->Max();
4299 const int64_t rmin = right_->Min();
4300 const int64_t rmax = right_->Max();
4301 return std::min(std::min(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4302 std::min(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4303 }
4304 void SetMin(int64_t m) override;
4305 int64_t Max() const override {
4306 const int64_t lmin = left_->Min();
4307 const int64_t lmax = left_->Max();
4308 const int64_t rmin = right_->Min();
4309 const int64_t rmax = right_->Max();
4310 return std::max(std::max(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4311 std::max(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4312 }
4313 void SetMax(int64_t m) override;
4314 bool Bound() const override;
4315 std::string name() const override {
4316 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4317 }
4318 std::string DebugString() const override {
4319 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4320 right_->DebugString());
4321 }
4322 void WhenRange(Demon* d) override {
4323 left_->WhenRange(d);
4324 right_->WhenRange(d);
4325 }
4326
4327 void Accept(ModelVisitor* const visitor) const override {
4328 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4329 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4330 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4331 right_);
4332 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4333 }
4334
4335 private:
4336 IntExpr* const left_;
4337 IntExpr* const right_;
4338 IntExpr* const minus_left_;
4339 IntExpr* const minus_right_;
4340};
4341
4342void TimesIntExpr::SetMin(int64_t m) {
4343 if (m != std::numeric_limits<int64_t>::min()) {
4344 TimesSetMin(left_, right_, minus_left_, minus_right_, m);
4345 }
4346}
4347
4348void TimesIntExpr::SetMax(int64_t m) {
4349 if (m != std::numeric_limits<int64_t>::max()) {
4350 TimesSetMin(left_, minus_right_, minus_left_, right_, CapOpp(m));
4351 }
4352}
4353
4354bool TimesIntExpr::Bound() const {
4355 const bool left_bound = left_->Bound();
4356 const bool right_bound = right_->Bound();
4357 return ((left_bound && left_->Max() == 0) ||
4358 (right_bound && right_->Max() == 0) || (left_bound && right_bound));
4359}
4360
4361// ----- TimesPosIntExpr -----
4362
4363class TimesPosIntExpr : public BaseIntExpr {
4364 public:
4365 TimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4366 : BaseIntExpr(s), left_(l), right_(r) {}
4367 ~TimesPosIntExpr() override {}
4368 int64_t Min() const override { return (left_->Min() * right_->Min()); }
4369 void SetMin(int64_t m) override;
4370 int64_t Max() const override { return (left_->Max() * right_->Max()); }
4371 void SetMax(int64_t m) override;
4372 bool Bound() const override;
4373 std::string name() const override {
4374 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4375 }
4376 std::string DebugString() const override {
4377 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4378 right_->DebugString());
4379 }
4380 void WhenRange(Demon* d) override {
4381 left_->WhenRange(d);
4382 right_->WhenRange(d);
4383 }
4384
4385 void Accept(ModelVisitor* const visitor) const override {
4386 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4387 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4388 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4389 right_);
4390 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4391 }
4392
4393 private:
4394 IntExpr* const left_;
4395 IntExpr* const right_;
4396};
4397
4398void TimesPosIntExpr::SetMin(int64_t m) { SetPosPosMinExpr(left_, right_, m); }
4399
4400void TimesPosIntExpr::SetMax(int64_t m) { SetPosPosMaxExpr(left_, right_, m); }
4401
4402bool TimesPosIntExpr::Bound() const {
4403 return (left_->Max() == 0 || right_->Max() == 0 ||
4404 (left_->Bound() && right_->Bound()));
4405}
4406
4407// ----- SafeTimesPosIntExpr -----
4408
4409class SafeTimesPosIntExpr : public BaseIntExpr {
4410 public:
4411 SafeTimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4412 : BaseIntExpr(s), left_(l), right_(r) {}
4413 ~SafeTimesPosIntExpr() override {}
4414 int64_t Min() const override { return CapProd(left_->Min(), right_->Min()); }
4415 void SetMin(int64_t m) override {
4416 if (m != std::numeric_limits<int64_t>::min()) {
4417 SetPosPosMinExpr(left_, right_, m);
4418 }
4419 }
4420 int64_t Max() const override { return CapProd(left_->Max(), right_->Max()); }
4421 void SetMax(int64_t m) override {
4422 if (m != std::numeric_limits<int64_t>::max()) {
4423 SetPosPosMaxExpr(left_, right_, m);
4424 }
4425 }
4426 bool Bound() const override {
4427 return (left_->Max() == 0 || right_->Max() == 0 ||
4428 (left_->Bound() && right_->Bound()));
4429 }
4430 std::string name() const override {
4431 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4432 }
4433 std::string DebugString() const override {
4434 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4435 right_->DebugString());
4436 }
4437 void WhenRange(Demon* d) override {
4438 left_->WhenRange(d);
4439 right_->WhenRange(d);
4440 }
4441
4442 void Accept(ModelVisitor* const visitor) const override {
4443 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4444 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4445 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4446 right_);
4447 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4448 }
4449
4450 private:
4451 IntExpr* const left_;
4452 IntExpr* const right_;
4453};
4454
4455// ----- TimesBooleanPosIntExpr -----
4456
4457class TimesBooleanPosIntExpr : public BaseIntExpr {
4458 public:
4459 TimesBooleanPosIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4460 : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4461 ~TimesBooleanPosIntExpr() override {}
4462 int64_t Min() const override {
4463 return (boolvar_->RawValue() == 1 ? expr_->Min() : 0);
4464 }
4465 void SetMin(int64_t m) override;
4466 int64_t Max() const override {
4467 return (boolvar_->RawValue() == 0 ? 0 : expr_->Max());
4468 }
4469 void SetMax(int64_t m) override;
4470 void Range(int64_t* mi, int64_t* ma) override;
4471 void SetRange(int64_t mi, int64_t ma) override;
4472 bool Bound() const override;
4473 std::string name() const override {
4474 return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4475 }
4476 std::string DebugString() const override {
4477 return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4478 expr_->DebugString());
4479 }
4480 void WhenRange(Demon* d) override {
4481 boolvar_->WhenRange(d);
4482 expr_->WhenRange(d);
4483 }
4484
4485 void Accept(ModelVisitor* const visitor) const override {
4486 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4487 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4488 boolvar_);
4489 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4490 expr_);
4491 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4492 }
4493
4494 private:
4495 BooleanVar* const boolvar_;
4496 IntExpr* const expr_;
4497};
4498
4499void TimesBooleanPosIntExpr::SetMin(int64_t m) {
4500 if (m > 0) {
4501 boolvar_->SetValue(1);
4502 expr_->SetMin(m);
4503 }
4504}
4505
4506void TimesBooleanPosIntExpr::SetMax(int64_t m) {
4507 if (m < 0) {
4508 solver()->Fail();
4509 }
4510 if (m < expr_->Min()) {
4511 boolvar_->SetValue(0);
4512 }
4513 if (boolvar_->RawValue() == 1) {
4514 expr_->SetMax(m);
4515 }
4516}
4517
4518void TimesBooleanPosIntExpr::Range(int64_t* mi, int64_t* ma) {
4519 const int value = boolvar_->RawValue();
4520 if (value == 0) {
4521 *mi = 0;
4522 *ma = 0;
4523 } else if (value == 1) {
4524 expr_->Range(mi, ma);
4525 } else {
4526 *mi = 0;
4527 *ma = expr_->Max();
4528 }
4529}
4530
4531void TimesBooleanPosIntExpr::SetRange(int64_t mi, int64_t ma) {
4532 if (ma < 0 || mi > ma) {
4533 solver()->Fail();
4534 }
4535 if (mi > 0) {
4536 boolvar_->SetValue(1);
4537 expr_->SetMin(mi);
4538 }
4539 if (ma < expr_->Min()) {
4540 boolvar_->SetValue(0);
4541 }
4542 if (boolvar_->RawValue() == 1) {
4543 expr_->SetMax(ma);
4544 }
4545}
4546
4547bool TimesBooleanPosIntExpr::Bound() const {
4548 return (boolvar_->RawValue() == 0 || expr_->Max() == 0 ||
4549 (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue &&
4550 expr_->Bound()));
4551}
4552
4553// ----- TimesBooleanIntExpr -----
4554
4555class TimesBooleanIntExpr : public BaseIntExpr {
4556 public:
4557 TimesBooleanIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4558 : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4559 ~TimesBooleanIntExpr() override {}
4560 int64_t Min() const override {
4561 switch (boolvar_->RawValue()) {
4562 case 0: {
4563 return 0LL;
4564 }
4565 case 1: {
4566 return expr_->Min();
4567 }
4568 default: {
4569 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4570 return std::min(int64_t{0}, expr_->Min());
4571 }
4572 }
4573 }
4574 void SetMin(int64_t m) override;
4575 int64_t Max() const override {
4576 switch (boolvar_->RawValue()) {
4577 case 0: {
4578 return 0LL;
4579 }
4580 case 1: {
4581 return expr_->Max();
4582 }
4583 default: {
4584 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4585 return std::max(int64_t{0}, expr_->Max());
4586 }
4587 }
4588 }
4589 void SetMax(int64_t m) override;
4590 void Range(int64_t* mi, int64_t* ma) override;
4591 void SetRange(int64_t mi, int64_t ma) override;
4592 bool Bound() const override;
4593 std::string name() const override {
4594 return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4595 }
4596 std::string DebugString() const override {
4597 return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4598 expr_->DebugString());
4599 }
4600 void WhenRange(Demon* d) override {
4601 boolvar_->WhenRange(d);
4602 expr_->WhenRange(d);
4603 }
4604
4605 void Accept(ModelVisitor* const visitor) const override {
4606 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4607 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4608 boolvar_);
4609 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4610 expr_);
4611 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4612 }
4613
4614 private:
4615 BooleanVar* const boolvar_;
4616 IntExpr* const expr_;
4617};
4618
4619void TimesBooleanIntExpr::SetMin(int64_t m) {
4620 switch (boolvar_->RawValue()) {
4621 case 0: {
4622 if (m > 0) {
4623 solver()->Fail();
4624 }
4625 break;
4626 }
4627 case 1: {
4628 expr_->SetMin(m);
4629 break;
4630 }
4631 default: {
4632 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4633 if (m > 0) { // 0 is no longer possible for boolvar because min > 0.
4634 boolvar_->SetValue(1);
4635 expr_->SetMin(m);
4636 } else if (m <= 0 && expr_->Max() < m) {
4637 boolvar_->SetValue(0);
4638 }
4639 }
4640 }
4641}
4642
4643void TimesBooleanIntExpr::SetMax(int64_t m) {
4644 switch (boolvar_->RawValue()) {
4645 case 0: {
4646 if (m < 0) {
4647 solver()->Fail();
4648 }
4649 break;
4650 }
4651 case 1: {
4652 expr_->SetMax(m);
4653 break;
4654 }
4655 default: {
4656 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4657 if (m < 0) { // 0 is no longer possible for boolvar because max < 0.
4658 boolvar_->SetValue(1);
4659 expr_->SetMax(m);
4660 } else if (m >= 0 && expr_->Min() > m) {
4661 boolvar_->SetValue(0);
4662 }
4663 }
4664 }
4665}
4666
4667void TimesBooleanIntExpr::Range(int64_t* mi, int64_t* ma) {
4668 switch (boolvar_->RawValue()) {
4669 case 0: {
4670 *mi = 0;
4671 *ma = 0;
4672 break;
4673 }
4674 case 1: {
4675 *mi = expr_->Min();
4676 *ma = expr_->Max();
4677 break;
4678 }
4679 default: {
4680 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4681 *mi = std::min(int64_t{0}, expr_->Min());
4682 *ma = std::max(int64_t{0}, expr_->Max());
4683 break;
4684 }
4685 }
4686}
4687
4688void TimesBooleanIntExpr::SetRange(int64_t mi, int64_t ma) {
4689 if (mi > ma) {
4690 solver()->Fail();
4691 }
4692 switch (boolvar_->RawValue()) {
4693 case 0: {
4694 if (mi > 0 || ma < 0) {
4695 solver()->Fail();
4696 }
4697 break;
4698 }
4699 case 1: {
4700 expr_->SetRange(mi, ma);
4701 break;
4702 }
4703 default: {
4704 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4705 if (mi > 0) {
4706 boolvar_->SetValue(1);
4707 expr_->SetMin(mi);
4708 } else if (mi == 0 && expr_->Max() < 0) {
4709 boolvar_->SetValue(0);
4710 }
4711 if (ma < 0) {
4712 boolvar_->SetValue(1);
4713 expr_->SetMax(ma);
4714 } else if (ma == 0 && expr_->Min() > 0) {
4715 boolvar_->SetValue(0);
4716 }
4717 break;
4718 }
4719 }
4720}
4721
4722bool TimesBooleanIntExpr::Bound() const {
4723 return (boolvar_->RawValue() == 0 ||
4724 (expr_->Bound() &&
4725 (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue ||
4726 expr_->Max() == 0)));
4727}
4728
4729// ----- DivPosIntCstExpr -----
4730
4731class DivPosIntCstExpr : public BaseIntExpr {
4732 public:
4733 DivPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4734 : BaseIntExpr(s), expr_(e), value_(v) {
4735 CHECK_GE(v, 0);
4736 }
4737 ~DivPosIntCstExpr() override {}
4738
4739 int64_t Min() const override { return expr_->Min() / value_; }
4740
4741 void SetMin(int64_t m) override {
4742 if (m > 0) {
4743 expr_->SetMin(m * value_);
4744 } else {
4745 expr_->SetMin((m - 1) * value_ + 1);
4746 }
4747 }
4748 int64_t Max() const override { return expr_->Max() / value_; }
4749
4750 void SetMax(int64_t m) override {
4751 if (m >= 0) {
4752 expr_->SetMax((m + 1) * value_ - 1);
4753 } else {
4754 expr_->SetMax(m * value_);
4755 }
4756 }
4757
4758 std::string name() const override {
4759 return absl::StrFormat("(%s div %d)", expr_->name(), value_);
4760 }
4761
4762 std::string DebugString() const override {
4763 return absl::StrFormat("(%s div %d)", expr_->DebugString(), value_);
4764 }
4765
4766 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4767
4768 void Accept(ModelVisitor* const visitor) const override {
4769 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4770 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4771 expr_);
4772 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4773 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4774 }
4775
4776 private:
4777 IntExpr* const expr_;
4778 const int64_t value_;
4779};
4780
4781// DivPosIntExpr
4782
4783class DivPosIntExpr : public BaseIntExpr {
4784 public:
4785 DivPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4786 : BaseIntExpr(s),
4787 num_(num),
4788 denom_(denom),
4789 opp_num_(s->MakeOpposite(num)) {}
4790
4791 ~DivPosIntExpr() override {}
4792
4793 int64_t Min() const override {
4794 return num_->Min() >= 0
4795 ? num_->Min() / denom_->Max()
4796 : (denom_->Min() == 0 ? num_->Min()
4797 : num_->Min() / denom_->Min());
4798 }
4799
4800 int64_t Max() const override {
4801 return num_->Max() >= 0 ? (denom_->Min() == 0 ? num_->Max()
4802 : num_->Max() / denom_->Min())
4803 : num_->Max() / denom_->Max();
4804 }
4805
4806 static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64_t m) {
4807 num->SetMin(m * denom->Min());
4808 denom->SetMax(num->Max() / m);
4809 }
4810
4811 static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64_t m) {
4812 num->SetMax((m + 1) * denom->Max() - 1);
4813 denom->SetMin(num->Min() / (m + 1) + 1);
4814 }
4815
4816 void SetMin(int64_t m) override {
4817 if (m > 0) {
4818 SetPosMin(num_, denom_, m);
4819 } else {
4820 SetPosMax(opp_num_, denom_, -m);
4821 }
4822 }
4823
4824 void SetMax(int64_t m) override {
4825 if (m >= 0) {
4826 SetPosMax(num_, denom_, m);
4827 } else {
4828 SetPosMin(opp_num_, denom_, -m);
4829 }
4830 }
4831
4832 std::string name() const override {
4833 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4834 }
4835 std::string DebugString() const override {
4836 return absl::StrFormat("(%s div %s)", num_->DebugString(),
4837 denom_->DebugString());
4838 }
4839 void WhenRange(Demon* d) override {
4840 num_->WhenRange(d);
4841 denom_->WhenRange(d);
4842 }
4843
4844 void Accept(ModelVisitor* const visitor) const override {
4845 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4846 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4847 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4848 denom_);
4849 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4850 }
4851
4852 private:
4853 IntExpr* const num_;
4854 IntExpr* const denom_;
4855 IntExpr* const opp_num_;
4856};
4857
4858class DivPosPosIntExpr : public BaseIntExpr {
4859 public:
4860 DivPosPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4861 : BaseIntExpr(s), num_(num), denom_(denom) {}
4862
4863 ~DivPosPosIntExpr() override {}
4864
4865 int64_t Min() const override {
4866 if (denom_->Max() == 0) {
4867 solver()->Fail();
4868 }
4869 return num_->Min() / denom_->Max();
4870 }
4871
4872 int64_t Max() const override {
4873 if (denom_->Min() == 0) {
4874 return num_->Max();
4875 } else {
4876 return num_->Max() / denom_->Min();
4877 }
4878 }
4879
4880 void SetMin(int64_t m) override {
4881 if (m > 0) {
4882 num_->SetMin(m * denom_->Min());
4883 denom_->SetMax(num_->Max() / m);
4884 }
4885 }
4886
4887 void SetMax(int64_t m) override {
4888 if (m >= 0) {
4889 num_->SetMax((m + 1) * denom_->Max() - 1);
4890 denom_->SetMin(num_->Min() / (m + 1) + 1);
4891 } else {
4892 solver()->Fail();
4893 }
4894 }
4895
4896 std::string name() const override {
4897 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4898 }
4899
4900 std::string DebugString() const override {
4901 return absl::StrFormat("(%s div %s)", num_->DebugString(),
4902 denom_->DebugString());
4903 }
4904
4905 void WhenRange(Demon* d) override {
4906 num_->WhenRange(d);
4907 denom_->WhenRange(d);
4908 }
4909
4910 void Accept(ModelVisitor* const visitor) const override {
4911 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4912 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4913 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4914 denom_);
4915 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4916 }
4917
4918 private:
4919 IntExpr* const num_;
4920 IntExpr* const denom_;
4921};
4922
4923// DivIntExpr
4924
4925class DivIntExpr : public BaseIntExpr {
4926 public:
4927 DivIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4928 : BaseIntExpr(s),
4929 num_(num),
4930 denom_(denom),
4931 opp_num_(s->MakeOpposite(num)) {}
4932
4933 ~DivIntExpr() override {}
4934
4935 int64_t Min() const override {
4936 const int64_t num_min = num_->Min();
4937 const int64_t num_max = num_->Max();
4938 const int64_t denom_min = denom_->Min();
4939 const int64_t denom_max = denom_->Max();
4940
4941 if (denom_min == 0 && denom_max == 0) {
4942 return std::numeric_limits<int64_t>::max(); // TODO(user): Check this
4943 // convention.
4944 }
4945
4946 if (denom_min >= 0) { // Denominator strictly positive.
4947 DCHECK_GT(denom_max, 0);
4948 const int64_t adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4949 return num_min >= 0 ? num_min / denom_max : num_min / adjusted_denom_min;
4950 } else if (denom_max <= 0) { // Denominator strictly negative.
4951 DCHECK_LT(denom_min, 0);
4952 const int64_t adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4953 return num_max >= 0 ? num_max / adjusted_denom_max : num_max / denom_min;
4954 } else { // Denominator across 0.
4955 return std::min(num_min, -num_max);
4956 }
4957 }
4958
4959 int64_t Max() const override {
4960 const int64_t num_min = num_->Min();
4961 const int64_t num_max = num_->Max();
4962 const int64_t denom_min = denom_->Min();
4963 const int64_t denom_max = denom_->Max();
4964
4965 if (denom_min == 0 && denom_max == 0) {
4966 return std::numeric_limits<int64_t>::min(); // TODO(user): Check this
4967 // convention.
4968 }
4969
4970 if (denom_min >= 0) { // Denominator strictly positive.
4971 DCHECK_GT(denom_max, 0);
4972 const int64_t adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4973 return num_max >= 0 ? num_max / adjusted_denom_min : num_max / denom_max;
4974 } else if (denom_max <= 0) { // Denominator strictly negative.
4975 DCHECK_LT(denom_min, 0);
4976 const int64_t adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4977 return num_min >= 0 ? num_min / denom_min
4978 : -num_min / -adjusted_denom_max;
4979 } else { // Denominator across 0.
4980 return std::max(num_max, -num_min);
4981 }
4982 }
4983
4984 void AdjustDenominator() {
4985 if (denom_->Min() == 0) {
4986 denom_->SetMin(1);
4987 } else if (denom_->Max() == 0) {
4988 denom_->SetMax(-1);
4989 }
4990 }
4991
4992 // m > 0.
4993 static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64_t m) {
4994 DCHECK_GT(m, 0);
4995 const int64_t num_min = num->Min();
4996 const int64_t num_max = num->Max();
4997 const int64_t denom_min = denom->Min();
4998 const int64_t denom_max = denom->Max();
4999 DCHECK_NE(denom_min, 0);
5000 DCHECK_NE(denom_max, 0);
5001 if (denom_min > 0) { // Denominator strictly positive.
5002 num->SetMin(m * denom_min);
5003 denom->SetMax(num_max / m);
5004 } else if (denom_max < 0) { // Denominator strictly negative.
5005 num->SetMax(m * denom_max);
5006 denom->SetMin(num_min / m);
5007 } else { // Denominator across 0.
5008 if (num_min >= 0) {
5009 num->SetMin(m);
5010 denom->SetRange(1, num_max / m);
5011 } else if (num_max <= 0) {
5012 num->SetMax(-m);
5013 denom->SetRange(num_min / m, -1);
5014 } else {
5015 if (m > -num_min) { // Denominator is forced positive.
5016 num->SetMin(m);
5017 denom->SetRange(1, num_max / m);
5018 } else if (m > num_max) { // Denominator is forced negative.
5019 num->SetMax(-m);
5020 denom->SetRange(num_min / m, -1);
5021 } else {
5022 denom->SetRange(num_min / m, num_max / m);
5023 }
5024 }
5025 }
5026 }
5027
5028 // m >= 0.
5029 static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64_t m) {
5030 DCHECK_GE(m, 0);
5031 const int64_t num_min = num->Min();
5032 const int64_t num_max = num->Max();
5033 const int64_t denom_min = denom->Min();
5034 const int64_t denom_max = denom->Max();
5035 DCHECK_NE(denom_min, 0);
5036 DCHECK_NE(denom_max, 0);
5037 if (denom_min > 0) { // Denominator strictly positive.
5038 num->SetMax((m + 1) * denom_max - 1);
5039 denom->SetMin((num_min / (m + 1)) + 1);
5040 } else if (denom_max < 0) {
5041 num->SetMin((m + 1) * denom_min + 1);
5042 denom->SetMax(num_max / (m + 1) - 1);
5043 } else if (num_min > (m + 1) * denom_max - 1) {
5044 denom->SetMax(-1);
5045 } else if (num_max < (m + 1) * denom_min + 1) {
5046 denom->SetMin(1);
5047 }
5048 }
5049
5050 void SetMin(int64_t m) override {
5051 AdjustDenominator();
5052 if (m > 0) {
5053 SetPosMin(num_, denom_, m);
5054 } else {
5055 SetPosMax(opp_num_, denom_, -m);
5056 }
5057 }
5058
5059 void SetMax(int64_t m) override {
5060 AdjustDenominator();
5061 if (m >= 0) {
5062 SetPosMax(num_, denom_, m);
5063 } else {
5064 SetPosMin(opp_num_, denom_, -m);
5065 }
5066 }
5067
5068 std::string name() const override {
5069 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
5070 }
5071 std::string DebugString() const override {
5072 return absl::StrFormat("(%s div %s)", num_->DebugString(),
5073 denom_->DebugString());
5074 }
5075 void WhenRange(Demon* d) override {
5076 num_->WhenRange(d);
5077 denom_->WhenRange(d);
5078 }
5079
5080 void Accept(ModelVisitor* const visitor) const override {
5081 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
5082 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
5083 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5084 denom_);
5085 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
5086 }
5087
5088 private:
5089 IntExpr* const num_;
5090 IntExpr* const denom_;
5091 IntExpr* const opp_num_;
5092};
5093
5094// ----- IntAbs And IntAbsConstraint ------
5095
5096class IntAbsConstraint : public CastConstraint {
5097 public:
5098 IntAbsConstraint(Solver* const s, IntVar* const sub, IntVar* const target)
5099 : CastConstraint(s, target), sub_(sub) {}
5100
5101 ~IntAbsConstraint() override {}
5102
5103 void Post() override {
5104 Demon* const sub_demon = MakeConstraintDemon0(
5105 solver(), this, &IntAbsConstraint::PropagateSub, "PropagateSub");
5106 sub_->WhenRange(sub_demon);
5107 Demon* const target_demon = MakeConstraintDemon0(
5108 solver(), this, &IntAbsConstraint::PropagateTarget, "PropagateTarget");
5109 target_var_->WhenRange(target_demon);
5110 }
5111
5112 void InitialPropagate() override {
5113 PropagateSub();
5114 PropagateTarget();
5115 }
5116
5117 void PropagateSub() {
5118 const int64_t smin = sub_->Min();
5119 const int64_t smax = sub_->Max();
5120 if (smax <= 0) {
5121 target_var_->SetRange(-smax, -smin);
5122 } else if (smin >= 0) {
5123 target_var_->SetRange(smin, smax);
5124 } else {
5125 target_var_->SetRange(0, std::max(-smin, smax));
5126 }
5127 }
5128
5129 void PropagateTarget() {
5130 const int64_t target_max = target_var_->Max();
5131 sub_->SetRange(-target_max, target_max);
5132 const int64_t target_min = target_var_->Min();
5133 if (target_min > 0) {
5134 if (sub_->Min() > -target_min) {
5135 sub_->SetMin(target_min);
5136 } else if (sub_->Max() < target_min) {
5137 sub_->SetMax(-target_min);
5138 }
5139 }
5140 }
5141
5142 std::string DebugString() const override {
5143 return absl::StrFormat("IntAbsConstraint(%s, %s)", sub_->DebugString(),
5144 target_var_->DebugString());
5145 }
5146
5147 void Accept(ModelVisitor* const visitor) const override {
5148 visitor->BeginVisitConstraint(ModelVisitor::kAbsEqual, this);
5149 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5150 sub_);
5151 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
5152 target_var_);
5153 visitor->EndVisitConstraint(ModelVisitor::kAbsEqual, this);
5154 }
5155
5156 private:
5157 IntVar* const sub_;
5158};
5159
5160class IntAbs : public BaseIntExpr {
5161 public:
5162 IntAbs(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5163
5164 ~IntAbs() override {}
5165
5166 int64_t Min() const override {
5167 int64_t emin = 0;
5168 int64_t emax = 0;
5169 expr_->Range(&emin, &emax);
5170 if (emin >= 0) {
5171 return emin;
5172 }
5173 if (emax <= 0) {
5174 return -emax;
5175 }
5176 return 0;
5177 }
5178
5179 void SetMin(int64_t m) override {
5180 if (m > 0) {
5181 int64_t emin = 0;
5182 int64_t emax = 0;
5183 expr_->Range(&emin, &emax);
5184 if (emin > -m) {
5185 expr_->SetMin(m);
5186 } else if (emax < m) {
5187 expr_->SetMax(-m);
5188 }
5189 }
5190 }
5191
5192 int64_t Max() const override {
5193 int64_t emin = 0;
5194 int64_t emax = 0;
5195 expr_->Range(&emin, &emax);
5196 return std::max(-emin, emax);
5197 }
5198
5199 void SetMax(int64_t m) override { expr_->SetRange(-m, m); }
5200
5201 void SetRange(int64_t mi, int64_t ma) override {
5202 expr_->SetRange(-ma, ma);
5203 if (mi > 0) {
5204 int64_t emin = 0;
5205 int64_t emax = 0;
5206 expr_->Range(&emin, &emax);
5207 if (emin > -mi) {
5208 expr_->SetMin(mi);
5209 } else if (emax < mi) {
5210 expr_->SetMax(-mi);
5211 }
5212 }
5213 }
5214
5215 void Range(int64_t* mi, int64_t* ma) override {
5216 int64_t emin = 0;
5217 int64_t emax = 0;
5218 expr_->Range(&emin, &emax);
5219 if (emin >= 0) {
5220 *mi = emin;
5221 *ma = emax;
5222 } else if (emax <= 0) {
5223 *mi = -emax;
5224 *ma = -emin;
5225 } else {
5226 *mi = 0;
5227 *ma = std::max(-emin, emax);
5228 }
5229 }
5230
5231 bool Bound() const override { return expr_->Bound(); }
5232
5233 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5234
5235 std::string name() const override {
5236 return absl::StrFormat("IntAbs(%s)", expr_->name());
5237 }
5238
5239 std::string DebugString() const override {
5240 return absl::StrFormat("IntAbs(%s)", expr_->DebugString());
5241 }
5242
5243 void Accept(ModelVisitor* const visitor) const override {
5244 visitor->BeginVisitIntegerExpression(ModelVisitor::kAbs, this);
5245 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5246 expr_);
5247 visitor->EndVisitIntegerExpression(ModelVisitor::kAbs, this);
5248 }
5249
5250 IntVar* CastToVar() override {
5251 int64_t min_value = 0;
5252 int64_t max_value = 0;
5253 Range(&min_value, &max_value);
5254 Solver* const s = solver();
5255 const std::string name = absl::StrFormat("AbsVar(%s)", expr_->name());
5256 IntVar* const target = s->MakeIntVar(min_value, max_value, name);
5257 CastConstraint* const ct =
5258 s->RevAlloc(new IntAbsConstraint(s, expr_->Var(), target));
5259 s->AddCastConstraint(ct, target, this);
5260 return target;
5261 }
5262
5263 private:
5264 IntExpr* const expr_;
5265};
5266
5267// ----- Square -----
5268
5269// TODO(user): shouldn't we compare to kint32max^2 instead of kint64max?
5270class IntSquare : public BaseIntExpr {
5271 public:
5272 IntSquare(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5273 ~IntSquare() override {}
5274
5275 int64_t Min() const override {
5276 const int64_t emin = expr_->Min();
5277 if (emin >= 0) {
5278 return emin >= std::numeric_limits<int32_t>::max()
5279 ? std::numeric_limits<int64_t>::max()
5280 : emin * emin;
5281 }
5282 const int64_t emax = expr_->Max();
5283 if (emax < 0) {
5284 return emax <= -std::numeric_limits<int32_t>::max()
5285 ? std::numeric_limits<int64_t>::max()
5286 : emax * emax;
5287 }
5288 return 0LL;
5289 }
5290 void SetMin(int64_t m) override {
5291 if (m <= 0) {
5292 return;
5293 }
5294 // TODO(user): What happens if m is kint64max?
5295 const int64_t emin = expr_->Min();
5296 const int64_t emax = expr_->Max();
5297 const int64_t root =
5298 static_cast<int64_t>(ceil(sqrt(static_cast<double>(m))));
5299 if (emin >= 0) {
5300 expr_->SetMin(root);
5301 } else if (emax <= 0) {
5302 expr_->SetMax(-root);
5303 } else if (expr_->IsVar()) {
5304 reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5305 }
5306 }
5307 int64_t Max() const override {
5308 const int64_t emax = expr_->Max();
5309 const int64_t emin = expr_->Min();
5310 if (emax >= std::numeric_limits<int32_t>::max() ||
5311 emin <= -std::numeric_limits<int32_t>::max()) {
5312 return std::numeric_limits<int64_t>::max();
5313 }
5314 return std::max(emin * emin, emax * emax);
5315 }
5316 void SetMax(int64_t m) override {
5317 if (m < 0) {
5318 solver()->Fail();
5319 }
5320 if (m == std::numeric_limits<int64_t>::max()) {
5321 return;
5322 }
5323 const int64_t root =
5324 static_cast<int64_t>(floor(sqrt(static_cast<double>(m))));
5325 expr_->SetRange(-root, root);
5326 }
5327 bool Bound() const override { return expr_->Bound(); }
5328 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5329 std::string name() const override {
5330 return absl::StrFormat("IntSquare(%s)", expr_->name());
5331 }
5332 std::string DebugString() const override {
5333 return absl::StrFormat("IntSquare(%s)", expr_->DebugString());
5334 }
5335
5336 void Accept(ModelVisitor* const visitor) const override {
5337 visitor->BeginVisitIntegerExpression(ModelVisitor::kSquare, this);
5338 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5339 expr_);
5340 visitor->EndVisitIntegerExpression(ModelVisitor::kSquare, this);
5341 }
5342
5343 IntExpr* expr() const { return expr_; }
5344
5345 protected:
5346 IntExpr* const expr_;
5347};
5348
5349class PosIntSquare : public IntSquare {
5350 public:
5351 PosIntSquare(Solver* const s, IntExpr* const e) : IntSquare(s, e) {}
5352 ~PosIntSquare() override {}
5353
5354 int64_t Min() const override {
5355 const int64_t emin = expr_->Min();
5356 return emin >= std::numeric_limits<int32_t>::max()
5357 ? std::numeric_limits<int64_t>::max()
5358 : emin * emin;
5359 }
5360 void SetMin(int64_t m) override {
5361 if (m <= 0) {
5362 return;
5363 }
5364 int64_t root = static_cast<int64_t>(ceil(sqrt(static_cast<double>(m))));
5365 if (CapProd(root, root) < m) {
5366 root++;
5367 }
5368 expr_->SetMin(root);
5369 }
5370 int64_t Max() const override {
5371 const int64_t emax = expr_->Max();
5372 return emax >= std::numeric_limits<int32_t>::max()
5373 ? std::numeric_limits<int64_t>::max()
5374 : emax * emax;
5375 }
5376 void SetMax(int64_t m) override {
5377 if (m < 0) {
5378 solver()->Fail();
5379 }
5380 if (m == std::numeric_limits<int64_t>::max()) {
5381 return;
5382 }
5383 int64_t root = static_cast<int64_t>(floor(sqrt(static_cast<double>(m))));
5384 if (CapProd(root, root) > m) {
5385 root--;
5386 }
5387
5388 expr_->SetMax(root);
5389 }
5390};
5391
5392// ----- EvenPower -----
5393
// Returns value^power using square-and-multiply (O(log power)
// multiplications). `power` is expected to be non-negative; power == 0 yields
// 1. Overflow is not detected here: callers must make sure the result fits in
// an int64_t. No intermediate product exceeds the final result in magnitude,
// so a representable result never overflows on the way.
int64_t IntPower(int64_t value, int64_t power) {
  int64_t result = 1;
  int64_t base = value;
  int64_t remaining = power;
  while (remaining > 0) {
    if ((remaining & 1) != 0) {
      result *= base;
    }
    remaining >>= 1;
    // Only square when the squared base will actually be used; an unused
    // final squaring could overflow even though the result fits.
    if (remaining > 0) {
      base *= base;
    }
  }
  return result;
}
5402
// Returns the floor of the real `power`-th root of kint64max, computed in
// double precision as exp(log(kint64max) / power).
int64_t OverflowLimit(int64_t power) {
  const double log_of_max =
      std::log(static_cast<double>(std::numeric_limits<int64_t>::max()));
  const double real_root = std::exp(log_of_max / power);
  return static_cast<int64_t>(std::floor(real_root));
}
5407
5408class BasePower : public BaseIntExpr {
5409 public:
5410 BasePower(Solver* const s, IntExpr* const e, int64_t n)
5411 : BaseIntExpr(s), expr_(e), pow_(n), limit_(OverflowLimit(n)) {
5412 CHECK_GT(n, 0);
5413 }
5414
5415 ~BasePower() override {}
5416
5417 bool Bound() const override { return expr_->Bound(); }
5418
5419 IntExpr* expr() const { return expr_; }
5420
5421 int64_t exponant() const { return pow_; }
5422
5423 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5424
5425 std::string name() const override {
5426 return absl::StrFormat("IntPower(%s, %d)", expr_->name(), pow_);
5427 }
5428
5429 std::string DebugString() const override {
5430 return absl::StrFormat("IntPower(%s, %d)", expr_->DebugString(), pow_);
5431 }
5432
5433 void Accept(ModelVisitor* const visitor) const override {
5434 visitor->BeginVisitIntegerExpression(ModelVisitor::kPower, this);
5435 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5436 expr_);
5437 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, pow_);
5438 visitor->EndVisitIntegerExpression(ModelVisitor::kPower, this);
5439 }
5440
5441 protected:
5442 int64_t Pown(int64_t value) const {
5443 if (value >= limit_) {
5444 return std::numeric_limits<int64_t>::max();
5445 }
5446 if (value <= -limit_) {
5447 if (pow_ % 2 == 0) {
5448 return std::numeric_limits<int64_t>::max();
5449 } else {
5450 return std::numeric_limits<int64_t>::min();
5451 }
5452 }
5453 return IntPower(value, pow_);
5454 }
5455
5456 int64_t SqrnDown(int64_t value) const {
5457 if (value == std::numeric_limits<int64_t>::min()) {
5458 return std::numeric_limits<int64_t>::min();
5459 }
5460 if (value == std::numeric_limits<int64_t>::max()) {
5461 return std::numeric_limits<int64_t>::max();
5462 }
5463 int64_t res = 0;
5464 const double d_value = static_cast<double>(value);
5465 if (value >= 0) {
5466 const double sq = exp(log(d_value) / pow_);
5467 res = static_cast<int64_t>(floor(sq));
5468 } else {
5469 CHECK_EQ(1, pow_ % 2);
5470 const double sq = exp(log(-d_value) / pow_);
5471 res = -static_cast<int64_t>(ceil(sq));
5472 }
5473 const int64_t pow_res = Pown(res + 1);
5474 if (pow_res <= value) {
5475 return res + 1;
5476 } else {
5477 return res;
5478 }
5479 }
5480
5481 int64_t SqrnUp(int64_t value) const {
5482 if (value == std::numeric_limits<int64_t>::min()) {
5483 return std::numeric_limits<int64_t>::min();
5484 }
5485 if (value == std::numeric_limits<int64_t>::max()) {
5486 return std::numeric_limits<int64_t>::max();
5487 }
5488 int64_t res = 0;
5489 const double d_value = static_cast<double>(value);
5490 if (value >= 0) {
5491 const double sq = exp(log(d_value) / pow_);
5492 res = static_cast<int64_t>(ceil(sq));
5493 } else {
5494 CHECK_EQ(1, pow_ % 2);
5495 const double sq = exp(log(-d_value) / pow_);
5496 res = -static_cast<int64_t>(floor(sq));
5497 }
5498 const int64_t pow_res = Pown(res - 1);
5499 if (pow_res >= value) {
5500 return res - 1;
5501 } else {
5502 return res;
5503 }
5504 }
5505
5506 IntExpr* const expr_;
5507 const int64_t pow_;
5508 const int64_t limit_;
5509};
5510
5511class IntEvenPower : public BasePower {
5512 public:
5513 IntEvenPower(Solver* const s, IntExpr* const e, int64_t n)
5514 : BasePower(s, e, n) {
5515 CHECK_EQ(0, n % 2);
5516 }
5517
5518 ~IntEvenPower() override {}
5519
5520 int64_t Min() const override {
5521 int64_t emin = 0;
5522 int64_t emax = 0;
5523 expr_->Range(&emin, &emax);
5524 if (emin >= 0) {
5525 return Pown(emin);
5526 }
5527 if (emax < 0) {
5528 return Pown(emax);
5529 }
5530 return 0LL;
5531 }
5532 void SetMin(int64_t m) override {
5533 if (m <= 0) {
5534 return;
5535 }
5536 int64_t emin = 0;
5537 int64_t emax = 0;
5538 expr_->Range(&emin, &emax);
5539 const int64_t root = SqrnUp(m);
5540 if (emin > -root) {
5541 expr_->SetMin(root);
5542 } else if (emax < root) {
5543 expr_->SetMax(-root);
5544 } else if (expr_->IsVar()) {
5545 reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5546 }
5547 }
5548
5549 int64_t Max() const override {
5550 return std::max(Pown(expr_->Min()), Pown(expr_->Max()));
5551 }
5552
5553 void SetMax(int64_t m) override {
5554 if (m < 0) {
5555 solver()->Fail();
5556 }
5557 if (m == std::numeric_limits<int64_t>::max()) {
5558 return;
5559 }
5560 const int64_t root = SqrnDown(m);
5561 expr_->SetRange(-root, root);
5562 }
5563};
5564
5565class PosIntEvenPower : public BasePower {
5566 public:
5567 PosIntEvenPower(Solver* const s, IntExpr* const e, int64_t pow)
5568 : BasePower(s, e, pow) {
5569 CHECK_EQ(0, pow % 2);
5570 }
5571
5572 ~PosIntEvenPower() override {}
5573
5574 int64_t Min() const override { return Pown(expr_->Min()); }
5575
5576 void SetMin(int64_t m) override {
5577 if (m <= 0) {
5578 return;
5579 }
5580 expr_->SetMin(SqrnUp(m));
5581 }
5582 int64_t Max() const override { return Pown(expr_->Max()); }
5583
5584 void SetMax(int64_t m) override {
5585 if (m < 0) {
5586 solver()->Fail();
5587 }
5588 if (m == std::numeric_limits<int64_t>::max()) {
5589 return;
5590 }
5591 expr_->SetMax(SqrnDown(m));
5592 }
5593};
5594
5595class IntOddPower : public BasePower {
5596 public:
5597 IntOddPower(Solver* const s, IntExpr* const e, int64_t n)
5598 : BasePower(s, e, n) {
5599 CHECK_EQ(1, n % 2);
5600 }
5601
5602 ~IntOddPower() override {}
5603
5604 int64_t Min() const override { return Pown(expr_->Min()); }
5605
5606 void SetMin(int64_t m) override { expr_->SetMin(SqrnUp(m)); }
5607
5608 int64_t Max() const override { return Pown(expr_->Max()); }
5609
5610 void SetMax(int64_t m) override { expr_->SetMax(SqrnDown(m)); }
5611};
5612
5613// ----- Min(expr, expr) -----
5614
5615class MinIntExpr : public BaseIntExpr {
5616 public:
5617 MinIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5618 : BaseIntExpr(s), left_(l), right_(r) {}
5619 ~MinIntExpr() override {}
5620 int64_t Min() const override {
5621 const int64_t lmin = left_->Min();
5622 const int64_t rmin = right_->Min();
5623 return std::min(lmin, rmin);
5624 }
5625 void SetMin(int64_t m) override {
5626 left_->SetMin(m);
5627 right_->SetMin(m);
5628 }
5629 int64_t Max() const override {
5630 const int64_t lmax = left_->Max();
5631 const int64_t rmax = right_->Max();
5632 return std::min(lmax, rmax);
5633 }
5634 void SetMax(int64_t m) override {
5635 if (left_->Min() > m) {
5636 right_->SetMax(m);
5637 }
5638 if (right_->Min() > m) {
5639 left_->SetMax(m);
5640 }
5641 }
5642 std::string name() const override {
5643 return absl::StrFormat("MinIntExpr(%s, %s)", left_->name(), right_->name());
5644 }
5645 std::string DebugString() const override {
5646 return absl::StrFormat("MinIntExpr(%s, %s)", left_->DebugString(),
5647 right_->DebugString());
5648 }
5649 void WhenRange(Demon* d) override {
5650 left_->WhenRange(d);
5651 right_->WhenRange(d);
5652 }
5653
5654 void Accept(ModelVisitor* const visitor) const override {
5655 visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5656 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5657 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5658 right_);
5659 visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5660 }
5661
5662 private:
5663 IntExpr* const left_;
5664 IntExpr* const right_;
5665};
5666
5667// ----- Min(expr, constant) -----
5668
5669class MinCstIntExpr : public BaseIntExpr {
5670 public:
5671 MinCstIntExpr(Solver* const s, IntExpr* const e, int64_t v)
5672 : BaseIntExpr(s), expr_(e), value_(v) {}
5673
5674 ~MinCstIntExpr() override {}
5675
5676 int64_t Min() const override { return std::min(expr_->Min(), value_); }
5677
5678 void SetMin(int64_t m) override {
5679 if (m > value_) {
5680 solver()->Fail();
5681 }
5682 expr_->SetMin(m);
5683 }
5684
5685 int64_t Max() const override { return std::min(expr_->Max(), value_); }
5686
5687 void SetMax(int64_t m) override {
5688 if (value_ > m) {
5689 expr_->SetMax(m);
5690 }
5691 }
5692
5693 bool Bound() const override {
5694 return (expr_->Bound() || expr_->Min() >= value_);
5695 }
5696
5697 std::string name() const override {
5698 return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->name(), value_);
5699 }
5700
5701 std::string DebugString() const override {
5702 return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->DebugString(),
5703 value_);
5704 }
5705
5706 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5707
5708 void Accept(ModelVisitor* const visitor) const override {
5709 visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5710 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5711 expr_);
5712 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5713 visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5714 }
5715
5716 private:
5717 IntExpr* const expr_;
5718 const int64_t value_;
5719};
5720
5721// ----- Max(expr, expr) -----
5722
5723class MaxIntExpr : public BaseIntExpr {
5724 public:
5725 MaxIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5726 : BaseIntExpr(s), left_(l), right_(r) {}
5727
5728 ~MaxIntExpr() override {}
5729
5730 int64_t Min() const override { return std::max(left_->Min(), right_->Min()); }
5731
5732 void SetMin(int64_t m) override {
5733 if (left_->Max() < m) {
5734 right_->SetMin(m);
5735 } else {
5736 if (right_->Max() < m) {
5737 left_->SetMin(m);
5738 }
5739 }
5740 }
5741
5742 int64_t Max() const override { return std::max(left_->Max(), right_->Max()); }
5743
5744 void SetMax(int64_t m) override {
5745 left_->SetMax(m);
5746 right_->SetMax(m);
5747 }
5748
5749 std::string name() const override {
5750 return absl::StrFormat("MaxIntExpr(%s, %s)", left_->name(), right_->name());
5751 }
5752
5753 std::string DebugString() const override {
5754 return absl::StrFormat("MaxIntExpr(%s, %s)", left_->DebugString(),
5755 right_->DebugString());
5756 }
5757
5758 void WhenRange(Demon* d) override {
5759 left_->WhenRange(d);
5760 right_->WhenRange(d);
5761 }
5762
5763 void Accept(ModelVisitor* const visitor) const override {
5764 visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5765 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5766 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5767 right_);
5768 visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5769 }
5770
5771 private:
5772 IntExpr* const left_;
5773 IntExpr* const right_;
5774};
5775
5776// ----- Max(expr, constant) -----
5777
5778class MaxCstIntExpr : public BaseIntExpr {
5779 public:
5780 MaxCstIntExpr(Solver* const s, IntExpr* const e, int64_t v)
5781 : BaseIntExpr(s), expr_(e), value_(v) {}
5782
5783 ~MaxCstIntExpr() override {}
5784
5785 int64_t Min() const override { return std::max(expr_->Min(), value_); }
5786
5787 void SetMin(int64_t m) override {
5788 if (value_ < m) {
5789 expr_->SetMin(m);
5790 }
5791 }
5792
5793 int64_t Max() const override { return std::max(expr_->Max(), value_); }
5794
5795 void SetMax(int64_t m) override {
5796 if (m < value_) {
5797 solver()->Fail();
5798 }
5799 expr_->SetMax(m);
5800 }
5801
5802 bool Bound() const override {
5803 return (expr_->Bound() || expr_->Max() <= value_);
5804 }
5805
5806 std::string name() const override {
5807 return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->name(), value_);
5808 }
5809
5810 std::string DebugString() const override {
5811 return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->DebugString(),
5812 value_);
5813 }
5814
5815 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5816
5817 void Accept(ModelVisitor* const visitor) const override {
5818 visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5819 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5820 expr_);
5821 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5822 visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5823 }
5824
5825 private:
5826 IntExpr* const expr_;
5827 const int64_t value_;
5828};
5829
5830// ----- Convex Piecewise -----
5831
5832// This class is a very simple convex piecewise linear function. The
5833// argument of the function is the expression. Between early_date and
5834// late_date, the value of the function is 0. Before early date, it
5835// is affine and the cost is early_cost * (early_date - x). After
5836// late_date, the cost is late_cost * (x - late_date).
5837
5838class SimpleConvexPiecewiseExpr : public BaseIntExpr {
5839 public:
5840 SimpleConvexPiecewiseExpr(Solver* const s, IntExpr* const e, int64_t ec,
5841 int64_t ed, int64_t ld, int64_t lc)
5842 : BaseIntExpr(s),
5843 expr_(e),
5844 early_cost_(ec),
5845 early_date_(ec == 0 ? std::numeric_limits<int64_t>::min() : ed),
5846 late_date_(lc == 0 ? std::numeric_limits<int64_t>::max() : ld),
5847 late_cost_(lc) {
5848 DCHECK_GE(ec, int64_t{0});
5849 DCHECK_GE(lc, int64_t{0});
5850 DCHECK_GE(ld, ed);
5851
5852 // If the penalty is 0, we can push the "confort zone or zone
5853 // of no cost towards infinity.
5854 }
5855
5856 ~SimpleConvexPiecewiseExpr() override {}
5857
5858 int64_t Min() const override {
5859 const int64_t vmin = expr_->Min();
5860 const int64_t vmax = expr_->Max();
5861 if (vmin >= late_date_) {
5862 return (vmin - late_date_) * late_cost_;
5863 } else if (vmax <= early_date_) {
5864 return (early_date_ - vmax) * early_cost_;
5865 } else {
5866 return 0LL;
5867 }
5868 }
5869
5870 void SetMin(int64_t m) override {
5871 if (m <= 0) {
5872 return;
5873 }
5874 int64_t vmin = 0;
5875 int64_t vmax = 0;
5876 expr_->Range(&vmin, &vmax);
5877
5878 const int64_t rb =
5879 (late_cost_ == 0 ? vmax : late_date_ + PosIntDivUp(m, late_cost_) - 1);
5880 const int64_t lb =
5881 (early_cost_ == 0 ? vmin
5882 : early_date_ - PosIntDivUp(m, early_cost_) + 1);
5883
5884 if (expr_->IsVar()) {
5885 expr_->Var()->RemoveInterval(lb, rb);
5886 }
5887 }
5888
5889 int64_t Max() const override {
5890 const int64_t vmin = expr_->Min();
5891 const int64_t vmax = expr_->Max();
5892 const int64_t mr = vmax > late_date_ ? (vmax - late_date_) * late_cost_ : 0;
5893 const int64_t ml =
5894 vmin < early_date_ ? (early_date_ - vmin) * early_cost_ : 0;
5895 return std::max(mr, ml);
5896 }
5897
5898 void SetMax(int64_t m) override {
5899 if (m < 0) {
5900 solver()->Fail();
5901 }
5902 if (late_cost_ != 0LL) {
5903 const int64_t rb = late_date_ + PosIntDivDown(m, late_cost_);
5904 if (early_cost_ != 0LL) {
5905 const int64_t lb = early_date_ - PosIntDivDown(m, early_cost_);
5906 expr_->SetRange(lb, rb);
5907 } else {
5908 expr_->SetMax(rb);
5909 }
5910 } else {
5911 if (early_cost_ != 0LL) {
5912 const int64_t lb = early_date_ - PosIntDivDown(m, early_cost_);
5913 expr_->SetMin(lb);
5914 }
5915 }
5916 }
5917
5918 std::string name() const override {
5919 return absl::StrFormat(
5920 "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5921 expr_->name(), early_cost_, early_date_, late_date_, late_cost_);
5922 }
5923
5924 std::string DebugString() const override {
5925 return absl::StrFormat(
5926 "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5927 expr_->DebugString(), early_cost_, early_date_, late_date_, late_cost_);
5928 }
5929
5930 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5931
5932 void Accept(ModelVisitor* const visitor) const override {
5933 visitor->BeginVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5934 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5935 expr_);
5936 visitor->VisitIntegerArgument(ModelVisitor::kEarlyCostArgument,
5937 early_cost_);
5938 visitor->VisitIntegerArgument(ModelVisitor::kEarlyDateArgument,
5939 early_date_);
5940 visitor->VisitIntegerArgument(ModelVisitor::kLateCostArgument, late_cost_);
5941 visitor->VisitIntegerArgument(ModelVisitor::kLateDateArgument, late_date_);
5942 visitor->EndVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5943 }
5944
5945 private:
5946 IntExpr* const expr_;
5947 const int64_t early_cost_;
5948 const int64_t early_date_;
5949 const int64_t late_date_;
5950 const int64_t late_cost_;
5951};
5952
5953// ----- Semi Continuous -----
5954
5955class SemiContinuousExpr : public BaseIntExpr {
5956 public:
5957 SemiContinuousExpr(Solver* const s, IntExpr* const e, int64_t fixed_charge,
5958 int64_t step)
5959 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge), step_(step) {
5960 DCHECK_GE(fixed_charge, int64_t{0});
5961 DCHECK_GT(step, int64_t{0});
5962 }
5963
5964 ~SemiContinuousExpr() override {}
5965
5966 int64_t Value(int64_t x) const {
5967 if (x <= 0) {
5968 return 0;
5969 } else {
5970 return CapAdd(fixed_charge_, CapProd(x, step_));
5971 }
5972 }
5973
5974 int64_t Min() const override { return Value(expr_->Min()); }
5975
5976 void SetMin(int64_t m) override {
5977 if (m >= CapAdd(fixed_charge_, step_)) {
5978 const int64_t y = PosIntDivUp(CapSub(m, fixed_charge_), step_);
5979 expr_->SetMin(y);
5980 } else if (m > 0) {
5981 expr_->SetMin(1);
5982 }
5983 }
5984
5985 int64_t Max() const override { return Value(expr_->Max()); }
5986
5987 void SetMax(int64_t m) override {
5988 if (m < 0) {
5989 solver()->Fail();
5990 }
5991 if (m == std::numeric_limits<int64_t>::max()) {
5992 return;
5993 }
5994 if (m < CapAdd(fixed_charge_, step_)) {
5995 expr_->SetMax(0);
5996 } else {
5997 const int64_t y = PosIntDivDown(CapSub(m, fixed_charge_), step_);
5998 expr_->SetMax(y);
5999 }
6000 }
6001
6002 std::string name() const override {
6003 return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
6004 expr_->name(), fixed_charge_, step_);
6005 }
6006
6007 std::string DebugString() const override {
6008 return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
6009 expr_->DebugString(), fixed_charge_, step_);
6010 }
6011
6012 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6013
6014 void Accept(ModelVisitor* const visitor) const override {
6015 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6016 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6017 expr_);
6018 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6019 fixed_charge_);
6020 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, step_);
6021 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6022 }
6023
6024 private:
6025 IntExpr* const expr_;
6026 const int64_t fixed_charge_;
6027 const int64_t step_;
6028};
6029
6030class SemiContinuousStepOneExpr : public BaseIntExpr {
6031 public:
6032 SemiContinuousStepOneExpr(Solver* const s, IntExpr* const e,
6033 int64_t fixed_charge)
6034 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6035 DCHECK_GE(fixed_charge, int64_t{0});
6036 }
6037
6038 ~SemiContinuousStepOneExpr() override {}
6039
6040 int64_t Value(int64_t x) const {
6041 if (x <= 0) {
6042 return 0;
6043 } else {
6044 return fixed_charge_ + x;
6045 }
6046 }
6047
6048 int64_t Min() const override { return Value(expr_->Min()); }
6049
6050 void SetMin(int64_t m) override {
6051 if (m >= fixed_charge_ + 1) {
6052 expr_->SetMin(m - fixed_charge_);
6053 } else if (m > 0) {
6054 expr_->SetMin(1);
6055 }
6056 }
6057
6058 int64_t Max() const override { return Value(expr_->Max()); }
6059
6060 void SetMax(int64_t m) override {
6061 if (m < 0) {
6062 solver()->Fail();
6063 }
6064 if (m < fixed_charge_ + 1) {
6065 expr_->SetMax(0);
6066 } else {
6067 expr_->SetMax(m - fixed_charge_);
6068 }
6069 }
6070
6071 std::string name() const override {
6072 return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6073 expr_->name(), fixed_charge_);
6074 }
6075
6076 std::string DebugString() const override {
6077 return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6078 expr_->DebugString(), fixed_charge_);
6079 }
6080
6081 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6082
6083 void Accept(ModelVisitor* const visitor) const override {
6084 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6085 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6086 expr_);
6087 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6088 fixed_charge_);
6089 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 1);
6090 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6091 }
6092
6093 private:
6094 IntExpr* const expr_;
6095 const int64_t fixed_charge_;
6096};
6097
6098class SemiContinuousStepZeroExpr : public BaseIntExpr {
6099 public:
6100 SemiContinuousStepZeroExpr(Solver* const s, IntExpr* const e,
6101 int64_t fixed_charge)
6102 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6103 DCHECK_GT(fixed_charge, int64_t{0});
6104 }
6105
6106 ~SemiContinuousStepZeroExpr() override {}
6107
6108 int64_t Value(int64_t x) const {
6109 if (x <= 0) {
6110 return 0;
6111 } else {
6112 return fixed_charge_;
6113 }
6114 }
6115
6116 int64_t Min() const override { return Value(expr_->Min()); }
6117
6118 void SetMin(int64_t m) override {
6119 if (m >= fixed_charge_) {
6120 solver()->Fail();
6121 } else if (m > 0) {
6122 expr_->SetMin(1);
6123 }
6124 }
6125
6126 int64_t Max() const override { return Value(expr_->Max()); }
6127
6128 void SetMax(int64_t m) override {
6129 if (m < 0) {
6130 solver()->Fail();
6131 }
6132 if (m < fixed_charge_) {
6133 expr_->SetMax(0);
6134 }
6135 }
6136
6137 std::string name() const override {
6138 return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6139 expr_->name(), fixed_charge_);
6140 }
6141
6142 std::string DebugString() const override {
6143 return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6144 expr_->DebugString(), fixed_charge_);
6145 }
6146
6147 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6148
6149 void Accept(ModelVisitor* const visitor) const override {
6150 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6151 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6152 expr_);
6153 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6154 fixed_charge_);
6155 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 0);
6156 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6157 }
6158
6159 private:
6160 IntExpr* const expr_;
6161 const int64_t fixed_charge_;
6162};
6163
// This constraint links an expression and the variable it is cast into.
6165class LinkExprAndVar : public CastConstraint {
6166 public:
6167 LinkExprAndVar(Solver* const s, IntExpr* const expr, IntVar* const var)
6168 : CastConstraint(s, var), expr_(expr) {}
6169
6170 ~LinkExprAndVar() override {}
6171
6172 void Post() override {
6173 Solver* const s = solver();
6174 Demon* d = s->MakeConstraintInitialPropagateCallback(this);
6175 expr_->WhenRange(d);
6176 target_var_->WhenRange(d);
6177 }
6178
6179 void InitialPropagate() override {
6180 expr_->SetRange(target_var_->Min(), target_var_->Max());
6181 int64_t l, u;
6182 expr_->Range(&l, &u);
6183 target_var_->SetRange(l, u);
6184 }
6185
6186 std::string DebugString() const override {
6187 return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6188 target_var_->DebugString());
6189 }
6190
6191 void Accept(ModelVisitor* const visitor) const override {
6192 visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6193 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6194 expr_);
6195 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6196 target_var_);
6197 visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6198 }
6199
6200 private:
6201 IntExpr* const expr_;
6202};
6203
6204// ----- Conditional Expression -----
6205
6206class ExprWithEscapeValue : public BaseIntExpr {
6207 public:
6208 ExprWithEscapeValue(Solver* const s, IntVar* const c, IntExpr* const e,
6209 int64_t unperformed_value)
6210 : BaseIntExpr(s),
6211 condition_(c),
6212 expression_(e),
6213 unperformed_value_(unperformed_value) {}
6214
6215 // This type is neither copyable nor movable.
6216 ExprWithEscapeValue(const ExprWithEscapeValue&) = delete;
6217 ExprWithEscapeValue& operator=(const ExprWithEscapeValue&) = delete;
6218
6219 ~ExprWithEscapeValue() override {}
6220
6221 int64_t Min() const override {
6222 if (condition_->Min() == 1) {
6223 return expression_->Min();
6224 } else if (condition_->Max() == 1) {
6225 return std::min(unperformed_value_, expression_->Min());
6226 } else {
6227 return unperformed_value_;
6228 }
6229 }
6230
6231 void SetMin(int64_t m) override {
6232 if (m > unperformed_value_) {
6233 condition_->SetValue(1);
6234 expression_->SetMin(m);
6235 } else if (condition_->Min() == 1) {
6236 expression_->SetMin(m);
6237 } else if (m > expression_->Max()) {
6238 condition_->SetValue(0);
6239 }
6240 }
6241
6242 int64_t Max() const override {
6243 if (condition_->Min() == 1) {
6244 return expression_->Max();
6245 } else if (condition_->Max() == 1) {
6246 return std::max(unperformed_value_, expression_->Max());
6247 } else {
6248 return unperformed_value_;
6249 }
6250 }
6251
6252 void SetMax(int64_t m) override {
6253 if (m < unperformed_value_) {
6254 condition_->SetValue(1);
6255 expression_->SetMax(m);
6256 } else if (condition_->Min() == 1) {
6257 expression_->SetMax(m);
6258 } else if (m < expression_->Min()) {
6259 condition_->SetValue(0);
6260 }
6261 }
6262
6263 void SetRange(int64_t mi, int64_t ma) override {
6264 if (ma < unperformed_value_ || mi > unperformed_value_) {
6265 condition_->SetValue(1);
6266 expression_->SetRange(mi, ma);
6267 } else if (condition_->Min() == 1) {
6268 expression_->SetRange(mi, ma);
6269 } else if (ma < expression_->Min() || mi > expression_->Max()) {
6270 condition_->SetValue(0);
6271 }
6272 }
6273
6274 void SetValue(int64_t v) override {
6275 if (v != unperformed_value_) {
6276 condition_->SetValue(1);
6277 expression_->SetValue(v);
6278 } else if (condition_->Min() == 1) {
6279 expression_->SetValue(v);
6280 } else if (v < expression_->Min() || v > expression_->Max()) {
6281 condition_->SetValue(0);
6282 }
6283 }
6284
6285 bool Bound() const override {
6286 return condition_->Max() == 0 || expression_->Bound();
6287 }
6288
6289 void WhenRange(Demon* d) override {
6290 expression_->WhenRange(d);
6291 condition_->WhenBound(d);
6292 }
6293
6294 std::string DebugString() const override {
6295 return absl::StrFormat("ConditionExpr(%s, %s, %d)",
6296 condition_->DebugString(),
6297 expression_->DebugString(), unperformed_value_);
6298 }
6299
6300 void Accept(ModelVisitor* const visitor) const override {
6301 visitor->BeginVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6302 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
6303 condition_);
6304 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6305 expression_);
6306 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument,
6307 unperformed_value_);
6308 visitor->EndVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6309 }
6310
6311 private:
6312 IntVar* const condition_;
6313 IntExpr* const expression_;
6314 const int64_t unperformed_value_;
6315};
6316
6317// ----- This is a specialized case when the variable exact type is known -----
6318class LinkExprAndDomainIntVar : public CastConstraint {
6319 public:
6320 LinkExprAndDomainIntVar(Solver* const s, IntExpr* const expr,
6321 DomainIntVar* const var)
6322 : CastConstraint(s, var),
6323 expr_(expr),
6324 cached_min_(std::numeric_limits<int64_t>::min()),
6325 cached_max_(std::numeric_limits<int64_t>::max()),
6326 fail_stamp_(uint64_t{0}) {}
6327
6328 ~LinkExprAndDomainIntVar() override {}
6329
6330 DomainIntVar* var() const {
6331 return reinterpret_cast<DomainIntVar*>(target_var_);
6332 }
6333
6334 void Post() override {
6335 Solver* const s = solver();
6336 Demon* const d = s->MakeConstraintInitialPropagateCallback(this);
6337 expr_->WhenRange(d);
6338 Demon* const target_var_demon = MakeConstraintDemon0(
6339 solver(), this, &LinkExprAndDomainIntVar::Propagate, "Propagate");
6340 target_var_->WhenRange(target_var_demon);
6341 }
6342
6343 void InitialPropagate() override {
6344 expr_->SetRange(var()->min_.Value(), var()->max_.Value());
6345 expr_->Range(&cached_min_, &cached_max_);
6346 var()->DomainIntVar::SetRange(cached_min_, cached_max_);
6347 }
6348
6349 void Propagate() {
6350 if (var()->min_.Value() > cached_min_ ||
6351 var()->max_.Value() < cached_max_ ||
6352 solver()->fail_stamp() != fail_stamp_) {
6353 InitialPropagate();
6354 fail_stamp_ = solver()->fail_stamp();
6355 }
6356 }
6357
6358 std::string DebugString() const override {
6359 return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6360 target_var_->DebugString());
6361 }
6362
6363 void Accept(ModelVisitor* const visitor) const override {
6364 visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6365 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6366 expr_);
6367 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6368 target_var_);
6369 visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6370 }
6371
6372 private:
6373 IntExpr* const expr_;
6374 int64_t cached_min_;
6375 int64_t cached_max_;
6376 uint64_t fail_stamp_;
6377};
6378} // namespace
6379
6380// ----- Misc -----
6381
6383 return CondRevAlloc(solver(), reversible, new EmptyIterator());
6384}
6386 return CondRevAlloc(solver(), reversible, new RangeIterator(this));
6387}
6388
6389// ----- API -----
6390
6392 DCHECK_EQ(DOMAIN_INT_VAR, var->VarType());
6393 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6394 dvar->CleanInProcess();
6395}
6396
6397Constraint* SetIsEqual(IntVar* const var, absl::Span<const int64_t> values,
6398 const std::vector<IntVar*>& vars) {
6399 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6400 CHECK(dvar != nullptr);
6401 return dvar->SetIsEqual(values, vars);
6402}
6403
6405 absl::Span<const int64_t> values,
6406 const std::vector<IntVar*>& vars) {
6407 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6408 CHECK(dvar != nullptr);
6409 return dvar->SetIsGreaterOrEqual(values, vars);
6410}
6411
6413 DCHECK_EQ(BOOLEAN_VAR, var->VarType());
6414 BooleanVar* const boolean_var = reinterpret_cast<BooleanVar*>(var);
6415 boolean_var->RestoreValue();
6416}
6417
6418// ----- API -----
6419
6420IntVar* Solver::MakeIntVar(int64_t min, int64_t max, const std::string& name) {
6421 if (min == max) {
6422 return MakeIntConst(min, name);
6423 }
6424 if (min == 0 && max == 1) {
6425 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6426 } else if (CapSub(max, min) == 1) {
6427 const std::string inner_name = "inner_" + name;
6428 return RegisterIntVar(
6429 MakeSum(RevAlloc(new ConcreteBooleanVar(this, inner_name)), min)
6430 ->VarWithName(name));
6431 } else {
6432 return RegisterIntVar(RevAlloc(new DomainIntVar(this, min, max, name)));
6433 }
6434}
6435
6436IntVar* Solver::MakeIntVar(int64_t min, int64_t max) {
6437 return MakeIntVar(min, max, "");
6438}
6439
6440IntVar* Solver::MakeBoolVar(const std::string& name) {
6441 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6442}
6443
6445 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, "")));
6446}
6447
6448IntVar* Solver::MakeIntVar(const std::vector<int64_t>& values,
6449 const std::string& name) {
6450 DCHECK(!values.empty());
6451 // Fast-track the case where we have a single value.
6452 if (values.size() == 1) return MakeIntConst(values[0], name);
6453 // Sort and remove duplicates.
6454 std::vector<int64_t> unique_sorted_values = values;
6455 gtl::STLSortAndRemoveDuplicates(&unique_sorted_values);
6456 // Case when we have a single value, after clean-up.
6457 if (unique_sorted_values.size() == 1) return MakeIntConst(values[0], name);
6458 // Case when the values are a dense interval of integers.
6459 if (unique_sorted_values.size() ==
6460 unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6461 return MakeIntVar(unique_sorted_values.front(), unique_sorted_values.back(),
6462 name);
6463 }
6464 // Compute the GCD: if it's not 1, we can express the variable's domain as
6465 // the product of the GCD and of a domain with smaller values.
6466 int64_t gcd = 0;
6467 for (const int64_t v : unique_sorted_values) {
6468 if (gcd == 0) {
6469 gcd = std::abs(v);
6470 } else {
6471 gcd = MathUtil::GCD64(gcd, std::abs(v)); // Supports v==0.
6472 }
6473 if (gcd == 1) {
6474 // If it's 1, though, we can't do anything special, so we
6475 // immediately return a new DomainIntVar.
6476 return RegisterIntVar(
6477 RevAlloc(new DomainIntVar(this, unique_sorted_values, name)));
6478 }
6479 }
6480 DCHECK_GT(gcd, 1);
6481 for (int64_t& v : unique_sorted_values) {
6482 DCHECK_EQ(0, v % gcd);
6483 v /= gcd;
6484 }
6485 const std::string new_name = name.empty() ? "" : "inner_" + name;
6486 // Catch the case where the divided values are a dense set of integers.
6487 IntVar* inner_intvar = nullptr;
6488 if (unique_sorted_values.size() ==
6489 unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6490 inner_intvar = MakeIntVar(unique_sorted_values.front(),
6491 unique_sorted_values.back(), new_name);
6492 } else {
6493 inner_intvar = RegisterIntVar(
6494 RevAlloc(new DomainIntVar(this, unique_sorted_values, new_name)));
6495 }
6496 return MakeProd(inner_intvar, gcd)->Var();
6497}
6498
6499IntVar* Solver::MakeIntVar(const std::vector<int64_t>& values) {
6500 return MakeIntVar(values, "");
6501}
6502
6503IntVar* Solver::MakeIntVar(const std::vector<int>& values,
6504 const std::string& name) {
6505 return MakeIntVar(ToInt64Vector(values), name);
6506}
6507
6508IntVar* Solver::MakeIntVar(const std::vector<int>& values) {
6509 return MakeIntVar(values, "");
6510}
6511
6512IntVar* Solver::MakeIntConst(int64_t val, const std::string& name) {
6513 // If IntConst is going to be named after its creation,
6514 // cp_share_int_consts should be set to false otherwise names can potentially
6515 // be overwritten.
6516 if (absl::GetFlag(FLAGS_cp_share_int_consts) && name.empty() &&
6517 val >= MIN_CACHED_INT_CONST && val <= MAX_CACHED_INT_CONST) {
6518 return cached_constants_[val - MIN_CACHED_INT_CONST];
6519 }
6520 return RevAlloc(new IntConst(this, val, name));
6521}
6522
6523IntVar* Solver::MakeIntConst(int64_t val) { return MakeIntConst(val, ""); }
6524
6525// ----- Int Var and associated methods -----
6526
6527namespace {
6528std::string IndexedName(absl::string_view prefix, int index, int max_index) {
6529#if 0
6530#if defined(_MSC_VER)
6531 const int digits = max_index > 0 ?
6532 static_cast<int>(log(1.0L * max_index) / log(10.0L)) + 1 :
6533 1;
6534#else
6535 const int digits = max_index > 0 ? static_cast<int>(log10(max_index)) + 1: 1;
6536#endif
6537 return absl::StrFormat("%s%0*d", prefix, digits, index);
6538#else
6539 return absl::StrCat(prefix, index);
6540#endif
6541}
6542} // namespace
6543
6544void Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6545 const std::string& name,
6546 std::vector<IntVar*>* vars) {
6547 for (int i = 0; i < var_count; ++i) {
6548 vars->push_back(MakeIntVar(vmin, vmax, IndexedName(name, i, var_count)));
6549 }
6550}
6551
6552void Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6553 std::vector<IntVar*>* vars) {
6554 for (int i = 0; i < var_count; ++i) {
6555 vars->push_back(MakeIntVar(vmin, vmax));
6556 }
6557}
6558
6559IntVar** Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6560 const std::string& name) {
6561 IntVar** vars = new IntVar*[var_count];
6562 for (int i = 0; i < var_count; ++i) {
6563 vars[i] = MakeIntVar(vmin, vmax, IndexedName(name, i, var_count));
6564 }
6565 return vars;
6566}
6567
6568void Solver::MakeBoolVarArray(int var_count, const std::string& name,
6569 std::vector<IntVar*>* vars) {
6570 for (int i = 0; i < var_count; ++i) {
6571 vars->push_back(MakeBoolVar(IndexedName(name, i, var_count)));
6572 }
6573}
6574
6575void Solver::MakeBoolVarArray(int var_count, std::vector<IntVar*>* vars) {
6576 for (int i = 0; i < var_count; ++i) {
6577 vars->push_back(MakeBoolVar());
6578 }
6579}
6580
6581IntVar** Solver::MakeBoolVarArray(int var_count, const std::string& name) {
6582 IntVar** vars = new IntVar*[var_count];
6583 for (int i = 0; i < var_count; ++i) {
6584 vars[i] = MakeBoolVar(IndexedName(name, i, var_count));
6585 }
6586 return vars;
6587}
6588
6589void Solver::InitCachedIntConstants() {
6590 for (int i = MIN_CACHED_INT_CONST; i <= MAX_CACHED_INT_CONST; ++i) {
6591 cached_constants_[i - MIN_CACHED_INT_CONST] =
6592 RevAlloc(new IntConst(this, i, "")); // note the empty name
6593 }
6594}
6595
6596IntExpr* Solver::MakeSum(IntExpr* const left, IntExpr* const right) {
6597 CHECK_EQ(this, left->solver());
6598 CHECK_EQ(this, right->solver());
6599 if (right->Bound()) {
6600 return MakeSum(left, right->Min());
6601 }
6602 if (left->Bound()) {
6603 return MakeSum(right, left->Min());
6604 }
6605 if (left == right) {
6606 return MakeProd(left, 2);
6607 }
6608 IntExpr* cache = model_cache_->FindExprExprExpression(
6609 left, right, ModelCache::EXPR_EXPR_SUM);
6610 if (cache == nullptr) {
6611 cache = model_cache_->FindExprExprExpression(right, left,
6613 }
6614 if (cache != nullptr) {
6615 return cache;
6616 } else {
6617 IntExpr* const result =
6618 AddOverflows(left->Max(), right->Max()) ||
6619 AddOverflows(left->Min(), right->Min())
6620 ? RegisterIntExpr(RevAlloc(new SafePlusIntExpr(this, left, right)))
6621 : RegisterIntExpr(RevAlloc(new PlusIntExpr(this, left, right)));
6622 model_cache_->InsertExprExprExpression(result, left, right,
6624 return result;
6625 }
6626}
6627
6628IntExpr* Solver::MakeSum(IntExpr* const expr, int64_t value) {
6629 CHECK_EQ(this, expr->solver());
6630 if (expr->Bound()) {
6631 return MakeIntConst(CapAdd(expr->Min(), value));
6632 }
6633 if (value == 0) {
6634 return expr;
6635 }
6636 IntExpr* result = Cache()->FindExprConstantExpression(
6637 expr, value, ModelCache::EXPR_CONSTANT_SUM);
6638 if (result == nullptr) {
6639 if (expr->IsVar() && !AddOverflows(value, expr->Max()) &&
6640 !AddOverflows(value, expr->Min())) {
6641 IntVar* const var = expr->Var();
6642 switch (var->VarType()) {
6643 case DOMAIN_INT_VAR: {
6644 result = RegisterIntExpr(RevAlloc(new PlusCstDomainIntVar(
6645 this, reinterpret_cast<DomainIntVar*>(var), value)));
6646 break;
6647 }
6648 case CONST_VAR: {
6649 result = RegisterIntExpr(MakeIntConst(var->Min() + value));
6650 break;
6651 }
6652 case VAR_ADD_CST: {
6653 PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6654 IntVar* const sub_var = add_var->SubVar();
6655 const int64_t new_constant = value + add_var->Constant();
6656 if (new_constant == 0) {
6657 result = sub_var;
6658 } else {
6659 if (sub_var->VarType() == DOMAIN_INT_VAR) {
6660 DomainIntVar* const dvar =
6661 reinterpret_cast<DomainIntVar*>(sub_var);
6662 result = RegisterIntExpr(
6663 RevAlloc(new PlusCstDomainIntVar(this, dvar, new_constant)));
6664 } else {
6665 result = RegisterIntExpr(
6666 RevAlloc(new PlusCstIntVar(this, sub_var, new_constant)));
6667 }
6668 }
6669 break;
6670 }
6671 case CST_SUB_VAR: {
6672 SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6673 IntVar* const sub_var = add_var->SubVar();
6674 const int64_t new_constant = value + add_var->Constant();
6675 result = RegisterIntExpr(
6676 RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6677 break;
6678 }
6679 case OPP_VAR: {
6680 OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6681 IntVar* const sub_var = add_var->SubVar();
6682 result =
6683 RegisterIntExpr(RevAlloc(new SubCstIntVar(this, sub_var, value)));
6684 break;
6685 }
6686 default:
6687 result =
6688 RegisterIntExpr(RevAlloc(new PlusCstIntVar(this, var, value)));
6689 }
6690 } else {
6691 result = RegisterIntExpr(RevAlloc(new PlusIntCstExpr(this, expr, value)));
6692 }
6693 Cache()->InsertExprConstantExpression(result, expr, value,
6695 }
6696 return result;
6697}
6698
6699IntExpr* Solver::MakeDifference(IntExpr* const left, IntExpr* const right) {
6700 CHECK_EQ(this, left->solver());
6701 CHECK_EQ(this, right->solver());
6702 if (left->Bound()) {
6703 return MakeDifference(left->Min(), right);
6704 }
6705 if (right->Bound()) {
6706 return MakeSum(left, -right->Min());
6707 }
6708 IntExpr* sub_left = nullptr;
6709 IntExpr* sub_right = nullptr;
6710 int64_t left_coef = 1;
6711 int64_t right_coef = 1;
6712 if (IsProduct(left, &sub_left, &left_coef) &&
6713 IsProduct(right, &sub_right, &right_coef)) {
6714 const int64_t abs_gcd =
6715 MathUtil::GCD64(std::abs(left_coef), std::abs(right_coef));
6716 if (abs_gcd != 0 && abs_gcd != 1) {
6717 return MakeProd(MakeDifference(MakeProd(sub_left, left_coef / abs_gcd),
6718 MakeProd(sub_right, right_coef / abs_gcd)),
6719 abs_gcd);
6720 }
6721 }
6722
6723 IntExpr* result = Cache()->FindExprExprExpression(
6725 if (result == nullptr) {
6726 if (!SubOverflows(left->Min(), right->Max()) &&
6727 !SubOverflows(left->Max(), right->Min())) {
6728 result = RegisterIntExpr(RevAlloc(new SubIntExpr(this, left, right)));
6729 } else {
6730 result = RegisterIntExpr(RevAlloc(new SafeSubIntExpr(this, left, right)));
6731 }
6732 Cache()->InsertExprExprExpression(result, left, right,
6734 }
6735 return result;
6736}
6737
6738// warning: this is 'value - expr'.
6739IntExpr* Solver::MakeDifference(int64_t value, IntExpr* const expr) {
6740 CHECK_EQ(this, expr->solver());
6741 if (expr->Bound()) {
6742 return MakeIntConst(value - expr->Min());
6743 }
6744 if (value == 0) {
6745 return MakeOpposite(expr);
6746 }
6747 IntExpr* result = Cache()->FindExprConstantExpression(
6749 if (result == nullptr) {
6750 if (expr->IsVar() && expr->Min() != std::numeric_limits<int64_t>::min() &&
6751 !SubOverflows(value, expr->Min()) &&
6752 !SubOverflows(value, expr->Max())) {
6753 IntVar* const var = expr->Var();
6754 switch (var->VarType()) {
6755 case VAR_ADD_CST: {
6756 PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6757 IntVar* const sub_var = add_var->SubVar();
6758 const int64_t new_constant = value - add_var->Constant();
6759 if (new_constant == 0) {
6760 result = sub_var;
6761 } else {
6762 result = RegisterIntExpr(
6763 RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6764 }
6765 break;
6766 }
6767 case CST_SUB_VAR: {
6768 SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6769 IntVar* const sub_var = add_var->SubVar();
6770 const int64_t new_constant = value - add_var->Constant();
6771 result = MakeSum(sub_var, new_constant);
6772 break;
6773 }
6774 case OPP_VAR: {
6775 OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6776 IntVar* const sub_var = add_var->SubVar();
6777 result = MakeSum(sub_var, value);
6778 break;
6779 }
6780 default:
6781 result =
6782 RegisterIntExpr(RevAlloc(new SubCstIntVar(this, var, value)));
6783 }
6784 } else {
6785 result = RegisterIntExpr(RevAlloc(new SubIntCstExpr(this, expr, value)));
6786 }
6787 Cache()->InsertExprConstantExpression(result, expr, value,
6789 }
6790 return result;
6791}
6792
6794 CHECK_EQ(this, expr->solver());
6795 if (expr->Bound()) {
6796 return MakeIntConst(CapOpp(expr->Min()));
6797 }
6798 IntExpr* result =
6799 Cache()->FindExprExpression(expr, ModelCache::EXPR_OPPOSITE);
6800 if (result == nullptr) {
6801 if (expr->IsVar()) {
6802 result = RegisterIntVar(RevAlloc(new OppIntExpr(this, expr))->Var());
6803 } else {
6804 result = RegisterIntExpr(RevAlloc(new OppIntExpr(this, expr)));
6805 }
6806 Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_OPPOSITE);
6807 }
6808 return result;
6809}
6810
6811IntExpr* Solver::MakeProd(IntExpr* const expr, int64_t value) {
6812 CHECK_EQ(this, expr->solver());
6813 IntExpr* result = Cache()->FindExprConstantExpression(
6814 expr, value, ModelCache::EXPR_CONSTANT_PROD);
6815 if (result != nullptr) {
6816 return result;
6817 } else {
6818 IntExpr* m_expr = nullptr;
6819 int64_t coefficient = 1;
6820 if (IsProduct(expr, &m_expr, &coefficient)) {
6821 coefficient = CapProd(coefficient, value);
6822 } else {
6823 m_expr = expr;
6824 coefficient = value;
6825 }
6826 if (m_expr->Bound()) {
6827 return MakeIntConst(CapProd(coefficient, m_expr->Min()));
6828 } else if (coefficient == 1) {
6829 return m_expr;
6830 } else if (coefficient == -1) {
6831 return MakeOpposite(m_expr);
6832 } else if (coefficient > 0) {
6833 if (m_expr->Max() > std::numeric_limits<int64_t>::max() / coefficient ||
6834 m_expr->Min() < std::numeric_limits<int64_t>::min() / coefficient) {
6835 result = RegisterIntExpr(
6836 RevAlloc(new SafeTimesPosIntCstExpr(this, m_expr, coefficient)));
6837 } else {
6838 result = RegisterIntExpr(
6839 RevAlloc(new TimesPosIntCstExpr(this, m_expr, coefficient)));
6840 }
6841 } else if (coefficient == 0) {
6842 result = MakeIntConst(0);
6843 } else { // coefficient < 0.
6844 result = RegisterIntExpr(
6845 RevAlloc(new TimesIntNegCstExpr(this, m_expr, coefficient)));
6846 }
6847 if (m_expr->IsVar() &&
6848 !absl::GetFlag(FLAGS_cp_disable_expression_optimization)) {
6849 result = result->Var();
6850 }
6851 Cache()->InsertExprConstantExpression(result, expr, value,
6853 return result;
6854 }
6855}
6856
6857namespace {
6858void ExtractPower(IntExpr** const expr, int64_t* const exponant) {
6859 if (dynamic_cast<BasePower*>(*expr) != nullptr) {
6860 BasePower* const power = dynamic_cast<BasePower*>(*expr);
6861 *expr = power->expr();
6862 *exponant = power->exponant();
6863 }
6864 if (dynamic_cast<IntSquare*>(*expr) != nullptr) {
6865 IntSquare* const power = dynamic_cast<IntSquare*>(*expr);
6866 *expr = power->expr();
6867 *exponant = 2;
6868 }
6869 if ((*expr)->IsVar()) {
6870 IntVar* const var = (*expr)->Var();
6871 IntExpr* const sub = var->solver()->CastExpression(var);
6872 if (sub != nullptr && dynamic_cast<BasePower*>(sub) != nullptr) {
6873 BasePower* const power = dynamic_cast<BasePower*>(sub);
6874 *expr = power->expr();
6875 *exponant = power->exponant();
6876 }
6877 if (sub != nullptr && dynamic_cast<IntSquare*>(sub) != nullptr) {
6878 IntSquare* const power = dynamic_cast<IntSquare*>(sub);
6879 *expr = power->expr();
6880 *exponant = 2;
6881 }
6882 }
6883}
6884
6885void ExtractProduct(IntExpr** const expr, int64_t* const coefficient,
6886 bool* modified) {
6887 if (dynamic_cast<TimesCstIntVar*>(*expr) != nullptr) {
6888 TimesCstIntVar* const left_prod = dynamic_cast<TimesCstIntVar*>(*expr);
6889 *coefficient *= left_prod->Constant();
6890 *expr = left_prod->SubVar();
6891 *modified = true;
6892 } else if (dynamic_cast<TimesIntCstExpr*>(*expr) != nullptr) {
6893 TimesIntCstExpr* const left_prod = dynamic_cast<TimesIntCstExpr*>(*expr);
6894 *coefficient *= left_prod->Constant();
6895 *expr = left_prod->Expr();
6896 *modified = true;
6897 }
6898}
6899} // namespace
6900
6901IntExpr* Solver::MakeProd(IntExpr* const left, IntExpr* const right) {
6902 if (left->Bound()) {
6903 return MakeProd(right, left->Min());
6904 }
6905
6906 if (right->Bound()) {
6907 return MakeProd(left, right->Min());
6908 }
6909
6910 // ----- Discover squares and powers -----
6911
6912 IntExpr* m_left = left;
6913 IntExpr* m_right = right;
6914 int64_t left_exponant = 1;
6915 int64_t right_exponant = 1;
6916 ExtractPower(&m_left, &left_exponant);
6917 ExtractPower(&m_right, &right_exponant);
6918
6919 if (m_left == m_right) {
6920 return MakePower(m_left, left_exponant + right_exponant);
6921 }
6922
6923 // ----- Discover nested products -----
6924
6925 m_left = left;
6926 m_right = right;
6927 int64_t coefficient = 1;
6928 bool modified = false;
6929
6930 ExtractProduct(&m_left, &coefficient, &modified);
6931 ExtractProduct(&m_right, &coefficient, &modified);
6932 if (modified) {
6933 return MakeProd(MakeProd(m_left, m_right), coefficient);
6934 }
6935
6936 // ----- Standard build -----
6937
6938 CHECK_EQ(this, left->solver());
6939 CHECK_EQ(this, right->solver());
6940 IntExpr* result = model_cache_->FindExprExprExpression(
6941 left, right, ModelCache::EXPR_EXPR_PROD);
6942 if (result == nullptr) {
6943 result = model_cache_->FindExprExprExpression(right, left,
6945 }
6946 if (result != nullptr) {
6947 return result;
6948 }
6949 if (left->IsVar() && left->Var()->VarType() == BOOLEAN_VAR) {
6950 if (right->Min() >= 0) {
6951 result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6952 this, reinterpret_cast<BooleanVar*>(left), right)));
6953 } else {
6954 result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6955 this, reinterpret_cast<BooleanVar*>(left), right)));
6956 }
6957 } else if (right->IsVar() &&
6958 reinterpret_cast<IntVar*>(right)->VarType() == BOOLEAN_VAR) {
6959 if (left->Min() >= 0) {
6960 result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6961 this, reinterpret_cast<BooleanVar*>(right), left)));
6962 } else {
6963 result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6964 this, reinterpret_cast<BooleanVar*>(right), left)));
6965 }
6966 } else if (left->Min() >= 0 && right->Min() >= 0) {
6967 if (CapProd(left->Max(), right->Max()) ==
6968 std::numeric_limits<int64_t>::max()) { // Potential overflow.
6969 result =
6970 RegisterIntExpr(RevAlloc(new SafeTimesPosIntExpr(this, left, right)));
6971 } else {
6972 result =
6973 RegisterIntExpr(RevAlloc(new TimesPosIntExpr(this, left, right)));
6974 }
6975 } else {
6976 result = RegisterIntExpr(RevAlloc(new TimesIntExpr(this, left, right)));
6977 }
6978 model_cache_->InsertExprExprExpression(result, left, right,
6980 return result;
6981}
6982
6983IntExpr* Solver::MakeDiv(IntExpr* const numerator, IntExpr* const denominator) {
6984 CHECK(numerator != nullptr);
6985 CHECK(denominator != nullptr);
6986 if (denominator->Bound()) {
6987 return MakeDiv(numerator, denominator->Min());
6988 }
6989 IntExpr* result = model_cache_->FindExprExprExpression(
6990 numerator, denominator, ModelCache::EXPR_EXPR_DIV);
6991 if (result != nullptr) {
6992 return result;
6993 }
6994
6995 if (denominator->Min() <= 0 && denominator->Max() >= 0) {
6996 AddConstraint(MakeNonEquality(denominator, 0));
6997 }
6998
6999 if (denominator->Min() >= 0) {
7000 if (numerator->Min() >= 0) {
7001 result = RevAlloc(new DivPosPosIntExpr(this, numerator, denominator));
7002 } else {
7003 result = RevAlloc(new DivPosIntExpr(this, numerator, denominator));
7004 }
7005 } else if (denominator->Max() <= 0) {
7006 if (numerator->Max() <= 0) {
7007 result = RevAlloc(new DivPosPosIntExpr(this, MakeOpposite(numerator),
7008 MakeOpposite(denominator)));
7009 } else {
7010 result = MakeOpposite(RevAlloc(
7011 new DivPosIntExpr(this, numerator, MakeOpposite(denominator))));
7012 }
7013 } else {
7014 result = RevAlloc(new DivIntExpr(this, numerator, denominator));
7015 }
7016 model_cache_->InsertExprExprExpression(result, numerator, denominator,
7018 return result;
7019}
7020
7021IntExpr* Solver::MakeDiv(IntExpr* const expr, int64_t value) {
7022 CHECK(expr != nullptr);
7023 CHECK_EQ(this, expr->solver());
7024 if (expr->Bound()) {
7025 return MakeIntConst(expr->Min() / value);
7026 } else if (value == 1) {
7027 return expr;
7028 } else if (value == -1) {
7029 return MakeOpposite(expr);
7030 } else if (value > 0) {
7031 return RegisterIntExpr(RevAlloc(new DivPosIntCstExpr(this, expr, value)));
7032 } else if (value == 0) {
7033 LOG(FATAL) << "Cannot divide by 0";
7034 return nullptr;
7035 } else {
7036 return RegisterIntExpr(
7037 MakeOpposite(RevAlloc(new DivPosIntCstExpr(this, expr, -value))));
7038 // TODO(user) : implement special case.
7039 }
7040}
7041
7042Constraint* Solver::MakeAbsEquality(IntVar* const var, IntVar* const abs_var) {
7043 if (Cache()->FindExprExpression(var, ModelCache::EXPR_ABS) == nullptr) {
7044 Cache()->InsertExprExpression(abs_var, var, ModelCache::EXPR_ABS);
7045 }
7046 return RevAlloc(new IntAbsConstraint(this, var, abs_var));
7047}
7048
7050 CHECK_EQ(this, e->solver());
7051 if (e->Min() >= 0) {
7052 return e;
7053 } else if (e->Max() <= 0) {
7054 return MakeOpposite(e);
7055 }
7056 IntExpr* result = Cache()->FindExprExpression(e, ModelCache::EXPR_ABS);
7057 if (result == nullptr) {
7058 int64_t coefficient = 1;
7059 IntExpr* expr = nullptr;
7060 if (IsProduct(e, &expr, &coefficient)) {
7061 result = MakeProd(MakeAbs(expr), std::abs(coefficient));
7062 } else {
7063 result = RegisterIntExpr(RevAlloc(new IntAbs(this, e)));
7064 }
7065 Cache()->InsertExprExpression(result, e, ModelCache::EXPR_ABS);
7066 }
7067 return result;
7068}
7069
7071 CHECK_EQ(this, expr->solver());
7072 if (expr->Bound()) {
7073 const int64_t v = expr->Min();
7074 return MakeIntConst(v * v);
7075 }
7076 IntExpr* result = Cache()->FindExprExpression(expr, ModelCache::EXPR_SQUARE);
7077 if (result == nullptr) {
7078 if (expr->Min() >= 0) {
7079 result = RegisterIntExpr(RevAlloc(new PosIntSquare(this, expr)));
7080 } else {
7081 result = RegisterIntExpr(RevAlloc(new IntSquare(this, expr)));
7082 }
7083 Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_SQUARE);
7084 }
7085 return result;
7086}
7087
7088IntExpr* Solver::MakePower(IntExpr* const expr, int64_t n) {
7089 CHECK_EQ(this, expr->solver());
7090 CHECK_GE(n, 0);
7091 if (expr->Bound()) {
7092 const int64_t v = expr->Min();
7093 if (v >= OverflowLimit(n)) { // Overflow.
7094 return MakeIntConst(std::numeric_limits<int64_t>::max());
7095 }
7096 return MakeIntConst(IntPower(v, n));
7097 }
7098 switch (n) {
7099 case 0:
7100 return MakeIntConst(1);
7101 case 1:
7102 return expr;
7103 case 2:
7104 return MakeSquare(expr);
7105 default: {
7106 IntExpr* result = nullptr;
7107 if (n % 2 == 0) { // even.
7108 if (expr->Min() >= 0) {
7109 result =
7110 RegisterIntExpr(RevAlloc(new PosIntEvenPower(this, expr, n)));
7111 } else {
7112 result = RegisterIntExpr(RevAlloc(new IntEvenPower(this, expr, n)));
7113 }
7114 } else {
7115 result = RegisterIntExpr(RevAlloc(new IntOddPower(this, expr, n)));
7116 }
7117 return result;
7118 }
7119 }
7120}
7121
7122IntExpr* Solver::MakeMin(IntExpr* const left, IntExpr* const right) {
7123 CHECK_EQ(this, left->solver());
7124 CHECK_EQ(this, right->solver());
7125 if (left->Bound()) {
7126 return MakeMin(right, left->Min());
7127 }
7128 if (right->Bound()) {
7129 return MakeMin(left, right->Min());
7130 }
7131 if (left->Min() >= right->Max()) {
7132 return right;
7133 }
7134 if (right->Min() >= left->Max()) {
7135 return left;
7136 }
7137 return RegisterIntExpr(RevAlloc(new MinIntExpr(this, left, right)));
7138}
7139
7140IntExpr* Solver::MakeMin(IntExpr* const expr, int64_t value) {
7141 CHECK_EQ(this, expr->solver());
7142 if (value <= expr->Min()) {
7143 return MakeIntConst(value);
7144 }
7145 if (expr->Bound()) {
7146 return MakeIntConst(std::min(expr->Min(), value));
7147 }
7148 if (expr->Max() <= value) {
7149 return expr;
7150 }
7151 return RegisterIntExpr(RevAlloc(new MinCstIntExpr(this, expr, value)));
7152}
7153
7154IntExpr* Solver::MakeMin(IntExpr* const expr, int value) {
7155 return MakeMin(expr, static_cast<int64_t>(value));
7156}
7157
7158IntExpr* Solver::MakeMax(IntExpr* const left, IntExpr* const right) {
7159 CHECK_EQ(this, left->solver());
7160 CHECK_EQ(this, right->solver());
7161 if (left->Bound()) {
7162 return MakeMax(right, left->Min());
7163 }
7164 if (right->Bound()) {
7165 return MakeMax(left, right->Min());
7166 }
7167 if (left->Min() >= right->Max()) {
7168 return left;
7169 }
7170 if (right->Min() >= left->Max()) {
7171 return right;
7172 }
7173 return RegisterIntExpr(RevAlloc(new MaxIntExpr(this, left, right)));
7174}
7175
7176IntExpr* Solver::MakeMax(IntExpr* const expr, int64_t value) {
7177 CHECK_EQ(this, expr->solver());
7178 if (expr->Bound()) {
7179 return MakeIntConst(std::max(expr->Min(), value));
7180 }
7181 if (value <= expr->Min()) {
7182 return expr;
7183 }
7184 if (expr->Max() <= value) {
7185 return MakeIntConst(value);
7186 }
7187 return RegisterIntExpr(RevAlloc(new MaxCstIntExpr(this, expr, value)));
7188}
7189
7190IntExpr* Solver::MakeMax(IntExpr* const expr, int value) {
7191 return MakeMax(expr, static_cast<int64_t>(value));
7192}
7193
7195 int64_t early_date, int64_t late_date,
7196 int64_t late_cost) {
7197 return RegisterIntExpr(RevAlloc(new SimpleConvexPiecewiseExpr(
7198 this, expr, early_cost, early_date, late_date, late_cost)));
7199}
7200
7202 int64_t fixed_charge, int64_t step) {
7203 if (step == 0) {
7204 if (fixed_charge == 0) {
7205 return MakeIntConst(int64_t{0});
7206 } else {
7207 return RegisterIntExpr(
7208 RevAlloc(new SemiContinuousStepZeroExpr(this, expr, fixed_charge)));
7209 }
7210 } else if (step == 1) {
7211 return RegisterIntExpr(
7212 RevAlloc(new SemiContinuousStepOneExpr(this, expr, fixed_charge)));
7213 } else {
7214 return RegisterIntExpr(
7215 RevAlloc(new SemiContinuousExpr(this, expr, fixed_charge, step)));
7216 }
7217 // TODO(user) : benchmark with virtualization of
7218 // PosIntDivDown and PosIntDivUp - or function pointers.
7219}
7220
7221// ----- Piecewise Linear -----
7222
7224 public:
7226 const PiecewiseLinearFunction& f)
7227 : BaseIntExpr(solver), expr_(expr), f_(f) {}
7229 int64_t Min() const override {
7230 return f_.GetMinimum(expr_->Min(), expr_->Max());
7231 }
7232 void SetMin(int64_t m) override {
7233 const auto& range =
7234 f_.GetSmallestRangeGreaterThanValue(expr_->Min(), expr_->Max(), m);
7235 expr_->SetRange(range.first, range.second);
7236 }
7237
7238 int64_t Max() const override {
7239 return f_.GetMaximum(expr_->Min(), expr_->Max());
7240 }
7241
7242 void SetMax(int64_t m) override {
7243 const auto& range =
7244 f_.GetSmallestRangeLessThanValue(expr_->Min(), expr_->Max(), m);
7245 expr_->SetRange(range.first, range.second);
7246 }
7247
7248 void SetRange(int64_t l, int64_t u) override {
7249 const auto& range =
7250 f_.GetSmallestRangeInValueRange(expr_->Min(), expr_->Max(), l, u);
7251 expr_->SetRange(range.first, range.second);
7252 }
7253 std::string name() const override {
7254 return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->name(),
7255 f_.DebugString());
7256 }
7257
7258 std::string DebugString() const override {
7259 return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->DebugString(),
7260 f_.DebugString());
7261 }
7262
7263 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
7264
7265 void Accept(ModelVisitor* const visitor) const override {
7266 // TODO(user): Implement visitor.
7267 }
7268
7269 private:
7270 IntExpr* const expr_;
7271 const PiecewiseLinearFunction f_;
7272};
7273
7278
7279// ----- Conditional Expression -----
7280
7282 IntExpr* const expr,
7283 int64_t unperformed_value) {
// Builds the expression `condition ? expr : unperformed_value`.
// When the condition is already decided, no new expression is allocated.
7284 if (condition->Min() == 1) {
// The condition can only be 1: the result is `expr` itself.
7285 return expr;
7286 } else if (condition->Max() == 0) {
// The condition can only be 0: the result is the constant escape value.
7287 return MakeIntConst(unperformed_value);
7288 } else {
// Undecided condition: reuse an expression previously stored in the model
// cache for the same (condition, expr, unperformed_value) key, if any.
7289 IntExpr* cache = Cache()->FindExprExprConstantExpression(
7290 condition, expr, unperformed_value,
// NOTE(review): the last argument of this lookup is not visible here;
// confirm it is the same as the one given to the insertion below.
7292 if (cache == nullptr) {
// Cache miss: allocate the expression through RevAlloc and record it in the
// cache under the same key.
7293 cache = RevAlloc(
7294 new ExprWithEscapeValue(this, condition, expr, unperformed_value));
7295 Cache()->InsertExprExprConstantExpression(
7296 cache, condition, expr, unperformed_value,
7298 }
7299 return cache;
7300 }
7301}
7302
7303// ----- Modulo -----
7304
7305IntExpr* Solver::MakeModulo(IntExpr* const x, int64_t mod) {
7306 IntVar* const result =
7307 MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7308 if (mod >= 0) {
7309 AddConstraint(MakeBetweenCt(result, 0, mod - 1));
7310 } else {
7311 AddConstraint(MakeBetweenCt(result, mod + 1, 0));
7312 }
7313 return result;
7314}
7315
// A bound modulus is handled by the overload taking a constant value.
7317 if (mod->Bound()) {
7318 return MakeModulo(x, mod->Min());
7319 }
// Otherwise the result is the variable of x - (x / mod) * mod.
7320 IntVar* const result =
7321 MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
// Upper side of the remainder bound: result < |mod|.
7322 AddConstraint(MakeLess(result, MakeAbs(mod)));
7324 return result;
7325}
7326
7327// --------- IntVar ---------
7328
7329int IntVar::VarType() const { return UNSPECIFIED; }
7330
7331void IntVar::RemoveValues(const std::vector<int64_t>& values) {
7332 // TODO(user): Check and maybe inline this code.
7333 const int size = values.size();
7334 DCHECK_GE(size, 0);
7335 switch (size) {
7336 case 0: {
7337 return;
7338 }
7339 case 1: {
7340 RemoveValue(values[0]);
7341 return;
7342 }
7343 case 2: {
7344 RemoveValue(values[0]);
7345 RemoveValue(values[1]);
7346 return;
7347 }
7348 case 3: {
7349 RemoveValue(values[0]);
7350 RemoveValue(values[1]);
7351 RemoveValue(values[2]);
7352 return;
7353 }
7354 default: {
7355 // 4 values, let's start doing some more clever things.
7356 // TODO(user) : Sort values!
7357 int start_index = 0;
7358 int64_t new_min = Min();
7359 if (values[start_index] <= new_min) {
7360 while (start_index < size - 1 &&
7361 values[start_index + 1] == values[start_index] + 1) {
7362 new_min = values[start_index + 1] + 1;
7363 start_index++;
7364 }
7365 }
7366 int end_index = size - 1;
7367 int64_t new_max = Max();
7368 if (values[end_index] >= new_max) {
7369 while (end_index > start_index + 1 &&
7370 values[end_index - 1] == values[end_index] - 1) {
7371 new_max = values[end_index - 1] - 1;
7372 end_index--;
7373 }
7374 }
7375 SetRange(new_min, new_max);
7376 for (int i = start_index; i <= end_index; ++i) {
7377 RemoveValue(values[i]);
7378 }
7379 }
7380 }
7381}
7382
7383void IntVar::Accept(ModelVisitor* const visitor) const {
7384 IntExpr* const casted = solver()->CastExpression(this);
7385 visitor->VisitIntegerVariable(this, casted);
7386}
7387
7388void IntVar::SetValues(const std::vector<int64_t>& values) {
7389 switch (values.size()) {
7390 case 0: {
7391 solver()->Fail();
7392 break;
7393 }
7394 case 1: {
7395 SetValue(values.back());
7396 break;
7397 }
7398 case 2: {
7399 if (Contains(values[0])) {
7400 if (Contains(values[1])) {
7401 const int64_t l = std::min(values[0], values[1]);
7402 const int64_t u = std::max(values[0], values[1]);
7403 SetRange(l, u);
7404 if (u > l + 1) {
7405 RemoveInterval(l + 1, u - 1);
7406 }
7407 } else {
7408 SetValue(values[0]);
7409 }
7410 } else {
7411 SetValue(values[1]);
7412 }
7413 break;
7414 }
7415 default: {
7416 // TODO(user): use a clean and safe SortedUniqueCopy() class
7417 // that uses a global, static shared (and locked) storage.
7418 // TODO(user): [optional] consider porting
7419 // STLSortAndRemoveDuplicates from ortools/base/stl_util.h to the
7420 // existing base/stl_util.h and using it here.
7421 // TODO(user): We could filter out values not in the var.
7422 std::vector<int64_t>& tmp = solver()->tmp_vector_;
7423 tmp.clear();
7424 tmp.insert(tmp.end(), values.begin(), values.end());
7425 std::sort(tmp.begin(), tmp.end());
7426 tmp.erase(std::unique(tmp.begin(), tmp.end()), tmp.end());
7427 const int size = tmp.size();
7428 const int64_t vmin = Min();
7429 const int64_t vmax = Max();
7430 int first = 0;
7431 int last = size - 1;
7432 if (tmp.front() > vmax || tmp.back() < vmin) {
7433 solver()->Fail();
7434 }
7435 // TODO(user) : We could find the first position >= vmin by dichotomy.
7436 while (tmp[first] < vmin || !Contains(tmp[first])) {
7437 ++first;
7438 if (first > last || tmp[first] > vmax) {
7439 solver()->Fail();
7440 }
7441 }
7442 while (last > first && (tmp[last] > vmax || !Contains(tmp[last]))) {
7443 // Note that last >= first implies tmp[last] >= vmin.
7444 --last;
7445 }
7446 DCHECK_GE(last, first);
7447 SetRange(tmp[first], tmp[last]);
7448 while (first < last) {
7449 const int64_t start = tmp[first] + 1;
7450 const int64_t end = tmp[first + 1] - 1;
7451 if (start <= end) {
7452 RemoveInterval(start, end);
7453 }
7454 first++;
7455 }
7456 }
7457 }
7458}
7459// ---------- BaseIntExpr ---------
7460
7461void LinkVarExpr(Solver* s, IntExpr* expr, IntVar* var) {
7462 if (!var->Bound()) {
7463 if (var->VarType() == DOMAIN_INT_VAR) {
7464 DomainIntVar* dvar = reinterpret_cast<DomainIntVar*>(var);
7466 s->RevAlloc(new LinkExprAndDomainIntVar(s, expr, dvar)), dvar, expr);
7467 } else {
7468 s->AddCastConstraint(s->RevAlloc(new LinkExprAndVar(s, expr, var)), var,
7469 expr);
7470 }
7471 }
7472}
7473
7475 if (var_ == nullptr) {
7476 solver()->SaveValue(reinterpret_cast<void**>(&var_));
7477 var_ = CastToVar();
7478 }
7479 return var_;
7480}
7481
7483 int64_t vmin, vmax;
7484 Range(&vmin, &vmax);
7485 IntVar* const var = solver()->MakeIntVar(vmin, vmax);
7486 LinkVarExpr(solver(), this, var);
7487 return var;
7488}
7489
7490// Discovery methods
7491bool Solver::IsADifference(IntExpr* expr, IntExpr** const left,
7492 IntExpr** const right) {
7493 if (expr->IsVar()) {
7494 IntVar* const expr_var = expr->Var();
7495 expr = CastExpression(expr_var);
7496 }
7497 // This is a dynamic cast to check the type of expr.
7498 // It returns nullptr is expr is not a subclass of SubIntExpr.
7499 SubIntExpr* const sub_expr = dynamic_cast<SubIntExpr*>(expr);
7500 if (sub_expr != nullptr) {
7501 *left = sub_expr->left();
7502 *right = sub_expr->right();
7503 return true;
7504 }
7505 return false;
7506}
7507
7508bool Solver::IsBooleanVar(IntExpr* const expr, IntVar** inner_var,
7509 bool* is_negated) const {
7510 if (expr->IsVar() && expr->Var()->VarType() == BOOLEAN_VAR) {
7511 *inner_var = expr->Var();
7512 *is_negated = false;
7513 return true;
7514 } else if (expr->IsVar() && expr->Var()->VarType() == CST_SUB_VAR) {
7515 SubCstIntVar* const sub_var = reinterpret_cast<SubCstIntVar*>(expr);
7516 if (sub_var != nullptr && sub_var->Constant() == 1 &&
7517 sub_var->SubVar()->VarType() == BOOLEAN_VAR) {
7518 *is_negated = true;
7519 *inner_var = sub_var->SubVar();
7520 return true;
7521 }
7522 }
7523 return false;
7524}
7525
7526bool Solver::IsProduct(IntExpr* const expr, IntExpr** inner_expr,
7527 int64_t* coefficient) {
7528 if (dynamic_cast<TimesCstIntVar*>(expr) != nullptr) {
7529 TimesCstIntVar* const var = dynamic_cast<TimesCstIntVar*>(expr);
7530 *coefficient = var->Constant();
7531 *inner_expr = var->SubVar();
7532 return true;
7533 } else if (dynamic_cast<TimesIntCstExpr*>(expr) != nullptr) {
7534 TimesIntCstExpr* const prod = dynamic_cast<TimesIntCstExpr*>(expr);
7535 *coefficient = prod->Constant();
7536 *inner_expr = prod->Expr();
7537 return true;
7538 }
7539 *inner_expr = expr;
7540 *coefficient = 1;
7541 return false;
7542}
7543
7544} // namespace operations_research
IntVar * Var() override
Creates a variable from the expression.
bool Contains(int64_t v) const override
IntVarIterator * MakeDomainIterator(bool reversible) const override
IntVar * IsGreaterOrEqual(int64_t constant) override
IntVarIterator * MakeHoleIterator(bool reversible) const override
--— Misc --—
IntVar * IsEqual(int64_t constant) override
IsEqual.
void RemoveValue(int64_t v) override
This method removes the value 'v' from the domain of the variable.
IntVar * IsDifferent(int64_t constant) override
SimpleRevFIFO< Demon * > delayed_bound_demons_
void SetMax(int64_t m) override
void SetRange(int64_t mi, int64_t ma) override
This method sets both the min and the max of the expression.
void RemoveInterval(int64_t l, int64_t u) override
void SetMin(int64_t m) override
SimpleRevFIFO< Demon * > bound_demons_
uint64_t Size() const override
This method returns the number of values in the domain of the variable.
static const int kUnboundBooleanVarValue
--— Boolean variable --—
std::string DebugString() const override
void WhenBound(Demon *d) override
IntVar * IsLessOrEqual(int64_t constant) override
void WhenRange(Demon *d) override
Attach a demon that will watch the min or the max of the expression.
virtual Solver::DemonPriority priority() const
---------------— Demon class -------------—
virtual void SetValue(int64_t v)
This method sets the value of the expression.
virtual bool Bound() const
Returns true if the min and the max of the expression are equal.
virtual void SetMax(int64_t m)=0
virtual bool IsVar() const
Returns true if the expression is indeed a variable.
virtual void SetRange(int64_t l, int64_t u)
This method sets both the min and the max of the expression.
virtual int64_t Min() const =0
virtual void SetMin(int64_t m)=0
virtual void WhenRange(Demon *d)=0
Attach a demon that will watch the min or the max of the expression.
virtual void Range(int64_t *l, int64_t *u)
IntVar * VarWithName(const std::string &name)
-------— IntExpr -------—
virtual IntVar * Var()=0
Creates a variable from the expression.
virtual int64_t Max() const =0
virtual void WhenBound(Demon *d)=0
virtual void WhenDomain(Demon *d)=0
virtual void SetValues(const std::vector< int64_t > &values)
This method intersects the current domain with the values in the array.
void Accept(ModelVisitor *visitor) const override
Accepts the given visitor.
virtual IntVar * IsDifferent(int64_t constant)=0
IntVar(Solver *s)
-------— IntVar -------—
virtual int64_t OldMax() const =0
Returns the previous max.
virtual IntVar * IsLessOrEqual(int64_t constant)=0
virtual bool Contains(int64_t v) const =0
virtual int64_t Value() const =0
virtual int VarType() const
------— IntVar ------—
virtual void RemoveValue(int64_t v)=0
This method removes the value 'v' from the domain of the variable.
virtual uint64_t Size() const =0
This method returns the number of values in the domain of the variable.
virtual IntVar * IsGreaterOrEqual(int64_t constant)=0
virtual IntVar * IsEqual(int64_t constant)=0
IsEqual.
virtual void RemoveInterval(int64_t l, int64_t u)=0
virtual int64_t OldMin() const =0
Returns the previous min.
virtual void RemoveValues(const std::vector< int64_t > &values)
This method remove the values from the domain of the variable.
static int64_t GCD64(int64_t x, int64_t y)
Definition mathutil.h:108
virtual void VisitIntegerVariable(const IntVar *variable, IntExpr *delegate)
void SetRange(int64_t l, int64_t u) override
This method sets both the min and the max of the expression.
void Accept(ModelVisitor *const visitor) const override
Accepts the given visitor.
PiecewiseLinearExpr(Solver *solver, IntExpr *expr, const PiecewiseLinearFunction &f)
void WhenRange(Demon *d) override
Attach a demon that will watch the min or the max of the expression.
std::string name() const override
Object naming.
std::string DebugString() const override
virtual std::string name() const
Object naming.
For the time being, Solver is neither MT_SAFE nor MT_HOT.
IntExpr * MakeDiv(IntExpr *expr, int64_t value)
expr / value (integer division)
IntVar * MakeBoolVar(const std::string &name)
MakeBoolVar will create a variable with a {0, 1} domain.
Constraint * MakeNonEquality(IntExpr *left, IntExpr *right)
left != right
Definition range_cst.cc:570
IntExpr * MakeMax(const std::vector< IntVar * > &vars)
std::max(vars)
IntExpr * MakeDifference(IntExpr *left, IntExpr *right)
left - right
IntVar * MakeBoolVar()
MakeBoolVar will create a variable with a {0, 1} domain.
IntExpr * MakeSum(IntExpr *left, IntExpr *right)
--— Integer Expressions --—
IntExpr * MakeMin(const std::vector< IntVar * > &vars)
std::min(vars)
void Fail()
Abandon the current branch in the search tree. A backtrack will follow.
@ VAR_PRIORITY
VAR_PRIORITY is between DELAYED_PRIORITY and NORMAL_PRIORITY.
@ OUTSIDE_SEARCH
Before search, after search.
IntVar * RegisterIntVar(IntVar *var)
Registers a new IntVar and wraps it inside a TraceIntVar if necessary.
Definition trace.cc:860
bool IsBooleanVar(IntExpr *expr, IntVar **inner_var, bool *is_negated) const
Constraint * MakeLess(IntExpr *left, IntExpr *right)
left < right
Definition range_cst.cc:552
ModelCache * Cache() const
Returns the cache of the model.
IntExpr * MakeOpposite(IntExpr *expr)
-expr
Constraint * MakeAbsEquality(IntVar *var, IntVar *abs_var)
Creates the constraint abs(var) == abs_var.
void MakeBoolVarArray(int var_count, const std::string &name, std::vector< IntVar * > *vars)
IntExpr * MakePiecewiseLinearExpr(IntExpr *expr, const PiecewiseLinearFunction &f)
IntExpr * RegisterIntExpr(IntExpr *expr)
Registers a new IntExpr and wraps it inside a TraceIntExpr if necessary.
Definition trace.cc:848
IntExpr * MakeModulo(IntExpr *x, int64_t mod)
Modulo expression x % mod (with the python convention for modulo).
Constraint * MakeGreater(IntExpr *left, IntExpr *right)
left > right
Definition range_cst.cc:566
IntExpr * MakeAbs(IntExpr *expr)
expr
IntExpr * MakeSquare(IntExpr *expr)
expr * expr
Constraint * MakeBetweenCt(IntExpr *expr, int64_t l, int64_t u)
--— Between and related constraints --—
Definition expr_cst.cc:929
void MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax, const std::string &name, std::vector< IntVar * > *vars)
IntVar * MakeIntVar(int64_t min, int64_t max, const std::string &name)
--— Int Variables and Constants --—
bool IsProduct(IntExpr *expr, IntExpr **inner_expr, int64_t *coefficient)
IntExpr * MakeConvexPiecewiseExpr(IntExpr *expr, int64_t early_cost, int64_t early_date, int64_t late_date, int64_t late_cost)
Convex piecewise function.
IntExpr * MakePower(IntExpr *expr, int64_t n)
expr ^ n (n > 0)
IntExpr * MakeConditionalExpression(IntVar *condition, IntExpr *expr, int64_t unperformed_value)
Conditional Expr condition ? expr : unperformed_value.
IntExpr * MakeProd(IntExpr *left, IntExpr *right)
left * right
IntVar * MakeIntConst(int64_t val, const std::string &name)
IntConst will create a constant expression.
IntExpr * MakeSemiContinuousExpr(IntExpr *expr, int64_t fixed_charge, int64_t step)
void AddCastConstraint(CastConstraint *constraint, IntVar *target_var, IntExpr *expr)
ABSL_FLAG(bool, cp_disable_expression_optimization, false, "Disable special optimization when creating expressions.")
int RemoveAt(RepeatedType *array, const IndexContainer &indices)
const Collection::value_type::second_type FindPtrOrNull(const Collection &collection, const typename Collection::value_type::first_type &key)
Definition map_util.h:94
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
Definition stl_util.h:58
For infeasible and unbounded see Not checked if options check_solutions_if_inf_or_unbounded and the If options first_solution_only is false
problem is infeasible or unbounded (default).
Definition matchers.h:468
std::pair< double, double > Range
A range of values, first is the minimum, second is the maximum.
Definition statistics.h:27
dual_gradient T(y - `dual_solution`) class DiagonalTrustRegionProblemFromQp
std::function< int64_t(const Model &)> Value(IntegerVariable v)
This checks that the variable is fixed.
Definition integer.h:1559
In SWIG mode, we don't want anything besides these top-level includes.
int64_t SubOverflows(int64_t x, int64_t y)
int64_t UnsafeMostSignificantBitPosition64(const uint64_t *bitset, uint64_t start, uint64_t end)
void InternalSaveBooleanVarValue(Solver *const solver, IntVar *const var)
int64_t CapAdd(int64_t x, int64_t y)
void RestoreBoolValue(IntVar *var)
--— Trail --—
int64_t UnsafeLeastSignificantBitPosition64(const uint64_t *bitset, uint64_t start, uint64_t end)
int64_t CapSub(int64_t x, int64_t y)
Constraint * SetIsGreaterOrEqual(IntVar *const var, absl::Span< const int64_t > values, const std::vector< IntVar * > &vars)
Constraint * SetIsEqual(IntVar *const var, absl::Span< const int64_t > values, const std::vector< IntVar * > &vars)
bool AddOverflows(int64_t x, int64_t y)
void RegisterDemon(Solver *const solver, Demon *const demon, DemonProfiler *const monitor)
--— Exported Methods for Unit Tests --—
static const uint64_t kAllBits64
Basic bit operations.
Definition bitset.h:37
void LinkVarExpr(Solver *s, IntExpr *expr, IntVar *var)
--— IntExprElement --—
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
int64_t CapProd(int64_t x, int64_t y)
std::vector< int64_t > ToInt64Vector(const std::vector< int > &input)
--— Vector manipulations --—
Definition utilities.cc:829
uint64_t OneRange64(uint64_t s, uint64_t e)
Returns a word with bits from s to e set.
Definition bitset.h:289
uint32_t BitPos64(uint64_t pos)
Bit operators used to manipulates bitsets.
Definition bitset.h:334
uint64_t BitCountRange64(const uint64_t *bitset, uint64_t start, uint64_t end)
Returns the number of bits set in bitset between positions start and end.
uint64_t BitCount64(uint64_t n)
Returns the number of bits set in n.
Definition bitset.h:46
bool IsBitSet64(const uint64_t *const bitset, uint64_t pos)
Returns true if the bit pos is set in bitset.
Definition bitset.h:350
uint64_t OneBit64(int pos)
Returns a word with only bit pos set.
Definition bitset.h:42
uint64_t BitOffset64(uint64_t pos)
Returns the word number corresponding to bit number pos.
Definition bitset.h:338
int64_t PosIntDivDown(int64_t e, int64_t v)
uint64_t BitLength64(uint64_t size)
Returns the number of words needed to store size bits.
Definition bitset.h:342
int LeastSignificantBitPosition64(uint64_t n)
Definition bitset.h:131
void CleanVariableOnFail(IntVar *var)
---------------— Queue class ---------------—
int64_t CapOpp(int64_t v)
Note(user): -kint64min != kint64max, but kint64max == ~kint64min.
int MostSignificantBitPosition64(uint64_t n)
Definition bitset.h:235
int64_t PosIntDivUp(int64_t e, int64_t v)
STL namespace.
false true
Definition numbers.cc:228
trees with all degrees equal w the current value of w
bool Next()