Google OR-Tools v9.15
a fast and portable software suite for combinatorial optimization
Loading...
Searching...
No Matches
expressions.cc
Go to the documentation of this file.
1// Copyright 2010-2025 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14#include <algorithm>
15#include <cmath>
16#include <cstdint>
17#include <cstdlib>
18#include <limits>
19#include <memory>
20#include <string>
21#include <utility>
22#include <vector>
23
24#include "absl/flags/flag.h"
25#include "absl/strings/str_cat.h"
26#include "absl/strings/str_format.h"
27#include "absl/strings/string_view.h"
28#include "absl/types/span.h"
34#include "ortools/util/bitset.h"
37
// Command-line switch: when true, the special-case optimizations applied
// while creating expressions are disabled (see the help string).
// NOTE(review): the code paths reading this flag are not in this excerpt.
ABSL_FLAG(bool, cp_disable_expression_optimization, false,
          "Disable special optimization when creating expressions.");
// Command-line switch (default true): share IntConst objects that hold the
// same value instead of allocating a new one each time.
ABSL_FLAG(bool, cp_share_int_consts, true,
          "Share IntConst's with the same value.");

// Silence MSVC warnings 4351 and 4355 for this translation unit.
#if defined(_MSC_VER)
#pragma warning(disable : 4351 4355)
#endif
46
47namespace operations_research {
48
49// ---------- IntExpr ----------
50
51IntVar* IntExpr::VarWithName(const std::string& name) {
52 IntVar* const var = Var();
53 var->set_name(name);
54 return var;
55}
56
57// ---------- IntVar ----------
58
59IntVar::IntVar(Solver* const s) : IntExpr(s), index_(s->GetNewIntVarIndex()) {}
60
61IntVar::IntVar(Solver* const s, const std::string& name)
62 : IntExpr(s), index_(s->GetNewIntVarIndex()) {
64}
65
66// ----- Boolean variable -----
67
69
70void BooleanVar::SetMin(int64_t m) {
71 if (m <= 0) return;
72 if (m > 1) solver()->Fail();
73 SetValue(1);
74}
75
76void BooleanVar::SetMax(int64_t m) {
77 if (m >= 1) return;
78 if (m < 0) solver()->Fail();
79 SetValue(0);
80}
81
82void BooleanVar::SetRange(int64_t mi, int64_t ma) {
83 if (mi > 1 || ma < 0 || mi > ma) {
84 solver()->Fail();
85 }
86 if (mi == 1) {
87 SetValue(1);
88 } else if (ma == 0) {
89 SetValue(0);
90 }
91}
92
93void BooleanVar::RemoveValue(int64_t v) {
95 if (v == 0) {
96 SetValue(1);
97 } else if (v == 1) {
98 SetValue(0);
99 }
100 } else if (v == value_) {
101 solver()->Fail();
102 }
103}
104
105void BooleanVar::RemoveInterval(int64_t l, int64_t u) {
106 if (u < l) return;
107 if (l <= 0 && u >= 1) {
108 solver()->Fail();
109 } else if (l == 1) {
110 SetValue(0);
111 } else if (u == 0) {
112 SetValue(1);
113 }
114}
115
119 delayed_bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
120 } else {
121 bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
122 }
123 }
124}
125
126uint64_t BooleanVar::Size() const {
127 return (1 + (value_ == kUnboundBooleanVarValue));
128}
129
130bool BooleanVar::Contains(int64_t v) const {
131 return ((v == 0 && value_ != 1) || (v == 1 && value_ != 0));
132}
133
134IntVar* BooleanVar::IsEqual(int64_t constant) {
135 if (constant > 1 || constant < 0) {
136 return solver()->MakeIntConst(0);
137 }
138 if (constant == 1) {
139 return this;
140 } else { // constant == 0.
141 return solver()->MakeDifference(1, this)->Var();
142 }
143}
144
145IntVar* BooleanVar::IsDifferent(int64_t constant) {
146 if (constant > 1 || constant < 0) {
147 return solver()->MakeIntConst(1);
148 }
149 if (constant == 1) {
150 return solver()->MakeDifference(1, this)->Var();
151 } else { // constant == 0.
152 return this;
153 }
154}
155
157 if (constant > 1) {
158 return solver()->MakeIntConst(0);
159 } else if (constant <= 0) {
160 return solver()->MakeIntConst(1);
161 } else {
162 return this;
163 }
164}
165
167 if (constant < 0) {
168 return solver()->MakeIntConst(0);
169 } else if (constant >= 1) {
170 return solver()->MakeIntConst(1);
171 } else {
172 return IsEqual(0);
173 }
174}
175
176std::string BooleanVar::DebugString() const {
177 std::string out;
178 const std::string& var_name = name();
179 if (!var_name.empty()) {
180 out = var_name + "(";
181 } else {
182 out = "BooleanVar(";
183 }
184 switch (value_) {
185 case 0:
186 out += "0";
187 break;
188 case 1:
189 out += "1";
190 break;
192 out += "0 .. 1";
193 break;
194 }
195 out += ")";
196 return out;
197}
198
199namespace {
200// ---------- Subclasses of IntVar ----------
201
202// ----- Domain Int Var: base class for variables -----
203// It Contains bounds and a bitset representation of possible values.
204class DomainIntVar : public IntVar {
205 public:
206 // Utility classes
207 class BitSetIterator : public BaseObject {
208 public:
209 BitSetIterator(uint64_t* const bitset, int64_t omin)
210 : bitset_(bitset),
211 omin_(omin),
212 max_(std::numeric_limits<int64_t>::min()),
213 current_(std::numeric_limits<int64_t>::max()) {}
214
215 ~BitSetIterator() override {}
216
217 void Init(int64_t min, int64_t max) {
218 max_ = max;
219 current_ = min;
220 }
221
222 bool Ok() const { return current_ <= max_; }
223
224 int64_t Value() const { return current_; }
225
226 void Next() {
227 if (++current_ <= max_) {
229 bitset_, current_ - omin_, max_ - omin_) +
230 omin_;
231 }
232 }
233
234 std::string DebugString() const override { return "BitSetIterator"; }
235
236 private:
237 uint64_t* const bitset_;
238 const int64_t omin_;
239 int64_t max_;
240 int64_t current_;
241 };
242
243 class BitSet : public BaseObject {
244 public:
245 explicit BitSet(Solver* const s) : solver_(s), holes_stamp_(0) {}
246 ~BitSet() override {}
247
248 virtual int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) = 0;
249 virtual int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) = 0;
250 virtual bool Contains(int64_t val) const = 0;
251 virtual bool SetValue(int64_t val) = 0;
252 virtual bool RemoveValue(int64_t val) = 0;
253 virtual uint64_t Size() const = 0;
254 virtual void DelayRemoveValue(int64_t val) = 0;
255 virtual void ApplyRemovedValues(DomainIntVar* var) = 0;
256 virtual void ClearRemovedValues() = 0;
257 virtual std::string pretty_DebugString(int64_t min, int64_t max) const = 0;
258 virtual BitSetIterator* MakeIterator() = 0;
259
260 void InitHoles() {
261 const uint64_t current_stamp = solver_->stamp();
262 if (holes_stamp_ < current_stamp) {
263 holes_.clear();
264 holes_stamp_ = current_stamp;
265 }
266 }
267
268 virtual void ClearHoles() { holes_.clear(); }
269
270 const std::vector<int64_t>& Holes() { return holes_; }
271
272 void AddHole(int64_t value) { holes_.push_back(value); }
273
274 int NumHoles() const {
275 return holes_stamp_ < solver_->stamp() ? 0 : holes_.size();
276 }
277
278 protected:
279 Solver* const solver_;
280
281 private:
282 std::vector<int64_t> holes_;
283 uint64_t holes_stamp_;
284 };
285
286 class QueueHandler : public Demon {
287 public:
288 explicit QueueHandler(DomainIntVar* const var) : var_(var) {}
289 ~QueueHandler() override {}
290 void Run(Solver* const s) override {
291 s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
292 var_->Process();
293 s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
294 }
295 Solver::DemonPriority priority() const override {
297 }
298 std::string DebugString() const override {
299 return absl::StrFormat("Handler(%s)", var_->DebugString());
300 }
301
302 private:
303 DomainIntVar* const var_;
304 };
305
306 // Bounds and Value watchers
307
308 // This class stores the watchers variables attached to values. It is
309 // reversible and it helps maintaining the set of 'active' watchers
310 // (variables not bound to a single value).
311 template <class T>
312 class RevIntPtrMap {
313 public:
314 RevIntPtrMap(Solver* const solver, int64_t rmin, int64_t rmax)
315 : solver_(solver), range_min_(rmin), start_(0) {}
316
317 ~RevIntPtrMap() {}
318
319 bool Empty() const { return start_.Value() == elements_.size(); }
320
321 void SortActive() { std::sort(elements_.begin(), elements_.end()); }
322
323 // Access with value API.
324
325 // Add the pointer to the map attached to the given value.
326 void UnsafeRevInsert(int64_t value, T* elem) {
327 elements_.push_back(std::make_pair(value, elem));
328 if (solver_->state() != Solver::OUTSIDE_SEARCH) {
329 solver_->AddBacktrackAction(
330 [this, value](Solver* s) { Uninsert(value); }, false);
331 }
332 }
333
334 T* FindPtrOrNull(int64_t value, int* position) {
335 for (int pos = start_.Value(); pos < elements_.size(); ++pos) {
336 if (elements_[pos].first == value) {
337 if (position != nullptr) *position = pos;
338 return At(pos).second;
339 }
340 }
341 return nullptr;
342 }
343
344 // Access map through the underlying vector.
345 void RemoveAt(int position) {
346 const int start = start_.Value();
347 DCHECK_GE(position, start);
348 DCHECK_LT(position, elements_.size());
349 if (position > start) {
350 // Swap the current element with the one at the start position, and
351 // increase start.
352 const std::pair<int64_t, T*> copy = elements_[start];
353 elements_[start] = elements_[position];
354 elements_[position] = copy;
355 }
356 start_.Incr(solver_);
357 }
358
359 const std::pair<int64_t, T*>& At(int position) const {
360 DCHECK_GE(position, start_.Value());
361 DCHECK_LT(position, elements_.size());
362 return elements_[position];
363 }
364
365 void RemoveAll() { start_.SetValue(solver_, elements_.size()); }
366
367 int start() const { return start_.Value(); }
368 int end() const { return elements_.size(); }
369 // Number of active elements.
370 int Size() const { return elements_.size() - start_.Value(); }
371
372 // Removes the object permanently from the map.
373 void Uninsert(int64_t value) {
374 for (int pos = 0; pos < elements_.size(); ++pos) {
375 if (elements_[pos].first == value) {
376 DCHECK_GE(pos, start_.Value());
377 const int last = elements_.size() - 1;
378 if (pos != last) { // Swap the current with the last.
379 elements_[pos] = elements_.back();
380 }
381 elements_.pop_back();
382 return;
383 }
384 }
385 LOG(FATAL) << "The element should have been removed";
386 }
387
388 private:
389 Solver* const solver_;
390 const int64_t range_min_;
391 NumericalRev<int> start_;
392 std::vector<std::pair<int64_t, T*>> elements_;
393 };
394
395 // Base class for value watchers
396 class BaseValueWatcher : public Constraint {
397 public:
398 explicit BaseValueWatcher(Solver* const solver) : Constraint(solver) {}
399
400 ~BaseValueWatcher() override {}
401
402 virtual IntVar* GetOrMakeValueWatcher(int64_t value) = 0;
403
404 virtual void SetValueWatcher(IntVar* boolvar, int64_t value) = 0;
405 };
406
  // This class monitors the domain of the variable and updates the
  // IsEqual/IsDifferent boolean variables accordingly.
  //
  // Each watcher is a 0/1 variable b attached to a value v, kept consistent
  // with (variable_ == v) in both directions. Watchers live in a reversible
  // RevIntPtrMap. A watcher is removed from the active set once decided, and
  // the domain demon is inhibited when no active watcher remains.
  class ValueWatcher : public BaseValueWatcher {
   public:
    // Runs when a watcher boolean becomes bound and pushes that decision
    // back onto the watched variable.
    class WatchDemon : public Demon {
     public:
      WatchDemon(ValueWatcher* const watcher, int64_t value, IntVar* var)
          : value_watcher_(watcher), value_(value), var_(var) {}
      ~WatchDemon() override {}

      void Run(Solver* const solver) override {
        value_watcher_->ProcessValueWatcher(value_, var_);
      }

     private:
      ValueWatcher* const value_watcher_;
      const int64_t value_;
      IntVar* const var_;
    };

    // Runs on domain changes of the watched variable.
    class VarDemon : public Demon {
     public:
      explicit VarDemon(ValueWatcher* const watcher)
          : value_watcher_(watcher) {}

      ~VarDemon() override {}

      void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }

     private:
      ValueWatcher* const value_watcher_;
    };

    ValueWatcher(Solver* const solver, DomainIntVar* const variable)
        : BaseValueWatcher(solver),
          variable_(variable),
          hole_iterator_(variable_->MakeHoleIterator(true)),
          var_demon_(nullptr),
          watchers_(solver, variable->Min(), variable->Max()) {}

    ~ValueWatcher() override {}

    // Returns the watcher for (variable_ == value). If the answer is already
    // known, returns a constant: 1 when the variable is bound to `value`,
    // 0 when `value` is not in the domain. Otherwise creates a new boolean
    // watcher.
    IntVar* GetOrMakeValueWatcher(int64_t value) override {
      IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
      if (watcher != nullptr) return watcher;
      if (variable_->Contains(value)) {
        if (variable_->Bound()) {
          return solver()->MakeIntConst(1);
        } else {
          const std::string vname = variable_->HasName()
                                        ? variable_->name()
                                        : variable_->DebugString();
          const std::string bname =
              absl::StrFormat("Watch<%s == %d>", vname, value);
          IntVar* const boolvar = solver()->MakeBoolVar(bname);
          watchers_.UnsafeRevInsert(value, boolvar);
          // After Post(), a new watcher must be wired right away. The
          // domain demon is re-enabled in case CheckInhibit() disabled it.
          if (posted_.Switched()) {
            boolvar->WhenBound(
                solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
            var_demon_->desinhibit(solver());
          }
          return boolvar;
        }
      } else {
        return variable_->solver()->MakeIntConst(0);
      }
    }

    // Installs an existing boolean as the watcher of `value`. A `value`
    // that is already watched is a fatal error (CHECK). An already-bound
    // boolvar is not stored.
    void SetValueWatcher(IntVar* const boolvar, int64_t value) override {
      CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
      if (!boolvar->Bound()) {
        watchers_.UnsafeRevInsert(value, boolvar);
        if (posted_.Switched() && !boolvar->Bound()) {
          boolvar->WhenBound(
              solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
          var_demon_->desinhibit(solver());
        }
      }
    }

    // Attaches the domain demon to the variable and a WatchDemon to every
    // unbound watcher whose value is still in the domain.
    void Post() override {
      var_demon_ = solver()->RevAlloc(new VarDemon(this));
      variable_->WhenDomain(var_demon_);
      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
        const int64_t value = w.first;
        IntVar* const boolvar = w.second;
        if (!boolvar->Bound() && variable_->Contains(value)) {
          boolvar->WhenBound(
              solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
        }
      }
      posted_.Switch(solver());
    }

    // Sets watchers whose value left the domain to 0, and propagates
    // watchers that are already bound. Decided watchers are removed.
    // RemoveAt(pos) swaps the removed entry with the one at start() and
    // advances start(). The entry moved into `pos` was already visited, so
    // the scan continues at pos + 1.
    void InitialPropagate() override {
      if (variable_->Bound()) {
        VariableBound();
      } else {
        for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
          const int64_t value = w.first;
          IntVar* const boolvar = w.second;
          if (!variable_->Contains(value)) {
            boolvar->SetValue(0);
            watchers_.RemoveAt(pos);
          } else {
            if (boolvar->Bound()) {
              ProcessValueWatcher(value, boolvar);
              watchers_.RemoveAt(pos);
            }
          }
        }
        CheckInhibit();
      }
    }

    // A watcher became bound: 0 removes `value` from the variable, 1 assigns
    // it. If the domain size is 0xFFFFFF or more, the removal is posted as a
    // NonEquality constraint instead ("Delay removal").
    // NOTE(review): the rationale for that threshold is not visible here.
    void ProcessValueWatcher(int64_t value, IntVar* boolvar) {
      if (boolvar->Min() == 0) {
        if (variable_->Size() < 0xFFFFFF) {
          variable_->RemoveValue(value);
        } else {
          // Delay removal.
          solver()->AddConstraint(solver()->MakeNonEquality(variable_, value));
        }
      } else {
        variable_->SetValue(value);
      }
    }

    // Domain event on the watched variable.
    void ProcessVar() {
      const int kSmallList = 16;
      if (variable_->Bound()) {
        VariableBound();
      } else if (watchers_.Size() <= kSmallList ||
                 variable_->Min() != variable_->OldMin() ||
                 variable_->Max() != variable_->OldMax()) {
        // Brute force loop for small numbers of watchers, or if the bounds have
        // changed, which would have required a sort (n log(n)) anyway to take
        // advantage of.
        ScanWatchers();
        CheckInhibit();
      } else {
        // The bounds are unchanged here, so only holes can have removed
        // watched values. Without a bitset there are no holes and nothing
        // is left to do. Otherwise, look up each hole when holes are few
        // compared to active watchers, and rescan all watchers when they
        // are not.
        BitSet* const bitset = variable_->bitset();
        if (bitset != nullptr && !watchers_.Empty()) {
          if (bitset->NumHoles() * 2 < watchers_.Size()) {
            for (const int64_t hole : InitAndGetValues(hole_iterator_)) {
              int pos = 0;
              IntVar* const boolvar = watchers_.FindPtrOrNull(hole, &pos);
              if (boolvar != nullptr) {
                boolvar->SetValue(0);
                watchers_.RemoveAt(pos);
              }
            }
          } else {
            ScanWatchers();
          }
        }
        CheckInhibit();
      }
    }

    // Optimized case if the variable is bound: every watcher is decided,
    // all are removed, and the domain demon is inhibited.
    void VariableBound() {
      DCHECK(variable_->Bound());
      const int64_t value = variable_->Min();
      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
        w.second->SetValue(w.first == value);
      }
      watchers_.RemoveAll();
      var_demon_->inhibit(solver());
    }

    // Scans all the active watchers, sets to 0 (and removes) those whose
    // value is no longer in the domain.
    void ScanWatchers() {
      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
        if (!variable_->Contains(w.first)) {
          IntVar* const boolvar = w.second;
          boolvar->SetValue(0);
          watchers_.RemoveAt(pos);
        }
      }
    }

    // If the set of active watchers is empty, we can inhibit the demon on the
    // main variable.
    void CheckInhibit() {
      if (watchers_.Empty()) {
        var_demon_->inhibit(solver());
      }
    }

    // Reports the variable, the active watcher booleans, and their values
    // to the model visitor.
    void Accept(ModelVisitor* const visitor) const override {
      visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
      visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
                                              variable_);
      std::vector<int64_t> all_coefficients;
      std::vector<IntVar*> all_bool_vars;
      for (int position = watchers_.start(); position < watchers_.end();
           ++position) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(position);
        all_coefficients.push_back(w.first);
        all_bool_vars.push_back(w.second);
      }
      visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
                                                 all_bool_vars);
      visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
                                         all_coefficients);
      visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
    }

    std::string DebugString() const override {
      return absl::StrFormat("ValueWatcher(%s)", variable_->DebugString());
    }

   private:
    DomainIntVar* const variable_;
    IntVarIterator* const hole_iterator_;
    RevSwitch posted_;   // Reversibly switched on at the end of Post().
    Demon* var_demon_;   // Created in Post(); nullptr before that.
    RevIntPtrMap<IntVar> watchers_;
  };
634
  // Optimized case for small maps.
  //
  // Dense variant of ValueWatcher. Watchers are stored in a vector indexed by
  // (value - offset_). offset_ is the variable's Min() at construction, and
  // the vector covers [Min(), Max()] at that time. Slots are restored on
  // backtrack through Solver::SaveValue(). active_watchers_ is a reversible
  // count of the slots filled by RevInsert() minus those cleared by
  // RevRemove().
  class DenseValueWatcher : public BaseValueWatcher {
   public:
    // Runs when a watcher boolean becomes bound and pushes that decision
    // back onto the watched variable.
    class WatchDemon : public Demon {
     public:
      WatchDemon(DenseValueWatcher* const watcher, int64_t value, IntVar* var)
          : value_watcher_(watcher), value_(value), var_(var) {}
      ~WatchDemon() override {}

      void Run(Solver* const solver) override {
        value_watcher_->ProcessValueWatcher(value_, var_);
      }

     private:
      DenseValueWatcher* const value_watcher_;
      const int64_t value_;
      IntVar* const var_;
    };

    // Runs on domain changes of the watched variable.
    class VarDemon : public Demon {
     public:
      explicit VarDemon(DenseValueWatcher* const watcher)
          : value_watcher_(watcher) {}

      ~VarDemon() override {}

      void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }

     private:
      DenseValueWatcher* const value_watcher_;
    };

    DenseValueWatcher(Solver* const solver, DomainIntVar* const variable)
        : BaseValueWatcher(solver),
          variable_(variable),
          hole_iterator_(variable_->MakeHoleIterator(true)),
          var_demon_(nullptr),
          offset_(variable->Min()),
          watchers_(variable->Max() - variable->Min() + 1, nullptr),
          active_watchers_(0) {}

    ~DenseValueWatcher() override {}

    // Returns the watcher for (variable_ == value). Returns constant 0 for
    // values outside the stored range or the current domain, and constant 1
    // when the variable is bound to `value`. Otherwise creates a boolean.
    IntVar* GetOrMakeValueWatcher(int64_t value) override {
      // Mixes int64_t with size_t: the sum is computed in unsigned
      // arithmetic and converted back to int64_t.
      const int64_t var_max = offset_ + watchers_.size() - 1;  // Bad cast.
      if (value < offset_ || value > var_max) {
        return solver()->MakeIntConst(0);
      }
      const int index = value - offset_;
      IntVar* const watcher = watchers_[index];
      if (watcher != nullptr) return watcher;
      if (variable_->Contains(value)) {
        if (variable_->Bound()) {
          return solver()->MakeIntConst(1);
        } else {
          const std::string vname = variable_->HasName()
                                        ? variable_->name()
                                        : variable_->DebugString();
          const std::string bname =
              absl::StrFormat("Watch<%s == %d>", vname, value);
          IntVar* const boolvar = solver()->MakeBoolVar(bname);
          RevInsert(index, boolvar);
          // After Post(), wire the new watcher immediately and re-enable
          // the domain demon, which may have been inhibited.
          if (posted_.Switched()) {
            boolvar->WhenBound(
                solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
            var_demon_->desinhibit(solver());
          }
          return boolvar;
        }
      } else {
        return variable_->solver()->MakeIntConst(0);
      }
    }

    // Installs an existing boolean as the watcher of `value`.
    // NOTE(review): `value` is assumed to lie in the stored range; the index
    // is not bounds-checked here (unlike GetOrMakeValueWatcher).
    void SetValueWatcher(IntVar* const boolvar, int64_t value) override {
      const int index = value - offset_;
      CHECK(watchers_[index] == nullptr);
      if (!boolvar->Bound()) {
        RevInsert(index, boolvar);
        if (posted_.Switched() && !boolvar->Bound()) {
          boolvar->WhenBound(
              solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
          var_demon_->desinhibit(solver());
        }
      }
    }

    // Attaches the domain demon and a WatchDemon to every unbound watcher
    // whose value is still in the domain.
    void Post() override {
      var_demon_ = solver()->RevAlloc(new VarDemon(this));
      variable_->WhenDomain(var_demon_);
      for (int pos = 0; pos < watchers_.size(); ++pos) {
        const int64_t value = pos + offset_;
        IntVar* const boolvar = watchers_[pos];
        if (boolvar != nullptr && !boolvar->Bound() &&
            variable_->Contains(value)) {
          boolvar->WhenBound(
              solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
        }
      }
      posted_.Switch(solver());
    }

    // Sets watchers whose value left the domain to 0, propagates watchers
    // already bound, and inhibits the domain demon if none remain.
    void InitialPropagate() override {
      if (variable_->Bound()) {
        VariableBound();
      } else {
        for (int pos = 0; pos < watchers_.size(); ++pos) {
          IntVar* const boolvar = watchers_[pos];
          if (boolvar == nullptr) continue;
          const int64_t value = pos + offset_;
          if (!variable_->Contains(value)) {
            boolvar->SetValue(0);
            RevRemove(pos);
          } else if (boolvar->Bound()) {
            ProcessValueWatcher(value, boolvar);
            RevRemove(pos);
          }
        }
        if (active_watchers_.Value() == 0) {
          var_demon_->inhibit(solver());
        }
      }
    }

    // A watcher became bound: 0 removes `value`, 1 assigns it.
    void ProcessValueWatcher(int64_t value, IntVar* boolvar) {
      if (boolvar->Min() == 0) {
        variable_->RemoveValue(value);
      } else {
        variable_->SetValue(value);
      }
    }

    // Domain event on the watched variable.
    void ProcessVar() {
      if (variable_->Bound()) {
        VariableBound();
      } else {
        // Brute force loop for small numbers of watchers.
        ScanWatchers();
        if (active_watchers_.Value() == 0) {
          var_demon_->inhibit(solver());
        }
      }
    }

    // Optimized case if the variable is bound: decides and removes every
    // watcher, then inhibits the domain demon.
    void VariableBound() {
      DCHECK(variable_->Bound());
      const int64_t value = variable_->Min();
      for (int pos = 0; pos < watchers_.size(); ++pos) {
        IntVar* const boolvar = watchers_[pos];
        if (boolvar != nullptr) {
          boolvar->SetValue(pos + offset_ == value);
          RevRemove(pos);
        }
      }
      var_demon_->inhibit(solver());
    }

    // Scans all the watchers to check and assign them.
    // First clears watchers in the slices cut off by the bound changes,
    // [OldMin, Min) and (Max, OldMax]. Then handles the interior: either
    // visit the recorded holes (when they are few compared to active
    // watchers) or test every interior slot. Min and Max themselves are in
    // the domain, so the interior scan skips them.
    void ScanWatchers() {
      const int64_t old_min_index = variable_->OldMin() - offset_;
      const int64_t old_max_index = variable_->OldMax() - offset_;
      const int64_t min_index = variable_->Min() - offset_;
      const int64_t max_index = variable_->Max() - offset_;
      for (int pos = old_min_index; pos < min_index; ++pos) {
        IntVar* const boolvar = watchers_[pos];
        if (boolvar != nullptr) {
          boolvar->SetValue(0);
          RevRemove(pos);
        }
      }
      for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
        IntVar* const boolvar = watchers_[pos];
        if (boolvar != nullptr) {
          boolvar->SetValue(0);
          RevRemove(pos);
        }
      }
      BitSet* const bitset = variable_->bitset();
      if (bitset != nullptr) {
        if (bitset->NumHoles() * 2 < active_watchers_.Value()) {
          for (const int64_t hole : InitAndGetValues(hole_iterator_)) {
            IntVar* const boolvar = watchers_[hole - offset_];
            if (boolvar != nullptr) {
              boolvar->SetValue(0);
              RevRemove(hole - offset_);
            }
          }
        } else {
          for (int pos = min_index + 1; pos < max_index; ++pos) {
            IntVar* const boolvar = watchers_[pos];
            if (boolvar != nullptr && !variable_->Contains(offset_ + pos)) {
              boolvar->SetValue(0);
              RevRemove(pos);
            }
          }
        }
      }
    }

    // Reversibly clears slot `pos` and decrements the active count.
    void RevRemove(int pos) {
      solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
      watchers_[pos] = nullptr;
      active_watchers_.Decr(solver());
    }

    // Reversibly fills slot `pos` and increments the active count.
    void RevInsert(int pos, IntVar* boolvar) {
      solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
      watchers_[pos] = boolvar;
      active_watchers_.Incr(solver());
    }

    // Reports the variable, the non-null watcher booleans, and their values
    // to the model visitor.
    void Accept(ModelVisitor* const visitor) const override {
      visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
      visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
                                              variable_);
      std::vector<int64_t> all_coefficients;
      std::vector<IntVar*> all_bool_vars;
      for (int position = 0; position < watchers_.size(); ++position) {
        if (watchers_[position] != nullptr) {
          all_coefficients.push_back(position + offset_);
          all_bool_vars.push_back(watchers_[position]);
        }
      }
      visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
                                                 all_bool_vars);
      visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
                                         all_coefficients);
      visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
    }

    std::string DebugString() const override {
      return absl::StrFormat("DenseValueWatcher(%s)", variable_->DebugString());
    }

   private:
    DomainIntVar* const variable_;
    IntVarIterator* const hole_iterator_;
    RevSwitch posted_;  // Reversibly switched on at the end of Post().
    Demon* var_demon_;  // Created in Post(); nullptr before that.
    const int64_t offset_;
    std::vector<IntVar*> watchers_;
    NumericalRev<int> active_watchers_;
  };
879
880 class BaseUpperBoundWatcher : public Constraint {
881 public:
882 explicit BaseUpperBoundWatcher(Solver* const solver) : Constraint(solver) {}
883
884 ~BaseUpperBoundWatcher() override {}
885
886 virtual IntVar* GetOrMakeUpperBoundWatcher(int64_t value) = 0;
887
888 virtual void SetUpperBoundWatcher(IntVar* boolvar, int64_t value) = 0;
889 };
890
  // This class watches the bounds of the variable and updates the
  // IsGreater/IsGreaterOrEqual/IsLess/IsLessOrEqual demons
  // accordingly.
  //
  // Each watcher is a 0/1 variable b attached to a threshold v, kept
  // consistent with (variable_ >= v).
  // Two modes:
  //  - sorted_ (chosen in Post() when there are more than kTooSmallToSort
  //    watchers): watchers are sorted by value, and the reversible indices
  //    start_/end_ delimit the undecided ones, which are shrunk from both
  //    ends as the bounds move.
  //  - unsorted: every active watcher is scanned, and decided ones are
  //    removed from watchers_.
  // sorted_ is a plain bool, not reversible. Adding a watcher after Post()
  // switches permanently to the unsorted mode.
  class UpperBoundWatcher : public BaseUpperBoundWatcher {
   public:
    // Runs when a watcher boolean becomes bound. Despite its name, index_
    // holds the watched threshold value (callers pass `value`).
    class WatchDemon : public Demon {
     public:
      WatchDemon(UpperBoundWatcher* const watcher, int64_t index,
                 IntVar* const var)
          : value_watcher_(watcher), index_(index), var_(var) {}
      ~WatchDemon() override {}

      void Run(Solver* const solver) override {
        value_watcher_->ProcessUpperBoundWatcher(index_, var_);
      }

     private:
      UpperBoundWatcher* const value_watcher_;
      const int64_t index_;
      IntVar* const var_;
    };

    // Runs on range changes of the watched variable.
    class VarDemon : public Demon {
     public:
      explicit VarDemon(UpperBoundWatcher* const watcher)
          : value_watcher_(watcher) {}
      ~VarDemon() override {}

      void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }

     private:
      UpperBoundWatcher* const value_watcher_;
    };

    UpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
        : BaseUpperBoundWatcher(solver),
          variable_(variable),
          var_demon_(nullptr),
          watchers_(solver, variable->Min(), variable->Max()),
          start_(0),
          end_(0),
          sorted_(false) {}

    ~UpperBoundWatcher() override {}

    // Returns the watcher for (variable_ >= value). Returns constant 1 if
    // Min() >= value and constant 0 if Max() < value. Otherwise creates a
    // boolean.
    IntVar* GetOrMakeUpperBoundWatcher(int64_t value) override {
      IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
      if (watcher != nullptr) {
        return watcher;
      }
      if (variable_->Max() >= value) {
        if (variable_->Min() >= value) {
          return solver()->MakeIntConst(1);
        } else {
          const std::string vname = variable_->HasName()
                                        ? variable_->name()
                                        : variable_->DebugString();
          const std::string bname =
              absl::StrFormat("Watch<%s >= %d>", vname, value);
          IntVar* const boolvar = solver()->MakeBoolVar(bname);
          watchers_.UnsafeRevInsert(value, boolvar);
          // The appended watcher breaks the sorted order, so fall back to
          // the unsorted mode.
          if (posted_.Switched()) {
            boolvar->WhenBound(
                solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
            var_demon_->desinhibit(solver());
            sorted_ = false;
          }
          return boolvar;
        }
      } else {
        return variable_->solver()->MakeIntConst(0);
      }
    }

    // Installs an existing boolean as the (variable_ >= value) watcher. A
    // `value` that is already watched is a fatal error (CHECK).
    void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) override {
      CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
      watchers_.UnsafeRevInsert(value, boolvar);
      if (posted_.Switched() && !boolvar->Bound()) {
        boolvar->WhenBound(
            solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
        var_demon_->desinhibit(solver());
        sorted_ = false;
      }
    }

    // Attaches the range demon. Sorts the watchers if there are enough of
    // them, and attaches a WatchDemon to each unbound watcher whose
    // threshold lies in (Min(), Max()].
    void Post() override {
      const int kTooSmallToSort = 8;
      var_demon_ = solver()->RevAlloc(new VarDemon(this));
      variable_->WhenRange(var_demon_);

      if (watchers_.Size() > kTooSmallToSort) {
        watchers_.SortActive();
        sorted_ = true;
        start_.SetValue(solver(), watchers_.start());
        end_.SetValue(solver(), watchers_.end() - 1);
      }

      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
        IntVar* const boolvar = w.second;
        const int64_t value = w.first;
        if (!boolvar->Bound() && value > variable_->Min() &&
            value <= variable_->Max()) {
          boolvar->WhenBound(
              solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
        }
      }
      posted_.Switch(solver());
    }

    // Decides the watchers implied by the current bounds and propagates
    // those that are already bound.
    // NOTE(review): unlike ProcessVar(), the unsorted branch does not
    // inhibit var_demon_ when no watcher remains.
    void InitialPropagate() override {
      const int64_t var_min = variable_->Min();
      const int64_t var_max = variable_->Max();
      if (sorted_) {
        // Thresholds <= Min() are satisfied: set them to 1 from the front.
        while (start_.Value() <= end_.Value()) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(start_.Value());
          if (w.first <= var_min) {
            w.second->SetValue(1);
            start_.Incr(solver());
          } else {
            break;
          }
        }
        // Thresholds > Max() are violated: set them to 0 from the back.
        while (end_.Value() >= start_.Value()) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(end_.Value());
          if (w.first > var_max) {
            w.second->SetValue(0);
            end_.Decr(solver());
          } else {
            break;
          }
        }
        for (int i = start_.Value(); i <= end_.Value(); ++i) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(i);
          if (w.second->Bound()) {
            ProcessUpperBoundWatcher(w.first, w.second);
          }
        }
        if (start_.Value() > end_.Value()) {
          var_demon_->inhibit(solver());
        }
      } else {
        for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
          const int64_t value = w.first;
          IntVar* const boolvar = w.second;

          if (value <= var_min) {
            boolvar->SetValue(1);
            watchers_.RemoveAt(pos);
          } else if (value > var_max) {
            boolvar->SetValue(0);
            watchers_.RemoveAt(pos);
          } else if (boolvar->Bound()) {
            ProcessUpperBoundWatcher(value, boolvar);
            watchers_.RemoveAt(pos);
          }
        }
      }
    }

    // Reports the variable, the active watcher booleans, and their
    // thresholds to the model visitor.
    void Accept(ModelVisitor* const visitor) const override {
      visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
      visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
                                              variable_);
      std::vector<int64_t> all_coefficients;
      std::vector<IntVar*> all_bool_vars;
      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
        all_coefficients.push_back(w.first);
        all_bool_vars.push_back(w.second);
      }
      visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
                                                 all_bool_vars);
      visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
                                         all_coefficients);
      visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
    }

    std::string DebugString() const override {
      return absl::StrFormat("UpperBoundWatcher(%s)", variable_->DebugString());
    }

   private:
    // A watcher became bound: 0 means variable_ < value, so Max is lowered
    // to value - 1. 1 means variable_ >= value, so Min is raised to value.
    void ProcessUpperBoundWatcher(int64_t value, IntVar* const boolvar) {
      if (boolvar->Min() == 0) {
        variable_->SetMax(value - 1);
      } else {
        variable_->SetMin(value);
      }
    }

    // Range event on the watched variable: decides watchers implied by the
    // new bounds and inhibits the demon when none remain.
    void ProcessVar() {
      const int64_t var_min = variable_->Min();
      const int64_t var_max = variable_->Max();
      if (sorted_) {
        while (start_.Value() <= end_.Value()) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(start_.Value());
          if (w.first <= var_min) {
            w.second->SetValue(1);
            start_.Incr(solver());
          } else {
            break;
          }
        }
        while (end_.Value() >= start_.Value()) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(end_.Value());
          if (w.first > var_max) {
            w.second->SetValue(0);
            end_.Decr(solver());
          } else {
            break;
          }
        }
        if (start_.Value() > end_.Value()) {
          var_demon_->inhibit(solver());
        }
      } else {
        for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
          const int64_t value = w.first;
          IntVar* const boolvar = w.second;

          if (value <= var_min) {
            boolvar->SetValue(1);
            watchers_.RemoveAt(pos);
          } else if (value > var_max) {
            boolvar->SetValue(0);
            watchers_.RemoveAt(pos);
          }
        }
        if (watchers_.Empty()) {
          var_demon_->inhibit(solver());
        }
      }
    }

    DomainIntVar* const variable_;
    RevSwitch posted_;  // Reversibly switched on at the end of Post().
    Demon* var_demon_;  // Created in Post(); nullptr before that.
    RevIntPtrMap<IntVar> watchers_;
    NumericalRev<int> start_;  // Sorted mode: first undecided position.
    NumericalRev<int> end_;    // Sorted mode: last undecided position.
    bool sorted_;
  };
1136
1137 // Optimized case for small maps.
1138 class DenseUpperBoundWatcher : public BaseUpperBoundWatcher {
1139 public:
1140 class WatchDemon : public Demon {
1141 public:
1142 WatchDemon(DenseUpperBoundWatcher* const watcher, int64_t value,
1143 IntVar* var)
1144 : value_watcher_(watcher), value_(value), var_(var) {}
1145 ~WatchDemon() override {}
1146
1147 void Run(Solver* const solver) override {
1148 value_watcher_->ProcessUpperBoundWatcher(value_, var_);
1149 }
1150
1151 private:
1152 DenseUpperBoundWatcher* const value_watcher_;
1153 const int64_t value_;
1154 IntVar* const var_;
1155 };
1156
1157 class VarDemon : public Demon {
1158 public:
1159 explicit VarDemon(DenseUpperBoundWatcher* const watcher)
1160 : value_watcher_(watcher) {}
1161
1162 ~VarDemon() override {}
1163
1164 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
1165
1166 private:
1167 DenseUpperBoundWatcher* const value_watcher_;
1168 };
1169
1170 DenseUpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
1171 : BaseUpperBoundWatcher(solver),
1172 variable_(variable),
1173 var_demon_(nullptr),
1174 offset_(variable->Min()),
1175 watchers_(variable->Max() - variable->Min() + 1, nullptr),
1176 active_watchers_(0) {}
1177
1178 ~DenseUpperBoundWatcher() override {}
1179
1180 IntVar* GetOrMakeUpperBoundWatcher(int64_t value) override {
1181 if (variable_->Max() >= value) {
1182 if (variable_->Min() >= value) {
1183 return solver()->MakeIntConst(1);
1184 } else {
1185 const std::string vname = variable_->HasName()
1186 ? variable_->name()
1187 : variable_->DebugString();
1188 const std::string bname =
1189 absl::StrFormat("Watch<%s >= %d>", vname, value);
1190 IntVar* const boolvar = solver()->MakeBoolVar(bname);
1191 RevInsert(value - offset_, boolvar);
1192 if (posted_.Switched()) {
1193 boolvar->WhenBound(
1194 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1195 var_demon_->desinhibit(solver());
1196 }
1197 return boolvar;
1198 }
1199 } else {
1200 return variable_->solver()->MakeIntConst(0);
1201 }
1202 }
1203
1204 void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) override {
1205 const int index = value - offset_;
1206 CHECK(watchers_[index] == nullptr);
1207 if (!boolvar->Bound()) {
1208 RevInsert(index, boolvar);
1209 if (posted_.Switched() && !boolvar->Bound()) {
1210 boolvar->WhenBound(
1211 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1212 var_demon_->desinhibit(solver());
1213 }
1214 }
1215 }
1216
1217 void Post() override {
1218 var_demon_ = solver()->RevAlloc(new VarDemon(this));
1219 variable_->WhenRange(var_demon_);
1220 for (int pos = 0; pos < watchers_.size(); ++pos) {
1221 const int64_t value = pos + offset_;
1222 IntVar* const boolvar = watchers_[pos];
1223 if (boolvar != nullptr && !boolvar->Bound() &&
1224 value > variable_->Min() && value <= variable_->Max()) {
1225 boolvar->WhenBound(
1226 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1227 }
1228 }
1229 posted_.Switch(solver());
1230 }
1231
1232 void InitialPropagate() override {
1233 for (int pos = 0; pos < watchers_.size(); ++pos) {
1234 IntVar* const boolvar = watchers_[pos];
1235 if (boolvar == nullptr) continue;
1236 const int64_t value = pos + offset_;
1237 if (value <= variable_->Min()) {
1238 boolvar->SetValue(1);
1239 RevRemove(pos);
1240 } else if (value > variable_->Max()) {
1241 boolvar->SetValue(0);
1242 RevRemove(pos);
1243 } else if (boolvar->Bound()) {
1244 ProcessUpperBoundWatcher(value, boolvar);
1245 RevRemove(pos);
1246 }
1247 }
1248 if (active_watchers_.Value() == 0) {
1249 var_demon_->inhibit(solver());
1250 }
1251 }
1252
1253 void ProcessUpperBoundWatcher(int64_t value, IntVar* boolvar) {
1254 if (boolvar->Min() == 0) {
1255 variable_->SetMax(value - 1);
1256 } else {
1257 variable_->SetMin(value);
1258 }
1259 }
1260
1261 void ProcessVar() {
1262 const int64_t old_min_index = variable_->OldMin() - offset_;
1263 const int64_t old_max_index = variable_->OldMax() - offset_;
1264 const int64_t min_index = variable_->Min() - offset_;
1265 const int64_t max_index = variable_->Max() - offset_;
1266 for (int pos = old_min_index; pos <= min_index; ++pos) {
1267 IntVar* const boolvar = watchers_[pos];
1268 if (boolvar != nullptr) {
1269 boolvar->SetValue(1);
1270 RevRemove(pos);
1271 }
1272 }
1273
1274 for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
1275 IntVar* const boolvar = watchers_[pos];
1276 if (boolvar != nullptr) {
1277 boolvar->SetValue(0);
1278 RevRemove(pos);
1279 }
1280 }
1281 if (active_watchers_.Value() == 0) {
1282 var_demon_->inhibit(solver());
1283 }
1284 }
1285
1286 void RevRemove(int pos) {
1287 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1288 watchers_[pos] = nullptr;
1289 active_watchers_.Decr(solver());
1290 }
1291
1292 void RevInsert(int pos, IntVar* boolvar) {
1293 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1294 watchers_[pos] = boolvar;
1295 active_watchers_.Incr(solver());
1296 }
1297
1298 void Accept(ModelVisitor* const visitor) const override {
1299 visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1300 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1301 variable_);
1302 std::vector<int64_t> all_coefficients;
1303 std::vector<IntVar*> all_bool_vars;
1304 for (int position = 0; position < watchers_.size(); ++position) {
1305 if (watchers_[position] != nullptr) {
1306 all_coefficients.push_back(position + offset_);
1307 all_bool_vars.push_back(watchers_[position]);
1308 }
1309 }
1310 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1311 all_bool_vars);
1312 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1313 all_coefficients);
1314 visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1315 }
1316
1317 std::string DebugString() const override {
1318 return absl::StrFormat("DenseUpperBoundWatcher(%s)",
1319 variable_->DebugString());
1320 }
1321
1322 private:
1323 DomainIntVar* const variable_;
1324 RevSwitch posted_;
1325 Demon* var_demon_;
1326 const int64_t offset_;
1327 std::vector<IntVar*> watchers_;
1328 NumericalRev<int> active_watchers_;
1329 };
1330
1331 // ----- Main Class -----
1332 DomainIntVar(Solver* s, int64_t vmin, int64_t vmax, const std::string& name);
1333 DomainIntVar(Solver* s, absl::Span<const int64_t> sorted_values,
1334 const std::string& name);
1335 ~DomainIntVar() override;
1336
1337 int64_t Min() const override { return min_.Value(); }
1338 void SetMin(int64_t m) override;
1339 int64_t Max() const override { return max_.Value(); }
1340 void SetMax(int64_t m) override;
1341 void SetRange(int64_t mi, int64_t ma) override;
1342 void SetValue(int64_t v) override;
1343 bool Bound() const override { return (min_.Value() == max_.Value()); }
1344 int64_t Value() const override {
1345 CHECK_EQ(min_.Value(), max_.Value())
1346 << " variable " << DebugString() << " is not bound.";
1347 return min_.Value();
1348 }
1349 void RemoveValue(int64_t v) override;
1350 void RemoveInterval(int64_t l, int64_t u) override;
1351 void CreateBits();
1352 void WhenBound(Demon* d) override {
1353 if (min_.Value() != max_.Value()) {
1354 if (d->priority() == Solver::DELAYED_PRIORITY) {
1355 delayed_bound_demons_.PushIfNotTop(solver(),
1356 solver()->RegisterDemon(d));
1357 } else {
1358 bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1359 }
1360 }
1361 }
1362 void WhenRange(Demon* d) override {
1363 if (min_.Value() != max_.Value()) {
1364 if (d->priority() == Solver::DELAYED_PRIORITY) {
1365 delayed_range_demons_.PushIfNotTop(solver(),
1366 solver()->RegisterDemon(d));
1367 } else {
1368 range_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1369 }
1370 }
1371 }
1372 void WhenDomain(Demon* d) override {
1373 if (min_.Value() != max_.Value()) {
1374 if (d->priority() == Solver::DELAYED_PRIORITY) {
1375 delayed_domain_demons_.PushIfNotTop(solver(),
1376 solver()->RegisterDemon(d));
1377 } else {
1378 domain_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1379 }
1380 }
1381 }
1382
1383 IntVar* IsEqual(int64_t constant) override {
1384 Solver* const s = solver();
1385 if (constant == min_.Value() && value_watcher_ == nullptr) {
1386 return s->MakeIsLessOrEqualCstVar(this, constant);
1387 }
1388 if (constant == max_.Value() && value_watcher_ == nullptr) {
1389 return s->MakeIsGreaterOrEqualCstVar(this, constant);
1390 }
1391 if (!Contains(constant)) {
1392 return s->MakeIntConst(int64_t{0});
1393 }
1394 if (Bound() && min_.Value() == constant) {
1395 return s->MakeIntConst(int64_t{1});
1396 }
1397 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1398 this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1399 if (cache != nullptr) {
1400 return cache->Var();
1401 } else {
1402 if (value_watcher_ == nullptr) {
1403 if (CapSub(Max(), Min()) <= 256) {
1404 solver()->SaveAndSetValue(
1405 reinterpret_cast<void**>(&value_watcher_),
1406 reinterpret_cast<void*>(
1407 solver()->RevAlloc(new DenseValueWatcher(solver(), this))));
1408
1409 } else {
1410 solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1411 reinterpret_cast<void*>(solver()->RevAlloc(
1412 new ValueWatcher(solver(), this))));
1413 }
1414 solver()->AddConstraint(value_watcher_);
1415 }
1416 IntVar* const boolvar = value_watcher_->GetOrMakeValueWatcher(constant);
1417 s->Cache()->InsertExprConstantExpression(
1418 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1419 return boolvar;
1420 }
1421 }
1422
1423 Constraint* SetIsEqual(absl::Span<const int64_t> values,
1424 const std::vector<IntVar*>& vars) {
1425 if (value_watcher_ == nullptr) {
1426 solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1427 reinterpret_cast<void*>(solver()->RevAlloc(
1428 new ValueWatcher(solver(), this))));
1429 for (int i = 0; i < vars.size(); ++i) {
1430 value_watcher_->SetValueWatcher(vars[i], values[i]);
1431 }
1432 }
1433 return value_watcher_;
1434 }
1435
1436 IntVar* IsDifferent(int64_t constant) override {
1437 Solver* const s = solver();
1438 if (constant == min_.Value() && value_watcher_ == nullptr) {
1439 return s->MakeIsGreaterOrEqualCstVar(this, constant + 1);
1440 }
1441 if (constant == max_.Value() && value_watcher_ == nullptr) {
1442 return s->MakeIsLessOrEqualCstVar(this, constant - 1);
1443 }
1444 if (!Contains(constant)) {
1445 return s->MakeIntConst(int64_t{1});
1446 }
1447 if (Bound() && min_.Value() == constant) {
1448 return s->MakeIntConst(int64_t{0});
1449 }
1450 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1452 if (cache != nullptr) {
1453 return cache->Var();
1454 } else {
1455 IntVar* const boolvar = s->MakeDifference(1, IsEqual(constant))->Var();
1456 s->Cache()->InsertExprConstantExpression(
1457 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
1458 return boolvar;
1459 }
1460 }
1461
1462 IntVar* IsGreaterOrEqual(int64_t constant) override {
1463 Solver* const s = solver();
1464 if (max_.Value() < constant) {
1465 return s->MakeIntConst(int64_t{0});
1466 }
1467 if (min_.Value() >= constant) {
1468 return s->MakeIntConst(int64_t{1});
1469 }
1470 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1472 if (cache != nullptr) {
1473 return cache->Var();
1474 } else {
1475 if (bound_watcher_ == nullptr) {
1476 if (CapSub(Max(), Min()) <= 256) {
1477 solver()->SaveAndSetValue(
1478 reinterpret_cast<void**>(&bound_watcher_),
1479 reinterpret_cast<void*>(solver()->RevAlloc(
1480 new DenseUpperBoundWatcher(solver(), this))));
1481 solver()->AddConstraint(bound_watcher_);
1482 } else {
1483 solver()->SaveAndSetValue(
1484 reinterpret_cast<void**>(&bound_watcher_),
1485 reinterpret_cast<void*>(
1486 solver()->RevAlloc(new UpperBoundWatcher(solver(), this))));
1487 solver()->AddConstraint(bound_watcher_);
1488 }
1489 }
1490 IntVar* const boolvar =
1491 bound_watcher_->GetOrMakeUpperBoundWatcher(constant);
1492 s->Cache()->InsertExprConstantExpression(
1493 boolvar, this, constant,
1495 return boolvar;
1496 }
1497 }
1498
1499 Constraint* SetIsGreaterOrEqual(absl::Span<const int64_t> values,
1500 const std::vector<IntVar*>& vars) {
1501 if (bound_watcher_ == nullptr) {
1502 if (CapSub(Max(), Min()) <= 256) {
1503 solver()->SaveAndSetValue(
1504 reinterpret_cast<void**>(&bound_watcher_),
1505 reinterpret_cast<void*>(solver()->RevAlloc(
1506 new DenseUpperBoundWatcher(solver(), this))));
1507 solver()->AddConstraint(bound_watcher_);
1508 } else {
1509 solver()->SaveAndSetValue(reinterpret_cast<void**>(&bound_watcher_),
1510 reinterpret_cast<void*>(solver()->RevAlloc(
1511 new UpperBoundWatcher(solver(), this))));
1512 solver()->AddConstraint(bound_watcher_);
1513 }
1514 for (int i = 0; i < values.size(); ++i) {
1515 bound_watcher_->SetUpperBoundWatcher(vars[i], values[i]);
1516 }
1517 }
1518 return bound_watcher_;
1519 }
1520
1521 IntVar* IsLessOrEqual(int64_t constant) override {
1522 Solver* const s = solver();
1523 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1525 if (cache != nullptr) {
1526 return cache->Var();
1527 } else {
1528 IntVar* const boolvar =
1529 s->MakeDifference(1, IsGreaterOrEqual(constant + 1))->Var();
1530 s->Cache()->InsertExprConstantExpression(
1531 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
1532 return boolvar;
1533 }
1534 }
1535
1536 void Process();
1537 void Push();
1538 void CleanInProcess();
1539 uint64_t Size() const override {
1540 if (bits_ != nullptr) return bits_->Size();
1541 return (static_cast<uint64_t>(max_.Value()) -
1542 static_cast<uint64_t>(min_.Value()) + 1);
1543 }
1544 bool Contains(int64_t v) const override {
1545 if (v < min_.Value() || v > max_.Value()) return false;
1546 return (bits_ == nullptr ? true : bits_->Contains(v));
1547 }
1548 IntVarIterator* MakeHoleIterator(bool reversible) const override;
1549 IntVarIterator* MakeDomainIterator(bool reversible) const override;
1550 int64_t OldMin() const override { return std::min(old_min_, min_.Value()); }
1551 int64_t OldMax() const override { return std::max(old_max_, max_.Value()); }
1552
1553 std::string DebugString() const override;
1554 BitSet* bitset() const { return bits_; }
1555 int VarType() const override { return DOMAIN_INT_VAR; }
1556 std::string BaseName() const override { return "IntegerVar"; }
1557
1558 friend class PlusCstDomainIntVar;
1559 friend class LinkExprAndDomainIntVar;
1560
1561 private:
1562 void CheckOldMin() {
1563 if (old_min_ > min_.Value()) {
1564 old_min_ = min_.Value();
1565 }
1566 }
1567 void CheckOldMax() {
1568 if (old_max_ < max_.Value()) {
1569 old_max_ = max_.Value();
1570 }
1571 }
1572 Rev<int64_t> min_;
1573 Rev<int64_t> max_;
1574 int64_t old_min_;
1575 int64_t old_max_;
1576 int64_t new_min_;
1577 int64_t new_max_;
1578 SimpleRevFIFO<Demon*> bound_demons_;
1579 SimpleRevFIFO<Demon*> range_demons_;
1580 SimpleRevFIFO<Demon*> domain_demons_;
1581 SimpleRevFIFO<Demon*> delayed_bound_demons_;
1582 SimpleRevFIFO<Demon*> delayed_range_demons_;
1583 SimpleRevFIFO<Demon*> delayed_domain_demons_;
1584 QueueHandler handler_;
1585 bool in_process_;
1586 BitSet* bits_;
1587 BaseValueWatcher* value_watcher_;
1588 BaseUpperBoundWatcher* bound_watcher_;
1589};
1590
1591// ----- BitSet -----
1592
1593// Return whether an integer interval [a..b] (inclusive) contains at most
1594// K values, i.e. b - a < K, in a way that's robust to overflows.
1595// For performance reasons, in opt mode it doesn't check that [a, b] is a
1596// valid interval, nor that K is nonnegative.
1597inline bool ClosedIntervalNoLargerThan(int64_t a, int64_t b, int64_t K) {
1598 DCHECK_LE(a, b);
1599 DCHECK_GE(K, 0);
1600 if (a > 0) {
1601 return a > b - K;
1602 } else {
1603 return a + K > b;
1604 }
1605}
1606
1607class SimpleBitSet : public DomainIntVar::BitSet {
1608 public:
1609 SimpleBitSet(Solver* const s, int64_t vmin, int64_t vmax)
1610 : BitSet(s),
1611 bits_(nullptr),
1612 stamps_(nullptr),
1613 omin_(vmin),
1614 omax_(vmax),
1615 size_(vmax - vmin + 1),
1616 bsize_(BitLength64(size_.Value())) {
1617 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1618 << "Bitset too large: [" << vmin << ", " << vmax << "]";
1619 bits_ = std::unique_ptr<uint64_t[]>(new uint64_t[bsize_]);
1620 stamps_ = std::unique_ptr<uint64_t[]>(new uint64_t[bsize_]);
1621 for (int i = 0; i < bsize_; ++i) {
1622 const int bs =
1623 (i == size_.Value() - 1) ? 63 - BitPos64(size_.Value()) : 0;
1624 bits_[i] = kAllBits64 >> bs;
1625 stamps_[i] = s->stamp() - 1;
1626 }
1627 }
1628
1629 SimpleBitSet(Solver* const s, absl::Span<const int64_t> sorted_values,
1630 int64_t vmin, int64_t vmax)
1631 : BitSet(s),
1632 bits_(nullptr),
1633 stamps_(nullptr),
1634 omin_(vmin),
1635 omax_(vmax),
1636 size_(sorted_values.size()),
1637 bsize_(BitLength64(vmax - vmin + 1)) {
1638 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1639 << "Bitset too large: [" << vmin << ", " << vmax << "]";
1640 bits_ = std::unique_ptr<uint64_t[]>(new uint64_t[bsize_]);
1641 stamps_ = std::unique_ptr<uint64_t[]>(new uint64_t[bsize_]);
1642 for (int i = 0; i < bsize_; ++i) {
1643 bits_[i] = uint64_t{0};
1644 stamps_[i] = s->stamp() - 1;
1645 }
1646 for (int i = 0; i < sorted_values.size(); ++i) {
1647 const int64_t val = sorted_values[i];
1648 DCHECK(!bit(val));
1649 const int offset = BitOffset64(val - omin_);
1650 const int pos = BitPos64(val - omin_);
1651 bits_[offset] |= OneBit64(pos);
1652 }
1653 }
1654
1655 ~SimpleBitSet() override = default;
1656
1657 bool bit(int64_t val) const { return IsBitSet64(bits_.get(), val - omin_); }
1658
1659 int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) override {
1660 DCHECK_GE(nmin, cmin);
1661 DCHECK_LE(nmin, cmax);
1662 DCHECK_LE(cmin, cmax);
1663 DCHECK_GE(cmin, omin_);
1664 DCHECK_LE(cmax, omax_);
1665 const int64_t new_min = UnsafeLeastSignificantBitPosition64(
1666 bits_.get(), nmin - omin_, cmax - omin_) +
1667 omin_;
1668 const uint64_t removed_bits =
1669 BitCountRange64(bits_.get(), cmin - omin_, new_min - omin_ - 1);
1670 size_.Add(solver_, -removed_bits);
1671 return new_min;
1672 }
1673
1674 int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) override {
1675 DCHECK_GE(nmax, cmin);
1676 DCHECK_LE(nmax, cmax);
1677 DCHECK_LE(cmin, cmax);
1678 DCHECK_GE(cmin, omin_);
1679 DCHECK_LE(cmax, omax_);
1680 const int64_t new_max = UnsafeMostSignificantBitPosition64(
1681 bits_.get(), cmin - omin_, nmax - omin_) +
1682 omin_;
1683 const uint64_t removed_bits =
1684 BitCountRange64(bits_.get(), new_max - omin_ + 1, cmax - omin_);
1685 size_.Add(solver_, -removed_bits);
1686 return new_max;
1687 }
1688
1689 bool SetValue(int64_t val) override {
1690 DCHECK_GE(val, omin_);
1691 DCHECK_LE(val, omax_);
1692 if (bit(val)) {
1693 size_.SetValue(solver_, 1);
1694 return true;
1695 }
1696 return false;
1697 }
1698
1699 bool Contains(int64_t val) const override {
1700 DCHECK_GE(val, omin_);
1701 DCHECK_LE(val, omax_);
1702 return bit(val);
1703 }
1704
1705 bool RemoveValue(int64_t val) override {
1706 if (val < omin_ || val > omax_ || !bit(val)) {
1707 return false;
1708 }
1709 // Bitset.
1710 const int64_t val_offset = val - omin_;
1711 const int offset = BitOffset64(val_offset);
1712 const uint64_t current_stamp = solver_->stamp();
1713 if (stamps_[offset] < current_stamp) {
1714 stamps_[offset] = current_stamp;
1715 solver_->SaveValue(&bits_[offset]);
1716 }
1717 const int pos = BitPos64(val_offset);
1718 bits_[offset] &= ~OneBit64(pos);
1719 // Size.
1720 size_.Decr(solver_);
1721 // Holes.
1722 InitHoles();
1723 AddHole(val);
1724 return true;
1725 }
1726 uint64_t Size() const override { return size_.Value(); }
1727
1728 std::string DebugString() const override {
1729 std::string out;
1730 absl::StrAppendFormat(&out, "SimpleBitSet(%d..%d : ", omin_, omax_);
1731 for (int i = 0; i < bsize_; ++i) {
1732 absl::StrAppendFormat(&out, "%x", bits_[i]);
1733 }
1734 out += ")";
1735 return out;
1736 }
1737
1738 void DelayRemoveValue(int64_t val) override { removed_.push_back(val); }
1739
1740 void ApplyRemovedValues(DomainIntVar* var) override {
1741 std::sort(removed_.begin(), removed_.end());
1742 for (std::vector<int64_t>::iterator it = removed_.begin();
1743 it != removed_.end(); ++it) {
1744 var->RemoveValue(*it);
1745 }
1746 }
1747
1748 void ClearRemovedValues() override { removed_.clear(); }
1749
1750 std::string pretty_DebugString(int64_t min, int64_t max) const override {
1751 std::string out;
1752 DCHECK(bit(min));
1753 DCHECK(bit(max));
1754 if (max != min) {
1755 int cumul = true;
1756 int64_t start_cumul = min;
1757 for (int64_t v = min + 1; v < max; ++v) {
1758 if (bit(v)) {
1759 if (!cumul) {
1760 cumul = true;
1761 start_cumul = v;
1762 }
1763 } else {
1764 if (cumul) {
1765 if (v == start_cumul + 1) {
1766 absl::StrAppendFormat(&out, "%d ", start_cumul);
1767 } else if (v == start_cumul + 2) {
1768 absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1769 } else {
1770 absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1771 }
1772 cumul = false;
1773 }
1774 }
1775 }
1776 if (cumul) {
1777 if (max == start_cumul + 1) {
1778 absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1779 } else {
1780 absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1781 }
1782 } else {
1783 absl::StrAppendFormat(&out, "%d", max);
1784 }
1785 } else {
1786 absl::StrAppendFormat(&out, "%d", min);
1787 }
1788 return out;
1789 }
1790
1791 DomainIntVar::BitSetIterator* MakeIterator() override {
1792 return new DomainIntVar::BitSetIterator(bits_.get(), omin_);
1793 }
1794
1795 private:
1796 std::unique_ptr<uint64_t[]> bits_;
1797 std::unique_ptr<uint64_t[]> stamps_;
1798 const int64_t omin_;
1799 const int64_t omax_;
1800 NumericalRev<int64_t> size_;
1801 const int bsize_;
1802 std::vector<int64_t> removed_;
1803};
1804
1805// This is a special case where the bitset fits into one 64 bit integer.
1806// In that case, there are no offset to compute.
1807// Overflows are caught by the robust ClosedIntervalNoLargerThan() method.
1808class SmallBitSet : public DomainIntVar::BitSet {
1809 public:
1810 SmallBitSet(Solver* const s, int64_t vmin, int64_t vmax)
1811 : BitSet(s),
1812 bits_(uint64_t{0}),
1813 stamp_(s->stamp() - 1),
1814 omin_(vmin),
1815 omax_(vmax),
1816 size_(vmax - vmin + 1) {
1817 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1818 bits_ = OneRange64(0, size_.Value() - 1);
1819 }
1820
1821 SmallBitSet(Solver* const s, absl::Span<const int64_t> sorted_values,
1822 int64_t vmin, int64_t vmax)
1823 : BitSet(s),
1824 bits_(uint64_t{0}),
1825 stamp_(s->stamp() - 1),
1826 omin_(vmin),
1827 omax_(vmax),
1828 size_(sorted_values.size()) {
1829 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1830 // We know the array is sorted and does not contains duplicate values.
1831 for (int i = 0; i < sorted_values.size(); ++i) {
1832 const int64_t val = sorted_values[i];
1833 DCHECK_GE(val, vmin);
1834 DCHECK_LE(val, vmax);
1835 DCHECK(!IsBitSet64(&bits_, val - omin_));
1836 bits_ |= OneBit64(val - omin_);
1837 }
1838 }
1839
1840 ~SmallBitSet() override {}
1841
1842 bool bit(int64_t val) const {
1843 DCHECK_GE(val, omin_);
1844 DCHECK_LE(val, omax_);
1845 return (bits_ & OneBit64(val - omin_)) != 0;
1846 }
1847
1848 int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) override {
1849 DCHECK_GE(nmin, cmin);
1850 DCHECK_LE(nmin, cmax);
1851 DCHECK_LE(cmin, cmax);
1852 DCHECK_GE(cmin, omin_);
1853 DCHECK_LE(cmax, omax_);
1854 // We do not clean the bits between cmin and nmin.
1855 // But we use mask to look only at 'active' bits.
1856
1857 // Create the mask and compute new bits
1858 const uint64_t new_bits = bits_ & OneRange64(nmin - omin_, cmax - omin_);
1859 if (new_bits != uint64_t{0}) {
1860 // Compute new size and new min
1861 size_.SetValue(solver_, BitCount64(new_bits));
1862 if (bit(nmin)) { // Common case, the new min is inside the bitset
1863 return nmin;
1864 }
1865 return LeastSignificantBitPosition64(new_bits) + omin_;
1866 } else { // == 0 -> Fail()
1867 solver_->Fail();
1868 return std::numeric_limits<int64_t>::max();
1869 }
1870 }
1871
1872 int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) override {
1873 DCHECK_GE(nmax, cmin);
1874 DCHECK_LE(nmax, cmax);
1875 DCHECK_LE(cmin, cmax);
1876 DCHECK_GE(cmin, omin_);
1877 DCHECK_LE(cmax, omax_);
1878 // We do not clean the bits between nmax and cmax.
1879 // But we use mask to look only at 'active' bits.
1880
1881 // Create the mask and compute new_bits
1882 const uint64_t new_bits = bits_ & OneRange64(cmin - omin_, nmax - omin_);
1883 if (new_bits != uint64_t{0}) {
1884 // Compute new size and new min
1885 size_.SetValue(solver_, BitCount64(new_bits));
1886 if (bit(nmax)) { // Common case, the new max is inside the bitset
1887 return nmax;
1888 }
1889 return MostSignificantBitPosition64(new_bits) + omin_;
1890 } else { // == 0 -> Fail()
1891 solver_->Fail();
1892 return std::numeric_limits<int64_t>::min();
1893 }
1894 }
1895
1896 bool SetValue(int64_t val) override {
1897 DCHECK_GE(val, omin_);
1898 DCHECK_LE(val, omax_);
1899 // We do not clean the bits. We will use masks to ignore the bits
1900 // that should have been cleaned.
1901 if (bit(val)) {
1902 size_.SetValue(solver_, 1);
1903 return true;
1904 }
1905 return false;
1906 }
1907
1908 bool Contains(int64_t val) const override {
1909 DCHECK_GE(val, omin_);
1910 DCHECK_LE(val, omax_);
1911 return bit(val);
1912 }
1913
1914 bool RemoveValue(int64_t val) override {
1915 DCHECK_GE(val, omin_);
1916 DCHECK_LE(val, omax_);
1917 if (bit(val)) {
1918 // Bitset.
1919 const uint64_t current_stamp = solver_->stamp();
1920 if (stamp_ < current_stamp) {
1921 stamp_ = current_stamp;
1922 solver_->SaveValue(&bits_);
1923 }
1924 bits_ &= ~OneBit64(val - omin_);
1925 DCHECK(!bit(val));
1926 // Size.
1927 size_.Decr(solver_);
1928 // Holes.
1929 InitHoles();
1930 AddHole(val);
1931 return true;
1932 } else {
1933 return false;
1934 }
1935 }
1936
1937 uint64_t Size() const override { return size_.Value(); }
1938
1939 std::string DebugString() const override {
1940 return absl::StrFormat("SmallBitSet(%d..%d : %llx)", omin_, omax_, bits_);
1941 }
1942
1943 void DelayRemoveValue(int64_t val) override {
1944 DCHECK_GE(val, omin_);
1945 DCHECK_LE(val, omax_);
1946 removed_.push_back(val);
1947 }
1948
1949 void ApplyRemovedValues(DomainIntVar* var) override {
1950 std::sort(removed_.begin(), removed_.end());
1951 for (std::vector<int64_t>::iterator it = removed_.begin();
1952 it != removed_.end(); ++it) {
1953 var->RemoveValue(*it);
1954 }
1955 }
1956
1957 void ClearRemovedValues() override { removed_.clear(); }
1958
1959 std::string pretty_DebugString(int64_t min, int64_t max) const override {
1960 std::string out;
1961 DCHECK(bit(min));
1962 DCHECK(bit(max));
1963 if (max != min) {
1964 int cumul = true;
1965 int64_t start_cumul = min;
1966 for (int64_t v = min + 1; v < max; ++v) {
1967 if (bit(v)) {
1968 if (!cumul) {
1969 cumul = true;
1970 start_cumul = v;
1971 }
1972 } else {
1973 if (cumul) {
1974 if (v == start_cumul + 1) {
1975 absl::StrAppendFormat(&out, "%d ", start_cumul);
1976 } else if (v == start_cumul + 2) {
1977 absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1978 } else {
1979 absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1980 }
1981 cumul = false;
1982 }
1983 }
1984 }
1985 if (cumul) {
1986 if (max == start_cumul + 1) {
1987 absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1988 } else {
1989 absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1990 }
1991 } else {
1992 absl::StrAppendFormat(&out, "%d", max);
1993 }
1994 } else {
1995 absl::StrAppendFormat(&out, "%d", min);
1996 }
1997 return out;
1998 }
1999
2000 DomainIntVar::BitSetIterator* MakeIterator() override {
2001 return new DomainIntVar::BitSetIterator(&bits_, omin_);
2002 }
2003
2004 private:
2005 uint64_t bits_;
2006 uint64_t stamp_;
2007 const int64_t omin_;
2008 const int64_t omax_;
2009 NumericalRev<int64_t> size_;
2010 std::vector<int64_t> removed_;
2011};
2012
2013class EmptyIterator : public IntVarIterator {
2014 public:
2015 ~EmptyIterator() override {}
2016 void Init() override {}
2017 bool Ok() const override { return false; }
2018 int64_t Value() const override {
2019 LOG(FATAL) << "Should not be called";
2020 return 0LL;
2021 }
2022 void Next() override {}
2023};
2024
2025class RangeIterator : public IntVarIterator {
2026 public:
2027 explicit RangeIterator(const IntVar* const var)
2028 : var_(var),
2029 min_(std::numeric_limits<int64_t>::max()),
2030 max_(std::numeric_limits<int64_t>::min()),
2031 current_(-1) {}
2032
2033 ~RangeIterator() override {}
2034
2035 void Init() override {
2036 min_ = var_->Min();
2037 max_ = var_->Max();
2038 current_ = min_;
2039 }
2040
2041 bool Ok() const override { return current_ <= max_; }
2042
2043 int64_t Value() const override { return current_; }
2044
2045 void Next() override { current_++; }
2046
2047 private:
2048 const IntVar* const var_;
2049 int64_t min_;
2050 int64_t max_;
2051 int64_t current_;
2052};
2053
2054class DomainIntVarHoleIterator : public IntVarIterator {
2055 public:
2056 explicit DomainIntVarHoleIterator(const DomainIntVar* const v)
2057 : var_(v), bits_(nullptr), values_(nullptr), size_(0), index_(0) {}
2058
2059 ~DomainIntVarHoleIterator() override {}
2060
2061 void Init() override {
2062 bits_ = var_->bitset();
2063 if (bits_ != nullptr) {
2064 bits_->InitHoles();
2065 values_ = bits_->Holes().data();
2066 size_ = bits_->Holes().size();
2067 } else {
2068 values_ = nullptr;
2069 size_ = 0;
2070 }
2071 index_ = 0;
2072 }
2073
2074 bool Ok() const override { return index_ < size_; }
2075
2076 int64_t Value() const override {
2077 DCHECK(bits_ != nullptr);
2078 DCHECK(index_ < size_);
2079 return values_[index_];
2080 }
2081
2082 void Next() override { index_++; }
2083
2084 private:
2085 const DomainIntVar* const var_;
2086 DomainIntVar::BitSet* bits_;
2087 const int64_t* values_;
2088 int size_;
2089 int index_;
2090};
2091
// Iterates over the values currently in a DomainIntVar's domain.
// - When the variable has a bitset and is not bound, iteration is delegated
//   to a BitSetIterator created by the bitset.
// - Otherwise it walks [Min(), Max()] directly.
// In reversible mode the BitSetIterator is RevAlloc'ed on the solver, and the
// pointer to it is saved on the trail. In non-reversible mode this object owns
// and deletes it.
class DomainIntVarDomainIterator : public IntVarIterator {
 public:
  explicit DomainIntVarDomainIterator(const DomainIntVar* const v,
                                      bool reversible)
      : var_(v),
        bitset_iterator_(nullptr),
        min_(std::numeric_limits<int64_t>::max()),
        max_(std::numeric_limits<int64_t>::min()),
        current_(-1),
        reversible_(reversible) {}

  // Only the non-reversible iterator is owned here; a RevAlloc'ed one is
  // left to the solver.
  ~DomainIntVarDomainIterator() override {
    if (!reversible_ && bitset_iterator_) {
      delete bitset_iterator_;
    }
  }

  void Init() override {
    if (var_->bitset() != nullptr && !var_->Bound()) {
      if (reversible_) {
        // Create the solver-allocated iterator once. The pointer assignment
        // is saved on the trail so that it is undone on backtrack.
        if (!bitset_iterator_) {
          Solver* const solver = var_->solver();
          solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
          bitset_iterator_ = solver->RevAlloc(var_->bitset()->MakeIterator());
        }
      } else {
        // Non-reversible: replace any previous owned iterator with a fresh
        // one.
        if (bitset_iterator_) {
          delete bitset_iterator_;
        }
        bitset_iterator_ = var_->bitset()->MakeIterator();
      }
      bitset_iterator_->Init(var_->Min(), var_->Max());
    } else {
      // No bitset (or bound variable): drop any bitset iterator, saving the
      // pointer on the trail in reversible mode and deleting it otherwise.
      // Then fall back to a plain range walk.
      if (bitset_iterator_) {
        if (reversible_) {
          Solver* const solver = var_->solver();
          solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
        } else {
          delete bitset_iterator_;
        }
        bitset_iterator_ = nullptr;
      }
      min_ = var_->Min();
      max_ = var_->Max();
      current_ = min_;
    }
  }

  bool Ok() const override {
    return bitset_iterator_ ? bitset_iterator_->Ok() : (current_ <= max_);
  }

  int64_t Value() const override {
    return bitset_iterator_ ? bitset_iterator_->Value() : current_;
  }

  void Next() override {
    if (bitset_iterator_) {
      bitset_iterator_->Next();
    } else {
      current_++;
    }
  }

 private:
  const DomainIntVar* const var_;
  // Non-null while iterating through the bitset.
  DomainIntVar::BitSetIterator* bitset_iterator_;
  // Range-walk state, used when bitset_iterator_ is null.
  int64_t min_;
  int64_t max_;
  int64_t current_;
  const bool reversible_;
};
2164
2165class UnaryIterator : public IntVarIterator {
2166 public:
2167 UnaryIterator(const IntVar* const v, bool hole, bool reversible)
2168 : iterator_(hole ? v->MakeHoleIterator(reversible)
2169 : v->MakeDomainIterator(reversible)),
2170 reversible_(reversible) {}
2171
2172 ~UnaryIterator() override {
2173 if (!reversible_) {
2174 delete iterator_;
2175 }
2176 }
2177
2178 void Init() override { iterator_->Init(); }
2179
2180 bool Ok() const override { return iterator_->Ok(); }
2181
2182 void Next() override { iterator_->Next(); }
2183
2184 protected:
2185 IntVarIterator* const iterator_;
2186 const bool reversible_;
2187};
2188
2189DomainIntVar::DomainIntVar(Solver* const s, int64_t vmin, int64_t vmax,
2190 const std::string& name)
2191 : IntVar(s, name),
2192 min_(vmin),
2193 max_(vmax),
2194 old_min_(vmin),
2195 old_max_(vmax),
2196 new_min_(vmin),
2197 new_max_(vmax),
2198 handler_(this),
2199 in_process_(false),
2200 bits_(nullptr),
2201 value_watcher_(nullptr),
2202 bound_watcher_(nullptr) {}
2203
2204DomainIntVar::DomainIntVar(Solver* const s,
2205 absl::Span<const int64_t> sorted_values,
2206 const std::string& name)
2207 : IntVar(s, name),
2208 min_(std::numeric_limits<int64_t>::max()),
2209 max_(std::numeric_limits<int64_t>::min()),
2210 old_min_(std::numeric_limits<int64_t>::max()),
2211 old_max_(std::numeric_limits<int64_t>::min()),
2212 new_min_(std::numeric_limits<int64_t>::max()),
2213 new_max_(std::numeric_limits<int64_t>::min()),
2214 handler_(this),
2215 in_process_(false),
2216 bits_(nullptr),
2217 value_watcher_(nullptr),
2218 bound_watcher_(nullptr) {
2219 CHECK_GE(sorted_values.size(), 1);
2220 // We know that the vector is sorted and does not have duplicate values.
2221 const int64_t vmin = sorted_values.front();
2222 const int64_t vmax = sorted_values.back();
2223 const bool contiguous = vmax - vmin + 1 == sorted_values.size();
2224
2225 min_.SetValue(solver(), vmin);
2226 old_min_ = vmin;
2227 new_min_ = vmin;
2228 max_.SetValue(solver(), vmax);
2229 old_max_ = vmax;
2230 new_max_ = vmax;
2231
2232 if (!contiguous) {
2233 if (vmax - vmin + 1 < 65) {
2234 bits_ = solver()->RevAlloc(
2235 new SmallBitSet(solver(), sorted_values, vmin, vmax));
2236 } else {
2237 bits_ = solver()->RevAlloc(
2238 new SimpleBitSet(solver(), sorted_values, vmin, vmax));
2239 }
2240 }
2241}
2242
2243DomainIntVar::~DomainIntVar() {}
2244
2245void DomainIntVar::SetMin(int64_t m) {
2246 if (m <= min_.Value()) return;
2247 if (m > max_.Value()) solver()->Fail();
2248 if (in_process_) {
2249 if (m > new_min_) {
2250 new_min_ = m;
2251 if (new_min_ > new_max_) {
2252 solver()->Fail();
2253 }
2254 }
2255 } else {
2256 CheckOldMin();
2257 const int64_t new_min =
2258 (bits_ == nullptr
2259 ? m
2260 : bits_->ComputeNewMin(m, min_.Value(), max_.Value()));
2261 min_.SetValue(solver(), new_min);
2262 if (min_.Value() > max_.Value()) {
2263 solver()->Fail();
2264 }
2265 Push();
2266 }
2267}
2268
2269void DomainIntVar::SetMax(int64_t m) {
2270 if (m >= max_.Value()) return;
2271 if (m < min_.Value()) solver()->Fail();
2272 if (in_process_) {
2273 if (m < new_max_) {
2274 new_max_ = m;
2275 if (new_max_ < new_min_) {
2276 solver()->Fail();
2277 }
2278 }
2279 } else {
2280 CheckOldMax();
2281 const int64_t new_max =
2282 (bits_ == nullptr
2283 ? m
2284 : bits_->ComputeNewMax(m, min_.Value(), max_.Value()));
2285 max_.SetValue(solver(), new_max);
2286 if (min_.Value() > max_.Value()) {
2287 solver()->Fail();
2288 }
2289 Push();
2290 }
2291}
2292
2293void DomainIntVar::SetRange(int64_t mi, int64_t ma) {
2294 if (mi == ma) {
2295 SetValue(mi);
2296 } else {
2297 if (mi > ma || mi > max_.Value() || ma < min_.Value()) solver()->Fail();
2298 if (mi <= min_.Value() && ma >= max_.Value()) return;
2299 if (in_process_) {
2300 if (ma < new_max_) {
2301 new_max_ = ma;
2302 }
2303 if (mi > new_min_) {
2304 new_min_ = mi;
2305 }
2306 if (new_min_ > new_max_) {
2307 solver()->Fail();
2308 }
2309 } else {
2310 if (mi > min_.Value()) {
2311 CheckOldMin();
2312 const int64_t new_min =
2313 (bits_ == nullptr
2314 ? mi
2315 : bits_->ComputeNewMin(mi, min_.Value(), max_.Value()));
2316 min_.SetValue(solver(), new_min);
2317 }
2318 if (min_.Value() > ma) {
2319 solver()->Fail();
2320 }
2321 if (ma < max_.Value()) {
2322 CheckOldMax();
2323 const int64_t new_max =
2324 (bits_ == nullptr
2325 ? ma
2326 : bits_->ComputeNewMax(ma, min_.Value(), max_.Value()));
2327 max_.SetValue(solver(), new_max);
2328 }
2329 if (min_.Value() > max_.Value()) {
2330 solver()->Fail();
2331 }
2332 Push();
2333 }
2334 }
2335}
2336
2337void DomainIntVar::SetValue(int64_t v) {
2338 if (v != min_.Value() || v != max_.Value()) {
2339 if (v < min_.Value() || v > max_.Value()) {
2340 solver()->Fail();
2341 }
2342 if (in_process_) {
2343 if (v > new_max_ || v < new_min_) {
2344 solver()->Fail();
2345 }
2346 new_min_ = v;
2347 new_max_ = v;
2348 } else {
2349 if (bits_ && !bits_->SetValue(v)) {
2350 solver()->Fail();
2351 }
2352 CheckOldMin();
2353 CheckOldMax();
2354 min_.SetValue(solver(), v);
2355 max_.SetValue(solver(), v);
2356 Push();
2357 }
2358 }
2359}
2360
2361void DomainIntVar::RemoveValue(int64_t v) {
2362 if (v < min_.Value() || v > max_.Value()) return;
2363 if (v == min_.Value()) {
2364 SetMin(v + 1);
2365 } else if (v == max_.Value()) {
2366 SetMax(v - 1);
2367 } else {
2368 if (bits_ == nullptr) {
2369 CreateBits();
2370 }
2371 if (in_process_) {
2372 if (v >= new_min_ && v <= new_max_ && bits_->Contains(v)) {
2373 bits_->DelayRemoveValue(v);
2374 }
2375 } else {
2376 if (bits_->RemoveValue(v)) {
2377 Push();
2378 }
2379 }
2380 }
2381}
2382
2383void DomainIntVar::RemoveInterval(int64_t l, int64_t u) {
2384 if (l <= min_.Value()) {
2385 SetMin(u + 1);
2386 } else if (u >= max_.Value()) {
2387 SetMax(l - 1);
2388 } else {
2389 for (int64_t v = l; v <= u; ++v) {
2390 RemoveValue(v);
2391 }
2392 }
2393}
2394
2395void DomainIntVar::CreateBits() {
2396 solver()->SaveValue(reinterpret_cast<void**>(&bits_));
2397 if (max_.Value() - min_.Value() < 64) {
2398 bits_ = solver()->RevAlloc(
2399 new SmallBitSet(solver(), min_.Value(), max_.Value()));
2400 } else {
2401 bits_ = solver()->RevAlloc(
2402 new SimpleBitSet(solver(), min_.Value(), max_.Value()));
2403 }
2404}
2405
2406void DomainIntVar::CleanInProcess() {
2407 in_process_ = false;
2408 if (bits_ != nullptr) {
2409 bits_->ClearHoles();
2410 }
2411}
2412
2413void DomainIntVar::Push() {
2414 const bool in_process = in_process_;
2415 EnqueueVar(&handler_);
2416 CHECK_EQ(in_process, in_process_);
2417}
2418
2419void DomainIntVar::Process() {
2420 CHECK(!in_process_);
2421 in_process_ = true;
2422 if (bits_ != nullptr) {
2423 bits_->ClearRemovedValues();
2424 }
2425 set_variable_to_clean_on_fail(this);
2426 new_min_ = min_.Value();
2427 new_max_ = max_.Value();
2428 const bool is_bound = min_.Value() == max_.Value();
2429 const bool range_changed =
2430 min_.Value() != OldMin() || max_.Value() != OldMax();
2431 // Process immediate demons.
2432 if (is_bound) {
2433 ExecuteAll(bound_demons_);
2434 }
2435 if (range_changed) {
2436 ExecuteAll(range_demons_);
2437 }
2438 ExecuteAll(domain_demons_);
2439
2440 // Process delayed demons.
2441 if (is_bound) {
2442 EnqueueAll(delayed_bound_demons_);
2443 }
2444 if (range_changed) {
2445 EnqueueAll(delayed_range_demons_);
2446 }
2447 EnqueueAll(delayed_domain_demons_);
2448
2449 // Everything went well if we arrive here. Let's clean the variable.
2450 set_variable_to_clean_on_fail(nullptr);
2451 CleanInProcess();
2452 old_min_ = min_.Value();
2453 old_max_ = max_.Value();
2454 if (min_.Value() < new_min_) {
2455 SetMin(new_min_);
2456 }
2457 if (max_.Value() > new_max_) {
2458 SetMax(new_max_);
2459 }
2460 if (bits_ != nullptr) {
2461 bits_->ApplyRemovedValues(this);
2462 }
2463}
2464
2465template <typename T>
2466T* CondRevAlloc(Solver* solver, bool reversible, T* object) {
2467 return reversible ? solver->RevAlloc(object) : object;
2468}
2469
2470IntVarIterator* DomainIntVar::MakeHoleIterator(bool reversible) const {
2471 return CondRevAlloc(solver(), reversible, new DomainIntVarHoleIterator(this));
2472}
2473
2474IntVarIterator* DomainIntVar::MakeDomainIterator(bool reversible) const {
2475 return CondRevAlloc(solver(), reversible,
2476 new DomainIntVarDomainIterator(this, reversible));
2477}
2478
2479std::string DomainIntVar::DebugString() const {
2480 std::string out;
2481 const std::string& var_name = name();
2482 if (!var_name.empty()) {
2483 out = var_name + "(";
2484 } else {
2485 out = "DomainIntVar(";
2486 }
2487 if (min_.Value() == max_.Value()) {
2488 absl::StrAppendFormat(&out, "%d", min_.Value());
2489 } else if (bits_ != nullptr) {
2490 out.append(bits_->pretty_DebugString(min_.Value(), max_.Value()));
2491 } else {
2492 absl::StrAppendFormat(&out, "%d..%d", min_.Value(), max_.Value());
2493 }
2494 out += ")";
2495 return out;
2496}
2497
2498// ----- Real Boolean Var -----
2499
2500class ConcreteBooleanVar : public BooleanVar {
2501 public:
2502 // Utility classes
2503 class Handler : public Demon {
2504 public:
2505 explicit Handler(ConcreteBooleanVar* const var) : Demon(), var_(var) {}
2506 ~Handler() override {}
2507 void Run(Solver* const s) override {
2508 s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
2509 var_->Process();
2510 s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
2511 }
2512 Solver::DemonPriority priority() const override {
2513 return Solver::VAR_PRIORITY;
2514 }
2515 std::string DebugString() const override {
2516 return absl::StrFormat("Handler(%s)", var_->DebugString());
2517 }
2518
2519 private:
2520 ConcreteBooleanVar* const var_;
2521 };
2522
2523 ConcreteBooleanVar(Solver* const s, const std::string& name)
2524 : BooleanVar(s, name), handler_(this) {}
2525
2526 ~ConcreteBooleanVar() override {}
2527
2528 void SetValue(int64_t v) override {
2529 if (value_ == kUnboundBooleanVarValue) {
2530 if ((v & 0xfffffffffffffffe) == 0) {
2531 InternalSaveBooleanVarValue(solver(), this);
2532 value_ = static_cast<int>(v);
2533 EnqueueVar(&handler_);
2534 return;
2535 }
2536 } else if (v == value_) {
2537 return;
2538 }
2539 solver()->Fail();
2540 }
2541
2542 void Process() {
2543 DCHECK_NE(value_, kUnboundBooleanVarValue);
2544 ExecuteAll(bound_demons_);
2545 for (SimpleRevFIFO<Demon*>::Iterator it(&delayed_bound_demons_); it.ok();
2546 ++it) {
2547 EnqueueDelayedDemon(*it);
2548 }
2549 }
2550
2551 int64_t OldMin() const override { return 0LL; }
2552 int64_t OldMax() const override { return 1LL; }
2553 void RestoreValue() override { value_ = kUnboundBooleanVarValue; }
2554
2555 private:
2556 Handler handler_;
2557};
2558
2559// ----- IntConst -----
2560
2561class IntConst : public IntVar {
2562 public:
2563 IntConst(Solver* const s, int64_t value, const std::string& name = "")
2564 : IntVar(s, name), value_(value) {}
2565 ~IntConst() override {}
2566
2567 int64_t Min() const override { return value_; }
2568 void SetMin(int64_t m) override {
2569 if (m > value_) {
2570 solver()->Fail();
2571 }
2572 }
2573 int64_t Max() const override { return value_; }
2574 void SetMax(int64_t m) override {
2575 if (m < value_) {
2576 solver()->Fail();
2577 }
2578 }
2579 void SetRange(int64_t l, int64_t u) override {
2580 if (l > value_ || u < value_) {
2581 solver()->Fail();
2582 }
2583 }
2584 void SetValue(int64_t v) override {
2585 if (v != value_) {
2586 solver()->Fail();
2587 }
2588 }
2589 bool Bound() const override { return true; }
2590 int64_t Value() const override { return value_; }
2591 void RemoveValue(int64_t v) override {
2592 if (v == value_) {
2593 solver()->Fail();
2594 }
2595 }
2596 void RemoveInterval(int64_t l, int64_t u) override {
2597 if (l <= value_ && value_ <= u) {
2598 solver()->Fail();
2599 }
2600 }
2601 void WhenBound(Demon* d) override {}
2602 void WhenRange(Demon* d) override {}
2603 void WhenDomain(Demon* d) override {}
2604 uint64_t Size() const override { return 1; }
2605 bool Contains(int64_t v) const override { return (v == value_); }
2606 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2607 return CondRevAlloc(solver(), reversible, new EmptyIterator());
2608 }
2609 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2610 return CondRevAlloc(solver(), reversible, new RangeIterator(this));
2611 }
2612 int64_t OldMin() const override { return value_; }
2613 int64_t OldMax() const override { return value_; }
2614 std::string DebugString() const override {
2615 std::string out;
2616 if (solver()->HasName(this)) {
2617 const std::string& var_name = name();
2618 absl::StrAppendFormat(&out, "%s(%d)", var_name, value_);
2619 } else {
2620 absl::StrAppendFormat(&out, "IntConst(%d)", value_);
2621 }
2622 return out;
2623 }
2624
2625 int VarType() const override { return CONST_VAR; }
2626
2627 IntVar* IsEqual(int64_t constant) override {
2628 if (constant == value_) {
2629 return solver()->MakeIntConst(1);
2630 } else {
2631 return solver()->MakeIntConst(0);
2632 }
2633 }
2634
2635 IntVar* IsDifferent(int64_t constant) override {
2636 if (constant == value_) {
2637 return solver()->MakeIntConst(0);
2638 } else {
2639 return solver()->MakeIntConst(1);
2640 }
2641 }
2642
2643 IntVar* IsGreaterOrEqual(int64_t constant) override {
2644 return solver()->MakeIntConst(value_ >= constant);
2645 }
2646
2647 IntVar* IsLessOrEqual(int64_t constant) override {
2648 return solver()->MakeIntConst(value_ <= constant);
2649 }
2650
2651 std::string name() const override {
2652 if (solver()->HasName(this)) {
2653 return PropagationBaseObject::name();
2654 } else {
2655 return absl::StrCat(value_);
2656 }
2657 }
2658
2659 private:
2660 int64_t value_;
2661};
2662
2663// ----- x + c variable, optimized case -----
2664
2665class PlusCstVar : public IntVar {
2666 public:
2667 PlusCstVar(Solver* const s, IntVar* v, int64_t c)
2668 : IntVar(s), var_(v), cst_(c) {}
2669
2670 ~PlusCstVar() override {}
2671
2672 void WhenRange(Demon* d) override { var_->WhenRange(d); }
2673
2674 void WhenBound(Demon* d) override { var_->WhenBound(d); }
2675
2676 void WhenDomain(Demon* d) override { var_->WhenDomain(d); }
2677
2678 int64_t OldMin() const override { return CapAdd(var_->OldMin(), cst_); }
2679
2680 int64_t OldMax() const override { return CapAdd(var_->OldMax(), cst_); }
2681
2682 std::string DebugString() const override {
2683 if (HasName()) {
2684 return absl::StrFormat("%s(%s + %d)", name(), var_->DebugString(), cst_);
2685 } else {
2686 return absl::StrFormat("(%s + %d)", var_->DebugString(), cst_);
2687 }
2688 }
2689
2690 int VarType() const override { return VAR_ADD_CST; }
2691
2692 void Accept(ModelVisitor* const visitor) const override {
2693 visitor->VisitIntegerVariable(this, ModelVisitor::kSumOperation, cst_,
2694 var_);
2695 }
2696
2697 IntVar* IsEqual(int64_t constant) override {
2698 return var_->IsEqual(constant - cst_);
2699 }
2700
2701 IntVar* IsDifferent(int64_t constant) override {
2702 return var_->IsDifferent(constant - cst_);
2703 }
2704
2705 IntVar* IsGreaterOrEqual(int64_t constant) override {
2706 return var_->IsGreaterOrEqual(constant - cst_);
2707 }
2708
2709 IntVar* IsLessOrEqual(int64_t constant) override {
2710 return var_->IsLessOrEqual(constant - cst_);
2711 }
2712
2713 IntVar* SubVar() const { return var_; }
2714
2715 int64_t Constant() const { return cst_; }
2716
2717 protected:
2718 IntVar* const var_;
2719 const int64_t cst_;
2720};
2721
2722class PlusCstIntVar : public PlusCstVar {
2723 public:
2724 class PlusCstIntVarIterator : public UnaryIterator {
2725 public:
2726 PlusCstIntVarIterator(const IntVar* const v, int64_t c, bool hole, bool rev)
2727 : UnaryIterator(v, hole, rev), cst_(c) {}
2728
2729 ~PlusCstIntVarIterator() override {}
2730
2731 int64_t Value() const override { return iterator_->Value() + cst_; }
2732
2733 private:
2734 const int64_t cst_;
2735 };
2736
2737 PlusCstIntVar(Solver* const s, IntVar* v, int64_t c) : PlusCstVar(s, v, c) {}
2738
2739 ~PlusCstIntVar() override {}
2740
2741 int64_t Min() const override { return var_->Min() + cst_; }
2742
2743 void SetMin(int64_t m) override { var_->SetMin(CapSub(m, cst_)); }
2744
2745 int64_t Max() const override { return var_->Max() + cst_; }
2746
2747 void SetMax(int64_t m) override { var_->SetMax(CapSub(m, cst_)); }
2748
2749 void SetRange(int64_t l, int64_t u) override {
2750 var_->SetRange(CapSub(l, cst_), CapSub(u, cst_));
2751 }
2752
2753 void SetValue(int64_t v) override { var_->SetValue(v - cst_); }
2754
2755 int64_t Value() const override { return var_->Value() + cst_; }
2756
2757 bool Bound() const override { return var_->Bound(); }
2758
2759 void RemoveValue(int64_t v) override { var_->RemoveValue(v - cst_); }
2760
2761 void RemoveInterval(int64_t l, int64_t u) override {
2762 var_->RemoveInterval(l - cst_, u - cst_);
2763 }
2764
2765 uint64_t Size() const override { return var_->Size(); }
2766
2767 bool Contains(int64_t v) const override { return var_->Contains(v - cst_); }
2768
2769 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2770 return CondRevAlloc(
2771 solver(), reversible,
2772 new PlusCstIntVarIterator(var_, cst_, true, reversible));
2773 }
2774 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2775 return CondRevAlloc(
2776 solver(), reversible,
2777 new PlusCstIntVarIterator(var_, cst_, false, reversible));
2778 }
2779};
2780
2781class PlusCstDomainIntVar : public PlusCstVar {
2782 public:
2783 class PlusCstDomainIntVarIterator : public UnaryIterator {
2784 public:
2785 PlusCstDomainIntVarIterator(const IntVar* const v, int64_t c, bool hole,
2786 bool reversible)
2787 : UnaryIterator(v, hole, reversible), cst_(c) {}
2788
2789 ~PlusCstDomainIntVarIterator() override {}
2790
2791 int64_t Value() const override { return iterator_->Value() + cst_; }
2792
2793 private:
2794 const int64_t cst_;
2795 };
2796
2797 PlusCstDomainIntVar(Solver* const s, DomainIntVar* v, int64_t c)
2798 : PlusCstVar(s, v, c) {}
2799
2800 ~PlusCstDomainIntVar() override {}
2801
2802 int64_t Min() const override;
2803 void SetMin(int64_t m) override;
2804 int64_t Max() const override;
2805 void SetMax(int64_t m) override;
2806 void SetRange(int64_t l, int64_t u) override;
2807 void SetValue(int64_t v) override;
2808 bool Bound() const override;
2809 int64_t Value() const override;
2810 void RemoveValue(int64_t v) override;
2811 void RemoveInterval(int64_t l, int64_t u) override;
2812 uint64_t Size() const override;
2813 bool Contains(int64_t v) const override;
2814
2815 DomainIntVar* domain_int_var() const {
2816 return reinterpret_cast<DomainIntVar*>(var_);
2817 }
2818
2819 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2820 return CondRevAlloc(
2821 solver(), reversible,
2822 new PlusCstDomainIntVarIterator(var_, cst_, true, reversible));
2823 }
2824 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2825 return CondRevAlloc(
2826 solver(), reversible,
2827 new PlusCstDomainIntVarIterator(var_, cst_, false, reversible));
2828 }
2829};
2830
2831int64_t PlusCstDomainIntVar::Min() const {
2832 return domain_int_var()->min_.Value() + cst_;
2833}
2834
2835void PlusCstDomainIntVar::SetMin(int64_t m) {
2836 domain_int_var()->DomainIntVar::SetMin(CapSub(m, cst_));
2837}
2838
2839int64_t PlusCstDomainIntVar::Max() const {
2840 return domain_int_var()->max_.Value() + cst_;
2841}
2842
2843void PlusCstDomainIntVar::SetMax(int64_t m) {
2844 domain_int_var()->DomainIntVar::SetMax(CapSub(m, cst_));
2845}
2846
2847void PlusCstDomainIntVar::SetRange(int64_t l, int64_t u) {
2848 domain_int_var()->DomainIntVar::SetRange(l - cst_, u - cst_);
2849}
2850
2851void PlusCstDomainIntVar::SetValue(int64_t v) {
2852 domain_int_var()->DomainIntVar::SetValue(v - cst_);
2853}
2854
2855bool PlusCstDomainIntVar::Bound() const {
2856 return domain_int_var()->min_.Value() == domain_int_var()->max_.Value();
2857}
2858
2859int64_t PlusCstDomainIntVar::Value() const {
2860 CHECK_EQ(domain_int_var()->min_.Value(), domain_int_var()->max_.Value())
2861 << " variable is not bound";
2862 return domain_int_var()->min_.Value() + cst_;
2863}
2864
2865void PlusCstDomainIntVar::RemoveValue(int64_t v) {
2866 domain_int_var()->DomainIntVar::RemoveValue(v - cst_);
2867}
2868
2869void PlusCstDomainIntVar::RemoveInterval(int64_t l, int64_t u) {
2870 domain_int_var()->DomainIntVar::RemoveInterval(l - cst_, u - cst_);
2871}
2872
2873uint64_t PlusCstDomainIntVar::Size() const {
2874 return domain_int_var()->DomainIntVar::Size();
2875}
2876
2877bool PlusCstDomainIntVar::Contains(int64_t v) const {
2878 return domain_int_var()->DomainIntVar::Contains(v - cst_);
2879}
2880
2881// c - x variable, optimized case
2882
2883class SubCstIntVar : public IntVar {
2884 public:
2885 class SubCstIntVarIterator : public UnaryIterator {
2886 public:
2887 SubCstIntVarIterator(const IntVar* const v, int64_t c, bool hole, bool rev)
2888 : UnaryIterator(v, hole, rev), cst_(c) {}
2889 ~SubCstIntVarIterator() override {}
2890
2891 int64_t Value() const override { return cst_ - iterator_->Value(); }
2892
2893 private:
2894 const int64_t cst_;
2895 };
2896
2897 SubCstIntVar(Solver* s, IntVar* v, int64_t c);
2898 ~SubCstIntVar() override;
2899
2900 int64_t Min() const override;
2901 void SetMin(int64_t m) override;
2902 int64_t Max() const override;
2903 void SetMax(int64_t m) override;
2904 void SetRange(int64_t l, int64_t u) override;
2905 void SetValue(int64_t v) override;
2906 bool Bound() const override;
2907 int64_t Value() const override;
2908 void RemoveValue(int64_t v) override;
2909 void RemoveInterval(int64_t l, int64_t u) override;
2910 uint64_t Size() const override;
2911 bool Contains(int64_t v) const override;
2912 void WhenRange(Demon* d) override;
2913 void WhenBound(Demon* d) override;
2914 void WhenDomain(Demon* d) override;
2915 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2916 return CondRevAlloc(solver(), reversible,
2917 new SubCstIntVarIterator(var_, cst_, true, reversible));
2918 }
2919 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2920 return CondRevAlloc(
2921 solver(), reversible,
2922 new SubCstIntVarIterator(var_, cst_, false, reversible));
2923 }
2924 int64_t OldMin() const override { return CapSub(cst_, var_->OldMax()); }
2925 int64_t OldMax() const override { return CapSub(cst_, var_->OldMin()); }
2926 std::string DebugString() const override;
2927 std::string name() const override;
2928 int VarType() const override { return CST_SUB_VAR; }
2929
2930 void Accept(ModelVisitor* const visitor) const override {
2931 visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation,
2932 cst_, var_);
2933 }
2934
2935 IntVar* IsEqual(int64_t constant) override {
2936 return var_->IsEqual(cst_ - constant);
2937 }
2938
2939 IntVar* IsDifferent(int64_t constant) override {
2940 return var_->IsDifferent(cst_ - constant);
2941 }
2942
2943 IntVar* IsGreaterOrEqual(int64_t constant) override {
2944 return var_->IsLessOrEqual(cst_ - constant);
2945 }
2946
2947 IntVar* IsLessOrEqual(int64_t constant) override {
2948 return var_->IsGreaterOrEqual(cst_ - constant);
2949 }
2950
2951 IntVar* SubVar() const { return var_; }
2952 int64_t Constant() const { return cst_; }
2953
2954 private:
2955 IntVar* const var_;
2956 const int64_t cst_;
2957};
2958
2959SubCstIntVar::SubCstIntVar(Solver* const s, IntVar* v, int64_t c)
2960 : IntVar(s), var_(v), cst_(c) {}
2961
2962SubCstIntVar::~SubCstIntVar() {}
2963
2964int64_t SubCstIntVar::Min() const { return cst_ - var_->Max(); }
2965
2966void SubCstIntVar::SetMin(int64_t m) { var_->SetMax(CapSub(cst_, m)); }
2967
2968int64_t SubCstIntVar::Max() const { return cst_ - var_->Min(); }
2969
2970void SubCstIntVar::SetMax(int64_t m) { var_->SetMin(CapSub(cst_, m)); }
2971
2972void SubCstIntVar::SetRange(int64_t l, int64_t u) {
2973 var_->SetRange(CapSub(cst_, u), CapSub(cst_, l));
2974}
2975
2976void SubCstIntVar::SetValue(int64_t v) { var_->SetValue(cst_ - v); }
2977
2978bool SubCstIntVar::Bound() const { return var_->Bound(); }
2979
2980void SubCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
2981
2982int64_t SubCstIntVar::Value() const { return cst_ - var_->Value(); }
2983
2984void SubCstIntVar::RemoveValue(int64_t v) { var_->RemoveValue(cst_ - v); }
2985
2986void SubCstIntVar::RemoveInterval(int64_t l, int64_t u) {
2987 var_->RemoveInterval(cst_ - u, cst_ - l);
2988}
2989
2990void SubCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
2991
2992void SubCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
2993
2994uint64_t SubCstIntVar::Size() const { return var_->Size(); }
2995
2996bool SubCstIntVar::Contains(int64_t v) const {
2997 return var_->Contains(cst_ - v);
2998}
2999
3000std::string SubCstIntVar::DebugString() const {
3001 if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
3002 return absl::StrFormat("Not(%s)", var_->DebugString());
3003 } else {
3004 return absl::StrFormat("(%d - %s)", cst_, var_->DebugString());
3005 }
3006}
3007
3008std::string SubCstIntVar::name() const {
3009 if (solver()->HasName(this)) {
3010 return PropagationBaseObject::name();
3011 } else if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
3012 return absl::StrFormat("Not(%s)", var_->name());
3013 } else {
3014 return absl::StrFormat("(%d - %s)", cst_, var_->name());
3015 }
3016}
3017
3018// -x variable, optimized case
3019
3020class OppIntVar : public IntVar {
3021 public:
3022 class OppIntVarIterator : public UnaryIterator {
3023 public:
3024 OppIntVarIterator(const IntVar* const v, bool hole, bool reversible)
3025 : UnaryIterator(v, hole, reversible) {}
3026 ~OppIntVarIterator() override {}
3027
3028 int64_t Value() const override { return -iterator_->Value(); }
3029 };
3030
3031 OppIntVar(Solver* s, IntVar* v);
3032 ~OppIntVar() override;
3033
3034 int64_t Min() const override;
3035 void SetMin(int64_t m) override;
3036 int64_t Max() const override;
3037 void SetMax(int64_t m) override;
3038 void SetRange(int64_t l, int64_t u) override;
3039 void SetValue(int64_t v) override;
3040 bool Bound() const override;
3041 int64_t Value() const override;
3042 void RemoveValue(int64_t v) override;
3043 void RemoveInterval(int64_t l, int64_t u) override;
3044 uint64_t Size() const override;
3045 bool Contains(int64_t v) const override;
3046 void WhenRange(Demon* d) override;
3047 void WhenBound(Demon* d) override;
3048 void WhenDomain(Demon* d) override;
3049 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3050 return CondRevAlloc(solver(), reversible,
3051 new OppIntVarIterator(var_, true, reversible));
3052 }
3053 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3054 return CondRevAlloc(solver(), reversible,
3055 new OppIntVarIterator(var_, false, reversible));
3056 }
3057 int64_t OldMin() const override { return CapOpp(var_->OldMax()); }
3058 int64_t OldMax() const override { return CapOpp(var_->OldMin()); }
3059 std::string DebugString() const override;
3060 int VarType() const override { return OPP_VAR; }
3061
3062 void Accept(ModelVisitor* const visitor) const override {
3063 visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation, 0,
3064 var_);
3065 }
3066
3067 IntVar* IsEqual(int64_t constant) override {
3068 return var_->IsEqual(-constant);
3069 }
3070
3071 IntVar* IsDifferent(int64_t constant) override {
3072 return var_->IsDifferent(-constant);
3073 }
3074
3075 IntVar* IsGreaterOrEqual(int64_t constant) override {
3076 return var_->IsLessOrEqual(-constant);
3077 }
3078
3079 IntVar* IsLessOrEqual(int64_t constant) override {
3080 return var_->IsGreaterOrEqual(-constant);
3081 }
3082
3083 IntVar* SubVar() const { return var_; }
3084
3085 private:
3086 IntVar* const var_;
3087};
3088
3089OppIntVar::OppIntVar(Solver* const s, IntVar* v) : IntVar(s), var_(v) {}
3090
3091OppIntVar::~OppIntVar() {}
3092
3093int64_t OppIntVar::Min() const { return -var_->Max(); }
3094
3095void OppIntVar::SetMin(int64_t m) { var_->SetMax(CapOpp(m)); }
3096
3097int64_t OppIntVar::Max() const { return -var_->Min(); }
3098
3099void OppIntVar::SetMax(int64_t m) { var_->SetMin(CapOpp(m)); }
3100
3101void OppIntVar::SetRange(int64_t l, int64_t u) {
3102 var_->SetRange(CapOpp(u), CapOpp(l));
3103}
3104
3105void OppIntVar::SetValue(int64_t v) { var_->SetValue(CapOpp(v)); }
3106
3107bool OppIntVar::Bound() const { return var_->Bound(); }
3108
3109void OppIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3110
3111int64_t OppIntVar::Value() const { return -var_->Value(); }
3112
3113void OppIntVar::RemoveValue(int64_t v) { var_->RemoveValue(-v); }
3114
3115void OppIntVar::RemoveInterval(int64_t l, int64_t u) {
3116 var_->RemoveInterval(-u, -l);
3117}
3118
3119void OppIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3120
3121void OppIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3122
3123uint64_t OppIntVar::Size() const { return var_->Size(); }
3124
3125bool OppIntVar::Contains(int64_t v) const { return var_->Contains(-v); }
3126
3127std::string OppIntVar::DebugString() const {
3128 return absl::StrFormat("-(%s)", var_->DebugString());
3129}
3130
3131// ----- Utility functions -----
3132
3133// x * c variable, optimized case
3134
3135class TimesCstIntVar : public IntVar {
3136 public:
3137 TimesCstIntVar(Solver* const s, IntVar* v, int64_t c)
3138 : IntVar(s), var_(v), cst_(c) {}
3139 ~TimesCstIntVar() override {}
3140
3141 IntVar* SubVar() const { return var_; }
3142 int64_t Constant() const { return cst_; }
3143
3144 void Accept(ModelVisitor* const visitor) const override {
3145 visitor->VisitIntegerVariable(this, ModelVisitor::kProductOperation, cst_,
3146 var_);
3147 }
3148
3149 IntVar* IsEqual(int64_t constant) override {
3150 if (constant % cst_ == 0) {
3151 return var_->IsEqual(constant / cst_);
3152 } else {
3153 return solver()->MakeIntConst(0);
3154 }
3155 }
3156
3157 IntVar* IsDifferent(int64_t constant) override {
3158 if (constant % cst_ == 0) {
3159 return var_->IsDifferent(constant / cst_);
3160 } else {
3161 return solver()->MakeIntConst(1);
3162 }
3163 }
3164
3165 IntVar* IsGreaterOrEqual(int64_t constant) override {
3166 if (cst_ > 0) {
3167 return var_->IsGreaterOrEqual(PosIntDivUp(constant, cst_));
3168 } else {
3169 return var_->IsLessOrEqual(PosIntDivDown(-constant, -cst_));
3170 }
3171 }
3172
3173 IntVar* IsLessOrEqual(int64_t constant) override {
3174 if (cst_ > 0) {
3175 return var_->IsLessOrEqual(PosIntDivDown(constant, cst_));
3176 } else {
3177 return var_->IsGreaterOrEqual(PosIntDivUp(-constant, -cst_));
3178 }
3179 }
3180
3181 std::string DebugString() const override {
3182 return absl::StrFormat("(%s * %d)", var_->DebugString(), cst_);
3183 }
3184
3185 int VarType() const override { return VAR_TIMES_CST; }
3186
3187 protected:
3188 IntVar* const var_;
3189 const int64_t cst_;
3190};
3191
3192class TimesPosCstIntVar : public TimesCstIntVar {
3193 public:
3194 class TimesPosCstIntVarIterator : public UnaryIterator {
3195 public:
3196 TimesPosCstIntVarIterator(const IntVar* const v, int64_t c, bool hole,
3197 bool reversible)
3198 : UnaryIterator(v, hole, reversible), cst_(c) {}
3199 ~TimesPosCstIntVarIterator() override {}
3200
3201 int64_t Value() const override { return iterator_->Value() * cst_; }
3202
3203 private:
3204 const int64_t cst_;
3205 };
3206
3207 TimesPosCstIntVar(Solver* s, IntVar* v, int64_t c);
3208 ~TimesPosCstIntVar() override;
3209
3210 int64_t Min() const override;
3211 void SetMin(int64_t m) override;
3212 int64_t Max() const override;
3213 void SetMax(int64_t m) override;
3214 void SetRange(int64_t l, int64_t u) override;
3215 void SetValue(int64_t v) override;
3216 bool Bound() const override;
3217 int64_t Value() const override;
3218 void RemoveValue(int64_t v) override;
3219 void RemoveInterval(int64_t l, int64_t u) override;
3220 uint64_t Size() const override;
3221 bool Contains(int64_t v) const override;
3222 void WhenRange(Demon* d) override;
3223 void WhenBound(Demon* d) override;
3224 void WhenDomain(Demon* d) override;
3225 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3226 return CondRevAlloc(
3227 solver(), reversible,
3228 new TimesPosCstIntVarIterator(var_, cst_, true, reversible));
3229 }
3230 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3231 return CondRevAlloc(
3232 solver(), reversible,
3233 new TimesPosCstIntVarIterator(var_, cst_, false, reversible));
3234 }
3235 int64_t OldMin() const override { return CapProd(var_->OldMin(), cst_); }
3236 int64_t OldMax() const override { return CapProd(var_->OldMax(), cst_); }
3237};
3238
3239// ----- TimesPosCstIntVar -----
3240
3241TimesPosCstIntVar::TimesPosCstIntVar(Solver* const s, IntVar* v, int64_t c)
3242 : TimesCstIntVar(s, v, c) {}
3243
3244TimesPosCstIntVar::~TimesPosCstIntVar() {}
3245
3246int64_t TimesPosCstIntVar::Min() const { return CapProd(var_->Min(), cst_); }
3247
3248void TimesPosCstIntVar::SetMin(int64_t m) {
3249 if (m != std::numeric_limits<int64_t>::min()) {
3250 var_->SetMin(PosIntDivUp(m, cst_));
3251 }
3252}
3253
3254int64_t TimesPosCstIntVar::Max() const { return CapProd(var_->Max(), cst_); }
3255
3256void TimesPosCstIntVar::SetMax(int64_t m) {
3257 if (m != std::numeric_limits<int64_t>::max()) {
3258 var_->SetMax(PosIntDivDown(m, cst_));
3259 }
3260}
3261
3262void TimesPosCstIntVar::SetRange(int64_t l, int64_t u) {
3263 var_->SetRange(PosIntDivUp(l, cst_), PosIntDivDown(u, cst_));
3264}
3265
3266void TimesPosCstIntVar::SetValue(int64_t v) {
3267 if (v % cst_ != 0) {
3268 solver()->Fail();
3269 }
3270 var_->SetValue(v / cst_);
3271}
3272
3273bool TimesPosCstIntVar::Bound() const { return var_->Bound(); }
3274
3275void TimesPosCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3276
3277int64_t TimesPosCstIntVar::Value() const {
3278 return CapProd(var_->Value(), cst_);
3279}
3280
3281void TimesPosCstIntVar::RemoveValue(int64_t v) {
3282 if (v % cst_ == 0) {
3283 var_->RemoveValue(v / cst_);
3284 }
3285}
3286
3287void TimesPosCstIntVar::RemoveInterval(int64_t l, int64_t u) {
3288 for (int64_t v = l; v <= u; ++v) {
3289 RemoveValue(v);
3290 }
3291 // TODO(user) : Improve me
3292}
3293
3294void TimesPosCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3295
3296void TimesPosCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3297
3298uint64_t TimesPosCstIntVar::Size() const { return var_->Size(); }
3299
3300bool TimesPosCstIntVar::Contains(int64_t v) const {
3301 return (v % cst_ == 0 && var_->Contains(v / cst_));
3302}
3303
3304// b * c variable, optimized case
3305
3306class TimesPosCstBoolVar : public TimesCstIntVar {
3307 public:
3308 class TimesPosCstBoolVarIterator : public UnaryIterator {
3309 public:
3310 // TODO(user) : optimize this.
3311 TimesPosCstBoolVarIterator(const IntVar* const v, int64_t c, bool hole,
3312 bool reversible)
3313 : UnaryIterator(v, hole, reversible), cst_(c) {}
3314 ~TimesPosCstBoolVarIterator() override {}
3315
3316 int64_t Value() const override { return iterator_->Value() * cst_; }
3317
3318 private:
3319 const int64_t cst_;
3320 };
3321
3322 TimesPosCstBoolVar(Solver* s, BooleanVar* v, int64_t c);
3323 ~TimesPosCstBoolVar() override;
3324
3325 int64_t Min() const override;
3326 void SetMin(int64_t m) override;
3327 int64_t Max() const override;
3328 void SetMax(int64_t m) override;
3329 void SetRange(int64_t l, int64_t u) override;
3330 void SetValue(int64_t v) override;
3331 bool Bound() const override;
3332 int64_t Value() const override;
3333 void RemoveValue(int64_t v) override;
3334 void RemoveInterval(int64_t l, int64_t u) override;
3335 uint64_t Size() const override;
3336 bool Contains(int64_t v) const override;
3337 void WhenRange(Demon* d) override;
3338 void WhenBound(Demon* d) override;
3339 void WhenDomain(Demon* d) override;
3340 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3341 return CondRevAlloc(solver(), reversible, new EmptyIterator());
3342 }
3343 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3344 return CondRevAlloc(
3345 solver(), reversible,
3346 new TimesPosCstBoolVarIterator(boolean_var(), cst_, false, reversible));
3347 }
3348 int64_t OldMin() const override { return 0; }
3349 int64_t OldMax() const override { return cst_; }
3350
3351 BooleanVar* boolean_var() const {
3352 return reinterpret_cast<BooleanVar*>(var_);
3353 }
3354};
3355
3356// ----- TimesPosCstBoolVar -----
3357
3358TimesPosCstBoolVar::TimesPosCstBoolVar(Solver* const s, BooleanVar* v,
3359 int64_t c)
3360 : TimesCstIntVar(s, v, c) {}
3361
3362TimesPosCstBoolVar::~TimesPosCstBoolVar() {}
3363
3364int64_t TimesPosCstBoolVar::Min() const {
3365 return (boolean_var()->RawValue() == 1) * cst_;
3366}
3367
3368void TimesPosCstBoolVar::SetMin(int64_t m) {
3369 if (m > cst_) {
3370 solver()->Fail();
3371 } else if (m > 0) {
3372 boolean_var()->SetMin(1);
3373 }
3374}
3375
3376int64_t TimesPosCstBoolVar::Max() const {
3377 return (boolean_var()->RawValue() != 0) * cst_;
3378}
3379
3380void TimesPosCstBoolVar::SetMax(int64_t m) {
3381 if (m < 0) {
3382 solver()->Fail();
3383 } else if (m < cst_) {
3384 boolean_var()->SetMax(0);
3385 }
3386}
3387
3388void TimesPosCstBoolVar::SetRange(int64_t l, int64_t u) {
3389 if (u < 0 || l > cst_ || l > u) {
3390 solver()->Fail();
3391 }
3392 if (l > 0) {
3393 boolean_var()->SetMin(1);
3394 } else if (u < cst_) {
3395 boolean_var()->SetMax(0);
3396 }
3397}
3398
3399void TimesPosCstBoolVar::SetValue(int64_t v) {
3400 if (v == 0) {
3401 boolean_var()->SetValue(0);
3402 } else if (v == cst_) {
3403 boolean_var()->SetValue(1);
3404 } else {
3405 solver()->Fail();
3406 }
3407}
3408
3409bool TimesPosCstBoolVar::Bound() const {
3410 return boolean_var()->RawValue() != BooleanVar::kUnboundBooleanVarValue;
3411}
3412
3413void TimesPosCstBoolVar::WhenRange(Demon* d) { boolean_var()->WhenRange(d); }
3414
3415int64_t TimesPosCstBoolVar::Value() const {
3416 CHECK_NE(boolean_var()->RawValue(), BooleanVar::kUnboundBooleanVarValue)
3417 << " variable is not bound";
3418 return boolean_var()->RawValue() * cst_;
3419}
3420
3421void TimesPosCstBoolVar::RemoveValue(int64_t v) {
3422 if (v == 0) {
3423 boolean_var()->RemoveValue(0);
3424 } else if (v == cst_) {
3425 boolean_var()->RemoveValue(1);
3426 }
3427}
3428
3429void TimesPosCstBoolVar::RemoveInterval(int64_t l, int64_t u) {
3430 if (l <= 0 && u >= 0) {
3431 boolean_var()->RemoveValue(0);
3432 }
3433 if (l <= cst_ && u >= cst_) {
3434 boolean_var()->RemoveValue(1);
3435 }
3436}
3437
3438void TimesPosCstBoolVar::WhenBound(Demon* d) { boolean_var()->WhenBound(d); }
3439
3440void TimesPosCstBoolVar::WhenDomain(Demon* d) { boolean_var()->WhenDomain(d); }
3441
3442uint64_t TimesPosCstBoolVar::Size() const {
3443 return (1 +
3444 (boolean_var()->RawValue() == BooleanVar::kUnboundBooleanVarValue));
3445}
3446
3447bool TimesPosCstBoolVar::Contains(int64_t v) const {
3448 if (v == 0) {
3449 return boolean_var()->RawValue() != 1;
3450 } else if (v == cst_) {
3451 return boolean_var()->RawValue() != 0;
3452 }
3453 return false;
3454}
3455
3456// TimesNegCstIntVar
3457
3458class TimesNegCstIntVar : public TimesCstIntVar {
3459 public:
3460 class TimesNegCstIntVarIterator : public UnaryIterator {
3461 public:
3462 TimesNegCstIntVarIterator(const IntVar* const v, int64_t c, bool hole,
3463 bool reversible)
3464 : UnaryIterator(v, hole, reversible), cst_(c) {}
3465 ~TimesNegCstIntVarIterator() override {}
3466
3467 int64_t Value() const override { return iterator_->Value() * cst_; }
3468
3469 private:
3470 const int64_t cst_;
3471 };
3472
3473 TimesNegCstIntVar(Solver* s, IntVar* v, int64_t c);
3474 ~TimesNegCstIntVar() override;
3475
3476 int64_t Min() const override;
3477 void SetMin(int64_t m) override;
3478 int64_t Max() const override;
3479 void SetMax(int64_t m) override;
3480 void SetRange(int64_t l, int64_t u) override;
3481 void SetValue(int64_t v) override;
3482 bool Bound() const override;
3483 int64_t Value() const override;
3484 void RemoveValue(int64_t v) override;
3485 void RemoveInterval(int64_t l, int64_t u) override;
3486 uint64_t Size() const override;
3487 bool Contains(int64_t v) const override;
3488 void WhenRange(Demon* d) override;
3489 void WhenBound(Demon* d) override;
3490 void WhenDomain(Demon* d) override;
3491 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3492 return CondRevAlloc(
3493 solver(), reversible,
3494 new TimesNegCstIntVarIterator(var_, cst_, true, reversible));
3495 }
3496 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3497 return CondRevAlloc(
3498 solver(), reversible,
3499 new TimesNegCstIntVarIterator(var_, cst_, false, reversible));
3500 }
3501 int64_t OldMin() const override { return CapProd(var_->OldMax(), cst_); }
3502 int64_t OldMax() const override { return CapProd(var_->OldMin(), cst_); }
3503};
3504
3505// ----- TimesNegCstIntVar -----
3506
3507TimesNegCstIntVar::TimesNegCstIntVar(Solver* const s, IntVar* v, int64_t c)
3508 : TimesCstIntVar(s, v, c) {}
3509
3510TimesNegCstIntVar::~TimesNegCstIntVar() {}
3511
3512int64_t TimesNegCstIntVar::Min() const { return CapProd(var_->Max(), cst_); }
3513
3514void TimesNegCstIntVar::SetMin(int64_t m) {
3515 if (m != std::numeric_limits<int64_t>::min()) {
3516 var_->SetMax(PosIntDivDown(-m, -cst_));
3517 }
3518}
3519
3520int64_t TimesNegCstIntVar::Max() const { return CapProd(var_->Min(), cst_); }
3521
3522void TimesNegCstIntVar::SetMax(int64_t m) {
3523 if (m != std::numeric_limits<int64_t>::max()) {
3524 var_->SetMin(PosIntDivUp(-m, -cst_));
3525 }
3526}
3527
3528void TimesNegCstIntVar::SetRange(int64_t l, int64_t u) {
3529 var_->SetRange(PosIntDivUp(CapOpp(u), CapOpp(cst_)),
3530 PosIntDivDown(CapOpp(l), CapOpp(cst_)));
3531}
3532
3533void TimesNegCstIntVar::SetValue(int64_t v) {
3534 if (v % cst_ != 0) {
3535 solver()->Fail();
3536 }
3537 var_->SetValue(v / cst_);
3538}
3539
3540bool TimesNegCstIntVar::Bound() const { return var_->Bound(); }
3541
3542void TimesNegCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3543
3544int64_t TimesNegCstIntVar::Value() const {
3545 return CapProd(var_->Value(), cst_);
3546}
3547
3548void TimesNegCstIntVar::RemoveValue(int64_t v) {
3549 if (v % cst_ == 0) {
3550 var_->RemoveValue(v / cst_);
3551 }
3552}
3553
3554void TimesNegCstIntVar::RemoveInterval(int64_t l, int64_t u) {
3555 for (int64_t v = l; v <= u; ++v) {
3556 RemoveValue(v);
3557 }
3558 // TODO(user) : Improve me
3559}
3560
3561void TimesNegCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3562
3563void TimesNegCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3564
3565uint64_t TimesNegCstIntVar::Size() const { return var_->Size(); }
3566
3567bool TimesNegCstIntVar::Contains(int64_t v) const {
3568 return (v % cst_ == 0 && var_->Contains(v / cst_));
3569}
3570
3571// ---------- arithmetic expressions ----------
3572
3573// ----- PlusIntExpr -----
3574
3575class PlusIntExpr : public BaseIntExpr {
3576 public:
3577 PlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3578 : BaseIntExpr(s), left_(l), right_(r) {}
3579
3580 ~PlusIntExpr() override {}
3581
3582 int64_t Min() const override { return left_->Min() + right_->Min(); }
3583
3584 void SetMin(int64_t m) override {
3585 if (m > left_->Min() + right_->Min()) {
3586 // Catching potential overflow.
3587 if (m > right_->Max() + left_->Max()) solver()->Fail();
3588 left_->SetMin(m - right_->Max());
3589 right_->SetMin(m - left_->Max());
3590 }
3591 }
3592
3593 void SetRange(int64_t l, int64_t u) override {
3594 const int64_t left_min = left_->Min();
3595 const int64_t right_min = right_->Min();
3596 const int64_t left_max = left_->Max();
3597 const int64_t right_max = right_->Max();
3598 if (l > left_min + right_min) {
3599 // Catching potential overflow.
3600 if (l > right_max + left_max) solver()->Fail();
3601 left_->SetMin(l - right_max);
3602 right_->SetMin(l - left_max);
3603 }
3604 if (u < left_max + right_max) {
3605 // Catching potential overflow.
3606 if (u < right_min + left_min) solver()->Fail();
3607 left_->SetMax(u - right_min);
3608 right_->SetMax(u - left_min);
3609 }
3610 }
3611
3612 int64_t Max() const override { return left_->Max() + right_->Max(); }
3613
3614 void SetMax(int64_t m) override {
3615 if (m < left_->Max() + right_->Max()) {
3616 // Catching potential overflow.
3617 if (m < right_->Min() + left_->Min()) solver()->Fail();
3618 left_->SetMax(m - right_->Min());
3619 right_->SetMax(m - left_->Min());
3620 }
3621 }
3622
3623 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3624
3625 void Range(int64_t* const mi, int64_t* const ma) override {
3626 *mi = left_->Min() + right_->Min();
3627 *ma = left_->Max() + right_->Max();
3628 }
3629
3630 std::string name() const override {
3631 return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3632 }
3633
3634 std::string DebugString() const override {
3635 return absl::StrFormat("(%s + %s)", left_->DebugString(),
3636 right_->DebugString());
3637 }
3638
3639 void WhenRange(Demon* d) override {
3640 left_->WhenRange(d);
3641 right_->WhenRange(d);
3642 }
3643
3644 void ExpandPlusIntExpr(IntExpr* const expr, std::vector<IntExpr*>* subs) {
3645 PlusIntExpr* const casted = dynamic_cast<PlusIntExpr*>(expr);
3646 if (casted != nullptr) {
3647 ExpandPlusIntExpr(casted->left_, subs);
3648 ExpandPlusIntExpr(casted->right_, subs);
3649 } else {
3650 subs->push_back(expr);
3651 }
3652 }
3653
3654 IntVar* CastToVar() override {
3655 if (dynamic_cast<PlusIntExpr*>(left_) != nullptr ||
3656 dynamic_cast<PlusIntExpr*>(right_) != nullptr) {
3657 std::vector<IntExpr*> sub_exprs;
3658 ExpandPlusIntExpr(left_, &sub_exprs);
3659 ExpandPlusIntExpr(right_, &sub_exprs);
3660 if (sub_exprs.size() >= 3) {
3661 std::vector<IntVar*> sub_vars(sub_exprs.size());
3662 for (int i = 0; i < sub_exprs.size(); ++i) {
3663 sub_vars[i] = sub_exprs[i]->Var();
3664 }
3665 return solver()->MakeSum(sub_vars)->Var();
3666 }
3667 }
3668 return BaseIntExpr::CastToVar();
3669 }
3670
3671 void Accept(ModelVisitor* const visitor) const override {
3672 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3673 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3674 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3675 right_);
3676 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3677 }
3678
3679 private:
3680 IntExpr* const left_;
3681 IntExpr* const right_;
3682};
3683
3684class SafePlusIntExpr : public BaseIntExpr {
3685 public:
3686 SafePlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3687 : BaseIntExpr(s), left_(l), right_(r) {}
3688
3689 ~SafePlusIntExpr() override {}
3690
3691 int64_t Min() const override { return CapAdd(left_->Min(), right_->Min()); }
3692
3693 void SetMin(int64_t m) override {
3694 left_->SetMin(CapSub(m, right_->Max()));
3695 right_->SetMin(CapSub(m, left_->Max()));
3696 }
3697
3698 void SetRange(int64_t l, int64_t u) override {
3699 const int64_t left_min = left_->Min();
3700 const int64_t right_min = right_->Min();
3701 const int64_t left_max = left_->Max();
3702 const int64_t right_max = right_->Max();
3703 if (l > CapAdd(left_min, right_min)) {
3704 left_->SetMin(CapSub(l, right_max));
3705 right_->SetMin(CapSub(l, left_max));
3706 }
3707 if (u < CapAdd(left_max, right_max)) {
3708 left_->SetMax(CapSub(u, right_min));
3709 right_->SetMax(CapSub(u, left_min));
3710 }
3711 }
3712
3713 int64_t Max() const override { return CapAdd(left_->Max(), right_->Max()); }
3714
3715 void SetMax(int64_t m) override {
3716 left_->SetMax(CapSub(m, right_->Min()));
3717 right_->SetMax(CapSub(m, left_->Min()));
3718 }
3719
3720 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3721
3722 std::string name() const override {
3723 return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3724 }
3725
3726 std::string DebugString() const override {
3727 return absl::StrFormat("(%s + %s)", left_->DebugString(),
3728 right_->DebugString());
3729 }
3730
3731 void WhenRange(Demon* d) override {
3732 left_->WhenRange(d);
3733 right_->WhenRange(d);
3734 }
3735
3736 void Accept(ModelVisitor* const visitor) const override {
3737 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3738 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3739 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3740 right_);
3741 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3742 }
3743
3744 private:
3745 IntExpr* const left_;
3746 IntExpr* const right_;
3747};
3748
3749// ----- PlusIntCstExpr -----
3750
3751class PlusIntCstExpr : public BaseIntExpr {
3752 public:
3753 PlusIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3754 : BaseIntExpr(s), expr_(e), value_(v) {}
3755 ~PlusIntCstExpr() override {}
3756 int64_t Min() const override { return CapAdd(expr_->Min(), value_); }
3757 void SetMin(int64_t m) override { expr_->SetMin(CapSub(m, value_)); }
3758 int64_t Max() const override { return CapAdd(expr_->Max(), value_); }
3759 void SetMax(int64_t m) override { expr_->SetMax(CapSub(m, value_)); }
3760 bool Bound() const override { return (expr_->Bound()); }
3761 std::string name() const override {
3762 return absl::StrFormat("(%s + %d)", expr_->name(), value_);
3763 }
3764 std::string DebugString() const override {
3765 return absl::StrFormat("(%s + %d)", expr_->DebugString(), value_);
3766 }
3767 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3768 IntVar* CastToVar() override;
3769 void Accept(ModelVisitor* const visitor) const override {
3770 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3771 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3772 expr_);
3773 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3774 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3775 }
3776
3777 private:
3778 IntExpr* const expr_;
3779 const int64_t value_;
3780};
3781
3782IntVar* PlusIntCstExpr::CastToVar() {
3783 Solver* const s = solver();
3784 IntVar* const var = expr_->Var();
3785 IntVar* cast = nullptr;
3786 if (AddOverflows(value_, expr_->Max()) ||
3787 AddOverflows(value_, expr_->Min())) {
3788 return BaseIntExpr::CastToVar();
3789 }
3790 switch (var->VarType()) {
3791 case DOMAIN_INT_VAR:
3792 cast = s->RegisterIntVar(s->RevAlloc(new PlusCstDomainIntVar(
3793 s, reinterpret_cast<DomainIntVar*>(var), value_)));
3794 // FIXME: Break was inserted during fallthrough cleanup. Please check.
3795 break;
3796 default:
3797 cast = s->RegisterIntVar(s->RevAlloc(new PlusCstIntVar(s, var, value_)));
3798 break;
3799 }
3800 return cast;
3801}
3802
3803// ----- SubIntExpr -----
3804
3805class SubIntExpr : public BaseIntExpr {
3806 public:
3807 SubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3808 : BaseIntExpr(s), left_(l), right_(r) {}
3809
3810 ~SubIntExpr() override {}
3811
3812 int64_t Min() const override { return left_->Min() - right_->Max(); }
3813
3814 void SetMin(int64_t m) override {
3815 left_->SetMin(CapAdd(m, right_->Min()));
3816 right_->SetMax(CapSub(left_->Max(), m));
3817 }
3818
3819 int64_t Max() const override { return left_->Max() - right_->Min(); }
3820
3821 void SetMax(int64_t m) override {
3822 left_->SetMax(CapAdd(m, right_->Max()));
3823 right_->SetMin(CapSub(left_->Min(), m));
3824 }
3825
3826 void Range(int64_t* mi, int64_t* ma) override {
3827 *mi = left_->Min() - right_->Max();
3828 *ma = left_->Max() - right_->Min();
3829 }
3830
3831 void SetRange(int64_t l, int64_t u) override {
3832 const int64_t left_min = left_->Min();
3833 const int64_t right_min = right_->Min();
3834 const int64_t left_max = left_->Max();
3835 const int64_t right_max = right_->Max();
3836 if (l > left_min - right_max) {
3837 left_->SetMin(CapAdd(l, right_min));
3838 right_->SetMax(CapSub(left_max, l));
3839 }
3840 if (u < left_max - right_min) {
3841 left_->SetMax(CapAdd(u, right_max));
3842 right_->SetMin(CapSub(left_min, u));
3843 }
3844 }
3845
3846 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3847
3848 std::string name() const override {
3849 return absl::StrFormat("(%s - %s)", left_->name(), right_->name());
3850 }
3851
3852 std::string DebugString() const override {
3853 return absl::StrFormat("(%s - %s)", left_->DebugString(),
3854 right_->DebugString());
3855 }
3856
3857 void WhenRange(Demon* d) override {
3858 left_->WhenRange(d);
3859 right_->WhenRange(d);
3860 }
3861
3862 void Accept(ModelVisitor* const visitor) const override {
3863 visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3864 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3865 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3866 right_);
3867 visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3868 }
3869
3870 IntExpr* left() const { return left_; }
3871 IntExpr* right() const { return right_; }
3872
3873 protected:
3874 IntExpr* const left_;
3875 IntExpr* const right_;
3876};
3877
3878class SafeSubIntExpr : public SubIntExpr {
3879 public:
3880 SafeSubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3881 : SubIntExpr(s, l, r) {}
3882
3883 ~SafeSubIntExpr() override {}
3884
3885 int64_t Min() const override { return CapSub(left_->Min(), right_->Max()); }
3886
3887 void SetMin(int64_t m) override {
3888 left_->SetMin(CapAdd(m, right_->Min()));
3889 right_->SetMax(CapSub(left_->Max(), m));
3890 }
3891
3892 void SetRange(int64_t l, int64_t u) override {
3893 const int64_t left_min = left_->Min();
3894 const int64_t right_min = right_->Min();
3895 const int64_t left_max = left_->Max();
3896 const int64_t right_max = right_->Max();
3897 if (l > CapSub(left_min, right_max)) {
3898 left_->SetMin(CapAdd(l, right_min));
3899 right_->SetMax(CapSub(left_max, l));
3900 }
3901 if (u < CapSub(left_max, right_min)) {
3902 left_->SetMax(CapAdd(u, right_max));
3903 right_->SetMin(CapSub(left_min, u));
3904 }
3905 }
3906
3907 void Range(int64_t* mi, int64_t* ma) override {
3908 *mi = CapSub(left_->Min(), right_->Max());
3909 *ma = CapSub(left_->Max(), right_->Min());
3910 }
3911
3912 int64_t Max() const override { return CapSub(left_->Max(), right_->Min()); }
3913
3914 void SetMax(int64_t m) override {
3915 left_->SetMax(CapAdd(m, right_->Max()));
3916 right_->SetMin(CapSub(left_->Min(), m));
3917 }
3918};
3919
3920// l - r
3921
3922// ----- SubIntCstExpr -----
3923
3924class SubIntCstExpr : public BaseIntExpr {
3925 public:
3926 SubIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3927 : BaseIntExpr(s), expr_(e), value_(v) {}
3928 ~SubIntCstExpr() override {}
3929 int64_t Min() const override { return CapSub(value_, expr_->Max()); }
3930 void SetMin(int64_t m) override { expr_->SetMax(CapSub(value_, m)); }
3931 int64_t Max() const override { return CapSub(value_, expr_->Min()); }
3932 void SetMax(int64_t m) override { expr_->SetMin(CapSub(value_, m)); }
3933 bool Bound() const override { return (expr_->Bound()); }
3934 std::string name() const override {
3935 return absl::StrFormat("(%d - %s)", value_, expr_->name());
3936 }
3937 std::string DebugString() const override {
3938 return absl::StrFormat("(%d - %s)", value_, expr_->DebugString());
3939 }
3940 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3941 IntVar* CastToVar() override;
3942
3943 void Accept(ModelVisitor* const visitor) const override {
3944 visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3945 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3946 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3947 expr_);
3948 visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3949 }
3950
3951 private:
3952 IntExpr* const expr_;
3953 const int64_t value_;
3954};
3955
3956IntVar* SubIntCstExpr::CastToVar() {
3957 if (SubOverflows(value_, expr_->Min()) ||
3958 SubOverflows(value_, expr_->Max())) {
3959 return BaseIntExpr::CastToVar();
3960 }
3961 Solver* const s = solver();
3962 IntVar* const var =
3963 s->RegisterIntVar(s->RevAlloc(new SubCstIntVar(s, expr_->Var(), value_)));
3964 return var;
3965}
3966
3967// ----- OppIntExpr -----
3968
3969class OppIntExpr : public BaseIntExpr {
3970 public:
3971 OppIntExpr(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
3972 ~OppIntExpr() override {}
3973 int64_t Min() const override { return (CapOpp(expr_->Max())); }
3974 void SetMin(int64_t m) override { expr_->SetMax(CapOpp(m)); }
3975 int64_t Max() const override { return (CapOpp(expr_->Min())); }
3976 void SetMax(int64_t m) override { expr_->SetMin(CapOpp(m)); }
3977 bool Bound() const override { return (expr_->Bound()); }
3978 std::string name() const override {
3979 return absl::StrFormat("(-%s)", expr_->name());
3980 }
3981 std::string DebugString() const override {
3982 return absl::StrFormat("(-%s)", expr_->DebugString());
3983 }
3984 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3985 IntVar* CastToVar() override;
3986
3987 void Accept(ModelVisitor* const visitor) const override {
3988 visitor->BeginVisitIntegerExpression(ModelVisitor::kOpposite, this);
3989 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3990 expr_);
3991 visitor->EndVisitIntegerExpression(ModelVisitor::kOpposite, this);
3992 }
3993
3994 private:
3995 IntExpr* const expr_;
3996};
3997
3998IntVar* OppIntExpr::CastToVar() {
3999 Solver* const s = solver();
4000 IntVar* const var =
4001 s->RegisterIntVar(s->RevAlloc(new OppIntVar(s, expr_->Var())));
4002 return var;
4003}
4004
4005// ----- TimesIntCstExpr -----
4006
4007class TimesIntCstExpr : public BaseIntExpr {
4008 public:
4009 TimesIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4010 : BaseIntExpr(s), expr_(e), value_(v) {}
4011
4012 ~TimesIntCstExpr() override {}
4013
4014 bool Bound() const override { return (expr_->Bound()); }
4015
4016 std::string name() const override {
4017 return absl::StrFormat("(%s * %d)", expr_->name(), value_);
4018 }
4019
4020 std::string DebugString() const override {
4021 return absl::StrFormat("(%s * %d)", expr_->DebugString(), value_);
4022 }
4023
4024 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4025
4026 IntExpr* Expr() const { return expr_; }
4027
4028 int64_t Constant() const { return value_; }
4029
4030 void Accept(ModelVisitor* const visitor) const override {
4031 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4032 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4033 expr_);
4034 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4035 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4036 }
4037
4038 protected:
4039 IntExpr* const expr_;
4040 const int64_t value_;
4041};
4042
4043// ----- TimesPosIntCstExpr -----
4044
4045class TimesPosIntCstExpr : public TimesIntCstExpr {
4046 public:
4047 TimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4048 : TimesIntCstExpr(s, e, v) {
4049 CHECK_GT(v, 0);
4050 }
4051
4052 ~TimesPosIntCstExpr() override {}
4053
4054 int64_t Min() const override { return expr_->Min() * value_; }
4055
4056 void SetMin(int64_t m) override { expr_->SetMin(PosIntDivUp(m, value_)); }
4057
4058 int64_t Max() const override { return expr_->Max() * value_; }
4059
4060 void SetMax(int64_t m) override { expr_->SetMax(PosIntDivDown(m, value_)); }
4061
4062 IntVar* CastToVar() override {
4063 Solver* const s = solver();
4064 IntVar* var = nullptr;
4065 if (expr_->IsVar() &&
4066 reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4067 var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4068 s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4069 } else {
4070 var = s->RegisterIntVar(
4071 s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4072 }
4073 return var;
4074 }
4075};
4076
4077// This expressions adds safe arithmetic (w.r.t. overflows) compared
4078// to the previous one.
4079class SafeTimesPosIntCstExpr : public TimesIntCstExpr {
4080 public:
4081 SafeTimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4082 : TimesIntCstExpr(s, e, v) {
4083 CHECK_GT(v, 0);
4084 }
4085
4086 ~SafeTimesPosIntCstExpr() override {}
4087
4088 int64_t Min() const override { return CapProd(expr_->Min(), value_); }
4089
4090 void SetMin(int64_t m) override {
4091 if (m != std::numeric_limits<int64_t>::min()) {
4092 expr_->SetMin(PosIntDivUp(m, value_));
4093 }
4094 }
4095
4096 int64_t Max() const override { return CapProd(expr_->Max(), value_); }
4097
4098 void SetMax(int64_t m) override {
4099 if (m != std::numeric_limits<int64_t>::max()) {
4100 expr_->SetMax(PosIntDivDown(m, value_));
4101 }
4102 }
4103
4104 IntVar* CastToVar() override {
4105 Solver* const s = solver();
4106 IntVar* var = nullptr;
4107 if (expr_->IsVar() &&
4108 reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4109 var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4110 s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4111 } else {
4112 // TODO(user): Check overflows.
4113 var = s->RegisterIntVar(
4114 s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4115 }
4116 return var;
4117 }
4118};
4119
4120// ----- TimesIntNegCstExpr -----
4121
4122class TimesIntNegCstExpr : public TimesIntCstExpr {
4123 public:
4124 TimesIntNegCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4125 : TimesIntCstExpr(s, e, v) {
4126 CHECK_LT(v, 0);
4127 }
4128
4129 ~TimesIntNegCstExpr() override {}
4130
4131 int64_t Min() const override { return CapProd(expr_->Max(), value_); }
4132
4133 void SetMin(int64_t m) override {
4134 if (m != std::numeric_limits<int64_t>::min()) {
4135 expr_->SetMax(PosIntDivDown(-m, -value_));
4136 }
4137 }
4138
4139 int64_t Max() const override { return CapProd(expr_->Min(), value_); }
4140
4141 void SetMax(int64_t m) override {
4142 if (m != std::numeric_limits<int64_t>::max()) {
4143 expr_->SetMin(PosIntDivUp(-m, -value_));
4144 }
4145 }
4146
4147 IntVar* CastToVar() override {
4148 Solver* const s = solver();
4149 IntVar* var = nullptr;
4150 var = s->RegisterIntVar(
4151 s->RevAlloc(new TimesNegCstIntVar(s, expr_->Var(), value_)));
4152 return var;
4153 }
4154};
4155
4156// ----- Utilities for product expression -----
4157
4158// Propagates set_min on left * right, left and right >= 0.
4159void SetPosPosMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4160 DCHECK_GE(left->Min(), 0);
4161 DCHECK_GE(right->Min(), 0);
4162 const int64_t lmax = left->Max();
4163 const int64_t rmax = right->Max();
4164 if (m > CapProd(lmax, rmax)) {
4165 left->solver()->Fail();
4166 }
4167 if (m > CapProd(left->Min(), right->Min())) {
4168 // Ok for m == 0 due to left and right being positive
4169 if (0 != rmax) {
4170 left->SetMin(PosIntDivUp(m, rmax));
4171 }
4172 if (0 != lmax) {
4173 right->SetMin(PosIntDivUp(m, lmax));
4174 }
4175 }
4176}
4177
4178// Propagates set_max on left * right, left and right >= 0.
4179void SetPosPosMaxExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4180 DCHECK_GE(left->Min(), 0);
4181 DCHECK_GE(right->Min(), 0);
4182 const int64_t lmin = left->Min();
4183 const int64_t rmin = right->Min();
4184 if (m < CapProd(lmin, rmin)) {
4185 left->solver()->Fail();
4186 }
4187 if (m < CapProd(left->Max(), right->Max())) {
4188 if (0 != lmin) {
4189 right->SetMax(PosIntDivDown(m, lmin));
4190 }
4191 if (0 != rmin) {
4192 left->SetMax(PosIntDivDown(m, rmin));
4193 }
4194 // else do nothing: 0 is supporting any value from other expr.
4195 }
4196}
4197
// Propagates set_min on left * right, left >= 0, right across 0.
// Enforces left * right >= m where left is non-negative and right's domain
// straddles 0 (right->Min() < 0 < right->Max()).
void SetPosGenMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
  DCHECK_GE(left->Min(), 0);
  DCHECK_GT(right->Max(), 0);
  DCHECK_LT(right->Min(), 0);
  const int64_t lmax = left->Max();
  const int64_t rmax = right->Max();
  // With left >= 0, the largest achievable product is lmax * rmax.
  if (m > CapProd(lmax, rmax)) {
    left->solver()->Fail();
  }
  if (left->Max() == 0) {  // left is bound to 0, product is bound to 0.
    DCHECK_EQ(0, left->Min());
    DCHECK_LE(m, 0);
  } else {
    if (m > 0) {  // We deduce right > 0.
      // A positive product needs both factors positive, so each factor must
      // be at least ceil(m / max(other factor)).
      left->SetMin(PosIntDivUp(m, rmax));
      right->SetMin(PosIntDivUp(m, lmax));
    } else if (m == 0) {
      // If left is strictly positive, a non-negative product needs right >= 0.
      const int64_t lmin = left->Min();
      if (lmin > 0) {
        right->SetMin(0);
      }
    } else {  // m < 0
      // For a negative right, the most permissive factor is left == lmin,
      // giving right >= ceil(m / lmin) == -floor(-m / lmin).
      const int64_t lmin = left->Min();
      if (0 != lmin) {  // We cannot deduce anything if 0 is in the domain.
        right->SetMin(-PosIntDivDown(-m, lmin));
      }
    }
  }
}
4228
// Propagates set_min on left * right, left and right across 0.
// Both domains straddle 0, so the positive products come either from two
// positive factors or from two negative factors.
void SetGenGenMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
  DCHECK_LT(left->Min(), 0);
  DCHECK_GT(left->Max(), 0);
  DCHECK_GT(right->Max(), 0);
  DCHECK_LT(right->Min(), 0);
  const int64_t lmin = left->Min();
  const int64_t lmax = left->Max();
  const int64_t rmin = right->Min();
  const int64_t rmax = right->Max();
  // The largest product is either negative * negative (lmin * rmin) or
  // positive * positive (lmax * rmax).
  if (m > std::max(CapProd(lmin, rmin), CapProd(lmax, rmax))) {
    left->solver()->Fail();
  }
  if (m >
      CapProd(lmin, rmin)) {  // Must be positive section * positive section.
    left->SetMin(PosIntDivUp(m, rmax));
    right->SetMin(PosIntDivUp(m, lmax));
  } else if (m > CapProd(lmax, rmax)) {  // Negative section * negative section.
    // Mirror of the previous branch applied to -left and -right, whose maxima
    // are -lmin and -rmin.
    left->SetMax(CapOpp(PosIntDivUp(m, CapOpp(rmin))));
    right->SetMax(CapOpp(PosIntDivUp(m, CapOpp(lmin))));
  }
  // Otherwise nothing is deduced from the bounds.
}
4251
// Enforces left * right >= m by dispatching on the sign of each factor's
// domain: non-negative, non-positive, or straddling 0.
// minus_left and minus_right must be the negations of left and right.
// Cases with a non-positive factor are rewritten on the negated expressions
// so that only the non-negative helpers are needed. For example,
// left * right >= m with right <= 0 becomes left * (-right) <= -m.
// NOTE(review): -m is computed unchecked. The TimesIntExpr callers never pass
// kint64min here: SetMin filters it, and SetMax passes CapOpp(m) with
// m != kint64max.
void TimesSetMin(IntExpr* const left, IntExpr* const right,
                 IntExpr* const minus_left, IntExpr* const minus_right,
                 int64_t m) {
  if (left->Min() >= 0) {
    if (right->Min() >= 0) {
      SetPosPosMinExpr(left, right, m);
    } else if (right->Max() <= 0) {
      // left * right >= m  <=>  left * (-right) <= -m.
      SetPosPosMaxExpr(left, minus_right, -m);
    } else {  // right->Min() < 0 && right->Max() > 0
      SetPosGenMinExpr(left, right, m);
    }
  } else if (left->Max() <= 0) {
    if (right->Min() >= 0) {
      // left * right >= m  <=>  right * (-left) <= -m.
      SetPosPosMaxExpr(right, minus_left, -m);
    } else if (right->Max() <= 0) {
      // (-left) * (-right) == left * right, with both negations >= 0.
      SetPosPosMinExpr(minus_left, minus_right, m);
    } else {  // right->Min() < 0 && right->Max() > 0
      SetPosGenMinExpr(minus_left, minus_right, m);
    }
  } else if (right->Min() >= 0) {  // left->Min() < 0 && left->Max() > 0
    SetPosGenMinExpr(right, left, m);
  } else if (right->Max() <= 0) {  // left->Min() < 0 && left->Max() > 0
    SetPosGenMinExpr(minus_right, minus_left, m);
  } else {  // left->Min() < 0 && left->Max() > 0 &&
            // right->Min() < 0 && right->Max() > 0
    SetGenGenMinExpr(left, right, m);
  }
}
4280
4281class TimesIntExpr : public BaseIntExpr {
4282 public:
4283 TimesIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4284 : BaseIntExpr(s),
4285 left_(l),
4286 right_(r),
4287 minus_left_(s->MakeOpposite(left_)),
4288 minus_right_(s->MakeOpposite(right_)) {}
4289 ~TimesIntExpr() override {}
4290 int64_t Min() const override {
4291 const int64_t lmin = left_->Min();
4292 const int64_t lmax = left_->Max();
4293 const int64_t rmin = right_->Min();
4294 const int64_t rmax = right_->Max();
4295 return std::min(std::min(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4296 std::min(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4297 }
4298 void SetMin(int64_t m) override;
4299 int64_t Max() const override {
4300 const int64_t lmin = left_->Min();
4301 const int64_t lmax = left_->Max();
4302 const int64_t rmin = right_->Min();
4303 const int64_t rmax = right_->Max();
4304 return std::max(std::max(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4305 std::max(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4306 }
4307 void SetMax(int64_t m) override;
4308 bool Bound() const override;
4309 std::string name() const override {
4310 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4311 }
4312 std::string DebugString() const override {
4313 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4314 right_->DebugString());
4315 }
4316 void WhenRange(Demon* d) override {
4317 left_->WhenRange(d);
4318 right_->WhenRange(d);
4319 }
4320
4321 void Accept(ModelVisitor* const visitor) const override {
4322 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4323 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4324 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4325 right_);
4326 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4327 }
4328
4329 private:
4330 IntExpr* const left_;
4331 IntExpr* const right_;
4332 IntExpr* const minus_left_;
4333 IntExpr* const minus_right_;
4334};
4335
4336void TimesIntExpr::SetMin(int64_t m) {
4337 if (m != std::numeric_limits<int64_t>::min()) {
4338 TimesSetMin(left_, right_, minus_left_, minus_right_, m);
4339 }
4340}
4341
4342void TimesIntExpr::SetMax(int64_t m) {
4343 if (m != std::numeric_limits<int64_t>::max()) {
4344 TimesSetMin(left_, minus_right_, minus_left_, right_, CapOpp(m));
4345 }
4346}
4347
4348bool TimesIntExpr::Bound() const {
4349 const bool left_bound = left_->Bound();
4350 const bool right_bound = right_->Bound();
4351 return ((left_bound && left_->Max() == 0) ||
4352 (right_bound && right_->Max() == 0) || (left_bound && right_bound));
4353}
4354
4355// ----- TimesPosIntExpr -----
4356
4357class TimesPosIntExpr : public BaseIntExpr {
4358 public:
4359 TimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4360 : BaseIntExpr(s), left_(l), right_(r) {}
4361 ~TimesPosIntExpr() override {}
4362 int64_t Min() const override { return (left_->Min() * right_->Min()); }
4363 void SetMin(int64_t m) override;
4364 int64_t Max() const override { return (left_->Max() * right_->Max()); }
4365 void SetMax(int64_t m) override;
4366 bool Bound() const override;
4367 std::string name() const override {
4368 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4369 }
4370 std::string DebugString() const override {
4371 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4372 right_->DebugString());
4373 }
4374 void WhenRange(Demon* d) override {
4375 left_->WhenRange(d);
4376 right_->WhenRange(d);
4377 }
4378
4379 void Accept(ModelVisitor* const visitor) const override {
4380 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4381 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4382 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4383 right_);
4384 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4385 }
4386
4387 private:
4388 IntExpr* const left_;
4389 IntExpr* const right_;
4390};
4391
4392void TimesPosIntExpr::SetMin(int64_t m) { SetPosPosMinExpr(left_, right_, m); }
4393
4394void TimesPosIntExpr::SetMax(int64_t m) { SetPosPosMaxExpr(left_, right_, m); }
4395
4396bool TimesPosIntExpr::Bound() const {
4397 return (left_->Max() == 0 || right_->Max() == 0 ||
4398 (left_->Bound() && right_->Bound()));
4399}
4400
4401// ----- SafeTimesPosIntExpr -----
4402
4403class SafeTimesPosIntExpr : public BaseIntExpr {
4404 public:
4405 SafeTimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4406 : BaseIntExpr(s), left_(l), right_(r) {}
4407 ~SafeTimesPosIntExpr() override {}
4408 int64_t Min() const override { return CapProd(left_->Min(), right_->Min()); }
4409 void SetMin(int64_t m) override {
4410 if (m != std::numeric_limits<int64_t>::min()) {
4411 SetPosPosMinExpr(left_, right_, m);
4412 }
4413 }
4414 int64_t Max() const override { return CapProd(left_->Max(), right_->Max()); }
4415 void SetMax(int64_t m) override {
4416 if (m != std::numeric_limits<int64_t>::max()) {
4417 SetPosPosMaxExpr(left_, right_, m);
4418 }
4419 }
4420 bool Bound() const override {
4421 return (left_->Max() == 0 || right_->Max() == 0 ||
4422 (left_->Bound() && right_->Bound()));
4423 }
4424 std::string name() const override {
4425 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4426 }
4427 std::string DebugString() const override {
4428 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4429 right_->DebugString());
4430 }
4431 void WhenRange(Demon* d) override {
4432 left_->WhenRange(d);
4433 right_->WhenRange(d);
4434 }
4435
4436 void Accept(ModelVisitor* const visitor) const override {
4437 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4438 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4439 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4440 right_);
4441 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4442 }
4443
4444 private:
4445 IntExpr* const left_;
4446 IntExpr* const right_;
4447};
4448
4449// ----- TimesBooleanPosIntExpr -----
4450
4451class TimesBooleanPosIntExpr : public BaseIntExpr {
4452 public:
4453 TimesBooleanPosIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4454 : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4455 ~TimesBooleanPosIntExpr() override {}
4456 int64_t Min() const override {
4457 return (boolvar_->RawValue() == 1 ? expr_->Min() : 0);
4458 }
4459 void SetMin(int64_t m) override;
4460 int64_t Max() const override {
4461 return (boolvar_->RawValue() == 0 ? 0 : expr_->Max());
4462 }
4463 void SetMax(int64_t m) override;
4464 void Range(int64_t* mi, int64_t* ma) override;
4465 void SetRange(int64_t mi, int64_t ma) override;
4466 bool Bound() const override;
4467 std::string name() const override {
4468 return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4469 }
4470 std::string DebugString() const override {
4471 return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4472 expr_->DebugString());
4473 }
4474 void WhenRange(Demon* d) override {
4475 boolvar_->WhenRange(d);
4476 expr_->WhenRange(d);
4477 }
4478
4479 void Accept(ModelVisitor* const visitor) const override {
4480 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4481 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4482 boolvar_);
4483 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4484 expr_);
4485 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4486 }
4487
4488 private:
4489 BooleanVar* const boolvar_;
4490 IntExpr* const expr_;
4491};
4492
4493void TimesBooleanPosIntExpr::SetMin(int64_t m) {
4494 if (m > 0) {
4495 boolvar_->SetValue(1);
4496 expr_->SetMin(m);
4497 }
4498}
4499
4500void TimesBooleanPosIntExpr::SetMax(int64_t m) {
4501 if (m < 0) {
4502 solver()->Fail();
4503 }
4504 if (m < expr_->Min()) {
4505 boolvar_->SetValue(0);
4506 }
4507 if (boolvar_->RawValue() == 1) {
4508 expr_->SetMax(m);
4509 }
4510}
4511
4512void TimesBooleanPosIntExpr::Range(int64_t* mi, int64_t* ma) {
4513 const int value = boolvar_->RawValue();
4514 if (value == 0) {
4515 *mi = 0;
4516 *ma = 0;
4517 } else if (value == 1) {
4518 expr_->Range(mi, ma);
4519 } else {
4520 *mi = 0;
4521 *ma = expr_->Max();
4522 }
4523}
4524
4525void TimesBooleanPosIntExpr::SetRange(int64_t mi, int64_t ma) {
4526 if (ma < 0 || mi > ma) {
4527 solver()->Fail();
4528 }
4529 if (mi > 0) {
4530 boolvar_->SetValue(1);
4531 expr_->SetMin(mi);
4532 }
4533 if (ma < expr_->Min()) {
4534 boolvar_->SetValue(0);
4535 }
4536 if (boolvar_->RawValue() == 1) {
4537 expr_->SetMax(ma);
4538 }
4539}
4540
4541bool TimesBooleanPosIntExpr::Bound() const {
4542 return (boolvar_->RawValue() == 0 || expr_->Max() == 0 ||
4543 (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue &&
4544 expr_->Bound()));
4545}
4546
4547// ----- TimesBooleanIntExpr -----
4548
4549class TimesBooleanIntExpr : public BaseIntExpr {
4550 public:
4551 TimesBooleanIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4552 : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4553 ~TimesBooleanIntExpr() override {}
4554 int64_t Min() const override {
4555 switch (boolvar_->RawValue()) {
4556 case 0: {
4557 return 0LL;
4558 }
4559 case 1: {
4560 return expr_->Min();
4561 }
4562 default: {
4563 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4564 return std::min(int64_t{0}, expr_->Min());
4565 }
4566 }
4567 }
4568 void SetMin(int64_t m) override;
4569 int64_t Max() const override {
4570 switch (boolvar_->RawValue()) {
4571 case 0: {
4572 return 0LL;
4573 }
4574 case 1: {
4575 return expr_->Max();
4576 }
4577 default: {
4578 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4579 return std::max(int64_t{0}, expr_->Max());
4580 }
4581 }
4582 }
4583 void SetMax(int64_t m) override;
4584 void Range(int64_t* mi, int64_t* ma) override;
4585 void SetRange(int64_t mi, int64_t ma) override;
4586 bool Bound() const override;
4587 std::string name() const override {
4588 return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4589 }
4590 std::string DebugString() const override {
4591 return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4592 expr_->DebugString());
4593 }
4594 void WhenRange(Demon* d) override {
4595 boolvar_->WhenRange(d);
4596 expr_->WhenRange(d);
4597 }
4598
4599 void Accept(ModelVisitor* const visitor) const override {
4600 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4601 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4602 boolvar_);
4603 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4604 expr_);
4605 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4606 }
4607
4608 private:
4609 BooleanVar* const boolvar_;
4610 IntExpr* const expr_;
4611};
4612
4613void TimesBooleanIntExpr::SetMin(int64_t m) {
4614 switch (boolvar_->RawValue()) {
4615 case 0: {
4616 if (m > 0) {
4617 solver()->Fail();
4618 }
4619 break;
4620 }
4621 case 1: {
4622 expr_->SetMin(m);
4623 break;
4624 }
4625 default: {
4626 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4627 if (m > 0) { // 0 is no longer possible for boolvar because min > 0.
4628 boolvar_->SetValue(1);
4629 expr_->SetMin(m);
4630 } else if (m <= 0 && expr_->Max() < m) {
4631 boolvar_->SetValue(0);
4632 }
4633 }
4634 }
4635}
4636
4637void TimesBooleanIntExpr::SetMax(int64_t m) {
4638 switch (boolvar_->RawValue()) {
4639 case 0: {
4640 if (m < 0) {
4641 solver()->Fail();
4642 }
4643 break;
4644 }
4645 case 1: {
4646 expr_->SetMax(m);
4647 break;
4648 }
4649 default: {
4650 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4651 if (m < 0) { // 0 is no longer possible for boolvar because max < 0.
4652 boolvar_->SetValue(1);
4653 expr_->SetMax(m);
4654 } else if (m >= 0 && expr_->Min() > m) {
4655 boolvar_->SetValue(0);
4656 }
4657 }
4658 }
4659}
4660
4661void TimesBooleanIntExpr::Range(int64_t* mi, int64_t* ma) {
4662 switch (boolvar_->RawValue()) {
4663 case 0: {
4664 *mi = 0;
4665 *ma = 0;
4666 break;
4667 }
4668 case 1: {
4669 *mi = expr_->Min();
4670 *ma = expr_->Max();
4671 break;
4672 }
4673 default: {
4674 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4675 *mi = std::min(int64_t{0}, expr_->Min());
4676 *ma = std::max(int64_t{0}, expr_->Max());
4677 break;
4678 }
4679 }
4680}
4681
4682void TimesBooleanIntExpr::SetRange(int64_t mi, int64_t ma) {
4683 if (mi > ma) {
4684 solver()->Fail();
4685 }
4686 switch (boolvar_->RawValue()) {
4687 case 0: {
4688 if (mi > 0 || ma < 0) {
4689 solver()->Fail();
4690 }
4691 break;
4692 }
4693 case 1: {
4694 expr_->SetRange(mi, ma);
4695 break;
4696 }
4697 default: {
4698 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4699 if (mi > 0) {
4700 boolvar_->SetValue(1);
4701 expr_->SetMin(mi);
4702 } else if (mi == 0 && expr_->Max() < 0) {
4703 boolvar_->SetValue(0);
4704 }
4705 if (ma < 0) {
4706 boolvar_->SetValue(1);
4707 expr_->SetMax(ma);
4708 } else if (ma == 0 && expr_->Min() > 0) {
4709 boolvar_->SetValue(0);
4710 }
4711 break;
4712 }
4713 }
4714}
4715
4716bool TimesBooleanIntExpr::Bound() const {
4717 return (boolvar_->RawValue() == 0 ||
4718 (expr_->Bound() &&
4719 (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue ||
4720 expr_->Max() == 0)));
4721}
4722
4723// ----- DivPosIntCstExpr -----
4724
4725class DivPosIntCstExpr : public BaseIntExpr {
4726 public:
4727 DivPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4728 : BaseIntExpr(s), expr_(e), value_(v) {
4729 CHECK_GE(v, 0);
4730 }
4731 ~DivPosIntCstExpr() override {}
4732
4733 int64_t Min() const override { return expr_->Min() / value_; }
4734
4735 void SetMin(int64_t m) override {
4736 if (m > 0) {
4737 expr_->SetMin(m * value_);
4738 } else {
4739 expr_->SetMin((m - 1) * value_ + 1);
4740 }
4741 }
4742 int64_t Max() const override { return expr_->Max() / value_; }
4743
4744 void SetMax(int64_t m) override {
4745 if (m >= 0) {
4746 expr_->SetMax((m + 1) * value_ - 1);
4747 } else {
4748 expr_->SetMax(m * value_);
4749 }
4750 }
4751
4752 std::string name() const override {
4753 return absl::StrFormat("(%s div %d)", expr_->name(), value_);
4754 }
4755
4756 std::string DebugString() const override {
4757 return absl::StrFormat("(%s div %d)", expr_->DebugString(), value_);
4758 }
4759
4760 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4761
4762 void Accept(ModelVisitor* const visitor) const override {
4763 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4764 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4765 expr_);
4766 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4767 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4768 }
4769
4770 private:
4771 IntExpr* const expr_;
4772 const int64_t value_;
4773};
4774
4775// DivPosIntExpr
4776
// Truncated division num / denom for a denominator presumably known to be
// non-negative. Min()/Max() treat a 0 lower bound on denom as a denominator
// of 1.
// NOTE(review): confirm that a zero denominator is excluded by the creator.
// opp_num_ is -num. It reduces non-positive bounds to the positive case,
// relying on C++ division truncating toward zero: (-a) / b == -(a / b).
class DivPosIntExpr : public BaseIntExpr {
 public:
  DivPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
      : BaseIntExpr(s),
        num_(num),
        denom_(denom),
        opp_num_(s->MakeOpposite(num)) {}

  ~DivPosIntExpr() override {}

  // Smallest quotient. A non-negative numerator is divided by the largest
  // denominator; a negative one by the smallest denominator, with a 0 lower
  // bound handled as 1.
  int64_t Min() const override {
    return num_->Min() >= 0
               ? num_->Min() / denom_->Max()
               : (denom_->Min() == 0 ? num_->Min()
                                     : num_->Min() / denom_->Min());
  }

  // Largest quotient, symmetric to Min().
  int64_t Max() const override {
    return num_->Max() >= 0 ? (denom_->Min() == 0 ? num_->Max()
                                                  : num_->Max() / denom_->Min())
                            : num_->Max() / denom_->Max();
  }

  // Enforces num / denom >= m for m > 0 (SetMin/SetMax only call it with a
  // positive m): num >= m * min(denom) and denom <= max(num) / m.
  // NOTE(review): m * denom->Min() is not overflow-checked.
  static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64_t m) {
    num->SetMin(m * denom->Min());
    denom->SetMax(num->Max() / m);
  }

  // Enforces num / denom <= m for m >= 0: num < (m + 1) * max(denom) and
  // denom > min(num) / (m + 1).
  // NOTE(review): (m + 1) and the product are not overflow-checked.
  static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64_t m) {
    num->SetMax((m + 1) * denom->Max() - 1);
    denom->SetMin(num->Min() / (m + 1) + 1);
  }

  // num / denom >= m. For m <= 0 this is rewritten as (-num) / denom <= -m.
  // NOTE(review): -m overflows if m == kint64min.
  void SetMin(int64_t m) override {
    if (m > 0) {
      SetPosMin(num_, denom_, m);
    } else {
      SetPosMax(opp_num_, denom_, -m);
    }
  }

  // num / denom <= m. For m < 0 this is rewritten as (-num) / denom >= -m.
  void SetMax(int64_t m) override {
    if (m >= 0) {
      SetPosMax(num_, denom_, m);
    } else {
      SetPosMin(opp_num_, denom_, -m);
    }
  }

  std::string name() const override {
    return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
  }
  std::string DebugString() const override {
    return absl::StrFormat("(%s div %s)", num_->DebugString(),
                           denom_->DebugString());
  }
  void WhenRange(Demon* d) override {
    num_->WhenRange(d);
    denom_->WhenRange(d);
  }

  void Accept(ModelVisitor* const visitor) const override {
    visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
                                            denom_);
    visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
  }

 private:
  IntExpr* const num_;
  IntExpr* const denom_;
  IntExpr* const opp_num_;  // -num_
};
4851
// Truncated division num / denom for operands presumably known to be
// non-negative. The code does not check this.
// NOTE(review): confirm the sign guarantees at the creation site.
class DivPosPosIntExpr : public BaseIntExpr {
 public:
  DivPosPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
      : BaseIntExpr(s), num_(num), denom_(denom) {}

  ~DivPosPosIntExpr() override {}

  // Smallest quotient: min(num) / max(denom). Fails the search when denom can
  // only be 0.
  // NOTE(review): Fail() is called from a const accessor.
  int64_t Min() const override {
    if (denom_->Max() == 0) {
      solver()->Fail();
    }
    return num_->Min() / denom_->Max();
  }

  // Largest quotient: max(num) / min(denom), with a 0 lower bound on denom
  // handled as a denominator of 1.
  int64_t Max() const override {
    if (denom_->Min() == 0) {
      return num_->Max();
    } else {
      return num_->Max() / denom_->Min();
    }
  }

  // num / denom >= m. Only a positive m yields a deduction:
  // num >= m * min(denom) and denom <= max(num) / m.
  // NOTE(review): m * denom_->Min() is not overflow-checked.
  void SetMin(int64_t m) override {
    if (m > 0) {
      num_->SetMin(m * denom_->Min());
      denom_->SetMax(num_->Max() / m);
    }
  }

  // num / denom <= m: num < (m + 1) * max(denom) and
  // denom > min(num) / (m + 1). A negative m is treated as infeasible.
  // NOTE(review): (m + 1) and the product are not overflow-checked.
  void SetMax(int64_t m) override {
    if (m >= 0) {
      num_->SetMax((m + 1) * denom_->Max() - 1);
      denom_->SetMin(num_->Min() / (m + 1) + 1);
    } else {
      solver()->Fail();
    }
  }

  std::string name() const override {
    return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
  }

  std::string DebugString() const override {
    return absl::StrFormat("(%s div %s)", num_->DebugString(),
                           denom_->DebugString());
  }

  void WhenRange(Demon* d) override {
    num_->WhenRange(d);
    denom_->WhenRange(d);
  }

  void Accept(ModelVisitor* const visitor) const override {
    visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
                                            denom_);
    visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
  }

 private:
  IntExpr* const num_;
  IntExpr* const denom_;
};
4916
4917// DivIntExpr
4918
// Truncated division num / denom with no sign assumption on either operand.
// Reasoning is on bounds only. A 0 at a bound of the denominator is skipped
// (treated as +/-1 in Min()/Max(), removed by AdjustDenominator() before
// propagation). opp_num_ is -num, used to reduce non-positive quotient bounds
// to the positive case, since C++ division truncates toward zero.
class DivIntExpr : public BaseIntExpr {
 public:
  DivIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
      : BaseIntExpr(s),
        num_(num),
        denom_(denom),
        opp_num_(s->MakeOpposite(num)) {}

  ~DivIntExpr() override {}

  int64_t Min() const override {
    const int64_t num_min = num_->Min();
    const int64_t num_max = num_->Max();
    const int64_t denom_min = denom_->Min();
    const int64_t denom_max = denom_->Max();

    if (denom_min == 0 && denom_max == 0) {
      return std::numeric_limits<int64_t>::max();  // TODO(user): Check this
                                                   // convention.
    }

    if (denom_min >= 0) {  // Denominator non-negative; a 0 minimum acts as 1.
      DCHECK_GT(denom_max, 0);
      const int64_t adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
      return num_min >= 0 ? num_min / denom_max : num_min / adjusted_denom_min;
    } else if (denom_max <= 0) {  // Denominator non-positive; a 0 max acts as -1.
      DCHECK_LT(denom_min, 0);
      const int64_t adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
      return num_max >= 0 ? num_max / adjusted_denom_max : num_max / denom_min;
    } else {  // Denominator across 0.
      // Dividing by +1 or -1 yields num or -num, the extreme quotients.
      return std::min(num_min, -num_max);
    }
  }

  int64_t Max() const override {
    const int64_t num_min = num_->Min();
    const int64_t num_max = num_->Max();
    const int64_t denom_min = denom_->Min();
    const int64_t denom_max = denom_->Max();

    if (denom_min == 0 && denom_max == 0) {
      return std::numeric_limits<int64_t>::min();  // TODO(user): Check this
                                                   // convention.
    }

    if (denom_min >= 0) {  // Denominator non-negative; a 0 minimum acts as 1.
      DCHECK_GT(denom_max, 0);
      const int64_t adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
      return num_max >= 0 ? num_max / adjusted_denom_min : num_max / denom_max;
    } else if (denom_max <= 0) {  // Denominator non-positive; a 0 max acts as -1.
      DCHECK_LT(denom_min, 0);
      const int64_t adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
      // NOTE(review): -num_min overflows if num_min == kint64min.
      return num_min >= 0 ? num_min / denom_min
                          : -num_min / -adjusted_denom_max;
    } else {  // Denominator across 0.
      // Dividing by +1 or -1 yields num or -num, the extreme quotients.
      return std::max(num_max, -num_min);
    }
  }

  // Removes 0 from the denominator when it sits at one of its bounds. A 0
  // strictly inside the denominator's range is not removed here.
  void AdjustDenominator() {
    if (denom_->Min() == 0) {
      denom_->SetMin(1);
    } else if (denom_->Max() == 0) {
      denom_->SetMax(-1);
    }
  }

  // m > 0.
  // Enforces num / denom >= m. The denominator's bounds must be non-zero
  // (guaranteed by AdjustDenominator() in the callers).
  static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64_t m) {
    DCHECK_GT(m, 0);
    const int64_t num_min = num->Min();
    const int64_t num_max = num->Max();
    const int64_t denom_min = denom->Min();
    const int64_t denom_max = denom->Max();
    DCHECK_NE(denom_min, 0);
    DCHECK_NE(denom_max, 0);
    if (denom_min > 0) {  // Denominator strictly positive.
      num->SetMin(m * denom_min);
      denom->SetMax(num_max / m);
    } else if (denom_max < 0) {  // Denominator strictly negative.
      num->SetMax(m * denom_max);
      denom->SetMin(num_min / m);
    } else {  // Denominator across 0.
      // A positive quotient needs num and denom of the same sign, so the
      // sign of num decides which side of the denominator remains.
      if (num_min >= 0) {
        num->SetMin(m);
        denom->SetRange(1, num_max / m);
      } else if (num_max <= 0) {
        num->SetMax(-m);
        denom->SetRange(num_min / m, -1);
      } else {
        if (m > -num_min) {  // Denominator is forced positive.
          num->SetMin(m);
          denom->SetRange(1, num_max / m);
        } else if (m > num_max) {  // Denominator is forced negative.
          num->SetMax(-m);
          denom->SetRange(num_min / m, -1);
        } else {
          denom->SetRange(num_min / m, num_max / m);
        }
      }
    }
  }

  // m >= 0.
  // Enforces num / denom <= m. The denominator's bounds must be non-zero
  // (guaranteed by AdjustDenominator() in the callers).
  // NOTE(review): (m + 1) and the products are not overflow-checked.
  static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64_t m) {
    DCHECK_GE(m, 0);
    const int64_t num_min = num->Min();
    const int64_t num_max = num->Max();
    const int64_t denom_min = denom->Min();
    const int64_t denom_max = denom->Max();
    DCHECK_NE(denom_min, 0);
    DCHECK_NE(denom_max, 0);
    if (denom_min > 0) {  // Denominator strictly positive.
      num->SetMax((m + 1) * denom_max - 1);
      denom->SetMin((num_min / (m + 1)) + 1);
    } else if (denom_max < 0) {  // Denominator strictly negative.
      num->SetMin((m + 1) * denom_min + 1);
      denom->SetMax(num_max / (m + 1) - 1);
    } else if (num_min > (m + 1) * denom_max - 1) {
      // No positive denominator keeps the quotient <= m: force it negative.
      denom->SetMax(-1);
    } else if (num_max < (m + 1) * denom_min + 1) {
      // No negative denominator keeps the quotient <= m: force it positive.
      denom->SetMin(1);
    }
  }

  // num / denom >= m. For m <= 0 this is rewritten as (-num) / denom <= -m.
  void SetMin(int64_t m) override {
    AdjustDenominator();
    if (m > 0) {
      SetPosMin(num_, denom_, m);
    } else {
      SetPosMax(opp_num_, denom_, -m);
    }
  }

  // num / denom <= m. For m < 0 this is rewritten as (-num) / denom >= -m.
  void SetMax(int64_t m) override {
    AdjustDenominator();
    if (m >= 0) {
      SetPosMax(num_, denom_, m);
    } else {
      SetPosMin(opp_num_, denom_, -m);
    }
  }

  std::string name() const override {
    return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
  }
  std::string DebugString() const override {
    return absl::StrFormat("(%s div %s)", num_->DebugString(),
                           denom_->DebugString());
  }
  void WhenRange(Demon* d) override {
    num_->WhenRange(d);
    denom_->WhenRange(d);
  }

  void Accept(ModelVisitor* const visitor) const override {
    visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
                                            denom_);
    visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
  }

 private:
  IntExpr* const num_;
  IntExpr* const denom_;
  IntExpr* const opp_num_;  // -num_
};
5087
5088// ----- IntAbs And IntAbsConstraint ------
5089
5090class IntAbsConstraint : public CastConstraint {
5091 public:
5092 IntAbsConstraint(Solver* const s, IntVar* const sub, IntVar* const target)
5093 : CastConstraint(s, target), sub_(sub) {}
5094
5095 ~IntAbsConstraint() override {}
5096
5097 void Post() override {
5098 Demon* const sub_demon = MakeConstraintDemon0(
5099 solver(), this, &IntAbsConstraint::PropagateSub, "PropagateSub");
5100 sub_->WhenRange(sub_demon);
5101 Demon* const target_demon = MakeConstraintDemon0(
5102 solver(), this, &IntAbsConstraint::PropagateTarget, "PropagateTarget");
5103 target_var_->WhenRange(target_demon);
5104 }
5105
5106 void InitialPropagate() override {
5107 PropagateSub();
5108 PropagateTarget();
5109 }
5110
5111 void PropagateSub() {
5112 const int64_t smin = sub_->Min();
5113 const int64_t smax = sub_->Max();
5114 if (smax <= 0) {
5115 target_var_->SetRange(-smax, -smin);
5116 } else if (smin >= 0) {
5117 target_var_->SetRange(smin, smax);
5118 } else {
5119 target_var_->SetRange(0, std::max(-smin, smax));
5120 }
5121 }
5122
5123 void PropagateTarget() {
5124 const int64_t target_max = target_var_->Max();
5125 sub_->SetRange(-target_max, target_max);
5126 const int64_t target_min = target_var_->Min();
5127 if (target_min > 0) {
5128 if (sub_->Min() > -target_min) {
5129 sub_->SetMin(target_min);
5130 } else if (sub_->Max() < target_min) {
5131 sub_->SetMax(-target_min);
5132 }
5133 }
5134 }
5135
5136 std::string DebugString() const override {
5137 return absl::StrFormat("IntAbsConstraint(%s, %s)", sub_->DebugString(),
5138 target_var_->DebugString());
5139 }
5140
5141 void Accept(ModelVisitor* const visitor) const override {
5142 visitor->BeginVisitConstraint(ModelVisitor::kAbsEqual, this);
5143 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5144 sub_);
5145 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
5146 target_var_);
5147 visitor->EndVisitConstraint(ModelVisitor::kAbsEqual, this);
5148 }
5149
5150 private:
5151 IntVar* const sub_;
5152};
5153
5154class IntAbs : public BaseIntExpr {
5155 public:
5156 IntAbs(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5157
5158 ~IntAbs() override {}
5159
5160 int64_t Min() const override {
5161 int64_t emin = 0;
5162 int64_t emax = 0;
5163 expr_->Range(&emin, &emax);
5164 if (emin >= 0) {
5165 return emin;
5166 }
5167 if (emax <= 0) {
5168 return -emax;
5169 }
5170 return 0;
5171 }
5172
5173 void SetMin(int64_t m) override {
5174 if (m > 0) {
5175 int64_t emin = 0;
5176 int64_t emax = 0;
5177 expr_->Range(&emin, &emax);
5178 if (emin > -m) {
5179 expr_->SetMin(m);
5180 } else if (emax < m) {
5181 expr_->SetMax(-m);
5182 }
5183 }
5184 }
5185
5186 int64_t Max() const override {
5187 int64_t emin = 0;
5188 int64_t emax = 0;
5189 expr_->Range(&emin, &emax);
5190 return std::max(-emin, emax);
5191 }
5192
5193 void SetMax(int64_t m) override { expr_->SetRange(-m, m); }
5194
5195 void SetRange(int64_t mi, int64_t ma) override {
5196 expr_->SetRange(-ma, ma);
5197 if (mi > 0) {
5198 int64_t emin = 0;
5199 int64_t emax = 0;
5200 expr_->Range(&emin, &emax);
5201 if (emin > -mi) {
5202 expr_->SetMin(mi);
5203 } else if (emax < mi) {
5204 expr_->SetMax(-mi);
5205 }
5206 }
5207 }
5208
5209 void Range(int64_t* mi, int64_t* ma) override {
5210 int64_t emin = 0;
5211 int64_t emax = 0;
5212 expr_->Range(&emin, &emax);
5213 if (emin >= 0) {
5214 *mi = emin;
5215 *ma = emax;
5216 } else if (emax <= 0) {
5217 *mi = -emax;
5218 *ma = -emin;
5219 } else {
5220 *mi = 0;
5221 *ma = std::max(-emin, emax);
5222 }
5223 }
5224
5225 bool Bound() const override { return expr_->Bound(); }
5226
5227 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5228
5229 std::string name() const override {
5230 return absl::StrFormat("IntAbs(%s)", expr_->name());
5231 }
5232
5233 std::string DebugString() const override {
5234 return absl::StrFormat("IntAbs(%s)", expr_->DebugString());
5235 }
5236
5237 void Accept(ModelVisitor* const visitor) const override {
5238 visitor->BeginVisitIntegerExpression(ModelVisitor::kAbs, this);
5239 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5240 expr_);
5241 visitor->EndVisitIntegerExpression(ModelVisitor::kAbs, this);
5242 }
5243
5244 IntVar* CastToVar() override {
5245 int64_t min_value = 0;
5246 int64_t max_value = 0;
5247 Range(&min_value, &max_value);
5248 Solver* const s = solver();
5249 const std::string name = absl::StrFormat("AbsVar(%s)", expr_->name());
5250 IntVar* const target = s->MakeIntVar(min_value, max_value, name);
5251 CastConstraint* const ct =
5252 s->RevAlloc(new IntAbsConstraint(s, expr_->Var(), target));
5253 s->AddCastConstraint(ct, target, this);
5254 return target;
5255 }
5256
5257 private:
5258 IntExpr* const expr_;
5259};
5260
5261// ----- Square -----
5262
5263// TODO(user): shouldn't we compare to kint32max^2 instead of kint64max?
5264class IntSquare : public BaseIntExpr {
5265 public:
5266 IntSquare(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5267 ~IntSquare() override {}
5268
5269 int64_t Min() const override {
5270 const int64_t emin = expr_->Min();
5271 if (emin >= 0) {
5272 return emin >= std::numeric_limits<int32_t>::max()
5273 ? std::numeric_limits<int64_t>::max()
5274 : emin * emin;
5275 }
5276 const int64_t emax = expr_->Max();
5277 if (emax < 0) {
5278 return emax <= -std::numeric_limits<int32_t>::max()
5279 ? std::numeric_limits<int64_t>::max()
5280 : emax * emax;
5281 }
5282 return 0LL;
5283 }
5284 void SetMin(int64_t m) override {
5285 if (m <= 0) {
5286 return;
5287 }
5288 // TODO(user): What happens if m is kint64max?
5289 const int64_t emin = expr_->Min();
5290 const int64_t emax = expr_->Max();
5291 const int64_t root =
5292 static_cast<int64_t>(ceil(sqrt(static_cast<double>(m))));
5293 if (emin >= 0) {
5294 expr_->SetMin(root);
5295 } else if (emax <= 0) {
5296 expr_->SetMax(-root);
5297 } else if (expr_->IsVar()) {
5298 reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5299 }
5300 }
5301 int64_t Max() const override {
5302 const int64_t emax = expr_->Max();
5303 const int64_t emin = expr_->Min();
5304 if (emax >= std::numeric_limits<int32_t>::max() ||
5305 emin <= -std::numeric_limits<int32_t>::max()) {
5306 return std::numeric_limits<int64_t>::max();
5307 }
5308 return std::max(emin * emin, emax * emax);
5309 }
5310 void SetMax(int64_t m) override {
5311 if (m < 0) {
5312 solver()->Fail();
5313 }
5314 if (m == std::numeric_limits<int64_t>::max()) {
5315 return;
5316 }
5317 const int64_t root =
5318 static_cast<int64_t>(floor(sqrt(static_cast<double>(m))));
5319 expr_->SetRange(-root, root);
5320 }
5321 bool Bound() const override { return expr_->Bound(); }
5322 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5323 std::string name() const override {
5324 return absl::StrFormat("IntSquare(%s)", expr_->name());
5325 }
5326 std::string DebugString() const override {
5327 return absl::StrFormat("IntSquare(%s)", expr_->DebugString());
5328 }
5329
5330 void Accept(ModelVisitor* const visitor) const override {
5331 visitor->BeginVisitIntegerExpression(ModelVisitor::kSquare, this);
5332 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5333 expr_);
5334 visitor->EndVisitIntegerExpression(ModelVisitor::kSquare, this);
5335 }
5336
5337 IntExpr* expr() const { return expr_; }
5338
5339 protected:
5340 IntExpr* const expr_;
5341};
5342
5343class PosIntSquare : public IntSquare {
5344 public:
5345 PosIntSquare(Solver* const s, IntExpr* const e) : IntSquare(s, e) {}
5346 ~PosIntSquare() override {}
5347
5348 int64_t Min() const override {
5349 const int64_t emin = expr_->Min();
5350 return emin >= std::numeric_limits<int32_t>::max()
5351 ? std::numeric_limits<int64_t>::max()
5352 : emin * emin;
5353 }
5354 void SetMin(int64_t m) override {
5355 if (m <= 0) {
5356 return;
5357 }
5358 int64_t root = static_cast<int64_t>(ceil(sqrt(static_cast<double>(m))));
5359 if (CapProd(root, root) < m) {
5360 root++;
5361 }
5362 expr_->SetMin(root);
5363 }
5364 int64_t Max() const override {
5365 const int64_t emax = expr_->Max();
5366 return emax >= std::numeric_limits<int32_t>::max()
5367 ? std::numeric_limits<int64_t>::max()
5368 : emax * emax;
5369 }
5370 void SetMax(int64_t m) override {
5371 if (m < 0) {
5372 solver()->Fail();
5373 }
5374 if (m == std::numeric_limits<int64_t>::max()) {
5375 return;
5376 }
5377 int64_t root = static_cast<int64_t>(floor(sqrt(static_cast<double>(m))));
5378 if (CapProd(root, root) > m) {
5379 root--;
5380 }
5381
5382 expr_->SetMax(root);
5383 }
5384};
5385
5386// ----- EvenPower -----
5387
// Returns value^power using exponentiation by squaring (O(log power)
// multiplications instead of power - 1).
//
// Preconditions (not checked): power >= 1, and the exact result fits in
// int64_t. Under that precondition no intermediate product overflows either:
// `base` is squared only while exponent bits remain, so it never exceeds
// |value|^power. For power < 1 the function returns `value` unchanged, which
// matches the previous loop-based implementation.
int64_t IntPower(int64_t value, int64_t power) {
  if (power <= 1) {
    return value;
  }
  int64_t result = 1;
  int64_t base = value;
  while (true) {
    if (power & 1) {
      result *= base;
    }
    power >>= 1;
    if (power == 0) {
      break;
    }
    base *= base;
  }
  return result;
}
5396
// Returns floor(kint64max^(1 / power)), estimated in floating point. Values
// at or beyond this limit are treated as overflowing when raised to `power`.
//
// For power <= 1 the root is kint64max itself (or undefined for power <= 0).
// The old code converted that double, which can round to 2^63 or be
// infinite, back to int64_t, which is undefined behavior. kint64max is
// returned directly instead.
int64_t OverflowLimit(int64_t power) {
  if (power <= 1) {
    return std::numeric_limits<int64_t>::max();
  }
  return static_cast<int64_t>(std::floor(std::exp(
      std::log(static_cast<double>(std::numeric_limits<int64_t>::max())) /
      static_cast<double>(power))));
}
5401
5402class BasePower : public BaseIntExpr {
5403 public:
5404 BasePower(Solver* const s, IntExpr* const e, int64_t n)
5405 : BaseIntExpr(s), expr_(e), pow_(n), limit_(OverflowLimit(n)) {
5406 CHECK_GT(n, 0);
5407 }
5408
5409 ~BasePower() override {}
5410
5411 bool Bound() const override { return expr_->Bound(); }
5412
5413 IntExpr* expr() const { return expr_; }
5414
5415 int64_t exponant() const { return pow_; }
5416
5417 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5418
5419 std::string name() const override {
5420 return absl::StrFormat("IntPower(%s, %d)", expr_->name(), pow_);
5421 }
5422
5423 std::string DebugString() const override {
5424 return absl::StrFormat("IntPower(%s, %d)", expr_->DebugString(), pow_);
5425 }
5426
5427 void Accept(ModelVisitor* const visitor) const override {
5428 visitor->BeginVisitIntegerExpression(ModelVisitor::kPower, this);
5429 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5430 expr_);
5431 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, pow_);
5432 visitor->EndVisitIntegerExpression(ModelVisitor::kPower, this);
5433 }
5434
5435 protected:
5436 int64_t Pown(int64_t value) const {
5437 if (value >= limit_) {
5438 return std::numeric_limits<int64_t>::max();
5439 }
5440 if (value <= -limit_) {
5441 if (pow_ % 2 == 0) {
5442 return std::numeric_limits<int64_t>::max();
5443 } else {
5444 return std::numeric_limits<int64_t>::min();
5445 }
5446 }
5447 return IntPower(value, pow_);
5448 }
5449
5450 int64_t SqrnDown(int64_t value) const {
5451 if (value == std::numeric_limits<int64_t>::min()) {
5452 return std::numeric_limits<int64_t>::min();
5453 }
5454 if (value == std::numeric_limits<int64_t>::max()) {
5455 return std::numeric_limits<int64_t>::max();
5456 }
5457 int64_t res = 0;
5458 const double d_value = static_cast<double>(value);
5459 if (value >= 0) {
5460 const double sq = exp(log(d_value) / pow_);
5461 res = static_cast<int64_t>(floor(sq));
5462 } else {
5463 CHECK_EQ(1, pow_ % 2);
5464 const double sq = exp(log(-d_value) / pow_);
5465 res = -static_cast<int64_t>(ceil(sq));
5466 }
5467 const int64_t pow_res = Pown(res + 1);
5468 if (pow_res <= value) {
5469 return res + 1;
5470 } else {
5471 return res;
5472 }
5473 }
5474
5475 int64_t SqrnUp(int64_t value) const {
5476 if (value == std::numeric_limits<int64_t>::min()) {
5477 return std::numeric_limits<int64_t>::min();
5478 }
5479 if (value == std::numeric_limits<int64_t>::max()) {
5480 return std::numeric_limits<int64_t>::max();
5481 }
5482 int64_t res = 0;
5483 const double d_value = static_cast<double>(value);
5484 if (value >= 0) {
5485 const double sq = exp(log(d_value) / pow_);
5486 res = static_cast<int64_t>(ceil(sq));
5487 } else {
5488 CHECK_EQ(1, pow_ % 2);
5489 const double sq = exp(log(-d_value) / pow_);
5490 res = -static_cast<int64_t>(floor(sq));
5491 }
5492 const int64_t pow_res = Pown(res - 1);
5493 if (pow_res >= value) {
5494 return res - 1;
5495 } else {
5496 return res;
5497 }
5498 }
5499
5500 IntExpr* const expr_;
5501 const int64_t pow_;
5502 const int64_t limit_;
5503};
5504
5505class IntEvenPower : public BasePower {
5506 public:
5507 IntEvenPower(Solver* const s, IntExpr* const e, int64_t n)
5508 : BasePower(s, e, n) {
5509 CHECK_EQ(0, n % 2);
5510 }
5511
5512 ~IntEvenPower() override {}
5513
5514 int64_t Min() const override {
5515 int64_t emin = 0;
5516 int64_t emax = 0;
5517 expr_->Range(&emin, &emax);
5518 if (emin >= 0) {
5519 return Pown(emin);
5520 }
5521 if (emax < 0) {
5522 return Pown(emax);
5523 }
5524 return 0LL;
5525 }
5526 void SetMin(int64_t m) override {
5527 if (m <= 0) {
5528 return;
5529 }
5530 int64_t emin = 0;
5531 int64_t emax = 0;
5532 expr_->Range(&emin, &emax);
5533 const int64_t root = SqrnUp(m);
5534 if (emin > -root) {
5535 expr_->SetMin(root);
5536 } else if (emax < root) {
5537 expr_->SetMax(-root);
5538 } else if (expr_->IsVar()) {
5539 reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5540 }
5541 }
5542
5543 int64_t Max() const override {
5544 return std::max(Pown(expr_->Min()), Pown(expr_->Max()));
5545 }
5546
5547 void SetMax(int64_t m) override {
5548 if (m < 0) {
5549 solver()->Fail();
5550 }
5551 if (m == std::numeric_limits<int64_t>::max()) {
5552 return;
5553 }
5554 const int64_t root = SqrnDown(m);
5555 expr_->SetRange(-root, root);
5556 }
5557};
5558
5559class PosIntEvenPower : public BasePower {
5560 public:
5561 PosIntEvenPower(Solver* const s, IntExpr* const e, int64_t pow)
5562 : BasePower(s, e, pow) {
5563 CHECK_EQ(0, pow % 2);
5564 }
5565
5566 ~PosIntEvenPower() override {}
5567
5568 int64_t Min() const override { return Pown(expr_->Min()); }
5569
5570 void SetMin(int64_t m) override {
5571 if (m <= 0) {
5572 return;
5573 }
5574 expr_->SetMin(SqrnUp(m));
5575 }
5576 int64_t Max() const override { return Pown(expr_->Max()); }
5577
5578 void SetMax(int64_t m) override {
5579 if (m < 0) {
5580 solver()->Fail();
5581 }
5582 if (m == std::numeric_limits<int64_t>::max()) {
5583 return;
5584 }
5585 expr_->SetMax(SqrnDown(m));
5586 }
5587};
5588
5589class IntOddPower : public BasePower {
5590 public:
5591 IntOddPower(Solver* const s, IntExpr* const e, int64_t n)
5592 : BasePower(s, e, n) {
5593 CHECK_EQ(1, n % 2);
5594 }
5595
5596 ~IntOddPower() override {}
5597
5598 int64_t Min() const override { return Pown(expr_->Min()); }
5599
5600 void SetMin(int64_t m) override { expr_->SetMin(SqrnUp(m)); }
5601
5602 int64_t Max() const override { return Pown(expr_->Max()); }
5603
5604 void SetMax(int64_t m) override { expr_->SetMax(SqrnDown(m)); }
5605};
5606
5607// ----- Min(expr, expr) -----
5608
5609class MinIntExpr : public BaseIntExpr {
5610 public:
5611 MinIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5612 : BaseIntExpr(s), left_(l), right_(r) {}
5613 ~MinIntExpr() override {}
5614 int64_t Min() const override {
5615 const int64_t lmin = left_->Min();
5616 const int64_t rmin = right_->Min();
5617 return std::min(lmin, rmin);
5618 }
5619 void SetMin(int64_t m) override {
5620 left_->SetMin(m);
5621 right_->SetMin(m);
5622 }
5623 int64_t Max() const override {
5624 const int64_t lmax = left_->Max();
5625 const int64_t rmax = right_->Max();
5626 return std::min(lmax, rmax);
5627 }
5628 void SetMax(int64_t m) override {
5629 if (left_->Min() > m) {
5630 right_->SetMax(m);
5631 }
5632 if (right_->Min() > m) {
5633 left_->SetMax(m);
5634 }
5635 }
5636 std::string name() const override {
5637 return absl::StrFormat("MinIntExpr(%s, %s)", left_->name(), right_->name());
5638 }
5639 std::string DebugString() const override {
5640 return absl::StrFormat("MinIntExpr(%s, %s)", left_->DebugString(),
5641 right_->DebugString());
5642 }
5643 void WhenRange(Demon* d) override {
5644 left_->WhenRange(d);
5645 right_->WhenRange(d);
5646 }
5647
5648 void Accept(ModelVisitor* const visitor) const override {
5649 visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5650 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5651 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5652 right_);
5653 visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5654 }
5655
5656 private:
5657 IntExpr* const left_;
5658 IntExpr* const right_;
5659};
5660
5661// ----- Min(expr, constant) -----
5662
5663class MinCstIntExpr : public BaseIntExpr {
5664 public:
5665 MinCstIntExpr(Solver* const s, IntExpr* const e, int64_t v)
5666 : BaseIntExpr(s), expr_(e), value_(v) {}
5667
5668 ~MinCstIntExpr() override {}
5669
5670 int64_t Min() const override { return std::min(expr_->Min(), value_); }
5671
5672 void SetMin(int64_t m) override {
5673 if (m > value_) {
5674 solver()->Fail();
5675 }
5676 expr_->SetMin(m);
5677 }
5678
5679 int64_t Max() const override { return std::min(expr_->Max(), value_); }
5680
5681 void SetMax(int64_t m) override {
5682 if (value_ > m) {
5683 expr_->SetMax(m);
5684 }
5685 }
5686
5687 bool Bound() const override {
5688 return (expr_->Bound() || expr_->Min() >= value_);
5689 }
5690
5691 std::string name() const override {
5692 return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->name(), value_);
5693 }
5694
5695 std::string DebugString() const override {
5696 return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->DebugString(),
5697 value_);
5698 }
5699
5700 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5701
5702 void Accept(ModelVisitor* const visitor) const override {
5703 visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5704 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5705 expr_);
5706 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5707 visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5708 }
5709
5710 private:
5711 IntExpr* const expr_;
5712 const int64_t value_;
5713};
5714
5715// ----- Max(expr, expr) -----
5716
5717class MaxIntExpr : public BaseIntExpr {
5718 public:
5719 MaxIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5720 : BaseIntExpr(s), left_(l), right_(r) {}
5721
5722 ~MaxIntExpr() override {}
5723
5724 int64_t Min() const override { return std::max(left_->Min(), right_->Min()); }
5725
5726 void SetMin(int64_t m) override {
5727 if (left_->Max() < m) {
5728 right_->SetMin(m);
5729 } else {
5730 if (right_->Max() < m) {
5731 left_->SetMin(m);
5732 }
5733 }
5734 }
5735
5736 int64_t Max() const override { return std::max(left_->Max(), right_->Max()); }
5737
5738 void SetMax(int64_t m) override {
5739 left_->SetMax(m);
5740 right_->SetMax(m);
5741 }
5742
5743 std::string name() const override {
5744 return absl::StrFormat("MaxIntExpr(%s, %s)", left_->name(), right_->name());
5745 }
5746
5747 std::string DebugString() const override {
5748 return absl::StrFormat("MaxIntExpr(%s, %s)", left_->DebugString(),
5749 right_->DebugString());
5750 }
5751
5752 void WhenRange(Demon* d) override {
5753 left_->WhenRange(d);
5754 right_->WhenRange(d);
5755 }
5756
5757 void Accept(ModelVisitor* const visitor) const override {
5758 visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5759 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5760 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5761 right_);
5762 visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5763 }
5764
5765 private:
5766 IntExpr* const left_;
5767 IntExpr* const right_;
5768};
5769
5770// ----- Max(expr, constant) -----
5771
5772class MaxCstIntExpr : public BaseIntExpr {
5773 public:
5774 MaxCstIntExpr(Solver* const s, IntExpr* const e, int64_t v)
5775 : BaseIntExpr(s), expr_(e), value_(v) {}
5776
5777 ~MaxCstIntExpr() override {}
5778
5779 int64_t Min() const override { return std::max(expr_->Min(), value_); }
5780
5781 void SetMin(int64_t m) override {
5782 if (value_ < m) {
5783 expr_->SetMin(m);
5784 }
5785 }
5786
5787 int64_t Max() const override { return std::max(expr_->Max(), value_); }
5788
5789 void SetMax(int64_t m) override {
5790 if (m < value_) {
5791 solver()->Fail();
5792 }
5793 expr_->SetMax(m);
5794 }
5795
5796 bool Bound() const override {
5797 return (expr_->Bound() || expr_->Max() <= value_);
5798 }
5799
5800 std::string name() const override {
5801 return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->name(), value_);
5802 }
5803
5804 std::string DebugString() const override {
5805 return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->DebugString(),
5806 value_);
5807 }
5808
5809 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5810
5811 void Accept(ModelVisitor* const visitor) const override {
5812 visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5813 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5814 expr_);
5815 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5816 visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5817 }
5818
5819 private:
5820 IntExpr* const expr_;
5821 const int64_t value_;
5822};
5823
5824// ----- Convex Piecewise -----
5825
5826// This class is a very simple convex piecewise linear function. The
5827// argument of the function is the expression. Between early_date and
5828// late_date, the value of the function is 0. Before early date, it
5829// is affine and the cost is early_cost * (early_date - x). After
5830// late_date, the cost is late_cost * (x - late_date).
5831
5832class SimpleConvexPiecewiseExpr : public BaseIntExpr {
5833 public:
5834 SimpleConvexPiecewiseExpr(Solver* const s, IntExpr* const e, int64_t ec,
5835 int64_t ed, int64_t ld, int64_t lc)
5836 : BaseIntExpr(s),
5837 expr_(e),
5838 early_cost_(ec),
5839 early_date_(ec == 0 ? std::numeric_limits<int64_t>::min() : ed),
5840 late_date_(lc == 0 ? std::numeric_limits<int64_t>::max() : ld),
5841 late_cost_(lc) {
5842 DCHECK_GE(ec, int64_t{0});
5843 DCHECK_GE(lc, int64_t{0});
5844 DCHECK_GE(ld, ed);
5845
5846 // If the penalty is 0, we can push the "comfort zone or zone of no cost
5847 // towards infinity.
5848 }
5849
5850 ~SimpleConvexPiecewiseExpr() override {}
5851
5852 int64_t Min() const override {
5853 const int64_t vmin = expr_->Min();
5854 const int64_t vmax = expr_->Max();
5855 if (vmin >= late_date_) {
5856 return (vmin - late_date_) * late_cost_;
5857 } else if (vmax <= early_date_) {
5858 return (early_date_ - vmax) * early_cost_;
5859 } else {
5860 return 0LL;
5861 }
5862 }
5863
5864 void SetMin(int64_t m) override {
5865 if (m <= 0) {
5866 return;
5867 }
5868 int64_t vmin = 0;
5869 int64_t vmax = 0;
5870 expr_->Range(&vmin, &vmax);
5871
5872 const int64_t rb =
5873 (late_cost_ == 0 ? vmax : late_date_ + PosIntDivUp(m, late_cost_) - 1);
5874 const int64_t lb =
5875 (early_cost_ == 0 ? vmin
5876 : early_date_ - PosIntDivUp(m, early_cost_) + 1);
5877
5878 if (expr_->IsVar()) {
5879 expr_->Var()->RemoveInterval(lb, rb);
5880 }
5881 }
5882
5883 int64_t Max() const override {
5884 const int64_t vmin = expr_->Min();
5885 const int64_t vmax = expr_->Max();
5886 const int64_t mr = vmax > late_date_ ? (vmax - late_date_) * late_cost_ : 0;
5887 const int64_t ml =
5888 vmin < early_date_ ? (early_date_ - vmin) * early_cost_ : 0;
5889 return std::max(mr, ml);
5890 }
5891
5892 void SetMax(int64_t m) override {
5893 if (m < 0) {
5894 solver()->Fail();
5895 }
5896 if (late_cost_ != 0LL) {
5897 const int64_t rb = late_date_ + PosIntDivDown(m, late_cost_);
5898 if (early_cost_ != 0LL) {
5899 const int64_t lb = early_date_ - PosIntDivDown(m, early_cost_);
5900 expr_->SetRange(lb, rb);
5901 } else {
5902 expr_->SetMax(rb);
5903 }
5904 } else {
5905 if (early_cost_ != 0LL) {
5906 const int64_t lb = early_date_ - PosIntDivDown(m, early_cost_);
5907 expr_->SetMin(lb);
5908 }
5909 }
5910 }
5911
5912 std::string name() const override {
5913 return absl::StrFormat(
5914 "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5915 expr_->name(), early_cost_, early_date_, late_date_, late_cost_);
5916 }
5917
5918 std::string DebugString() const override {
5919 return absl::StrFormat(
5920 "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5921 expr_->DebugString(), early_cost_, early_date_, late_date_, late_cost_);
5922 }
5923
5924 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5925
5926 void Accept(ModelVisitor* const visitor) const override {
5927 visitor->BeginVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5928 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5929 expr_);
5930 visitor->VisitIntegerArgument(ModelVisitor::kEarlyCostArgument,
5931 early_cost_);
5932 visitor->VisitIntegerArgument(ModelVisitor::kEarlyDateArgument,
5933 early_date_);
5934 visitor->VisitIntegerArgument(ModelVisitor::kLateCostArgument, late_cost_);
5935 visitor->VisitIntegerArgument(ModelVisitor::kLateDateArgument, late_date_);
5936 visitor->EndVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5937 }
5938
5939 private:
5940 IntExpr* const expr_;
5941 const int64_t early_cost_;
5942 const int64_t early_date_;
5943 const int64_t late_date_;
5944 const int64_t late_cost_;
5945};
5946
5947// ----- Semi Continuous -----
5948
5949class SemiContinuousExpr : public BaseIntExpr {
5950 public:
5951 SemiContinuousExpr(Solver* const s, IntExpr* const e, int64_t fixed_charge,
5952 int64_t step)
5953 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge), step_(step) {
5954 DCHECK_GE(fixed_charge, int64_t{0});
5955 DCHECK_GT(step, int64_t{0});
5956 }
5957
5958 ~SemiContinuousExpr() override {}
5959
5960 int64_t Value(int64_t x) const {
5961 if (x <= 0) {
5962 return 0;
5963 } else {
5964 return CapAdd(fixed_charge_, CapProd(x, step_));
5965 }
5966 }
5967
5968 int64_t Min() const override { return Value(expr_->Min()); }
5969
5970 void SetMin(int64_t m) override {
5971 if (m >= CapAdd(fixed_charge_, step_)) {
5972 const int64_t y = PosIntDivUp(CapSub(m, fixed_charge_), step_);
5973 expr_->SetMin(y);
5974 } else if (m > 0) {
5975 expr_->SetMin(1);
5976 }
5977 }
5978
5979 int64_t Max() const override { return Value(expr_->Max()); }
5980
5981 void SetMax(int64_t m) override {
5982 if (m < 0) {
5983 solver()->Fail();
5984 }
5985 if (m == std::numeric_limits<int64_t>::max()) {
5986 return;
5987 }
5988 if (m < CapAdd(fixed_charge_, step_)) {
5989 expr_->SetMax(0);
5990 } else {
5991 const int64_t y = PosIntDivDown(CapSub(m, fixed_charge_), step_);
5992 expr_->SetMax(y);
5993 }
5994 }
5995
5996 std::string name() const override {
5997 return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
5998 expr_->name(), fixed_charge_, step_);
5999 }
6000
6001 std::string DebugString() const override {
6002 return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
6003 expr_->DebugString(), fixed_charge_, step_);
6004 }
6005
6006 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6007
6008 void Accept(ModelVisitor* const visitor) const override {
6009 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6010 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6011 expr_);
6012 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6013 fixed_charge_);
6014 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, step_);
6015 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6016 }
6017
6018 private:
6019 IntExpr* const expr_;
6020 const int64_t fixed_charge_;
6021 const int64_t step_;
6022};
6023
6024class SemiContinuousStepOneExpr : public BaseIntExpr {
6025 public:
6026 SemiContinuousStepOneExpr(Solver* const s, IntExpr* const e,
6027 int64_t fixed_charge)
6028 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6029 DCHECK_GE(fixed_charge, int64_t{0});
6030 }
6031
6032 ~SemiContinuousStepOneExpr() override {}
6033
6034 int64_t Value(int64_t x) const {
6035 if (x <= 0) {
6036 return 0;
6037 } else {
6038 return fixed_charge_ + x;
6039 }
6040 }
6041
6042 int64_t Min() const override { return Value(expr_->Min()); }
6043
6044 void SetMin(int64_t m) override {
6045 if (m >= fixed_charge_ + 1) {
6046 expr_->SetMin(m - fixed_charge_);
6047 } else if (m > 0) {
6048 expr_->SetMin(1);
6049 }
6050 }
6051
6052 int64_t Max() const override { return Value(expr_->Max()); }
6053
6054 void SetMax(int64_t m) override {
6055 if (m < 0) {
6056 solver()->Fail();
6057 }
6058 if (m < fixed_charge_ + 1) {
6059 expr_->SetMax(0);
6060 } else {
6061 expr_->SetMax(m - fixed_charge_);
6062 }
6063 }
6064
6065 std::string name() const override {
6066 return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6067 expr_->name(), fixed_charge_);
6068 }
6069
6070 std::string DebugString() const override {
6071 return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6072 expr_->DebugString(), fixed_charge_);
6073 }
6074
6075 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6076
6077 void Accept(ModelVisitor* const visitor) const override {
6078 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6079 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6080 expr_);
6081 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6082 fixed_charge_);
6083 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 1);
6084 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6085 }
6086
6087 private:
6088 IntExpr* const expr_;
6089 const int64_t fixed_charge_;
6090};
6091
6092class SemiContinuousStepZeroExpr : public BaseIntExpr {
6093 public:
6094 SemiContinuousStepZeroExpr(Solver* const s, IntExpr* const e,
6095 int64_t fixed_charge)
6096 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6097 DCHECK_GT(fixed_charge, int64_t{0});
6098 }
6099
6100 ~SemiContinuousStepZeroExpr() override {}
6101
6102 int64_t Value(int64_t x) const {
6103 if (x <= 0) {
6104 return 0;
6105 } else {
6106 return fixed_charge_;
6107 }
6108 }
6109
6110 int64_t Min() const override { return Value(expr_->Min()); }
6111
6112 void SetMin(int64_t m) override {
6113 if (m >= fixed_charge_) {
6114 solver()->Fail();
6115 } else if (m > 0) {
6116 expr_->SetMin(1);
6117 }
6118 }
6119
6120 int64_t Max() const override { return Value(expr_->Max()); }
6121
6122 void SetMax(int64_t m) override {
6123 if (m < 0) {
6124 solver()->Fail();
6125 }
6126 if (m < fixed_charge_) {
6127 expr_->SetMax(0);
6128 }
6129 }
6130
6131 std::string name() const override {
6132 return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6133 expr_->name(), fixed_charge_);
6134 }
6135
6136 std::string DebugString() const override {
6137 return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6138 expr_->DebugString(), fixed_charge_);
6139 }
6140
6141 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6142
6143 void Accept(ModelVisitor* const visitor) const override {
6144 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6145 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6146 expr_);
6147 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6148 fixed_charge_);
6149 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 0);
6150 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6151 }
6152
6153 private:
6154 IntExpr* const expr_;
6155 const int64_t fixed_charge_;
6156};
6157
6158// This constraints links an expression and the variable it is casted into
6159class LinkExprAndVar : public CastConstraint {
6160 public:
6161 LinkExprAndVar(Solver* const s, IntExpr* const expr, IntVar* const var)
6162 : CastConstraint(s, var), expr_(expr) {}
6163
6164 ~LinkExprAndVar() override {}
6165
6166 void Post() override {
6167 Solver* const s = solver();
6168 Demon* d = s->MakeConstraintInitialPropagateCallback(this);
6169 expr_->WhenRange(d);
6170 target_var_->WhenRange(d);
6171 }
6172
6173 void InitialPropagate() override {
6174 expr_->SetRange(target_var_->Min(), target_var_->Max());
6175 int64_t l, u;
6176 expr_->Range(&l, &u);
6177 target_var_->SetRange(l, u);
6178 }
6179
6180 std::string DebugString() const override {
6181 return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6182 target_var_->DebugString());
6183 }
6184
6185 void Accept(ModelVisitor* const visitor) const override {
6186 visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6187 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6188 expr_);
6189 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6190 target_var_);
6191 visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6192 }
6193
6194 private:
6195 IntExpr* const expr_;
6196};
6197
6198// ----- Conditional Expression -----
6199
6200class ExprWithEscapeValue : public BaseIntExpr {
6201 public:
6202 ExprWithEscapeValue(Solver* const s, IntVar* const c, IntExpr* const e,
6203 int64_t unperformed_value)
6204 : BaseIntExpr(s),
6205 condition_(c),
6206 expression_(e),
6207 unperformed_value_(unperformed_value) {}
6208
6209 // This type is neither copyable nor movable.
6210 ExprWithEscapeValue(const ExprWithEscapeValue&) = delete;
6211 ExprWithEscapeValue& operator=(const ExprWithEscapeValue&) = delete;
6212
6213 ~ExprWithEscapeValue() override {}
6214
6215 int64_t Min() const override {
6216 if (condition_->Min() == 1) {
6217 return expression_->Min();
6218 } else if (condition_->Max() == 1) {
6219 return std::min(unperformed_value_, expression_->Min());
6220 } else {
6221 return unperformed_value_;
6222 }
6223 }
6224
6225 void SetMin(int64_t m) override {
6226 if (m > unperformed_value_) {
6227 condition_->SetValue(1);
6228 expression_->SetMin(m);
6229 } else if (condition_->Min() == 1) {
6230 expression_->SetMin(m);
6231 } else if (m > expression_->Max()) {
6232 condition_->SetValue(0);
6233 }
6234 }
6235
6236 int64_t Max() const override {
6237 if (condition_->Min() == 1) {
6238 return expression_->Max();
6239 } else if (condition_->Max() == 1) {
6240 return std::max(unperformed_value_, expression_->Max());
6241 } else {
6242 return unperformed_value_;
6243 }
6244 }
6245
6246 void SetMax(int64_t m) override {
6247 if (m < unperformed_value_) {
6248 condition_->SetValue(1);
6249 expression_->SetMax(m);
6250 } else if (condition_->Min() == 1) {
6251 expression_->SetMax(m);
6252 } else if (m < expression_->Min()) {
6253 condition_->SetValue(0);
6254 }
6255 }
6256
6257 void SetRange(int64_t mi, int64_t ma) override {
6258 if (ma < unperformed_value_ || mi > unperformed_value_) {
6259 condition_->SetValue(1);
6260 expression_->SetRange(mi, ma);
6261 } else if (condition_->Min() == 1) {
6262 expression_->SetRange(mi, ma);
6263 } else if (ma < expression_->Min() || mi > expression_->Max()) {
6264 condition_->SetValue(0);
6265 }
6266 }
6267
6268 void SetValue(int64_t v) override {
6269 if (v != unperformed_value_) {
6270 condition_->SetValue(1);
6271 expression_->SetValue(v);
6272 } else if (condition_->Min() == 1) {
6273 expression_->SetValue(v);
6274 } else if (v < expression_->Min() || v > expression_->Max()) {
6275 condition_->SetValue(0);
6276 }
6277 }
6278
6279 bool Bound() const override {
6280 return condition_->Max() == 0 || expression_->Bound();
6281 }
6282
6283 void WhenRange(Demon* d) override {
6284 expression_->WhenRange(d);
6285 condition_->WhenBound(d);
6286 }
6287
6288 std::string DebugString() const override {
6289 return absl::StrFormat("ConditionExpr(%s, %s, %d)",
6290 condition_->DebugString(),
6291 expression_->DebugString(), unperformed_value_);
6292 }
6293
6294 void Accept(ModelVisitor* const visitor) const override {
6295 visitor->BeginVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6296 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
6297 condition_);
6298 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6299 expression_);
6300 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument,
6301 unperformed_value_);
6302 visitor->EndVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6303 }
6304
6305 private:
6306 IntVar* const condition_;
6307 IntExpr* const expression_;
6308 const int64_t unperformed_value_;
6309};
6310
6311// ----- This is a specialized case when the variable exact type is known -----
6312class LinkExprAndDomainIntVar : public CastConstraint {
6313 public:
6314 LinkExprAndDomainIntVar(Solver* const s, IntExpr* const expr,
6315 DomainIntVar* const var)
6316 : CastConstraint(s, var),
6317 expr_(expr),
6318 cached_min_(std::numeric_limits<int64_t>::min()),
6319 cached_max_(std::numeric_limits<int64_t>::max()),
6320 fail_stamp_(uint64_t{0}) {}
6321
6322 ~LinkExprAndDomainIntVar() override {}
6323
6324 DomainIntVar* var() const {
6325 return reinterpret_cast<DomainIntVar*>(target_var_);
6326 }
6327
6328 void Post() override {
6329 Solver* const s = solver();
6330 Demon* const d = s->MakeConstraintInitialPropagateCallback(this);
6331 expr_->WhenRange(d);
6332 Demon* const target_var_demon = MakeConstraintDemon0(
6333 solver(), this, &LinkExprAndDomainIntVar::Propagate, "Propagate");
6334 target_var_->WhenRange(target_var_demon);
6335 }
6336
6337 void InitialPropagate() override {
6338 expr_->SetRange(var()->min_.Value(), var()->max_.Value());
6339 expr_->Range(&cached_min_, &cached_max_);
6340 var()->DomainIntVar::SetRange(cached_min_, cached_max_);
6341 }
6342
6343 void Propagate() {
6344 if (var()->min_.Value() > cached_min_ ||
6345 var()->max_.Value() < cached_max_ ||
6346 solver()->fail_stamp() != fail_stamp_) {
6347 InitialPropagate();
6348 fail_stamp_ = solver()->fail_stamp();
6349 }
6350 }
6351
6352 std::string DebugString() const override {
6353 return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6354 target_var_->DebugString());
6355 }
6356
6357 void Accept(ModelVisitor* const visitor) const override {
6358 visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6359 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6360 expr_);
6361 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6362 target_var_);
6363 visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6364 }
6365
6366 private:
6367 IntExpr* const expr_;
6368 int64_t cached_min_;
6369 int64_t cached_max_;
6370 uint64_t fail_stamp_;
6371};
6372} // namespace
6373
6374// ----- Misc -----
6375
6377 return CondRevAlloc(solver(), reversible, new EmptyIterator());
6378}
6380 return CondRevAlloc(solver(), reversible, new RangeIterator(this));
6381}
6382
6383// ----- API -----
6384
6386 DCHECK_EQ(DOMAIN_INT_VAR, var->VarType());
6387 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6388 dvar->CleanInProcess();
6389}
6390
6391Constraint* SetIsEqual(IntVar* const var, absl::Span<const int64_t> values,
6392 const std::vector<IntVar*>& vars) {
6393 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6394 CHECK(dvar != nullptr);
6395 return dvar->SetIsEqual(values, vars);
6396}
6397
6399 absl::Span<const int64_t> values,
6400 const std::vector<IntVar*>& vars) {
6401 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6402 CHECK(dvar != nullptr);
6403 return dvar->SetIsGreaterOrEqual(values, vars);
6404}
6405
6407 DCHECK_EQ(BOOLEAN_VAR, var->VarType());
6408 BooleanVar* const boolean_var = reinterpret_cast<BooleanVar*>(var);
6409 boolean_var->RestoreValue();
6410}
6411
6412// ----- API -----
6413
6414IntVar* Solver::MakeIntVar(int64_t min, int64_t max, const std::string& name) {
6415 if (min == max) {
6416 return MakeIntConst(min, name);
6417 }
6418 if (min == 0 && max == 1) {
6419 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6420 } else if (CapSub(max, min) == 1) {
6421 const std::string inner_name = "inner_" + name;
6422 return RegisterIntVar(
6423 MakeSum(RevAlloc(new ConcreteBooleanVar(this, inner_name)), min)
6424 ->VarWithName(name));
6425 } else {
6426 return RegisterIntVar(RevAlloc(new DomainIntVar(this, min, max, name)));
6427 }
6428}
6429
6430IntVar* Solver::MakeIntVar(int64_t min, int64_t max) {
6431 return MakeIntVar(min, max, "");
6432}
6433
6434IntVar* Solver::MakeBoolVar(const std::string& name) {
6435 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6436}
6437
6439 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, "")));
6440}
6441
6442IntVar* Solver::MakeIntVar(const std::vector<int64_t>& values,
6443 const std::string& name) {
6444 DCHECK(!values.empty());
6445 // Fast-track the case where we have a single value.
6446 if (values.size() == 1) return MakeIntConst(values[0], name);
6447 // Sort and remove duplicates.
6448 std::vector<int64_t> unique_sorted_values = values;
6449 gtl::STLSortAndRemoveDuplicates(&unique_sorted_values);
6450 // Case when we have a single value, after clean-up.
6451 if (unique_sorted_values.size() == 1) return MakeIntConst(values[0], name);
6452 // Case when the values are a dense interval of integers.
6453 if (unique_sorted_values.size() ==
6454 unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6455 return MakeIntVar(unique_sorted_values.front(), unique_sorted_values.back(),
6456 name);
6457 }
6458 // Compute the GCD: if it's not 1, we can express the variable's domain as
6459 // the product of the GCD and of a domain with smaller values.
6460 int64_t gcd = 0;
6461 for (const int64_t v : unique_sorted_values) {
6462 if (gcd == 0) {
6463 gcd = std::abs(v);
6464 } else {
6465 gcd = MathUtil::GCD64(gcd, std::abs(v)); // Supports v==0.
6466 }
6467 if (gcd == 1) {
6468 // If it's 1, though, we can't do anything special, so we
6469 // immediately return a new DomainIntVar.
6470 return RegisterIntVar(
6471 RevAlloc(new DomainIntVar(this, unique_sorted_values, name)));
6472 }
6473 }
6474 DCHECK_GT(gcd, 1);
6475 for (int64_t& v : unique_sorted_values) {
6476 DCHECK_EQ(0, v % gcd);
6477 v /= gcd;
6478 }
6479 const std::string new_name = name.empty() ? "" : "inner_" + name;
6480 // Catch the case where the divided values are a dense set of integers.
6481 IntVar* inner_intvar = nullptr;
6482 if (unique_sorted_values.size() ==
6483 unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6484 inner_intvar = MakeIntVar(unique_sorted_values.front(),
6485 unique_sorted_values.back(), new_name);
6486 } else {
6487 inner_intvar = RegisterIntVar(
6488 RevAlloc(new DomainIntVar(this, unique_sorted_values, new_name)));
6489 }
6490 return MakeProd(inner_intvar, gcd)->Var();
6491}
6492
6493IntVar* Solver::MakeIntVar(const std::vector<int64_t>& values) {
6494 return MakeIntVar(values, "");
6495}
6496
6497IntVar* Solver::MakeIntVar(const std::vector<int>& values,
6498 const std::string& name) {
6499 return MakeIntVar(ToInt64Vector(values), name);
6500}
6501
6502IntVar* Solver::MakeIntVar(const std::vector<int>& values) {
6503 return MakeIntVar(values, "");
6504}
6505
6506IntVar* Solver::MakeIntConst(int64_t val, const std::string& name) {
6507 // If IntConst is going to be named after its creation,
6508 // cp_share_int_consts should be set to false otherwise names can potentially
6509 // be overwritten.
6510 if (absl::GetFlag(FLAGS_cp_share_int_consts) && name.empty() &&
6511 val >= MIN_CACHED_INT_CONST && val <= MAX_CACHED_INT_CONST) {
6512 return cached_constants_[val - MIN_CACHED_INT_CONST];
6513 }
6514 return RevAlloc(new IntConst(this, val, name));
6515}
6516
6517IntVar* Solver::MakeIntConst(int64_t val) { return MakeIntConst(val, ""); }
6518
6519// ----- Int Var and associated methods -----
6520
6521namespace {
6522std::string IndexedName(absl::string_view prefix, int index, int max_index) {
6523#if 0
6524#if defined(_MSC_VER)
6525 const int digits = max_index > 0 ?
6526 static_cast<int>(log(1.0L * max_index) / log(10.0L)) + 1 :
6527 1;
6528#else
6529 const int digits = max_index > 0 ? static_cast<int>(log10(max_index)) + 1: 1;
6530#endif
6531 return absl::StrFormat("%s%0*d", prefix, digits, index);
6532#else
6533 return absl::StrCat(prefix, index);
6534#endif
6535}
6536} // namespace
6537
6538void Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6539 const std::string& name,
6540 std::vector<IntVar*>* vars) {
6541 for (int i = 0; i < var_count; ++i) {
6542 vars->push_back(MakeIntVar(vmin, vmax, IndexedName(name, i, var_count)));
6543 }
6544}
6545
6546void Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6547 std::vector<IntVar*>* vars) {
6548 for (int i = 0; i < var_count; ++i) {
6549 vars->push_back(MakeIntVar(vmin, vmax));
6550 }
6551}
6552
6553IntVar** Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6554 const std::string& name) {
6555 IntVar** vars = new IntVar*[var_count];
6556 for (int i = 0; i < var_count; ++i) {
6557 vars[i] = MakeIntVar(vmin, vmax, IndexedName(name, i, var_count));
6558 }
6559 return vars;
6560}
6561
6562void Solver::MakeBoolVarArray(int var_count, const std::string& name,
6563 std::vector<IntVar*>* vars) {
6564 for (int i = 0; i < var_count; ++i) {
6565 vars->push_back(MakeBoolVar(IndexedName(name, i, var_count)));
6566 }
6567}
6568
6569void Solver::MakeBoolVarArray(int var_count, std::vector<IntVar*>* vars) {
6570 for (int i = 0; i < var_count; ++i) {
6571 vars->push_back(MakeBoolVar());
6572 }
6573}
6574
6575IntVar** Solver::MakeBoolVarArray(int var_count, const std::string& name) {
6576 IntVar** vars = new IntVar*[var_count];
6577 for (int i = 0; i < var_count; ++i) {
6578 vars[i] = MakeBoolVar(IndexedName(name, i, var_count));
6579 }
6580 return vars;
6581}
6582
6583void Solver::InitCachedIntConstants() {
6584 for (int i = MIN_CACHED_INT_CONST; i <= MAX_CACHED_INT_CONST; ++i) {
6585 cached_constants_[i - MIN_CACHED_INT_CONST] =
6586 RevAlloc(new IntConst(this, i, "")); // note the empty name
6587 }
6588}
6589
// Returns an expression equal to left + right.
// A fixed operand is folded through MakeSum(expr, value), and left + left
// becomes MakeProd(left, 2). Otherwise the model cache is queried with both
// operand orders, and a new expression is cached under (left, right). A
// SafePlusIntExpr is built when adding the two maxima or the two minima could
// overflow int64_t; a PlusIntExpr is built otherwise.
IntExpr* Solver::MakeSum(IntExpr* const left, IntExpr* const right) {
  CHECK_EQ(this, left->solver());
  CHECK_EQ(this, right->solver());
  if (right->Bound()) {
    return MakeSum(left, right->Min());
  }
  if (left->Bound()) {
    return MakeSum(right, left->Min());
  }
  if (left == right) {
    return MakeProd(left, 2);
  }
  // Addition is commutative: a cached (right, left) sum is reused as well.
  IntExpr* cache = model_cache_->FindExprExprExpression(
      left, right, ModelCache::EXPR_EXPR_SUM);
  if (cache == nullptr) {
    cache = model_cache_->FindExprExprExpression(right, left,
  }
  if (cache != nullptr) {
    return cache;
  } else {
    IntExpr* const result =
        AddOverflows(left->Max(), right->Max()) ||
                AddOverflows(left->Min(), right->Min())
            ? RegisterIntExpr(RevAlloc(new SafePlusIntExpr(this, left, right)))
            : RegisterIntExpr(RevAlloc(new PlusIntExpr(this, left, right)));
    model_cache_->InsertExprExprExpression(result, left, right,
    return result;
  }
}
6621
// Returns an expression equal to expr + value.
// A fixed expr is folded into a constant with the saturated CapAdd, and
// value == 0 returns expr unchanged. Otherwise the result is cached per
// (expr, value). When expr is a variable and neither bound of expr + value
// overflows, the result is specialized on the variable's concrete type, and
// nested "offset" and "opposite" variables are flattened. Otherwise a generic
// PlusIntCstExpr is built.
IntExpr* Solver::MakeSum(IntExpr* const expr, int64_t value) {
  CHECK_EQ(this, expr->solver());
  if (expr->Bound()) {
    return MakeIntConst(CapAdd(expr->Min(), value));
  }
  if (value == 0) {
    return expr;
  }
  IntExpr* result = Cache()->FindExprConstantExpression(
      expr, value, ModelCache::EXPR_CONSTANT_SUM);
  if (result == nullptr) {
    if (expr->IsVar() && !AddOverflows(value, expr->Max()) &&
        !AddOverflows(value, expr->Min())) {
      IntVar* const var = expr->Var();
      switch (var->VarType()) {
        case DOMAIN_INT_VAR: {
          result = RegisterIntExpr(RevAlloc(new PlusCstDomainIntVar(
              this, reinterpret_cast<DomainIntVar*>(var), value)));
          break;
        }
        case CONST_VAR: {
          // NOTE(review): a CONST_VAR is presumably always Bound(), and so
          // already handled above; also, this addition is not saturated.
          result = RegisterIntExpr(MakeIntConst(var->Min() + value));
          break;
        }
        case VAR_ADD_CST: {
          // Presumably var == SubVar() + Constant(): fold both constants
          // into one offset, dropping it entirely when it cancels out.
          // NOTE(review): value + Constant() is not overflow-checked.
          PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
          IntVar* const sub_var = add_var->SubVar();
          const int64_t new_constant = value + add_var->Constant();
          if (new_constant == 0) {
            result = sub_var;
          } else {
            if (sub_var->VarType() == DOMAIN_INT_VAR) {
              DomainIntVar* const dvar =
                  reinterpret_cast<DomainIntVar*>(sub_var);
              result = RegisterIntExpr(
                  RevAlloc(new PlusCstDomainIntVar(this, dvar, new_constant)));
            } else {
              result = RegisterIntExpr(
                  RevAlloc(new PlusCstIntVar(this, sub_var, new_constant)));
            }
          }
          break;
        }
        case CST_SUB_VAR: {
          // Presumably var == Constant() - SubVar(): the offset moves into
          // the constant of a new SubCstIntVar.
          SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
          IntVar* const sub_var = add_var->SubVar();
          const int64_t new_constant = value + add_var->Constant();
          result = RegisterIntExpr(
              RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
          break;
        }
        case OPP_VAR: {
          // Presumably var == -SubVar(): -sub_var + value is value - sub_var.
          OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
          IntVar* const sub_var = add_var->SubVar();
          result =
              RegisterIntExpr(RevAlloc(new SubCstIntVar(this, sub_var, value)));
          break;
        }
        default:
          result =
              RegisterIntExpr(RevAlloc(new PlusCstIntVar(this, var, value)));
      }
    } else {
      result = RegisterIntExpr(RevAlloc(new PlusIntCstExpr(this, expr, value)));
    }
    Cache()->InsertExprConstantExpression(result, expr, value,
  }
  return result;
}
6692
// Returns an expression equal to left - right.
// Fixed operands are folded into the constant overloads. When IsProduct()
// decomposes both sides as coefficient * sub-expression, and the coefficients
// share a factor g > 1, g is factored out of the difference. Otherwise the
// result is cached, and is a SubIntExpr when neither extreme of the
// difference overflows int64_t (a SafeSubIntExpr otherwise).
IntExpr* Solver::MakeDifference(IntExpr* const left, IntExpr* const right) {
  CHECK_EQ(this, left->solver());
  CHECK_EQ(this, right->solver());
  if (left->Bound()) {
    return MakeDifference(left->Min(), right);
  }
  if (right->Bound()) {
    // NOTE(review): -right->Min() overflows when right is fixed at kint64min.
    return MakeSum(left, -right->Min());
  }
  IntExpr* sub_left = nullptr;
  IntExpr* sub_right = nullptr;
  int64_t left_coef = 1;
  int64_t right_coef = 1;
  if (IsProduct(left, &sub_left, &left_coef) &&
      IsProduct(right, &sub_right, &right_coef)) {
    // a*x - b*y == g * ((a/g)*x - (b/g)*y) with g = gcd(|a|, |b|).
    const int64_t abs_gcd =
        MathUtil::GCD64(std::abs(left_coef), std::abs(right_coef));
    if (abs_gcd != 0 && abs_gcd != 1) {
      return MakeProd(MakeDifference(MakeProd(sub_left, left_coef / abs_gcd),
                                     MakeProd(sub_right, right_coef / abs_gcd)),
                      abs_gcd);
    }
  }

  IntExpr* result = Cache()->FindExprExprExpression(
  if (result == nullptr) {
    // The extremes of left - right are left.Min - right.Max and
    // left.Max - right.Min.
    if (!SubOverflows(left->Min(), right->Max()) &&
        !SubOverflows(left->Max(), right->Min())) {
      result = RegisterIntExpr(RevAlloc(new SubIntExpr(this, left, right)));
    } else {
      result = RegisterIntExpr(RevAlloc(new SafeSubIntExpr(this, left, right)));
    }
    Cache()->InsertExprExprExpression(result, left, right,
  }
  return result;
}
6731
// warning: this is 'value - expr'.
// A fixed expr is folded into a constant, and value == 0 yields
// MakeOpposite(expr). Otherwise the result is cached per (expr, value). For a
// variable whose Min() is not kint64min and for which value - expr cannot
// overflow, the result is specialized on the variable's concrete type, so
// that nested offsets and negations are flattened. Otherwise a generic
// SubIntCstExpr is built.
IntExpr* Solver::MakeDifference(int64_t value, IntExpr* const expr) {
  CHECK_EQ(this, expr->solver());
  if (expr->Bound()) {
    // NOTE(review): unlike MakeSum(expr, value), which uses the saturated
    // CapAdd, this subtraction can overflow.
    return MakeIntConst(value - expr->Min());
  }
  if (value == 0) {
    return MakeOpposite(expr);
  }
  IntExpr* result = Cache()->FindExprConstantExpression(
  if (result == nullptr) {
    if (expr->IsVar() && expr->Min() != std::numeric_limits<int64_t>::min() &&
        !SubOverflows(value, expr->Min()) &&
        !SubOverflows(value, expr->Max())) {
      IntVar* const var = expr->Var();
      switch (var->VarType()) {
        case VAR_ADD_CST: {
          // Presumably var == SubVar() + Constant(), so
          // value - var == (value - Constant()) - SubVar().
          PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
          IntVar* const sub_var = add_var->SubVar();
          const int64_t new_constant = value - add_var->Constant();
          // NOTE(review): when new_constant == 0 the result should be
          // -sub_var, yet sub_var itself is returned — confirm intent.
          if (new_constant == 0) {
            result = sub_var;
          } else {
            result = RegisterIntExpr(
                RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
          }
          break;
        }
        case CST_SUB_VAR: {
          // Presumably var == Constant() - SubVar(), so
          // value - var == SubVar() + (value - Constant()).
          SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
          IntVar* const sub_var = add_var->SubVar();
          const int64_t new_constant = value - add_var->Constant();
          result = MakeSum(sub_var, new_constant);
          break;
        }
        case OPP_VAR: {
          // Presumably var == -SubVar(), so value - var == SubVar() + value.
          OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
          IntVar* const sub_var = add_var->SubVar();
          result = MakeSum(sub_var, value);
          break;
        }
        default:
          result =
              RegisterIntExpr(RevAlloc(new SubCstIntVar(this, var, value)));
      }
    } else {
      result = RegisterIntExpr(RevAlloc(new SubIntCstExpr(this, expr, value)));
    }
    Cache()->InsertExprConstantExpression(result, expr, value,
  }
  return result;
}
6786
6788 CHECK_EQ(this, expr->solver());
6789 if (expr->Bound()) {
6790 return MakeIntConst(CapOpp(expr->Min()));
6791 }
6792 IntExpr* result =
6793 Cache()->FindExprExpression(expr, ModelCache::EXPR_OPPOSITE);
6794 if (result == nullptr) {
6795 if (expr->IsVar()) {
6796 result = RegisterIntVar(RevAlloc(new OppIntExpr(this, expr))->Var());
6797 } else {
6798 result = RegisterIntExpr(RevAlloc(new OppIntExpr(this, expr)));
6799 }
6800 Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_OPPOSITE);
6801 }
6802 return result;
6803}
6804
// Returns an expression equal to expr * value.
// If IsProduct() decomposes expr into coefficient * m_expr, the constants are
// combined with the saturated CapProd. The result is then:
// - a constant, when m_expr is fixed;
// - m_expr or its opposite, for a coefficient of 1 or -1;
// - the constant 0, for a coefficient of 0;
// - otherwise a dedicated product expression.
// For positive coefficients, the "Safe" variant is used when
// m_expr's range times the coefficient may leave int64_t. For variable
// operands, the result is turned into a variable unless
// cp_disable_expression_optimization is set. Early returns are not cached;
// other results are cached under the original (expr, value).
IntExpr* Solver::MakeProd(IntExpr* const expr, int64_t value) {
  CHECK_EQ(this, expr->solver());
  IntExpr* result = Cache()->FindExprConstantExpression(
      expr, value, ModelCache::EXPR_CONSTANT_PROD);
  if (result != nullptr) {
    return result;
  } else {
    IntExpr* m_expr = nullptr;
    int64_t coefficient = 1;
    if (IsProduct(expr, &m_expr, &coefficient)) {
      coefficient = CapProd(coefficient, value);
    } else {
      m_expr = expr;
      coefficient = value;
    }
    if (m_expr->Bound()) {
      return MakeIntConst(CapProd(coefficient, m_expr->Min()));
    } else if (coefficient == 1) {
      return m_expr;
    } else if (coefficient == -1) {
      return MakeOpposite(m_expr);
    } else if (coefficient > 0) {
      // Overflow test done by dividing the int64_t limits by the coefficient,
      // so that the test itself cannot overflow.
      if (m_expr->Max() > std::numeric_limits<int64_t>::max() / coefficient ||
          m_expr->Min() < std::numeric_limits<int64_t>::min() / coefficient) {
        result = RegisterIntExpr(
            RevAlloc(new SafeTimesPosIntCstExpr(this, m_expr, coefficient)));
      } else {
        result = RegisterIntExpr(
            RevAlloc(new TimesPosIntCstExpr(this, m_expr, coefficient)));
      }
    } else if (coefficient == 0) {
      result = MakeIntConst(0);
    } else {  // coefficient < 0.
      result = RegisterIntExpr(
          RevAlloc(new TimesIntNegCstExpr(this, m_expr, coefficient)));
    }
    if (m_expr->IsVar() &&
        !absl::GetFlag(FLAGS_cp_disable_expression_optimization)) {
      result = result->Var();
    }
    Cache()->InsertExprConstantExpression(result, expr, value,
    return result;
  }
}
6850
6851namespace {
6852void ExtractPower(IntExpr** const expr, int64_t* const exponant) {
6853 if (dynamic_cast<BasePower*>(*expr) != nullptr) {
6854 BasePower* const power = dynamic_cast<BasePower*>(*expr);
6855 *expr = power->expr();
6856 *exponant = power->exponant();
6857 }
6858 if (dynamic_cast<IntSquare*>(*expr) != nullptr) {
6859 IntSquare* const power = dynamic_cast<IntSquare*>(*expr);
6860 *expr = power->expr();
6861 *exponant = 2;
6862 }
6863 if ((*expr)->IsVar()) {
6864 IntVar* const var = (*expr)->Var();
6865 IntExpr* const sub = var->solver()->CastExpression(var);
6866 if (sub != nullptr && dynamic_cast<BasePower*>(sub) != nullptr) {
6867 BasePower* const power = dynamic_cast<BasePower*>(sub);
6868 *expr = power->expr();
6869 *exponant = power->exponant();
6870 }
6871 if (sub != nullptr && dynamic_cast<IntSquare*>(sub) != nullptr) {
6872 IntSquare* const power = dynamic_cast<IntSquare*>(sub);
6873 *expr = power->expr();
6874 *exponant = 2;
6875 }
6876 }
6877}
6878
6879void ExtractProduct(IntExpr** const expr, int64_t* const coefficient,
6880 bool* modified) {
6881 if (dynamic_cast<TimesCstIntVar*>(*expr) != nullptr) {
6882 TimesCstIntVar* const left_prod = dynamic_cast<TimesCstIntVar*>(*expr);
6883 *coefficient *= left_prod->Constant();
6884 *expr = left_prod->SubVar();
6885 *modified = true;
6886 } else if (dynamic_cast<TimesIntCstExpr*>(*expr) != nullptr) {
6887 TimesIntCstExpr* const left_prod = dynamic_cast<TimesIntCstExpr*>(*expr);
6888 *coefficient *= left_prod->Constant();
6889 *expr = left_prod->Expr();
6890 *modified = true;
6891 }
6892}
6893} // namespace
6894
// Returns an expression equal to left * right.
// Simplifications, in order:
// - a fixed operand turns this into a product by a constant;
// - when both operands reduce (via ExtractPower) to the same base, the result
//   is a power of that base with the summed exponents;
// - constant factors pulled out by ExtractProduct are multiplied outside.
// Otherwise a cached product, looked up in both operand orders, is reused, or
// a new product is built. The new product is specialized when one operand is
// a boolean variable and when both operands are non-negative.
// NOTE(review): the solver CHECKs only run after the simplifications above.
IntExpr* Solver::MakeProd(IntExpr* const left, IntExpr* const right) {
  if (left->Bound()) {
    return MakeProd(right, left->Min());
  }

  if (right->Bound()) {
    return MakeProd(left, right->Min());
  }

  // ----- Discover squares and powers -----

  IntExpr* m_left = left;
  IntExpr* m_right = right;
  int64_t left_exponant = 1;
  int64_t right_exponant = 1;
  ExtractPower(&m_left, &left_exponant);
  ExtractPower(&m_right, &right_exponant);

  if (m_left == m_right) {
    return MakePower(m_left, left_exponant + right_exponant);
  }

  // ----- Discover nested products -----

  m_left = left;
  m_right = right;
  int64_t coefficient = 1;
  bool modified = false;

  ExtractProduct(&m_left, &coefficient, &modified);
  ExtractProduct(&m_right, &coefficient, &modified);
  if (modified) {
    return MakeProd(MakeProd(m_left, m_right), coefficient);
  }

  // ----- Standard build -----

  CHECK_EQ(this, left->solver());
  CHECK_EQ(this, right->solver());
  IntExpr* result = model_cache_->FindExprExprExpression(
      left, right, ModelCache::EXPR_EXPR_PROD);
  if (result == nullptr) {
    result = model_cache_->FindExprExprExpression(right, left,
  }
  if (result != nullptr) {
    return result;
  }
  // NOTE(review): the boolean branches cast the operand itself (not
  // operand->Var()) to BooleanVar*, relying on IsVar() operands being
  // their own Var().
  if (left->IsVar() && left->Var()->VarType() == BOOLEAN_VAR) {
    if (right->Min() >= 0) {
      result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
          this, reinterpret_cast<BooleanVar*>(left), right)));
    } else {
      result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
          this, reinterpret_cast<BooleanVar*>(left), right)));
    }
  } else if (right->IsVar() &&
             reinterpret_cast<IntVar*>(right)->VarType() == BOOLEAN_VAR) {
    if (left->Min() >= 0) {
      result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
          this, reinterpret_cast<BooleanVar*>(right), left)));
    } else {
      result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
          this, reinterpret_cast<BooleanVar*>(right), left)));
    }
  } else if (left->Min() >= 0 && right->Min() >= 0) {
    // A saturated product of the maxima signals a possible overflow.
    if (CapProd(left->Max(), right->Max()) ==
        std::numeric_limits<int64_t>::max()) {  // Potential overflow.
      result =
          RegisterIntExpr(RevAlloc(new SafeTimesPosIntExpr(this, left, right)));
    } else {
      result =
          RegisterIntExpr(RevAlloc(new TimesPosIntExpr(this, left, right)));
    }
  } else {
    result = RegisterIntExpr(RevAlloc(new TimesIntExpr(this, left, right)));
  }
  model_cache_->InsertExprExprExpression(result, left, right,
  return result;
}
6976
// Returns an expression equal to numerator / denominator.
// A fixed denominator is delegated to MakeDiv(expr, value). If the
// denominator's current range contains 0, the constraint denominator != 0 is
// added to the model. The division expression is then chosen from the signs
// of the operands' ranges; a non-positive denominator is handled through
// MakeOpposite(). The cache lookup is order-sensitive, because division is
// not commutative.
// NOTE(review): unlike most builders in this file, the expressions allocated
// here are not passed through RegisterIntExpr().
IntExpr* Solver::MakeDiv(IntExpr* const numerator, IntExpr* const denominator) {
  CHECK(numerator != nullptr);
  CHECK(denominator != nullptr);
  if (denominator->Bound()) {
    return MakeDiv(numerator, denominator->Min());
  }
  IntExpr* result = model_cache_->FindExprExprExpression(
      numerator, denominator, ModelCache::EXPR_EXPR_DIV);
  if (result != nullptr) {
    return result;
  }

  if (denominator->Min() <= 0 && denominator->Max() >= 0) {
    AddConstraint(MakeNonEquality(denominator, 0));
  }

  if (denominator->Min() >= 0) {
    if (numerator->Min() >= 0) {
      result = RevAlloc(new DivPosPosIntExpr(this, numerator, denominator));
    } else {
      result = RevAlloc(new DivPosIntExpr(this, numerator, denominator));
    }
  } else if (denominator->Max() <= 0) {
    // n / d == (-n) / (-d) and n / d == -(n / -d).
    if (numerator->Max() <= 0) {
      result = RevAlloc(new DivPosPosIntExpr(this, MakeOpposite(numerator),
                                             MakeOpposite(denominator)));
    } else {
      result = MakeOpposite(RevAlloc(
          new DivPosIntExpr(this, numerator, MakeOpposite(denominator))));
    }
  } else {
    result = RevAlloc(new DivIntExpr(this, numerator, denominator));
  }
  model_cache_->InsertExprExprExpression(result, numerator, denominator,
  return result;
}
7014
7015IntExpr* Solver::MakeDiv(IntExpr* const expr, int64_t value) {
7016 CHECK(expr != nullptr);
7017 CHECK_EQ(this, expr->solver());
7018 if (expr->Bound()) {
7019 return MakeIntConst(expr->Min() / value);
7020 } else if (value == 1) {
7021 return expr;
7022 } else if (value == -1) {
7023 return MakeOpposite(expr);
7024 } else if (value > 0) {
7025 return RegisterIntExpr(RevAlloc(new DivPosIntCstExpr(this, expr, value)));
7026 } else if (value == 0) {
7027 LOG(FATAL) << "Cannot divide by 0";
7028 return nullptr;
7029 } else {
7030 return RegisterIntExpr(
7031 MakeOpposite(RevAlloc(new DivPosIntCstExpr(this, expr, -value))));
7032 // TODO(user) : implement special case.
7033 }
7034}
7035
7036Constraint* Solver::MakeAbsEquality(IntVar* const var, IntVar* const abs_var) {
7037 if (Cache()->FindExprExpression(var, ModelCache::EXPR_ABS) == nullptr) {
7038 Cache()->InsertExprExpression(abs_var, var, ModelCache::EXPR_ABS);
7039 }
7040 return RevAlloc(new IntAbsConstraint(this, var, abs_var));
7041}
7042
7044 CHECK_EQ(this, e->solver());
7045 if (e->Min() >= 0) {
7046 return e;
7047 } else if (e->Max() <= 0) {
7048 return MakeOpposite(e);
7049 }
7050 IntExpr* result = Cache()->FindExprExpression(e, ModelCache::EXPR_ABS);
7051 if (result == nullptr) {
7052 int64_t coefficient = 1;
7053 IntExpr* expr = nullptr;
7054 if (IsProduct(e, &expr, &coefficient)) {
7055 result = MakeProd(MakeAbs(expr), std::abs(coefficient));
7056 } else {
7057 result = RegisterIntExpr(RevAlloc(new IntAbs(this, e)));
7058 }
7059 Cache()->InsertExprExpression(result, e, ModelCache::EXPR_ABS);
7060 }
7061 return result;
7062}
7063
7065 CHECK_EQ(this, expr->solver());
7066 if (expr->Bound()) {
7067 const int64_t v = expr->Min();
7068 return MakeIntConst(v * v);
7069 }
7070 IntExpr* result = Cache()->FindExprExpression(expr, ModelCache::EXPR_SQUARE);
7071 if (result == nullptr) {
7072 if (expr->Min() >= 0) {
7073 result = RegisterIntExpr(RevAlloc(new PosIntSquare(this, expr)));
7074 } else {
7075 result = RegisterIntExpr(RevAlloc(new IntSquare(this, expr)));
7076 }
7077 Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_SQUARE);
7078 }
7079 return result;
7080}
7081
7082IntExpr* Solver::MakePower(IntExpr* const expr, int64_t n) {
7083 CHECK_EQ(this, expr->solver());
7084 CHECK_GE(n, 0);
7085 if (expr->Bound()) {
7086 const int64_t v = expr->Min();
7087 if (v >= OverflowLimit(n)) { // Overflow.
7088 return MakeIntConst(std::numeric_limits<int64_t>::max());
7089 }
7090 return MakeIntConst(IntPower(v, n));
7091 }
7092 switch (n) {
7093 case 0:
7094 return MakeIntConst(1);
7095 case 1:
7096 return expr;
7097 case 2:
7098 return MakeSquare(expr);
7099 default: {
7100 IntExpr* result = nullptr;
7101 if (n % 2 == 0) { // even.
7102 if (expr->Min() >= 0) {
7103 result =
7104 RegisterIntExpr(RevAlloc(new PosIntEvenPower(this, expr, n)));
7105 } else {
7106 result = RegisterIntExpr(RevAlloc(new IntEvenPower(this, expr, n)));
7107 }
7108 } else {
7109 result = RegisterIntExpr(RevAlloc(new IntOddPower(this, expr, n)));
7110 }
7111 return result;
7112 }
7113 }
7114}
7115
7116IntExpr* Solver::MakeMin(IntExpr* const left, IntExpr* const right) {
7117 CHECK_EQ(this, left->solver());
7118 CHECK_EQ(this, right->solver());
7119 if (left->Bound()) {
7120 return MakeMin(right, left->Min());
7121 }
7122 if (right->Bound()) {
7123 return MakeMin(left, right->Min());
7124 }
7125 if (left->Min() >= right->Max()) {
7126 return right;
7127 }
7128 if (right->Min() >= left->Max()) {
7129 return left;
7130 }
7131 return RegisterIntExpr(RevAlloc(new MinIntExpr(this, left, right)));
7132}
7133
7134IntExpr* Solver::MakeMin(IntExpr* const expr, int64_t value) {
7135 CHECK_EQ(this, expr->solver());
7136 if (value <= expr->Min()) {
7137 return MakeIntConst(value);
7138 }
7139 if (expr->Bound()) {
7140 return MakeIntConst(std::min(expr->Min(), value));
7141 }
7142 if (expr->Max() <= value) {
7143 return expr;
7144 }
7145 return RegisterIntExpr(RevAlloc(new MinCstIntExpr(this, expr, value)));
7146}
7147
7148IntExpr* Solver::MakeMin(IntExpr* const expr, int value) {
7149 return MakeMin(expr, static_cast<int64_t>(value));
7150}
7151
7152IntExpr* Solver::MakeMax(IntExpr* const left, IntExpr* const right) {
7153 CHECK_EQ(this, left->solver());
7154 CHECK_EQ(this, right->solver());
7155 if (left->Bound()) {
7156 return MakeMax(right, left->Min());
7157 }
7158 if (right->Bound()) {
7159 return MakeMax(left, right->Min());
7160 }
7161 if (left->Min() >= right->Max()) {
7162 return left;
7163 }
7164 if (right->Min() >= left->Max()) {
7165 return right;
7166 }
7167 return RegisterIntExpr(RevAlloc(new MaxIntExpr(this, left, right)));
7168}
7169
7170IntExpr* Solver::MakeMax(IntExpr* const expr, int64_t value) {
7171 CHECK_EQ(this, expr->solver());
7172 if (expr->Bound()) {
7173 return MakeIntConst(std::max(expr->Min(), value));
7174 }
7175 if (value <= expr->Min()) {
7176 return expr;
7177 }
7178 if (expr->Max() <= value) {
7179 return MakeIntConst(value);
7180 }
7181 return RegisterIntExpr(RevAlloc(new MaxCstIntExpr(this, expr, value)));
7182}
7183
7184IntExpr* Solver::MakeMax(IntExpr* const expr, int value) {
7185 return MakeMax(expr, static_cast<int64_t>(value));
7186}
7187
7189 int64_t early_date, int64_t late_date,
7190 int64_t late_cost) {
7191 return RegisterIntExpr(RevAlloc(new SimpleConvexPiecewiseExpr(
7192 this, expr, early_cost, early_date, late_date, late_cost)));
7193}
7194
7196 int64_t fixed_charge, int64_t step) {
7197 if (step == 0) {
7198 if (fixed_charge == 0) {
7199 return MakeIntConst(int64_t{0});
7200 } else {
7201 return RegisterIntExpr(
7202 RevAlloc(new SemiContinuousStepZeroExpr(this, expr, fixed_charge)));
7203 }
7204 } else if (step == 1) {
7205 return RegisterIntExpr(
7206 RevAlloc(new SemiContinuousStepOneExpr(this, expr, fixed_charge)));
7207 } else {
7208 return RegisterIntExpr(
7209 RevAlloc(new SemiContinuousExpr(this, expr, fixed_charge, step)));
7210 }
7211 // TODO(user) : benchmark with virtualization of
7212 // PosIntDivDown and PosIntDivUp - or function pointers.
7213}
7214
7215// ----- Piecewise Linear -----
7216
7218 public:
7220 const PiecewiseLinearFunction& f)
7221 : BaseIntExpr(solver), expr_(expr), f_(f) {}
7223 int64_t Min() const override {
7224 return f_.GetMinimum(expr_->Min(), expr_->Max());
7225 }
7226 void SetMin(int64_t m) override {
7227 const auto& range =
7228 f_.GetSmallestRangeGreaterThanValue(expr_->Min(), expr_->Max(), m);
7229 expr_->SetRange(range.first, range.second);
7230 }
7231
7232 int64_t Max() const override {
7233 return f_.GetMaximum(expr_->Min(), expr_->Max());
7234 }
7235
7236 void SetMax(int64_t m) override {
7237 const auto& range =
7238 f_.GetSmallestRangeLessThanValue(expr_->Min(), expr_->Max(), m);
7239 expr_->SetRange(range.first, range.second);
7240 }
7241
7242 void SetRange(int64_t l, int64_t u) override {
7243 const auto& range =
7244 f_.GetSmallestRangeInValueRange(expr_->Min(), expr_->Max(), l, u);
7245 expr_->SetRange(range.first, range.second);
7246 }
7247 std::string name() const override {
7248 return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->name(),
7249 f_.DebugString());
7250 }
7251
7252 std::string DebugString() const override {
7253 return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->DebugString(),
7254 f_.DebugString());
7255 }
7256
7257 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
7258
7259 void Accept(ModelVisitor* const visitor) const override {
7260 // TODO(user): Implement visitor.
7261 }
7262
7263 private:
7264 IntExpr* const expr_;
7265 const PiecewiseLinearFunction f_;
7266};
7267
7272
7273// ----- Conditional Expression -----
7274
7276 IntExpr* const expr,
7277 int64_t unperformed_value) {
7278 if (condition->Min() == 1) {
7279 return expr;
7280 } else if (condition->Max() == 0) {
7281 return MakeIntConst(unperformed_value);
7282 } else {
7283 IntExpr* cache = Cache()->FindExprExprConstantExpression(
7284 condition, expr, unperformed_value,
7286 if (cache == nullptr) {
7287 cache = RevAlloc(
7288 new ExprWithEscapeValue(this, condition, expr, unperformed_value));
7289 Cache()->InsertExprExprConstantExpression(
7290 cache, condition, expr, unperformed_value,
7292 }
7293 return cache;
7294 }
7295}
7296
7297// ----- Modulo -----
7298
7299IntExpr* Solver::MakeModulo(IntExpr* const x, int64_t mod) {
7300 IntVar* const result =
7301 MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7302 if (mod >= 0) {
7303 AddConstraint(MakeBetweenCt(result, 0, mod - 1));
7304 } else {
7305 AddConstraint(MakeBetweenCt(result, mod + 1, 0));
7306 }
7307 return result;
7308}
7309
7311 if (mod->Bound()) {
7312 return MakeModulo(x, mod->Min());
7313 }
7314 IntVar* const result =
7315 MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7316 AddConstraint(MakeLess(result, MakeAbs(mod)));
7318 return result;
7319}
7320
7321// --------- IntVar ---------
7322
7323int IntVar::VarType() const { return UNSPECIFIED; }
7324
7325void IntVar::RemoveValues(const std::vector<int64_t>& values) {
7326 // TODO(user): Check and maybe inline this code.
7327 const int size = values.size();
7328 DCHECK_GE(size, 0);
7329 switch (size) {
7330 case 0: {
7331 return;
7332 }
7333 case 1: {
7334 RemoveValue(values[0]);
7335 return;
7336 }
7337 case 2: {
7338 RemoveValue(values[0]);
7339 RemoveValue(values[1]);
7340 return;
7341 }
7342 case 3: {
7343 RemoveValue(values[0]);
7344 RemoveValue(values[1]);
7345 RemoveValue(values[2]);
7346 return;
7347 }
7348 default: {
7349 // 4 values, let's start doing some more clever things.
7350 // TODO(user) : Sort values!
7351 int start_index = 0;
7352 int64_t new_min = Min();
7353 if (values[start_index] <= new_min) {
7354 while (start_index < size - 1 &&
7355 values[start_index + 1] == values[start_index] + 1) {
7356 new_min = values[start_index + 1] + 1;
7357 start_index++;
7358 }
7359 }
7360 int end_index = size - 1;
7361 int64_t new_max = Max();
7362 if (values[end_index] >= new_max) {
7363 while (end_index > start_index + 1 &&
7364 values[end_index - 1] == values[end_index] - 1) {
7365 new_max = values[end_index - 1] - 1;
7366 end_index--;
7367 }
7368 }
7369 SetRange(new_min, new_max);
7370 for (int i = start_index; i <= end_index; ++i) {
7371 RemoveValue(values[i]);
7372 }
7373 }
7374 }
7375}
7376
7377void IntVar::Accept(ModelVisitor* const visitor) const {
7378 IntExpr* const casted = solver()->CastExpression(this);
7379 visitor->VisitIntegerVariable(this, casted);
7380}
7381
7382void IntVar::SetValues(const std::vector<int64_t>& values) {
7383 switch (values.size()) {
7384 case 0: {
7385 solver()->Fail();
7386 break;
7387 }
7388 case 1: {
7389 SetValue(values.back());
7390 break;
7391 }
7392 case 2: {
7393 if (Contains(values[0])) {
7394 if (Contains(values[1])) {
7395 const int64_t l = std::min(values[0], values[1]);
7396 const int64_t u = std::max(values[0], values[1]);
7397 SetRange(l, u);
7398 if (u > l + 1) {
7399 RemoveInterval(l + 1, u - 1);
7400 }
7401 } else {
7402 SetValue(values[0]);
7403 }
7404 } else {
7405 SetValue(values[1]);
7406 }
7407 break;
7408 }
7409 default: {
7410 // TODO(user): use a clean and safe SortedUniqueCopy() class
7411 // that uses a global, static shared (and locked) storage.
7412 // TODO(user): We could filter out values not in the var.
7413 std::vector<int64_t>& tmp = solver()->tmp_vector_;
7414 tmp.clear();
7415 tmp.insert(tmp.end(), values.begin(), values.end());
7416 std::sort(tmp.begin(), tmp.end());
7417 tmp.erase(std::unique(tmp.begin(), tmp.end()), tmp.end());
7418 const int size = tmp.size();
7419 const int64_t vmin = Min();
7420 const int64_t vmax = Max();
7421 int first = 0;
7422 int last = size - 1;
7423 if (tmp.front() > vmax || tmp.back() < vmin) {
7424 solver()->Fail();
7425 }
7426 // TODO(user) : We could find the first position >= vmin by dichotomy.
7427 while (tmp[first] < vmin || !Contains(tmp[first])) {
7428 ++first;
7429 if (first > last || tmp[first] > vmax) {
7430 solver()->Fail();
7431 }
7432 }
7433 while (last > first && (tmp[last] > vmax || !Contains(tmp[last]))) {
7434 // Note that last >= first implies tmp[last] >= vmin.
7435 --last;
7436 }
7437 DCHECK_GE(last, first);
7438 SetRange(tmp[first], tmp[last]);
7439 while (first < last) {
7440 const int64_t start = tmp[first] + 1;
7441 const int64_t end = tmp[first + 1] - 1;
7442 if (start <= end) {
7443 RemoveInterval(start, end);
7444 }
7445 first++;
7446 }
7447 }
7448 }
7449}
7450// ---------- BaseIntExpr ---------
7451
7452void LinkVarExpr(Solver* s, IntExpr* expr, IntVar* var) {
7453 if (!var->Bound()) {
7454 if (var->VarType() == DOMAIN_INT_VAR) {
7455 DomainIntVar* dvar = reinterpret_cast<DomainIntVar*>(var);
7457 s->RevAlloc(new LinkExprAndDomainIntVar(s, expr, dvar)), dvar, expr);
7458 } else {
7459 s->AddCastConstraint(s->RevAlloc(new LinkExprAndVar(s, expr, var)), var,
7460 expr);
7461 }
7462 }
7463}
7464
7466 if (var_ == nullptr) {
7467 solver()->SaveValue(reinterpret_cast<void**>(&var_));
7468 var_ = CastToVar();
7469 }
7470 return var_;
7471}
7472
7474 int64_t vmin, vmax;
7475 Range(&vmin, &vmax);
7476 IntVar* const var = solver()->MakeIntVar(vmin, vmax);
7477 LinkVarExpr(solver(), this, var);
7478 return var;
7479}
7480
7481// Discovery methods
7482bool Solver::IsADifference(IntExpr* expr, IntExpr** const left,
7483 IntExpr** const right) {
7484 if (expr->IsVar()) {
7485 IntVar* const expr_var = expr->Var();
7486 expr = CastExpression(expr_var);
7487 }
7488 // This is a dynamic cast to check the type of expr.
7489 // It returns nullptr is expr is not a subclass of SubIntExpr.
7490 SubIntExpr* const sub_expr = dynamic_cast<SubIntExpr*>(expr);
7491 if (sub_expr != nullptr) {
7492 *left = sub_expr->left();
7493 *right = sub_expr->right();
7494 return true;
7495 }
7496 return false;
7497}
7498
7499bool Solver::IsBooleanVar(IntExpr* const expr, IntVar** inner_var,
7500 bool* is_negated) const {
7501 if (expr->IsVar() && expr->Var()->VarType() == BOOLEAN_VAR) {
7502 *inner_var = expr->Var();
7503 *is_negated = false;
7504 return true;
7505 } else if (expr->IsVar() && expr->Var()->VarType() == CST_SUB_VAR) {
7506 SubCstIntVar* const sub_var = reinterpret_cast<SubCstIntVar*>(expr);
7507 if (sub_var != nullptr && sub_var->Constant() == 1 &&
7508 sub_var->SubVar()->VarType() == BOOLEAN_VAR) {
7509 *is_negated = true;
7510 *inner_var = sub_var->SubVar();
7511 return true;
7512 }
7513 }
7514 return false;
7515}
7516
7517bool Solver::IsProduct(IntExpr* const expr, IntExpr** inner_expr,
7518 int64_t* coefficient) {
7519 if (dynamic_cast<TimesCstIntVar*>(expr) != nullptr) {
7520 TimesCstIntVar* const var = dynamic_cast<TimesCstIntVar*>(expr);
7521 *coefficient = var->Constant();
7522 *inner_expr = var->SubVar();
7523 return true;
7524 } else if (dynamic_cast<TimesIntCstExpr*>(expr) != nullptr) {
7525 TimesIntCstExpr* const prod = dynamic_cast<TimesIntCstExpr*>(expr);
7526 *coefficient = prod->Constant();
7527 *inner_expr = prod->Expr();
7528 return true;
7529 }
7530 *inner_expr = expr;
7531 *coefficient = 1;
7532 return false;
7533}
7534
7535} // namespace operations_research
IntVar * Var() override
Creates a variable from the expression.
bool Contains(int64_t v) const override
IntVarIterator * MakeDomainIterator(bool reversible) const override
IntVar * IsGreaterOrEqual(int64_t constant) override
IntVarIterator * MakeHoleIterator(bool reversible) const override
IntVar * IsEqual(int64_t constant) override
IsEqual.
void RemoveValue(int64_t v) override
This method removes the value 'v' from the domain of the variable.
IntVar * IsDifferent(int64_t constant) override
SimpleRevFIFO< Demon * > delayed_bound_demons_
void SetMax(int64_t m) override
void SetRange(int64_t mi, int64_t ma) override
This method sets both the min and the max of the expression.
void RemoveInterval(int64_t l, int64_t u) override
void SetMin(int64_t m) override
SimpleRevFIFO< Demon * > bound_demons_
uint64_t Size() const override
This method returns the number of values in the domain of the variable.
std::string DebugString() const override
void WhenBound(Demon *d) override
IntVar * IsLessOrEqual(int64_t constant) override
void WhenRange(Demon *d) override
Attach a demon that will watch the min or the max of the expression.
virtual Solver::DemonPriority priority() const
virtual void SetValue(int64_t v)
This method sets the value of the expression.
virtual bool Bound() const
Returns true if the min and the max of the expression are equal.
virtual void SetMax(int64_t m)=0
virtual bool IsVar() const
Returns true if the expression is indeed a variable.
virtual void SetRange(int64_t l, int64_t u)
This method sets both the min and the max of the expression.
virtual int64_t Min() const =0
virtual void SetMin(int64_t m)=0
virtual void WhenRange(Demon *d)=0
Attach a demon that will watch the min or the max of the expression.
virtual void Range(int64_t *l, int64_t *u)
IntVar * VarWithName(const std::string &name)
virtual IntVar * Var()=0
Creates a variable from the expression.
virtual int64_t Max() const =0
virtual void WhenBound(Demon *d)=0
virtual void WhenDomain(Demon *d)=0
virtual void SetValues(const std::vector< int64_t > &values)
This method intersects the current domain with the values in the array.
void Accept(ModelVisitor *visitor) const override
Accepts the given visitor.
virtual IntVar * IsDifferent(int64_t constant)=0
virtual int64_t OldMax() const =0
Returns the previous max.
virtual IntVar * IsLessOrEqual(int64_t constant)=0
virtual bool Contains(int64_t v) const =0
virtual int64_t Value() const =0
virtual int VarType() const
virtual void RemoveValue(int64_t v)=0
This method removes the value 'v' from the domain of the variable.
virtual uint64_t Size() const =0
This method returns the number of values in the domain of the variable.
virtual IntVar * IsGreaterOrEqual(int64_t constant)=0
virtual IntVar * IsEqual(int64_t constant)=0
IsEqual.
virtual void RemoveInterval(int64_t l, int64_t u)=0
virtual int64_t OldMin() const =0
Returns the previous min.
virtual void RemoveValues(const std::vector< int64_t > &values)
This method remove the values from the domain of the variable.
static int64_t GCD64(int64_t x, int64_t y)
Definition mathutil.h:107
virtual void VisitIntegerVariable(const IntVar *variable, IntExpr *delegate)
void SetRange(int64_t l, int64_t u) override
This method sets both the min and the max of the expression.
void Accept(ModelVisitor *const visitor) const override
Accepts the given visitor.
PiecewiseLinearExpr(Solver *solver, IntExpr *expr, const PiecewiseLinearFunction &f)
void WhenRange(Demon *d) override
Attach a demon that will watch the min or the max of the expression.
std::string name() const override
Object naming.
std::string DebugString() const override
virtual std::string name() const
Object naming.
IntExpr * MakeDiv(IntExpr *expr, int64_t value)
expr / value (integer division)
IntVar * MakeBoolVar(const std::string &name)
MakeBoolVar will create a variable with a {0, 1} domain.
Constraint * MakeNonEquality(IntExpr *left, IntExpr *right)
left != right
Definition range_cst.cc:569
IntExpr * MakeMax(const std::vector< IntVar * > &vars)
std::max(vars)
IntExpr * MakeDifference(IntExpr *left, IntExpr *right)
left - right
IntVar * MakeBoolVar()
MakeBoolVar will create a variable with a {0, 1} domain.
IntExpr * MakeSum(IntExpr *left, IntExpr *right)
left + right.
IntExpr * MakeMin(const std::vector< IntVar * > &vars)
std::min(vars)
void Fail()
Abandon the current branch in the search tree. A backtrack will follow.
@ VAR_PRIORITY
VAR_PRIORITY is between DELAYED_PRIORITY and NORMAL_PRIORITY.
@ OUTSIDE_SEARCH
Before search, after search.
IntVar * RegisterIntVar(IntVar *var)
Registers a new IntVar and wraps it inside a TraceIntVar if necessary.
Definition trace.cc:860
bool IsBooleanVar(IntExpr *expr, IntVar **inner_var, bool *is_negated) const
Constraint * MakeLess(IntExpr *left, IntExpr *right)
left < right
Definition range_cst.cc:551
ModelCache * Cache() const
Returns the cache of the model.
IntExpr * MakeOpposite(IntExpr *expr)
-expr
Constraint * MakeAbsEquality(IntVar *var, IntVar *abs_var)
Creates the constraint abs(var) == abs_var.
void MakeBoolVarArray(int var_count, const std::string &name, std::vector< IntVar * > *vars)
IntExpr * MakePiecewiseLinearExpr(IntExpr *expr, const PiecewiseLinearFunction &f)
expressions.
IntExpr * RegisterIntExpr(IntExpr *expr)
Registers a new IntExpr and wraps it inside a TraceIntExpr if necessary.
Definition trace.cc:848
IntExpr * MakeModulo(IntExpr *x, int64_t mod)
Modulo expression x % mod (with the python convention for modulo).
Constraint * MakeGreater(IntExpr *left, IntExpr *right)
left > right
Definition range_cst.cc:565
IntExpr * MakeAbs(IntExpr *expr)
expr
IntExpr * MakeSquare(IntExpr *expr)
expr * expr
Constraint * MakeBetweenCt(IntExpr *expr, int64_t l, int64_t u)
(l <= expr <= u)
Definition expr_cst.cc:928
void MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax, const std::string &name, std::vector< IntVar * > *vars)
IntVar * MakeIntVar(int64_t min, int64_t max, const std::string &name)
MakeIntVar will create the best range based int var for the bounds given.
bool IsProduct(IntExpr *expr, IntExpr **inner_expr, int64_t *coefficient)
IntExpr * MakeConvexPiecewiseExpr(IntExpr *expr, int64_t early_cost, int64_t early_date, int64_t late_date, int64_t late_cost)
Convex piecewise function.
IntExpr * MakePower(IntExpr *expr, int64_t n)
expr ^ n (n > 0)
IntExpr * MakeConditionalExpression(IntVar *condition, IntExpr *expr, int64_t unperformed_value)
Conditional Expr condition ? expr : unperformed_value.
IntExpr * MakeProd(IntExpr *left, IntExpr *right)
left * right
IntVar * MakeIntConst(int64_t val, const std::string &name)
IntConst will create a constant expression.
void AddConstraint(Constraint *c)
Adds the constraint 'c' to the model.
IntExpr * MakeSemiContinuousExpr(IntExpr *expr, int64_t fixed_charge, int64_t step)
void AddCastConstraint(CastConstraint *constraint, IntVar *target_var, IntExpr *expr)
ABSL_FLAG(bool, cp_disable_expression_optimization, false, "Disable special optimization when creating expressions.")
int RemoveAt(RepeatedType *array, const IndexContainer &indices)
const Collection::value_type::second_type FindPtrOrNull(const Collection &collection, const typename Collection::value_type::first_type &key)
Definition map_util.h:94
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
Definition stl_util.h:55
For infeasible and unbounded see Not checked if options check_solutions_if_inf_or_unbounded and the If options first_solution_only is false
Definition matchers.h:467
std::pair< double, double > Range
Definition statistics.h:27
dual_gradient T(y - `dual_solution`) class DiagonalTrustRegionProblemFromQp
std::function< int64_t(const Model &)> Value(IntegerVariable v)
Definition integer.h:1839
OR-Tools root namespace.
int64_t SubOverflows(int64_t x, int64_t y)
int64_t UnsafeMostSignificantBitPosition64(const uint64_t *bitset, uint64_t start, uint64_t end)
void InternalSaveBooleanVarValue(Solver *const solver, IntVar *const var)
int64_t CapAdd(int64_t x, int64_t y)
void RestoreBoolValue(IntVar *var)
int64_t UnsafeLeastSignificantBitPosition64(const uint64_t *bitset, uint64_t start, uint64_t end)
int64_t CapSub(int64_t x, int64_t y)
Constraint * SetIsGreaterOrEqual(IntVar *const var, absl::Span< const int64_t > values, const std::vector< IntVar * > &vars)
ClosedInterval::Iterator end(ClosedInterval interval)
Constraint * SetIsEqual(IntVar *const var, absl::Span< const int64_t > values, const std::vector< IntVar * > &vars)
bool AddOverflows(int64_t x, int64_t y)
void RegisterDemon(Solver *const solver, Demon *const demon, DemonProfiler *const monitor)
static const uint64_t kAllBits64
Definition bitset.h:36
void LinkVarExpr(Solver *s, IntExpr *expr, IntVar *var)
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
int64_t CapProd(int64_t x, int64_t y)
std::vector< int64_t > ToInt64Vector(const std::vector< int > &input)
Definition utilities.cc:824
uint64_t OneRange64(uint64_t s, uint64_t e)
Definition bitset.h:288
uint32_t BitPos64(uint64_t pos)
Definition bitset.h:333
uint64_t BitCountRange64(const uint64_t *bitset, uint64_t start, uint64_t end)
uint64_t BitCount64(uint64_t n)
Definition bitset.h:45
bool IsBitSet64(const uint64_t *const bitset, uint64_t pos)
Definition bitset.h:349
uint64_t OneBit64(int pos)
Definition bitset.h:41
uint64_t BitOffset64(uint64_t pos)
Definition bitset.h:337
int64_t PosIntDivDown(int64_t e, int64_t v)
uint64_t BitLength64(uint64_t size)
Definition bitset.h:341
int LeastSignificantBitPosition64(uint64_t n)
Definition bitset.h:130
void CleanVariableOnFail(IntVar *var)
int64_t CapOpp(int64_t v)
int MostSignificantBitPosition64(uint64_t n)
Definition bitset.h:234
int64_t PosIntDivUp(int64_t e, int64_t v)
STL namespace.
false true
Definition numbers.cc:223
trees with all degrees equal w the current value of w
bool Next()