Google OR-Tools v9.14
a fast and portable software suite for combinatorial optimization
Loading...
Searching...
No Matches
expressions.cc
Go to the documentation of this file.
1// Copyright 2010-2025 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14#include <algorithm>
15#include <cmath>
16#include <cstdint>
17#include <cstdlib>
18#include <limits>
19#include <memory>
20#include <string>
21#include <utility>
22#include <vector>
23
24#include "absl/flags/flag.h"
25#include "absl/strings/str_cat.h"
26#include "absl/strings/str_format.h"
27#include "absl/strings/string_view.h"
28#include "absl/types/span.h"
34#include "ortools/util/bitset.h"
37
38ABSL_FLAG(bool, cp_disable_expression_optimization, false,
39 "Disable special optimization when creating expressions.");
40ABSL_FLAG(bool, cp_share_int_consts, true,
41 "Share IntConst's with the same value.");
42
43#if defined(_MSC_VER)
44#pragma warning(disable : 4351 4355)
45#endif
46
47namespace operations_research {
48
49// ---------- IntExpr ----------
50
51IntVar* IntExpr::VarWithName(const std::string& name) {
52 IntVar* const var = Var();
53 var->set_name(name);
54 return var;
55}
56
57// ---------- IntVar ----------
58
59IntVar::IntVar(Solver* const s) : IntExpr(s), index_(s->GetNewIntVarIndex()) {}
60
61IntVar::IntVar(Solver* const s, const std::string& name)
62 : IntExpr(s), index_(s->GetNewIntVarIndex()) {
64}
65
66// ----- Boolean variable -----
67
69
70void BooleanVar::SetMin(int64_t m) {
71 if (m <= 0) return;
72 if (m > 1) solver()->Fail();
73 SetValue(1);
74}
75
76void BooleanVar::SetMax(int64_t m) {
77 if (m >= 1) return;
78 if (m < 0) solver()->Fail();
79 SetValue(0);
80}
81
82void BooleanVar::SetRange(int64_t mi, int64_t ma) {
83 if (mi > 1 || ma < 0 || mi > ma) {
84 solver()->Fail();
85 }
86 if (mi == 1) {
87 SetValue(1);
88 } else if (ma == 0) {
89 SetValue(0);
90 }
91}
92
93void BooleanVar::RemoveValue(int64_t v) {
95 if (v == 0) {
96 SetValue(1);
97 } else if (v == 1) {
98 SetValue(0);
99 }
100 } else if (v == value_) {
101 solver()->Fail();
102 }
103}
104
105void BooleanVar::RemoveInterval(int64_t l, int64_t u) {
106 if (u < l) return;
107 if (l <= 0 && u >= 1) {
108 solver()->Fail();
109 } else if (l == 1) {
110 SetValue(0);
111 } else if (u == 0) {
112 SetValue(1);
113 }
114}
115
119 delayed_bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
120 } else {
121 bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
122 }
123 }
124}
125
126uint64_t BooleanVar::Size() const {
127 return (1 + (value_ == kUnboundBooleanVarValue));
128}
129
130bool BooleanVar::Contains(int64_t v) const {
131 return ((v == 0 && value_ != 1) || (v == 1 && value_ != 0));
132}
133
134IntVar* BooleanVar::IsEqual(int64_t constant) {
135 if (constant > 1 || constant < 0) {
136 return solver()->MakeIntConst(0);
137 }
138 if (constant == 1) {
139 return this;
140 } else { // constant == 0.
141 return solver()->MakeDifference(1, this)->Var();
142 }
143}
144
145IntVar* BooleanVar::IsDifferent(int64_t constant) {
146 if (constant > 1 || constant < 0) {
147 return solver()->MakeIntConst(1);
148 }
149 if (constant == 1) {
150 return solver()->MakeDifference(1, this)->Var();
151 } else { // constant == 0.
152 return this;
153 }
154}
155
157 if (constant > 1) {
158 return solver()->MakeIntConst(0);
159 } else if (constant <= 0) {
160 return solver()->MakeIntConst(1);
161 } else {
162 return this;
163 }
164}
165
167 if (constant < 0) {
168 return solver()->MakeIntConst(0);
169 } else if (constant >= 1) {
170 return solver()->MakeIntConst(1);
171 } else {
172 return IsEqual(0);
173 }
174}
175
176std::string BooleanVar::DebugString() const {
177 std::string out;
178 const std::string& var_name = name();
179 if (!var_name.empty()) {
180 out = var_name + "(";
181 } else {
182 out = "BooleanVar(";
183 }
184 switch (value_) {
185 case 0:
186 out += "0";
187 break;
188 case 1:
189 out += "1";
190 break;
192 out += "0 .. 1";
193 break;
194 }
195 out += ")";
196 return out;
197}
198
199namespace {
200// ---------- Subclasses of IntVar ----------
201
202// ----- Domain Int Var: base class for variables -----
203// It Contains bounds and a bitset representation of possible values.
204class DomainIntVar : public IntVar {
205 public:
206 // Utility classes
207 class BitSetIterator : public BaseObject {
208 public:
209 BitSetIterator(uint64_t* const bitset, int64_t omin)
210 : bitset_(bitset),
211 omin_(omin),
212 max_(std::numeric_limits<int64_t>::min()),
213 current_(std::numeric_limits<int64_t>::max()) {}
214
215 ~BitSetIterator() override {}
216
217 void Init(int64_t min, int64_t max) {
218 max_ = max;
219 current_ = min;
220 }
221
222 bool Ok() const { return current_ <= max_; }
223
224 int64_t Value() const { return current_; }
225
226 void Next() {
227 if (++current_ <= max_) {
229 bitset_, current_ - omin_, max_ - omin_) +
230 omin_;
231 }
232 }
233
234 std::string DebugString() const override { return "BitSetIterator"; }
235
236 private:
237 uint64_t* const bitset_;
238 const int64_t omin_;
239 int64_t max_;
240 int64_t current_;
241 };
242
243 class BitSet : public BaseObject {
244 public:
245 explicit BitSet(Solver* const s) : solver_(s), holes_stamp_(0) {}
246 ~BitSet() override {}
247
248 virtual int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) = 0;
249 virtual int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) = 0;
250 virtual bool Contains(int64_t val) const = 0;
251 virtual bool SetValue(int64_t val) = 0;
252 virtual bool RemoveValue(int64_t val) = 0;
253 virtual uint64_t Size() const = 0;
254 virtual void DelayRemoveValue(int64_t val) = 0;
255 virtual void ApplyRemovedValues(DomainIntVar* var) = 0;
256 virtual void ClearRemovedValues() = 0;
257 virtual std::string pretty_DebugString(int64_t min, int64_t max) const = 0;
258 virtual BitSetIterator* MakeIterator() = 0;
259
260 void InitHoles() {
261 const uint64_t current_stamp = solver_->stamp();
262 if (holes_stamp_ < current_stamp) {
263 holes_.clear();
264 holes_stamp_ = current_stamp;
265 }
266 }
267
268 virtual void ClearHoles() { holes_.clear(); }
269
270 const std::vector<int64_t>& Holes() { return holes_; }
271
272 void AddHole(int64_t value) { holes_.push_back(value); }
273
274 int NumHoles() const {
275 return holes_stamp_ < solver_->stamp() ? 0 : holes_.size();
276 }
277
278 protected:
279 Solver* const solver_;
280
281 private:
282 std::vector<int64_t> holes_;
283 uint64_t holes_stamp_;
284 };
285
286 class QueueHandler : public Demon {
287 public:
288 explicit QueueHandler(DomainIntVar* const var) : var_(var) {}
289 ~QueueHandler() override {}
290 void Run(Solver* const s) override {
291 s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
292 var_->Process();
293 s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
294 }
295 Solver::DemonPriority priority() const override {
297 }
298 std::string DebugString() const override {
299 return absl::StrFormat("Handler(%s)", var_->DebugString());
300 }
301
302 private:
303 DomainIntVar* const var_;
304 };
305
306 // Bounds and Value watchers
307
  // This class stores the watchers variables attached to values. It is
  // reversible and it helps maintaining the set of 'active' watchers
  // (variables not bound to a single value).
  //
  // Layout: elements_[start_, elements_.size()) are the active entries.
  // Entries before start_ were retired by RemoveAt()/RemoveAll(). start_ is
  // a NumericalRev, so retirements are undone on backtrack. Insertions made
  // during search are undone by a backtrack action that calls Uninsert().
  template <class T>
  class RevIntPtrMap {
   public:
    // NOTE(review): `rmax` is unused, and `rmin` is only stored in
    // range_min_, which nothing in this class reads.
    RevIntPtrMap(Solver* const solver, int64_t rmin, int64_t rmax)
        : solver_(solver), range_min_(rmin), start_(0) {}

    ~RevIntPtrMap() {}

    // True when no active entry remains.
    bool Empty() const { return start_.Value() == elements_.size(); }

    // Sorts the stored (value, pointer) pairs.
    // NOTE(review): sorts from elements_.begin(), i.e. including retired
    // entries. Callers visible here only use it in Post(), presumably before
    // any removal — confirm.
    void SortActive() { std::sort(elements_.begin(), elements_.end()); }

    // Access with value API.

    // Add the pointer to the map attached to the given value.
    // This function does not check for an existing entry with the same value.
    // When called while the solver is not OUTSIDE_SEARCH, it registers a
    // backtrack action that removes the entry again via Uninsert().
    void UnsafeRevInsert(int64_t value, T* elem) {
      elements_.push_back(std::make_pair(value, elem));
      if (solver_->state() != Solver::OUTSIDE_SEARCH) {
        solver_->AddBacktrackAction(
            [this, value](Solver* s) { Uninsert(value); }, false);
      }
    }

    // Linear scan of the active entries for `value`. Returns the attached
    // pointer (and writes its index to *position when non-null), or nullptr
    // if absent.
    T* FindPtrOrNull(int64_t value, int* position) {
      for (int pos = start_.Value(); pos < elements_.size(); ++pos) {
        if (elements_[pos].first == value) {
          if (position != nullptr) *position = pos;
          return At(pos).second;
        }
      }
      return nullptr;
    }

    // Access map through the underlying vector.

    // Retires the active entry at `position`. It swaps that entry with the
    // one at start_, then reversibly increments start_. This reorders the
    // active entries, so a forward scan that calls RemoveAt(pos) stays valid:
    // the entry moved into `pos` has already been visited.
    void RemoveAt(int position) {
      const int start = start_.Value();
      DCHECK_GE(position, start);
      DCHECK_LT(position, elements_.size());
      if (position > start) {
        // Swap the current element with the one at the start position, and
        // increase start.
        const std::pair<int64_t, T*> copy = elements_[start];
        elements_[start] = elements_[position];
        elements_[position] = copy;
      }
      start_.Incr(solver_);
    }

    // Returns the entry at `position` (DCHECK-ed to be active).
    const std::pair<int64_t, T*>& At(int position) const {
      DCHECK_GE(position, start_.Value());
      DCHECK_LT(position, elements_.size());
      return elements_[position];
    }

    // Reversibly retires every entry.
    void RemoveAll() { start_.SetValue(solver_, elements_.size()); }

    // Index of the first active entry.
    int start() const { return start_.Value(); }
    // One past the index of the last stored entry.
    int end() const { return elements_.size(); }
    // Number of active elements.
    int Size() const { return elements_.size() - start_.Value(); }

    // Removes the object permanently from the map.
    // Used by the backtrack action registered in UnsafeRevInsert(). It swaps
    // the entry with the last one and pops it. Aborts via LOG(FATAL) if
    // `value` is not stored.
    void Uninsert(int64_t value) {
      for (int pos = 0; pos < elements_.size(); ++pos) {
        if (elements_[pos].first == value) {
          DCHECK_GE(pos, start_.Value());
          const int last = elements_.size() - 1;
          if (pos != last) {  // Swap the current with the last.
            elements_[pos] = elements_.back();
          }
          elements_.pop_back();
          return;
        }
      }
      LOG(FATAL) << "The element should have been removed";
    }

   private:
    Solver* const solver_;
    const int64_t range_min_;
    // Index of the first active entry; restored on backtrack.
    NumericalRev<int> start_;
    std::vector<std::pair<int64_t, T*>> elements_;
  };
394
395 // Base class for value watchers
396 class BaseValueWatcher : public Constraint {
397 public:
398 explicit BaseValueWatcher(Solver* const solver) : Constraint(solver) {}
399
400 ~BaseValueWatcher() override {}
401
402 virtual IntVar* GetOrMakeValueWatcher(int64_t value) = 0;
403
404 virtual void SetValueWatcher(IntVar* boolvar, int64_t value) = 0;
405 };
406
  // This class monitors the domain of the variable and updates the
  // IsEqual/IsDifferent boolean variables accordingly.
  // Watchers are stored sparsely in a RevIntPtrMap keyed by value. Each
  // watcher boolean b_v mirrors (variable_ == v), and propagation runs in
  // both directions: domain changes fix b_v, and fixing b_v changes the
  // domain.
  class ValueWatcher : public BaseValueWatcher {
   public:
    // Fires when the watcher boolean `var_` (attached to `value_`) is bound
    // and pushes that decision onto the watched variable.
    class WatchDemon : public Demon {
     public:
      WatchDemon(ValueWatcher* const watcher, int64_t value, IntVar* var)
          : value_watcher_(watcher), value_(value), var_(var) {}
      ~WatchDemon() override {}

      void Run(Solver* const solver) override {
        value_watcher_->ProcessValueWatcher(value_, var_);
      }

     private:
      ValueWatcher* const value_watcher_;
      const int64_t value_;
      IntVar* const var_;
    };

    // Fires on domain changes of the watched variable (registered with
    // WhenDomain in Post()).
    class VarDemon : public Demon {
     public:
      explicit VarDemon(ValueWatcher* const watcher)
          : value_watcher_(watcher) {}

      ~VarDemon() override {}

      void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }

     private:
      ValueWatcher* const value_watcher_;
    };

    ValueWatcher(Solver* const solver, DomainIntVar* const variable)
        : BaseValueWatcher(solver),
          variable_(variable),
          hole_iterator_(variable_->MakeHoleIterator(true)),
          var_demon_(nullptr),
          watchers_(solver, variable->Min(), variable->Max()) {}

    ~ValueWatcher() override {}

    // Returns the boolean tracking (variable_ == value).
    // An existing watcher is reused. If `value` is not in the domain, the
    // constant 0 is returned; if the variable is already bound to it, the
    // constant 1 is returned. Otherwise a new named boolean is created. When
    // it is created after Post(), its WatchDemon is attached here and the
    // VarDemon is re-enabled.
    IntVar* GetOrMakeValueWatcher(int64_t value) override {
      IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
      if (watcher != nullptr) return watcher;
      if (variable_->Contains(value)) {
        if (variable_->Bound()) {
          return solver()->MakeIntConst(1);
        } else {
          const std::string vname = variable_->HasName()
                                        ? variable_->name()
                                        : variable_->DebugString();
          const std::string bname =
              absl::StrFormat("Watch<%s == %d>", vname, value);
          IntVar* const boolvar = solver()->MakeBoolVar(bname);
          watchers_.UnsafeRevInsert(value, boolvar);
          if (posted_.Switched()) {
            boolvar->WhenBound(
                solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
            var_demon_->desinhibit(solver());
          }
          return boolvar;
        }
      } else {
        return variable_->solver()->MakeIntConst(0);
      }
    }

    // Installs a caller-provided boolean as the watcher for `value`.
    // CHECK-fails if a watcher already exists. Booleans that are already
    // bound are not recorded.
    void SetValueWatcher(IntVar* const boolvar, int64_t value) override {
      CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
      if (!boolvar->Bound()) {
        watchers_.UnsafeRevInsert(value, boolvar);
        if (posted_.Switched() && !boolvar->Bound()) {
          boolvar->WhenBound(
              solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
          var_demon_->desinhibit(solver());
        }
      }
    }

    // Attaches the VarDemon to the variable's domain events, and a
    // WatchDemon to every active, unbound watcher whose value is still in
    // the domain. Then marks the constraint as posted.
    void Post() override {
      var_demon_ = solver()->RevAlloc(new VarDemon(this));
      variable_->WhenDomain(var_demon_);
      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
        const int64_t value = w.first;
        IntVar* const boolvar = w.second;
        if (!boolvar->Bound() && variable_->Contains(value)) {
          boolvar->WhenBound(
              solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
        }
      }
      posted_.Switch(solver());
    }

    // Synchronizes watchers with the current state. Watchers of removed
    // values are set to 0, and already-bound watchers are pushed onto the
    // variable. Both kinds are then retired.
    void InitialPropagate() override {
      if (variable_->Bound()) {
        VariableBound();
      } else {
        for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
          const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
          const int64_t value = w.first;
          IntVar* const boolvar = w.second;
          if (!variable_->Contains(value)) {
            boolvar->SetValue(0);
            watchers_.RemoveAt(pos);
          } else {
            if (boolvar->Bound()) {
              ProcessValueWatcher(value, boolvar);
              watchers_.RemoveAt(pos);
            }
          }
        }
        CheckInhibit();
      }
    }

    // Pushes a bound watcher onto the variable: 1 fixes the variable to
    // `value`, 0 removes `value`. For domains with Size() >= 0xFFFFFF the
    // removal is instead posted as a NonEquality constraint.
    void ProcessValueWatcher(int64_t value, IntVar* boolvar) {
      if (boolvar->Min() == 0) {
        if (variable_->Size() < 0xFFFFFF) {
          variable_->RemoveValue(value);
        } else {
          // Delay removal.
          solver()->AddConstraint(solver()->MakeNonEquality(variable_, value));
        }
      } else {
        variable_->SetValue(value);
      }
    }

    // Reacts to a domain change of the watched variable.
    void ProcessVar() {
      const int kSmallList = 16;
      if (variable_->Bound()) {
        VariableBound();
      } else if (watchers_.Size() <= kSmallList ||
                 variable_->Min() != variable_->OldMin() ||
                 variable_->Max() != variable_->OldMax()) {
        // Brute force loop for small numbers of watchers, or if the bounds have
        // changed, which would have required a sort (n log(n)) anyway to take
        // advantage of.
        ScanWatchers();
        CheckInhibit();
      } else {
        // Many watchers and unchanged bounds: only holes can have been
        // created. Without a bitset there are no holes, so nothing is left to
        // propagate. Otherwise, walk the holes when they are few compared to
        // the active watchers, or scan all watchers.
        BitSet* const bitset = variable_->bitset();
        if (bitset != nullptr && !watchers_.Empty()) {
          if (bitset->NumHoles() * 2 < watchers_.Size()) {
            for (const int64_t hole : InitAndGetValues(hole_iterator_)) {
              int pos = 0;
              IntVar* const boolvar = watchers_.FindPtrOrNull(hole, &pos);
              if (boolvar != nullptr) {
                boolvar->SetValue(0);
                watchers_.RemoveAt(pos);
              }
            }
          } else {
            ScanWatchers();
          }
        }
        CheckInhibit();
      }
    }

    // Optimized case if the variable is bound.
    // Sets every active watcher to (its value == bound value), retires all of
    // them, and inhibits the VarDemon.
    void VariableBound() {
      DCHECK(variable_->Bound());
      const int64_t value = variable_->Min();
      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
        w.second->SetValue(w.first == value);
      }
      watchers_.RemoveAll();
      var_demon_->inhibit(solver());
    }

    // Scans all the watchers to check and assign them.
    // Watchers whose value has left the domain are set to 0 and retired.
    void ScanWatchers() {
      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
        if (!variable_->Contains(w.first)) {
          IntVar* const boolvar = w.second;
          boolvar->SetValue(0);
          watchers_.RemoveAt(pos);
        }
      }
    }

    // If the set of active watchers is empty, we can inhibit the demon on the
    // main variable.
    void CheckInhibit() {
      if (watchers_.Empty()) {
        var_demon_->inhibit(solver());
      }
    }

    // Reports the variable and the active (value, boolean) watcher pairs to
    // the model visitor.
    void Accept(ModelVisitor* const visitor) const override {
      visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
      visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
                                              variable_);
      std::vector<int64_t> all_coefficients;
      std::vector<IntVar*> all_bool_vars;
      for (int position = watchers_.start(); position < watchers_.end();
           ++position) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(position);
        all_coefficients.push_back(w.first);
        all_bool_vars.push_back(w.second);
      }
      visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
                                                 all_bool_vars);
      visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
                                         all_coefficients);
      visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
    }

    std::string DebugString() const override {
      return absl::StrFormat("ValueWatcher(%s)", variable_->DebugString());
    }

   private:
    DomainIntVar* const variable_;
    // Iterator over the variable's holes (created with MakeHoleIterator).
    IntVarIterator* const hole_iterator_;
    // Switched once Post() has run.
    RevSwitch posted_;
    Demon* var_demon_;
    RevIntPtrMap<IntVar> watchers_;
  };
634
  // Optimized case for small maps.
  // Dense variant of ValueWatcher. watchers_ is a vector indexed by
  // value - offset_, where offset_ is the variable's Min() at construction,
  // and it covers [Min(), Max()] as of that moment. Slots are written
  // reversibly (SaveValue), and active_watchers_ reversibly counts the
  // non-null slots.
  class DenseValueWatcher : public BaseValueWatcher {
   public:
    // Fires when the watcher boolean `var_` (attached to `value_`) is bound.
    class WatchDemon : public Demon {
     public:
      WatchDemon(DenseValueWatcher* const watcher, int64_t value, IntVar* var)
          : value_watcher_(watcher), value_(value), var_(var) {}
      ~WatchDemon() override {}

      void Run(Solver* const solver) override {
        value_watcher_->ProcessValueWatcher(value_, var_);
      }

     private:
      DenseValueWatcher* const value_watcher_;
      const int64_t value_;
      IntVar* const var_;
    };

    // Fires on domain changes of the watched variable (see Post()).
    class VarDemon : public Demon {
     public:
      explicit VarDemon(DenseValueWatcher* const watcher)
          : value_watcher_(watcher) {}

      ~VarDemon() override {}

      void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }

     private:
      DenseValueWatcher* const value_watcher_;
    };

    DenseValueWatcher(Solver* const solver, DomainIntVar* const variable)
        : BaseValueWatcher(solver),
          variable_(variable),
          hole_iterator_(variable_->MakeHoleIterator(true)),
          var_demon_(nullptr),
          offset_(variable->Min()),
          watchers_(variable->Max() - variable->Min() + 1, nullptr),
          active_watchers_(0) {}

    ~DenseValueWatcher() override {}

    // Returns the boolean tracking (variable_ == value).
    // Values outside the vector's range, or no longer in the domain, give the
    // constant 0. A variable already bound to `value` gives the constant 1.
    // Otherwise an existing watcher is reused or a new named boolean is
    // created.
    IntVar* GetOrMakeValueWatcher(int64_t value) override {
      const int64_t var_max = offset_ + watchers_.size() - 1;  // Bad cast.
      if (value < offset_ || value > var_max) {
        return solver()->MakeIntConst(0);
      }
      const int index = value - offset_;
      IntVar* const watcher = watchers_[index];
      if (watcher != nullptr) return watcher;
      if (variable_->Contains(value)) {
        if (variable_->Bound()) {
          return solver()->MakeIntConst(1);
        } else {
          const std::string vname = variable_->HasName()
                                        ? variable_->name()
                                        : variable_->DebugString();
          const std::string bname =
              absl::StrFormat("Watch<%s == %d>", vname, value);
          IntVar* const boolvar = solver()->MakeBoolVar(bname);
          RevInsert(index, boolvar);
          if (posted_.Switched()) {
            boolvar->WhenBound(
                solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
            var_demon_->desinhibit(solver());
          }
          return boolvar;
        }
      } else {
        return variable_->solver()->MakeIntConst(0);
      }
    }

    // Installs a caller-provided boolean as the watcher for `value`.
    // NOTE(review): `value` is not range-checked before indexing watchers_.
    // Callers presumably pass a value inside the initial range — confirm.
    void SetValueWatcher(IntVar* const boolvar, int64_t value) override {
      const int index = value - offset_;
      CHECK(watchers_[index] == nullptr);
      if (!boolvar->Bound()) {
        RevInsert(index, boolvar);
        if (posted_.Switched() && !boolvar->Bound()) {
          boolvar->WhenBound(
              solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
          var_demon_->desinhibit(solver());
        }
      }
    }

    // Attaches the VarDemon to domain events, and a WatchDemon to every
    // unbound watcher whose value is still in the domain.
    void Post() override {
      var_demon_ = solver()->RevAlloc(new VarDemon(this));
      variable_->WhenDomain(var_demon_);
      for (int pos = 0; pos < watchers_.size(); ++pos) {
        const int64_t value = pos + offset_;
        IntVar* const boolvar = watchers_[pos];
        if (boolvar != nullptr && !boolvar->Bound() &&
            variable_->Contains(value)) {
          boolvar->WhenBound(
              solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
        }
      }
      posted_.Switch(solver());
    }

    // Synchronizes watchers with the current state. Watchers of removed
    // values become 0, and already-bound watchers are pushed onto the
    // variable. Both kinds are cleared. The VarDemon is inhibited when no
    // watcher remains.
    void InitialPropagate() override {
      if (variable_->Bound()) {
        VariableBound();
      } else {
        for (int pos = 0; pos < watchers_.size(); ++pos) {
          IntVar* const boolvar = watchers_[pos];
          if (boolvar == nullptr) continue;
          const int64_t value = pos + offset_;
          if (!variable_->Contains(value)) {
            boolvar->SetValue(0);
            RevRemove(pos);
          } else if (boolvar->Bound()) {
            ProcessValueWatcher(value, boolvar);
            RevRemove(pos);
          }
        }
        if (active_watchers_.Value() == 0) {
          var_demon_->inhibit(solver());
        }
      }
    }

    // Pushes a bound watcher onto the variable: 0 removes `value`, 1 fixes
    // the variable to `value`.
    void ProcessValueWatcher(int64_t value, IntVar* boolvar) {
      if (boolvar->Min() == 0) {
        variable_->RemoveValue(value);
      } else {
        variable_->SetValue(value);
      }
    }

    // Reacts to a domain change of the watched variable.
    void ProcessVar() {
      if (variable_->Bound()) {
        VariableBound();
      } else {
        // Brute force loop for small numbers of watchers.
        ScanWatchers();
        if (active_watchers_.Value() == 0) {
          var_demon_->inhibit(solver());
        }
      }
    }

    // Optimized case if the variable is bound.
    // Sets each watcher to (its value == bound value), clears it, and
    // inhibits the VarDemon.
    void VariableBound() {
      DCHECK(variable_->Bound());
      const int64_t value = variable_->Min();
      for (int pos = 0; pos < watchers_.size(); ++pos) {
        IntVar* const boolvar = watchers_[pos];
        if (boolvar != nullptr) {
          boolvar->SetValue(pos + offset_ == value);
          RevRemove(pos);
        }
      }
      var_demon_->inhibit(solver());
    }

    // Scans all the watchers to check and assign them.
    // First the slots cut off by bound changes are handled:
    // [OldMin, Min) and (Max, OldMax]. Then, if a bitset exists, interior
    // holes are handled either by walking the holes (when there are fewer
    // than half as many holes as active watchers) or by checking every
    // interior slot.
    void ScanWatchers() {
      const int64_t old_min_index = variable_->OldMin() - offset_;
      const int64_t old_max_index = variable_->OldMax() - offset_;
      const int64_t min_index = variable_->Min() - offset_;
      const int64_t max_index = variable_->Max() - offset_;
      for (int pos = old_min_index; pos < min_index; ++pos) {
        IntVar* const boolvar = watchers_[pos];
        if (boolvar != nullptr) {
          boolvar->SetValue(0);
          RevRemove(pos);
        }
      }
      for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
        IntVar* const boolvar = watchers_[pos];
        if (boolvar != nullptr) {
          boolvar->SetValue(0);
          RevRemove(pos);
        }
      }
      BitSet* const bitset = variable_->bitset();
      if (bitset != nullptr) {
        if (bitset->NumHoles() * 2 < active_watchers_.Value()) {
          for (const int64_t hole : InitAndGetValues(hole_iterator_)) {
            IntVar* const boolvar = watchers_[hole - offset_];
            if (boolvar != nullptr) {
              boolvar->SetValue(0);
              RevRemove(hole - offset_);
            }
          }
        } else {
          // Min and Max are in the domain, so only strictly interior slots
          // can be holes.
          for (int pos = min_index + 1; pos < max_index; ++pos) {
            IntVar* const boolvar = watchers_[pos];
            if (boolvar != nullptr && !variable_->Contains(offset_ + pos)) {
              boolvar->SetValue(0);
              RevRemove(pos);
            }
          }
        }
      }
    }

    // Reversibly clears slot `pos` and decrements the active count.
    void RevRemove(int pos) {
      solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
      watchers_[pos] = nullptr;
      active_watchers_.Decr(solver());
    }

    // Reversibly stores `boolvar` in slot `pos` and increments the active
    // count.
    void RevInsert(int pos, IntVar* boolvar) {
      solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
      watchers_[pos] = boolvar;
      active_watchers_.Incr(solver());
    }

    // Reports the variable and the non-null (value, boolean) watcher pairs
    // to the model visitor.
    void Accept(ModelVisitor* const visitor) const override {
      visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
      visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
                                              variable_);
      std::vector<int64_t> all_coefficients;
      std::vector<IntVar*> all_bool_vars;
      for (int position = 0; position < watchers_.size(); ++position) {
        if (watchers_[position] != nullptr) {
          all_coefficients.push_back(position + offset_);
          all_bool_vars.push_back(watchers_[position]);
        }
      }
      visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
                                                 all_bool_vars);
      visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
                                         all_coefficients);
      visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
    }

    std::string DebugString() const override {
      return absl::StrFormat("DenseValueWatcher(%s)", variable_->DebugString());
    }

   private:
    DomainIntVar* const variable_;
    IntVarIterator* const hole_iterator_;
    RevSwitch posted_;
    Demon* var_demon_;
    // Value stored in slot 0 of watchers_.
    const int64_t offset_;
    std::vector<IntVar*> watchers_;
    NumericalRev<int> active_watchers_;
  };
879
880 class BaseUpperBoundWatcher : public Constraint {
881 public:
882 explicit BaseUpperBoundWatcher(Solver* const solver) : Constraint(solver) {}
883
884 ~BaseUpperBoundWatcher() override {}
885
886 virtual IntVar* GetOrMakeUpperBoundWatcher(int64_t value) = 0;
887
888 virtual void SetUpperBoundWatcher(IntVar* boolvar, int64_t value) = 0;
889 };
890
891 // This class watches the bounds of the variable and updates the
892 // IsGreater/IsGreaterOrEqual/IsLess/IsLessOrEqual demons
893 // accordingly.
894 class UpperBoundWatcher : public BaseUpperBoundWatcher {
895 public:
896 class WatchDemon : public Demon {
897 public:
898 WatchDemon(UpperBoundWatcher* const watcher, int64_t index,
899 IntVar* const var)
900 : value_watcher_(watcher), index_(index), var_(var) {}
901 ~WatchDemon() override {}
902
903 void Run(Solver* const solver) override {
904 value_watcher_->ProcessUpperBoundWatcher(index_, var_);
905 }
906
907 private:
908 UpperBoundWatcher* const value_watcher_;
909 const int64_t index_;
910 IntVar* const var_;
911 };
912
913 class VarDemon : public Demon {
914 public:
915 explicit VarDemon(UpperBoundWatcher* const watcher)
916 : value_watcher_(watcher) {}
917 ~VarDemon() override {}
918
919 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
920
921 private:
922 UpperBoundWatcher* const value_watcher_;
923 };
924
925 UpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
926 : BaseUpperBoundWatcher(solver),
927 variable_(variable),
928 var_demon_(nullptr),
929 watchers_(solver, variable->Min(), variable->Max()),
930 start_(0),
931 end_(0),
932 sorted_(false) {}
933
934 ~UpperBoundWatcher() override {}
935
936 IntVar* GetOrMakeUpperBoundWatcher(int64_t value) override {
937 IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
938 if (watcher != nullptr) {
939 return watcher;
940 }
941 if (variable_->Max() >= value) {
942 if (variable_->Min() >= value) {
943 return solver()->MakeIntConst(1);
944 } else {
945 const std::string vname = variable_->HasName()
946 ? variable_->name()
947 : variable_->DebugString();
948 const std::string bname =
949 absl::StrFormat("Watch<%s >= %d>", vname, value);
950 IntVar* const boolvar = solver()->MakeBoolVar(bname);
951 watchers_.UnsafeRevInsert(value, boolvar);
952 if (posted_.Switched()) {
953 boolvar->WhenBound(
954 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
955 var_demon_->desinhibit(solver());
956 sorted_ = false;
957 }
958 return boolvar;
959 }
960 } else {
961 return variable_->solver()->MakeIntConst(0);
962 }
963 }
964
965 void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) override {
966 CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
967 watchers_.UnsafeRevInsert(value, boolvar);
968 if (posted_.Switched() && !boolvar->Bound()) {
969 boolvar->WhenBound(
970 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
971 var_demon_->desinhibit(solver());
972 sorted_ = false;
973 }
974 }
975
976 void Post() override {
977 const int kTooSmallToSort = 8;
978 var_demon_ = solver()->RevAlloc(new VarDemon(this));
979 variable_->WhenRange(var_demon_);
980
981 if (watchers_.Size() > kTooSmallToSort) {
982 watchers_.SortActive();
983 sorted_ = true;
984 start_.SetValue(solver(), watchers_.start());
985 end_.SetValue(solver(), watchers_.end() - 1);
986 }
987
988 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
989 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
990 IntVar* const boolvar = w.second;
991 const int64_t value = w.first;
992 if (!boolvar->Bound() && value > variable_->Min() &&
993 value <= variable_->Max()) {
994 boolvar->WhenBound(
995 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
996 }
997 }
998 posted_.Switch(solver());
999 }
1000
1001 void InitialPropagate() override {
1002 const int64_t var_min = variable_->Min();
1003 const int64_t var_max = variable_->Max();
1004 if (sorted_) {
1005 while (start_.Value() <= end_.Value()) {
1006 const std::pair<int64_t, IntVar*>& w = watchers_.At(start_.Value());
1007 if (w.first <= var_min) {
1008 w.second->SetValue(1);
1009 start_.Incr(solver());
1010 } else {
1011 break;
1012 }
1013 }
1014 while (end_.Value() >= start_.Value()) {
1015 const std::pair<int64_t, IntVar*>& w = watchers_.At(end_.Value());
1016 if (w.first > var_max) {
1017 w.second->SetValue(0);
1018 end_.Decr(solver());
1019 } else {
1020 break;
1021 }
1022 }
1023 for (int i = start_.Value(); i <= end_.Value(); ++i) {
1024 const std::pair<int64_t, IntVar*>& w = watchers_.At(i);
1025 if (w.second->Bound()) {
1026 ProcessUpperBoundWatcher(w.first, w.second);
1027 }
1028 }
1029 if (start_.Value() > end_.Value()) {
1030 var_demon_->inhibit(solver());
1031 }
1032 } else {
1033 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1034 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
1035 const int64_t value = w.first;
1036 IntVar* const boolvar = w.second;
1037
1038 if (value <= var_min) {
1039 boolvar->SetValue(1);
1040 watchers_.RemoveAt(pos);
1041 } else if (value > var_max) {
1042 boolvar->SetValue(0);
1043 watchers_.RemoveAt(pos);
1044 } else if (boolvar->Bound()) {
1045 ProcessUpperBoundWatcher(value, boolvar);
1046 watchers_.RemoveAt(pos);
1047 }
1048 }
1049 }
1050 }
1051
1052 void Accept(ModelVisitor* const visitor) const override {
1053 visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1054 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1055 variable_);
1056 std::vector<int64_t> all_coefficients;
1057 std::vector<IntVar*> all_bool_vars;
1058 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1059 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
1060 all_coefficients.push_back(w.first);
1061 all_bool_vars.push_back(w.second);
1062 }
1063 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1064 all_bool_vars);
1065 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1066 all_coefficients);
1067 visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1068 }
1069
1070 std::string DebugString() const override {
1071 return absl::StrFormat("UpperBoundWatcher(%s)", variable_->DebugString());
1072 }
1073
1074 private:
1075 void ProcessUpperBoundWatcher(int64_t value, IntVar* const boolvar) {
1076 if (boolvar->Min() == 0) {
1077 variable_->SetMax(value - 1);
1078 } else {
1079 variable_->SetMin(value);
1080 }
1081 }
1082
1083 void ProcessVar() {
1084 const int64_t var_min = variable_->Min();
1085 const int64_t var_max = variable_->Max();
1086 if (sorted_) {
1087 while (start_.Value() <= end_.Value()) {
1088 const std::pair<int64_t, IntVar*>& w = watchers_.At(start_.Value());
1089 if (w.first <= var_min) {
1090 w.second->SetValue(1);
1091 start_.Incr(solver());
1092 } else {
1093 break;
1094 }
1095 }
1096 while (end_.Value() >= start_.Value()) {
1097 const std::pair<int64_t, IntVar*>& w = watchers_.At(end_.Value());
1098 if (w.first > var_max) {
1099 w.second->SetValue(0);
1100 end_.Decr(solver());
1101 } else {
1102 break;
1103 }
1104 }
1105 if (start_.Value() > end_.Value()) {
1106 var_demon_->inhibit(solver());
1107 }
1108 } else {
1109 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1110 const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
1111 const int64_t value = w.first;
1112 IntVar* const boolvar = w.second;
1113
1114 if (value <= var_min) {
1115 boolvar->SetValue(1);
1116 watchers_.RemoveAt(pos);
1117 } else if (value > var_max) {
1118 boolvar->SetValue(0);
1119 watchers_.RemoveAt(pos);
1120 }
1121 }
1122 if (watchers_.Empty()) {
1123 var_demon_->inhibit(solver());
1124 }
1125 }
1126 }
1127
1128 DomainIntVar* const variable_;  // Watched variable; ProcessUpperBoundWatcher() tightens its bounds.
1129 RevSwitch posted_;  // NOTE(review): presumably switched by Post(), which is outside this view — confirm.
1130 Demon* var_demon_;  // Variable demon; ProcessVar() inhibits it once no undecided watcher is left.
1131 RevIntPtrMap<IntVar> watchers_;  // (threshold, boolean) pairs; the boolean is 1 iff variable_ >= threshold.
1132 NumericalRev<int> start_;  // Sorted mode: first still-undecided position in watchers_.
1133 NumericalRev<int> end_;  // Sorted mode: last still-undecided position in watchers_.
1134 bool sorted_;  // Selects the window scan (true) or the RemoveAt() scan (false) in ProcessVar().
1135 };
1136
1137 // Optimized case for small maps.
1138 class DenseUpperBoundWatcher : public BaseUpperBoundWatcher {
1139 public:
1140 class WatchDemon : public Demon {
1141 public:
1142 WatchDemon(DenseUpperBoundWatcher* const watcher, int64_t value,
1143 IntVar* var)
1144 : value_watcher_(watcher), value_(value), var_(var) {}
1145 ~WatchDemon() override {}
1146
1147 void Run(Solver* const solver) override {
1148 value_watcher_->ProcessUpperBoundWatcher(value_, var_);
1149 }
1150
1151 private:
1152 DenseUpperBoundWatcher* const value_watcher_;
1153 const int64_t value_;
1154 IntVar* const var_;
1155 };
1156
1157 class VarDemon : public Demon {
1158 public:
1159 explicit VarDemon(DenseUpperBoundWatcher* const watcher)
1160 : value_watcher_(watcher) {}
1161
1162 ~VarDemon() override {}
1163
1164 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
1165
1166 private:
1167 DenseUpperBoundWatcher* const value_watcher_;
1168 };
1169
1170 DenseUpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
1171 : BaseUpperBoundWatcher(solver),
1172 variable_(variable),
1173 var_demon_(nullptr),
1174 offset_(variable->Min()),
1175 watchers_(variable->Max() - variable->Min() + 1, nullptr),
1176 active_watchers_(0) {}
1177
1178 ~DenseUpperBoundWatcher() override {}
1179
1180 IntVar* GetOrMakeUpperBoundWatcher(int64_t value) override {
1181 if (variable_->Max() >= value) {
1182 if (variable_->Min() >= value) {
1183 return solver()->MakeIntConst(1);
1184 } else {
1185 const std::string vname = variable_->HasName()
1186 ? variable_->name()
1187 : variable_->DebugString();
1188 const std::string bname =
1189 absl::StrFormat("Watch<%s >= %d>", vname, value);
1190 IntVar* const boolvar = solver()->MakeBoolVar(bname);
1191 RevInsert(value - offset_, boolvar);
1192 if (posted_.Switched()) {
1193 boolvar->WhenBound(
1194 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1195 var_demon_->desinhibit(solver());
1196 }
1197 return boolvar;
1198 }
1199 } else {
1200 return variable_->solver()->MakeIntConst(0);
1201 }
1202 }
1203
1204 void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) override {
1205 const int index = value - offset_;
1206 CHECK(watchers_[index] == nullptr);
1207 if (!boolvar->Bound()) {
1208 RevInsert(index, boolvar);
1209 if (posted_.Switched() && !boolvar->Bound()) {
1210 boolvar->WhenBound(
1211 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1212 var_demon_->desinhibit(solver());
1213 }
1214 }
1215 }
1216
1217 void Post() override {
1218 var_demon_ = solver()->RevAlloc(new VarDemon(this));
1219 variable_->WhenRange(var_demon_);
1220 for (int pos = 0; pos < watchers_.size(); ++pos) {
1221 const int64_t value = pos + offset_;
1222 IntVar* const boolvar = watchers_[pos];
1223 if (boolvar != nullptr && !boolvar->Bound() &&
1224 value > variable_->Min() && value <= variable_->Max()) {
1225 boolvar->WhenBound(
1226 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1227 }
1228 }
1229 posted_.Switch(solver());
1230 }
1231
1232 void InitialPropagate() override {
1233 for (int pos = 0; pos < watchers_.size(); ++pos) {
1234 IntVar* const boolvar = watchers_[pos];
1235 if (boolvar == nullptr) continue;
1236 const int64_t value = pos + offset_;
1237 if (value <= variable_->Min()) {
1238 boolvar->SetValue(1);
1239 RevRemove(pos);
1240 } else if (value > variable_->Max()) {
1241 boolvar->SetValue(0);
1242 RevRemove(pos);
1243 } else if (boolvar->Bound()) {
1244 ProcessUpperBoundWatcher(value, boolvar);
1245 RevRemove(pos);
1246 }
1247 }
1248 if (active_watchers_.Value() == 0) {
1249 var_demon_->inhibit(solver());
1250 }
1251 }
1252
1253 void ProcessUpperBoundWatcher(int64_t value, IntVar* boolvar) {
1254 if (boolvar->Min() == 0) {
1255 variable_->SetMax(value - 1);
1256 } else {
1257 variable_->SetMin(value);
1258 }
1259 }
1260
1261 void ProcessVar() {
1262 const int64_t old_min_index = variable_->OldMin() - offset_;
1263 const int64_t old_max_index = variable_->OldMax() - offset_;
1264 const int64_t min_index = variable_->Min() - offset_;
1265 const int64_t max_index = variable_->Max() - offset_;
1266 for (int pos = old_min_index; pos <= min_index; ++pos) {
1267 IntVar* const boolvar = watchers_[pos];
1268 if (boolvar != nullptr) {
1269 boolvar->SetValue(1);
1270 RevRemove(pos);
1271 }
1272 }
1273
1274 for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
1275 IntVar* const boolvar = watchers_[pos];
1276 if (boolvar != nullptr) {
1277 boolvar->SetValue(0);
1278 RevRemove(pos);
1279 }
1280 }
1281 if (active_watchers_.Value() == 0) {
1282 var_demon_->inhibit(solver());
1283 }
1284 }
1285
1286 void RevRemove(int pos) {
1287 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1288 watchers_[pos] = nullptr;
1289 active_watchers_.Decr(solver());
1290 }
1291
1292 void RevInsert(int pos, IntVar* boolvar) {
1293 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1294 watchers_[pos] = boolvar;
1295 active_watchers_.Incr(solver());
1296 }
1297
1298 void Accept(ModelVisitor* const visitor) const override {
1299 visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1300 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1301 variable_);
1302 std::vector<int64_t> all_coefficients;
1303 std::vector<IntVar*> all_bool_vars;
1304 for (int position = 0; position < watchers_.size(); ++position) {
1305 if (watchers_[position] != nullptr) {
1306 all_coefficients.push_back(position + offset_);
1307 all_bool_vars.push_back(watchers_[position]);
1308 }
1309 }
1310 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1311 all_bool_vars);
1312 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1313 all_coefficients);
1314 visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1315 }
1316
1317 std::string DebugString() const override {
1318 return absl::StrFormat("DenseUpperBoundWatcher(%s)",
1319 variable_->DebugString());
1320 }
1321
1322 private:
1323 DomainIntVar* const variable_;
1324 RevSwitch posted_;
1325 Demon* var_demon_;
1326 const int64_t offset_;
1327 std::vector<IntVar*> watchers_;
1328 NumericalRev<int> active_watchers_;
1329 };
1330
1331 // ----- Main Class -----
1332 DomainIntVar(Solver* s, int64_t vmin, int64_t vmax, const std::string& name);
1333 DomainIntVar(Solver* s, absl::Span<const int64_t> sorted_values,
1334 const std::string& name);
1335 ~DomainIntVar() override;
1336
1337 int64_t Min() const override { return min_.Value(); }
1338 void SetMin(int64_t m) override;
1339 int64_t Max() const override { return max_.Value(); }
1340 void SetMax(int64_t m) override;
1341 void SetRange(int64_t mi, int64_t ma) override;
1342 void SetValue(int64_t v) override;
1343 bool Bound() const override { return (min_.Value() == max_.Value()); }
1344 int64_t Value() const override {
1345 CHECK_EQ(min_.Value(), max_.Value())
1346 << " variable " << DebugString() << " is not bound.";
1347 return min_.Value();
1348 }
1349 void RemoveValue(int64_t v) override;
1350 void RemoveInterval(int64_t l, int64_t u) override;
1351 void CreateBits();
1352 void WhenBound(Demon* d) override {
1353 if (min_.Value() != max_.Value()) {
1354 if (d->priority() == Solver::DELAYED_PRIORITY) {
1355 delayed_bound_demons_.PushIfNotTop(solver(),
1356 solver()->RegisterDemon(d));
1357 } else {
1358 bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1359 }
1360 }
1361 }
1362 void WhenRange(Demon* d) override {
1363 if (min_.Value() != max_.Value()) {
1364 if (d->priority() == Solver::DELAYED_PRIORITY) {
1365 delayed_range_demons_.PushIfNotTop(solver(),
1366 solver()->RegisterDemon(d));
1367 } else {
1368 range_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1369 }
1370 }
1371 }
1372 void WhenDomain(Demon* d) override {
1373 if (min_.Value() != max_.Value()) {
1374 if (d->priority() == Solver::DELAYED_PRIORITY) {
1375 delayed_domain_demons_.PushIfNotTop(solver(),
1376 solver()->RegisterDemon(d));
1377 } else {
1378 domain_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1379 }
1380 }
1381 }
1382
1383 IntVar* IsEqual(int64_t constant) override {
1384 Solver* const s = solver();
1385 if (constant == min_.Value() && value_watcher_ == nullptr) {
1386 return s->MakeIsLessOrEqualCstVar(this, constant);
1387 }
1388 if (constant == max_.Value() && value_watcher_ == nullptr) {
1389 return s->MakeIsGreaterOrEqualCstVar(this, constant);
1390 }
1391 if (!Contains(constant)) {
1392 return s->MakeIntConst(int64_t{0});
1393 }
1394 if (Bound() && min_.Value() == constant) {
1395 return s->MakeIntConst(int64_t{1});
1396 }
1397 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1398 this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1399 if (cache != nullptr) {
1400 return cache->Var();
1401 } else {
1402 if (value_watcher_ == nullptr) {
1403 if (CapSub(Max(), Min()) <= 256) {
1404 solver()->SaveAndSetValue(
1405 reinterpret_cast<void**>(&value_watcher_),
1406 reinterpret_cast<void*>(
1407 solver()->RevAlloc(new DenseValueWatcher(solver(), this))));
1408
1409 } else {
1410 solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1411 reinterpret_cast<void*>(solver()->RevAlloc(
1412 new ValueWatcher(solver(), this))));
1413 }
1414 solver()->AddConstraint(value_watcher_);
1415 }
1416 IntVar* const boolvar = value_watcher_->GetOrMakeValueWatcher(constant);
1417 s->Cache()->InsertExprConstantExpression(
1418 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1419 return boolvar;
1420 }
1421 }
1422
1423 Constraint* SetIsEqual(absl::Span<const int64_t> values,
1424 const std::vector<IntVar*>& vars) {
1425 if (value_watcher_ == nullptr) {
1426 solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1427 reinterpret_cast<void*>(solver()->RevAlloc(
1428 new ValueWatcher(solver(), this))));
1429 for (int i = 0; i < vars.size(); ++i) {
1430 value_watcher_->SetValueWatcher(vars[i], values[i]);
1431 }
1432 }
1433 return value_watcher_;
1434 }
1435
1436 IntVar* IsDifferent(int64_t constant) override {
1437 Solver* const s = solver();
1438 if (constant == min_.Value() && value_watcher_ == nullptr) {
1439 return s->MakeIsGreaterOrEqualCstVar(this, constant + 1);
1440 }
1441 if (constant == max_.Value() && value_watcher_ == nullptr) {
1442 return s->MakeIsLessOrEqualCstVar(this, constant - 1);
1443 }
1444 if (!Contains(constant)) {
1445 return s->MakeIntConst(int64_t{1});
1446 }
1447 if (Bound() && min_.Value() == constant) {
1448 return s->MakeIntConst(int64_t{0});
1449 }
1450 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1452 if (cache != nullptr) {
1453 return cache->Var();
1454 } else {
1455 IntVar* const boolvar = s->MakeDifference(1, IsEqual(constant))->Var();
1456 s->Cache()->InsertExprConstantExpression(
1457 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
1458 return boolvar;
1459 }
1460 }
1461
1462 IntVar* IsGreaterOrEqual(int64_t constant) override {
1463 Solver* const s = solver();
1464 if (max_.Value() < constant) {
1465 return s->MakeIntConst(int64_t{0});
1466 }
1467 if (min_.Value() >= constant) {
1468 return s->MakeIntConst(int64_t{1});
1469 }
1470 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1472 if (cache != nullptr) {
1473 return cache->Var();
1474 } else {
1475 if (bound_watcher_ == nullptr) {
1476 if (CapSub(Max(), Min()) <= 256) {
1477 solver()->SaveAndSetValue(
1478 reinterpret_cast<void**>(&bound_watcher_),
1479 reinterpret_cast<void*>(solver()->RevAlloc(
1480 new DenseUpperBoundWatcher(solver(), this))));
1481 solver()->AddConstraint(bound_watcher_);
1482 } else {
1483 solver()->SaveAndSetValue(
1484 reinterpret_cast<void**>(&bound_watcher_),
1485 reinterpret_cast<void*>(
1486 solver()->RevAlloc(new UpperBoundWatcher(solver(), this))));
1487 solver()->AddConstraint(bound_watcher_);
1488 }
1489 }
1490 IntVar* const boolvar =
1491 bound_watcher_->GetOrMakeUpperBoundWatcher(constant);
1492 s->Cache()->InsertExprConstantExpression(
1493 boolvar, this, constant,
1495 return boolvar;
1496 }
1497 }
1498
1499 Constraint* SetIsGreaterOrEqual(absl::Span<const int64_t> values,
1500 const std::vector<IntVar*>& vars) {
1501 if (bound_watcher_ == nullptr) {
1502 if (CapSub(Max(), Min()) <= 256) {
1503 solver()->SaveAndSetValue(
1504 reinterpret_cast<void**>(&bound_watcher_),
1505 reinterpret_cast<void*>(solver()->RevAlloc(
1506 new DenseUpperBoundWatcher(solver(), this))));
1507 solver()->AddConstraint(bound_watcher_);
1508 } else {
1509 solver()->SaveAndSetValue(reinterpret_cast<void**>(&bound_watcher_),
1510 reinterpret_cast<void*>(solver()->RevAlloc(
1511 new UpperBoundWatcher(solver(), this))));
1512 solver()->AddConstraint(bound_watcher_);
1513 }
1514 for (int i = 0; i < values.size(); ++i) {
1515 bound_watcher_->SetUpperBoundWatcher(vars[i], values[i]);
1516 }
1517 }
1518 return bound_watcher_;
1519 }
1520
1521 IntVar* IsLessOrEqual(int64_t constant) override {
1522 Solver* const s = solver();
1523 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1525 if (cache != nullptr) {
1526 return cache->Var();
1527 } else {
1528 IntVar* const boolvar =
1529 s->MakeDifference(1, IsGreaterOrEqual(constant + 1))->Var();
1530 s->Cache()->InsertExprConstantExpression(
1531 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
1532 return boolvar;
1533 }
1534 }
1535
1536 void Process();
1537 void Push();
1538 void CleanInProcess();
1539 uint64_t Size() const override {
1540 if (bits_ != nullptr) return bits_->Size();
1541 return (static_cast<uint64_t>(max_.Value()) -
1542 static_cast<uint64_t>(min_.Value()) + 1);
1543 }
1544 bool Contains(int64_t v) const override {
1545 if (v < min_.Value() || v > max_.Value()) return false;
1546 return (bits_ == nullptr ? true : bits_->Contains(v));
1547 }
1548 IntVarIterator* MakeHoleIterator(bool reversible) const override;
1549 IntVarIterator* MakeDomainIterator(bool reversible) const override;
1550 int64_t OldMin() const override { return std::min(old_min_, min_.Value()); }
1551 int64_t OldMax() const override { return std::max(old_max_, max_.Value()); }
1552
1553 std::string DebugString() const override;
1554 BitSet* bitset() const { return bits_; }
1555 int VarType() const override { return DOMAIN_INT_VAR; }
1556 std::string BaseName() const override { return "IntegerVar"; }
1557
1558 friend class PlusCstDomainIntVar;
1559 friend class LinkExprAndDomainIntVar;
1560
1561 private:
1562 void CheckOldMin() {
1563 if (old_min_ > min_.Value()) {
1564 old_min_ = min_.Value();
1565 }
1566 }
1567 void CheckOldMax() {
1568 if (old_max_ < max_.Value()) {
1569 old_max_ = max_.Value();
1570 }
1571 }
1572 Rev<int64_t> min_;  // Current lower bound (returned by Min()).
1573 Rev<int64_t> max_;  // Current upper bound (returned by Max()).
1574 int64_t old_min_;  // Combined with min_ in OldMin(); lowered by CheckOldMin().
1575 int64_t old_max_;  // Combined with max_ in OldMax(); raised by CheckOldMax().
1576 int64_t new_min_;  // NOTE(review): not read in this view; presumably the pending min during Process() — confirm.
1577 int64_t new_max_;  // NOTE(review): not read in this view; presumably the pending max during Process() — confirm.
1578 SimpleRevFIFO<Demon*> bound_demons_;  // Non-delayed demons registered through WhenBound().
1579 SimpleRevFIFO<Demon*> range_demons_;  // Non-delayed demons registered through WhenRange().
1580 SimpleRevFIFO<Demon*> domain_demons_;  // Non-delayed demons registered through WhenDomain().
1581 SimpleRevFIFO<Demon*> delayed_bound_demons_;  // DELAYED_PRIORITY demons from WhenBound().
1582 SimpleRevFIFO<Demon*> delayed_range_demons_;  // DELAYED_PRIORITY demons from WhenRange().
1583 SimpleRevFIFO<Demon*> delayed_domain_demons_;  // DELAYED_PRIORITY demons from WhenDomain().
1584 QueueHandler handler_;  // NOTE(review): usage not visible in this view.
1585 bool in_process_;  // NOTE(review): presumably true while Process() runs — confirm.
1586 BitSet* bits_;  // Domain bitset, or nullptr when the domain is the full interval [min_, max_].
1587 BaseValueWatcher* value_watcher_;  // Lazily created by IsEqual()/SetIsEqual(); saved with SaveAndSetValue().
1588 BaseUpperBoundWatcher* bound_watcher_;  // Lazily created by IsGreaterOrEqual()/SetIsGreaterOrEqual().
1589};
1590
1591// ----- BitSet -----
1592
1593// Return whether an integer interval [a..b] (inclusive) contains at most
1594// K values, i.e. b - a < K, in a way that's robust to overflows.
1595// For performance reasons, in opt mode it doesn't check that [a, b] is a
1596// valid interval, nor that K is nonnegative.
1597inline bool ClosedIntervalNoLargerThan(int64_t a, int64_t b, int64_t K) {
1598 DCHECK_LE(a, b);
1599 DCHECK_GE(K, 0);
1600 if (a > 0) {
1601 return a > b - K;
1602 } else {
1603 return a + K > b;
1604 }
1605}
1606
1607class SimpleBitSet : public DomainIntVar::BitSet {
1608 public:
1609 SimpleBitSet(Solver* const s, int64_t vmin, int64_t vmax)
1610 : BitSet(s),
1611 bits_(nullptr),
1612 stamps_(nullptr),
1613 omin_(vmin),
1614 omax_(vmax),
1615 size_(vmax - vmin + 1),
1616 bsize_(BitLength64(size_.Value())) {
1617 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1618 << "Bitset too large: [" << vmin << ", " << vmax << "]";
1619 bits_ = new uint64_t[bsize_];
1620 stamps_ = new uint64_t[bsize_];
1621 for (int i = 0; i < bsize_; ++i) {
1622 const int bs =
1623 (i == size_.Value() - 1) ? 63 - BitPos64(size_.Value()) : 0;
1624 bits_[i] = kAllBits64 >> bs;
1625 stamps_[i] = s->stamp() - 1;
1626 }
1627 }
1628
1629 SimpleBitSet(Solver* const s, absl::Span<const int64_t> sorted_values,
1630 int64_t vmin, int64_t vmax)
1631 : BitSet(s),
1632 bits_(nullptr),
1633 stamps_(nullptr),
1634 omin_(vmin),
1635 omax_(vmax),
1636 size_(sorted_values.size()),
1637 bsize_(BitLength64(vmax - vmin + 1)) {
1638 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1639 << "Bitset too large: [" << vmin << ", " << vmax << "]";
1640 bits_ = new uint64_t[bsize_];
1641 stamps_ = new uint64_t[bsize_];
1642 for (int i = 0; i < bsize_; ++i) {
1643 bits_[i] = uint64_t{0};
1644 stamps_[i] = s->stamp() - 1;
1645 }
1646 for (int i = 0; i < sorted_values.size(); ++i) {
1647 const int64_t val = sorted_values[i];
1648 DCHECK(!bit(val));
1649 const int offset = BitOffset64(val - omin_);
1650 const int pos = BitPos64(val - omin_);
1651 bits_[offset] |= OneBit64(pos);
1652 }
1653 }
1654
1655 ~SimpleBitSet() override {
1656 delete[] bits_;
1657 delete[] stamps_;
1658 }
1659
1660 bool bit(int64_t val) const { return IsBitSet64(bits_, val - omin_); }
1661
1662 int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) override {
1663 DCHECK_GE(nmin, cmin);
1664 DCHECK_LE(nmin, cmax);
1665 DCHECK_LE(cmin, cmax);
1666 DCHECK_GE(cmin, omin_);
1667 DCHECK_LE(cmax, omax_);
1668 const int64_t new_min =
1669 UnsafeLeastSignificantBitPosition64(bits_, nmin - omin_, cmax - omin_) +
1670 omin_;
1671 const uint64_t removed_bits =
1672 BitCountRange64(bits_, cmin - omin_, new_min - omin_ - 1);
1673 size_.Add(solver_, -removed_bits);
1674 return new_min;
1675 }
1676
1677 int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) override {
1678 DCHECK_GE(nmax, cmin);
1679 DCHECK_LE(nmax, cmax);
1680 DCHECK_LE(cmin, cmax);
1681 DCHECK_GE(cmin, omin_);
1682 DCHECK_LE(cmax, omax_);
1683 const int64_t new_max =
1684 UnsafeMostSignificantBitPosition64(bits_, cmin - omin_, nmax - omin_) +
1685 omin_;
1686 const uint64_t removed_bits =
1687 BitCountRange64(bits_, new_max - omin_ + 1, cmax - omin_);
1688 size_.Add(solver_, -removed_bits);
1689 return new_max;
1690 }
1691
1692 bool SetValue(int64_t val) override {
1693 DCHECK_GE(val, omin_);
1694 DCHECK_LE(val, omax_);
1695 if (bit(val)) {
1696 size_.SetValue(solver_, 1);
1697 return true;
1698 }
1699 return false;
1700 }
1701
1702 bool Contains(int64_t val) const override {
1703 DCHECK_GE(val, omin_);
1704 DCHECK_LE(val, omax_);
1705 return bit(val);
1706 }
1707
1708 bool RemoveValue(int64_t val) override {
1709 if (val < omin_ || val > omax_ || !bit(val)) {
1710 return false;
1711 }
1712 // Bitset.
1713 const int64_t val_offset = val - omin_;
1714 const int offset = BitOffset64(val_offset);
1715 const uint64_t current_stamp = solver_->stamp();
1716 if (stamps_[offset] < current_stamp) {
1717 stamps_[offset] = current_stamp;
1718 solver_->SaveValue(&bits_[offset]);
1719 }
1720 const int pos = BitPos64(val_offset);
1721 bits_[offset] &= ~OneBit64(pos);
1722 // Size.
1723 size_.Decr(solver_);
1724 // Holes.
1725 InitHoles();
1726 AddHole(val);
1727 return true;
1728 }
1729 uint64_t Size() const override { return size_.Value(); }
1730
1731 std::string DebugString() const override {
1732 std::string out;
1733 absl::StrAppendFormat(&out, "SimpleBitSet(%d..%d : ", omin_, omax_);
1734 for (int i = 0; i < bsize_; ++i) {
1735 absl::StrAppendFormat(&out, "%x", bits_[i]);
1736 }
1737 out += ")";
1738 return out;
1739 }
1740
1741 void DelayRemoveValue(int64_t val) override { removed_.push_back(val); }
1742
1743 void ApplyRemovedValues(DomainIntVar* var) override {
1744 std::sort(removed_.begin(), removed_.end());
1745 for (std::vector<int64_t>::iterator it = removed_.begin();
1746 it != removed_.end(); ++it) {
1747 var->RemoveValue(*it);
1748 }
1749 }
1750
1751 void ClearRemovedValues() override { removed_.clear(); }
1752
1753 std::string pretty_DebugString(int64_t min, int64_t max) const override {
1754 std::string out;
1755 DCHECK(bit(min));
1756 DCHECK(bit(max));
1757 if (max != min) {
1758 int cumul = true;
1759 int64_t start_cumul = min;
1760 for (int64_t v = min + 1; v < max; ++v) {
1761 if (bit(v)) {
1762 if (!cumul) {
1763 cumul = true;
1764 start_cumul = v;
1765 }
1766 } else {
1767 if (cumul) {
1768 if (v == start_cumul + 1) {
1769 absl::StrAppendFormat(&out, "%d ", start_cumul);
1770 } else if (v == start_cumul + 2) {
1771 absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1772 } else {
1773 absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1774 }
1775 cumul = false;
1776 }
1777 }
1778 }
1779 if (cumul) {
1780 if (max == start_cumul + 1) {
1781 absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1782 } else {
1783 absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1784 }
1785 } else {
1786 absl::StrAppendFormat(&out, "%d", max);
1787 }
1788 } else {
1789 absl::StrAppendFormat(&out, "%d", min);
1790 }
1791 return out;
1792 }
1793
1794 DomainIntVar::BitSetIterator* MakeIterator() override {
1795 return new DomainIntVar::BitSetIterator(bits_, omin_);
1796 }
1797
1798 private:
1799 uint64_t* bits_;
1800 uint64_t* stamps_;
1801 const int64_t omin_;
1802 const int64_t omax_;
1803 NumericalRev<int64_t> size_;
1804 const int bsize_;
1805 std::vector<int64_t> removed_;
1806};
1807
1808// This is a special case where the bitset fits into one 64 bit integer.
1809// In that case, there are no offset to compute.
1810// Overflows are caught by the robust ClosedIntervalNoLargerThan() method.
1811class SmallBitSet : public DomainIntVar::BitSet {
1812 public:
1813 SmallBitSet(Solver* const s, int64_t vmin, int64_t vmax)
1814 : BitSet(s),
1815 bits_(uint64_t{0}),
1816 stamp_(s->stamp() - 1),
1817 omin_(vmin),
1818 omax_(vmax),
1819 size_(vmax - vmin + 1) {
1820 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1821 bits_ = OneRange64(0, size_.Value() - 1);
1822 }
1823
1824 SmallBitSet(Solver* const s, absl::Span<const int64_t> sorted_values,
1825 int64_t vmin, int64_t vmax)
1826 : BitSet(s),
1827 bits_(uint64_t{0}),
1828 stamp_(s->stamp() - 1),
1829 omin_(vmin),
1830 omax_(vmax),
1831 size_(sorted_values.size()) {
1832 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1833 // We know the array is sorted and does not contains duplicate values.
1834 for (int i = 0; i < sorted_values.size(); ++i) {
1835 const int64_t val = sorted_values[i];
1836 DCHECK_GE(val, vmin);
1837 DCHECK_LE(val, vmax);
1838 DCHECK(!IsBitSet64(&bits_, val - omin_));
1839 bits_ |= OneBit64(val - omin_);
1840 }
1841 }
1842
1843 ~SmallBitSet() override {}
1844
1845 bool bit(int64_t val) const {
1846 DCHECK_GE(val, omin_);
1847 DCHECK_LE(val, omax_);
1848 return (bits_ & OneBit64(val - omin_)) != 0;
1849 }
1850
1851 int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) override {
1852 DCHECK_GE(nmin, cmin);
1853 DCHECK_LE(nmin, cmax);
1854 DCHECK_LE(cmin, cmax);
1855 DCHECK_GE(cmin, omin_);
1856 DCHECK_LE(cmax, omax_);
1857 // We do not clean the bits between cmin and nmin.
1858 // But we use mask to look only at 'active' bits.
1859
1860 // Create the mask and compute new bits
1861 const uint64_t new_bits = bits_ & OneRange64(nmin - omin_, cmax - omin_);
1862 if (new_bits != uint64_t{0}) {
1863 // Compute new size and new min
1864 size_.SetValue(solver_, BitCount64(new_bits));
1865 if (bit(nmin)) { // Common case, the new min is inside the bitset
1866 return nmin;
1867 }
1868 return LeastSignificantBitPosition64(new_bits) + omin_;
1869 } else { // == 0 -> Fail()
1870 solver_->Fail();
1871 return std::numeric_limits<int64_t>::max();
1872 }
1873 }
1874
1875 int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) override {
1876 DCHECK_GE(nmax, cmin);
1877 DCHECK_LE(nmax, cmax);
1878 DCHECK_LE(cmin, cmax);
1879 DCHECK_GE(cmin, omin_);
1880 DCHECK_LE(cmax, omax_);
1881 // We do not clean the bits between nmax and cmax.
1882 // But we use mask to look only at 'active' bits.
1883
1884 // Create the mask and compute new_bits
1885 const uint64_t new_bits = bits_ & OneRange64(cmin - omin_, nmax - omin_);
1886 if (new_bits != uint64_t{0}) {
1887 // Compute new size and new min
1888 size_.SetValue(solver_, BitCount64(new_bits));
1889 if (bit(nmax)) { // Common case, the new max is inside the bitset
1890 return nmax;
1891 }
1892 return MostSignificantBitPosition64(new_bits) + omin_;
1893 } else { // == 0 -> Fail()
1894 solver_->Fail();
1895 return std::numeric_limits<int64_t>::min();
1896 }
1897 }
1898
1899 bool SetValue(int64_t val) override {
1900 DCHECK_GE(val, omin_);
1901 DCHECK_LE(val, omax_);
1902 // We do not clean the bits. We will use masks to ignore the bits
1903 // that should have been cleaned.
1904 if (bit(val)) {
1905 size_.SetValue(solver_, 1);
1906 return true;
1907 }
1908 return false;
1909 }
1910
1911 bool Contains(int64_t val) const override {
1912 DCHECK_GE(val, omin_);
1913 DCHECK_LE(val, omax_);
1914 return bit(val);
1915 }
1916
1917 bool RemoveValue(int64_t val) override {
1918 DCHECK_GE(val, omin_);
1919 DCHECK_LE(val, omax_);
1920 if (bit(val)) {
1921 // Bitset.
1922 const uint64_t current_stamp = solver_->stamp();
1923 if (stamp_ < current_stamp) {
1924 stamp_ = current_stamp;
1925 solver_->SaveValue(&bits_);
1926 }
1927 bits_ &= ~OneBit64(val - omin_);
1928 DCHECK(!bit(val));
1929 // Size.
1930 size_.Decr(solver_);
1931 // Holes.
1932 InitHoles();
1933 AddHole(val);
1934 return true;
1935 } else {
1936 return false;
1937 }
1938 }
1939
1940 uint64_t Size() const override { return size_.Value(); }
1941
1942 std::string DebugString() const override {
1943 return absl::StrFormat("SmallBitSet(%d..%d : %llx)", omin_, omax_, bits_);
1944 }
1945
1946 void DelayRemoveValue(int64_t val) override {
1947 DCHECK_GE(val, omin_);
1948 DCHECK_LE(val, omax_);
1949 removed_.push_back(val);
1950 }
1951
1952 void ApplyRemovedValues(DomainIntVar* var) override {
1953 std::sort(removed_.begin(), removed_.end());
1954 for (std::vector<int64_t>::iterator it = removed_.begin();
1955 it != removed_.end(); ++it) {
1956 var->RemoveValue(*it);
1957 }
1958 }
1959
1960 void ClearRemovedValues() override { removed_.clear(); }
1961
1962 std::string pretty_DebugString(int64_t min, int64_t max) const override {
1963 std::string out;
1964 DCHECK(bit(min));
1965 DCHECK(bit(max));
1966 if (max != min) {
1967 int cumul = true;
1968 int64_t start_cumul = min;
1969 for (int64_t v = min + 1; v < max; ++v) {
1970 if (bit(v)) {
1971 if (!cumul) {
1972 cumul = true;
1973 start_cumul = v;
1974 }
1975 } else {
1976 if (cumul) {
1977 if (v == start_cumul + 1) {
1978 absl::StrAppendFormat(&out, "%d ", start_cumul);
1979 } else if (v == start_cumul + 2) {
1980 absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1981 } else {
1982 absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1983 }
1984 cumul = false;
1985 }
1986 }
1987 }
1988 if (cumul) {
1989 if (max == start_cumul + 1) {
1990 absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1991 } else {
1992 absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1993 }
1994 } else {
1995 absl::StrAppendFormat(&out, "%d", max);
1996 }
1997 } else {
1998 absl::StrAppendFormat(&out, "%d", min);
1999 }
2000 return out;
2001 }
2002
2003 DomainIntVar::BitSetIterator* MakeIterator() override {
2004 return new DomainIntVar::BitSetIterator(&bits_, omin_);
2005 }
2006
2007 private:
2008 uint64_t bits_;
2009 uint64_t stamp_;
2010 const int64_t omin_;
2011 const int64_t omax_;
2012 NumericalRev<int64_t> size_;
2013 std::vector<int64_t> removed_;
2014};
2015
2016class EmptyIterator : public IntVarIterator {
2017 public:
2018 ~EmptyIterator() override {}
2019 void Init() override {}
2020 bool Ok() const override { return false; }
2021 int64_t Value() const override {
2022 LOG(FATAL) << "Should not be called";
2023 return 0LL;
2024 }
2025 void Next() override {}
2026};
2027
2028class RangeIterator : public IntVarIterator {
2029 public:
2030 explicit RangeIterator(const IntVar* const var)
2031 : var_(var),
2032 min_(std::numeric_limits<int64_t>::max()),
2033 max_(std::numeric_limits<int64_t>::min()),
2034 current_(-1) {}
2035
2036 ~RangeIterator() override {}
2037
2038 void Init() override {
2039 min_ = var_->Min();
2040 max_ = var_->Max();
2041 current_ = min_;
2042 }
2043
2044 bool Ok() const override { return current_ <= max_; }
2045
2046 int64_t Value() const override { return current_; }
2047
2048 void Next() override { current_++; }
2049
2050 private:
2051 const IntVar* const var_;
2052 int64_t min_;
2053 int64_t max_;
2054 int64_t current_;
2055};
2056
2057class DomainIntVarHoleIterator : public IntVarIterator {
2058 public:
2059 explicit DomainIntVarHoleIterator(const DomainIntVar* const v)
2060 : var_(v), bits_(nullptr), values_(nullptr), size_(0), index_(0) {}
2061
2062 ~DomainIntVarHoleIterator() override {}
2063
2064 void Init() override {
2065 bits_ = var_->bitset();
2066 if (bits_ != nullptr) {
2067 bits_->InitHoles();
2068 values_ = bits_->Holes().data();
2069 size_ = bits_->Holes().size();
2070 } else {
2071 values_ = nullptr;
2072 size_ = 0;
2073 }
2074 index_ = 0;
2075 }
2076
2077 bool Ok() const override { return index_ < size_; }
2078
2079 int64_t Value() const override {
2080 DCHECK(bits_ != nullptr);
2081 DCHECK(index_ < size_);
2082 return values_[index_];
2083 }
2084
2085 void Next() override { index_++; }
2086
2087 private:
2088 const DomainIntVar* const var_;
2089 DomainIntVar::BitSet* bits_;
2090 const int64_t* values_;
2091 int size_;
2092 int index_;
2093};
2094
2095class DomainIntVarDomainIterator : public IntVarIterator {
2096 public:
2097 explicit DomainIntVarDomainIterator(const DomainIntVar* const v,
2098 bool reversible)
2099 : var_(v),
2100 bitset_iterator_(nullptr),
2101 min_(std::numeric_limits<int64_t>::max()),
2102 max_(std::numeric_limits<int64_t>::min()),
2103 current_(-1),
2104 reversible_(reversible) {}
2105
2106 ~DomainIntVarDomainIterator() override {
2107 if (!reversible_ && bitset_iterator_) {
2108 delete bitset_iterator_;
2109 }
2110 }
2111
2112 void Init() override {
2113 if (var_->bitset() != nullptr && !var_->Bound()) {
2114 if (reversible_) {
2115 if (!bitset_iterator_) {
2116 Solver* const solver = var_->solver();
2117 solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
2118 bitset_iterator_ = solver->RevAlloc(var_->bitset()->MakeIterator());
2119 }
2120 } else {
2121 if (bitset_iterator_) {
2122 delete bitset_iterator_;
2123 }
2124 bitset_iterator_ = var_->bitset()->MakeIterator();
2125 }
2126 bitset_iterator_->Init(var_->Min(), var_->Max());
2127 } else {
2128 if (bitset_iterator_) {
2129 if (reversible_) {
2130 Solver* const solver = var_->solver();
2131 solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
2132 } else {
2133 delete bitset_iterator_;
2134 }
2135 bitset_iterator_ = nullptr;
2136 }
2137 min_ = var_->Min();
2138 max_ = var_->Max();
2139 current_ = min_;
2140 }
2141 }
2142
2143 bool Ok() const override {
2144 return bitset_iterator_ ? bitset_iterator_->Ok() : (current_ <= max_);
2145 }
2146
2147 int64_t Value() const override {
2148 return bitset_iterator_ ? bitset_iterator_->Value() : current_;
2149 }
2150
2151 void Next() override {
2152 if (bitset_iterator_) {
2153 bitset_iterator_->Next();
2154 } else {
2155 current_++;
2156 }
2157 }
2158
2159 private:
2160 const DomainIntVar* const var_;
2161 DomainIntVar::BitSetIterator* bitset_iterator_;
2162 int64_t min_;
2163 int64_t max_;
2164 int64_t current_;
2165 const bool reversible_;
2166};
2167
2168class UnaryIterator : public IntVarIterator {
2169 public:
2170 UnaryIterator(const IntVar* const v, bool hole, bool reversible)
2171 : iterator_(hole ? v->MakeHoleIterator(reversible)
2172 : v->MakeDomainIterator(reversible)),
2173 reversible_(reversible) {}
2174
2175 ~UnaryIterator() override {
2176 if (!reversible_) {
2177 delete iterator_;
2178 }
2179 }
2180
2181 void Init() override { iterator_->Init(); }
2182
2183 bool Ok() const override { return iterator_->Ok(); }
2184
2185 void Next() override { iterator_->Next(); }
2186
2187 protected:
2188 IntVarIterator* const iterator_;
2189 const bool reversible_;
2190};
2191
2192DomainIntVar::DomainIntVar(Solver* const s, int64_t vmin, int64_t vmax,
2193 const std::string& name)
2194 : IntVar(s, name),
2195 min_(vmin),
2196 max_(vmax),
2197 old_min_(vmin),
2198 old_max_(vmax),
2199 new_min_(vmin),
2200 new_max_(vmax),
2201 handler_(this),
2202 in_process_(false),
2203 bits_(nullptr),
2204 value_watcher_(nullptr),
2205 bound_watcher_(nullptr) {}
2206
2207DomainIntVar::DomainIntVar(Solver* const s,
2208 absl::Span<const int64_t> sorted_values,
2209 const std::string& name)
2210 : IntVar(s, name),
2211 min_(std::numeric_limits<int64_t>::max()),
2212 max_(std::numeric_limits<int64_t>::min()),
2213 old_min_(std::numeric_limits<int64_t>::max()),
2214 old_max_(std::numeric_limits<int64_t>::min()),
2215 new_min_(std::numeric_limits<int64_t>::max()),
2216 new_max_(std::numeric_limits<int64_t>::min()),
2217 handler_(this),
2218 in_process_(false),
2219 bits_(nullptr),
2220 value_watcher_(nullptr),
2221 bound_watcher_(nullptr) {
2222 CHECK_GE(sorted_values.size(), 1);
2223 // We know that the vector is sorted and does not have duplicate values.
2224 const int64_t vmin = sorted_values.front();
2225 const int64_t vmax = sorted_values.back();
2226 const bool contiguous = vmax - vmin + 1 == sorted_values.size();
2227
2228 min_.SetValue(solver(), vmin);
2229 old_min_ = vmin;
2230 new_min_ = vmin;
2231 max_.SetValue(solver(), vmax);
2232 old_max_ = vmax;
2233 new_max_ = vmax;
2234
2235 if (!contiguous) {
2236 if (vmax - vmin + 1 < 65) {
2237 bits_ = solver()->RevAlloc(
2238 new SmallBitSet(solver(), sorted_values, vmin, vmax));
2239 } else {
2240 bits_ = solver()->RevAlloc(
2241 new SimpleBitSet(solver(), sorted_values, vmin, vmax));
2242 }
2243 }
2244}
2245
2246DomainIntVar::~DomainIntVar() {}
2247
2248void DomainIntVar::SetMin(int64_t m) {
2249 if (m <= min_.Value()) return;
2250 if (m > max_.Value()) solver()->Fail();
2251 if (in_process_) {
2252 if (m > new_min_) {
2253 new_min_ = m;
2254 if (new_min_ > new_max_) {
2255 solver()->Fail();
2256 }
2257 }
2258 } else {
2259 CheckOldMin();
2260 const int64_t new_min =
2261 (bits_ == nullptr
2262 ? m
2263 : bits_->ComputeNewMin(m, min_.Value(), max_.Value()));
2264 min_.SetValue(solver(), new_min);
2265 if (min_.Value() > max_.Value()) {
2266 solver()->Fail();
2267 }
2268 Push();
2269 }
2270}
2271
2272void DomainIntVar::SetMax(int64_t m) {
2273 if (m >= max_.Value()) return;
2274 if (m < min_.Value()) solver()->Fail();
2275 if (in_process_) {
2276 if (m < new_max_) {
2277 new_max_ = m;
2278 if (new_max_ < new_min_) {
2279 solver()->Fail();
2280 }
2281 }
2282 } else {
2283 CheckOldMax();
2284 const int64_t new_max =
2285 (bits_ == nullptr
2286 ? m
2287 : bits_->ComputeNewMax(m, min_.Value(), max_.Value()));
2288 max_.SetValue(solver(), new_max);
2289 if (min_.Value() > max_.Value()) {
2290 solver()->Fail();
2291 }
2292 Push();
2293 }
2294}
2295
2296void DomainIntVar::SetRange(int64_t mi, int64_t ma) {
2297 if (mi == ma) {
2298 SetValue(mi);
2299 } else {
2300 if (mi > ma || mi > max_.Value() || ma < min_.Value()) solver()->Fail();
2301 if (mi <= min_.Value() && ma >= max_.Value()) return;
2302 if (in_process_) {
2303 if (ma < new_max_) {
2304 new_max_ = ma;
2305 }
2306 if (mi > new_min_) {
2307 new_min_ = mi;
2308 }
2309 if (new_min_ > new_max_) {
2310 solver()->Fail();
2311 }
2312 } else {
2313 if (mi > min_.Value()) {
2314 CheckOldMin();
2315 const int64_t new_min =
2316 (bits_ == nullptr
2317 ? mi
2318 : bits_->ComputeNewMin(mi, min_.Value(), max_.Value()));
2319 min_.SetValue(solver(), new_min);
2320 }
2321 if (min_.Value() > ma) {
2322 solver()->Fail();
2323 }
2324 if (ma < max_.Value()) {
2325 CheckOldMax();
2326 const int64_t new_max =
2327 (bits_ == nullptr
2328 ? ma
2329 : bits_->ComputeNewMax(ma, min_.Value(), max_.Value()));
2330 max_.SetValue(solver(), new_max);
2331 }
2332 if (min_.Value() > max_.Value()) {
2333 solver()->Fail();
2334 }
2335 Push();
2336 }
2337 }
2338}
2339
2340void DomainIntVar::SetValue(int64_t v) {
2341 if (v != min_.Value() || v != max_.Value()) {
2342 if (v < min_.Value() || v > max_.Value()) {
2343 solver()->Fail();
2344 }
2345 if (in_process_) {
2346 if (v > new_max_ || v < new_min_) {
2347 solver()->Fail();
2348 }
2349 new_min_ = v;
2350 new_max_ = v;
2351 } else {
2352 if (bits_ && !bits_->SetValue(v)) {
2353 solver()->Fail();
2354 }
2355 CheckOldMin();
2356 CheckOldMax();
2357 min_.SetValue(solver(), v);
2358 max_.SetValue(solver(), v);
2359 Push();
2360 }
2361 }
2362}
2363
2364void DomainIntVar::RemoveValue(int64_t v) {
2365 if (v < min_.Value() || v > max_.Value()) return;
2366 if (v == min_.Value()) {
2367 SetMin(v + 1);
2368 } else if (v == max_.Value()) {
2369 SetMax(v - 1);
2370 } else {
2371 if (bits_ == nullptr) {
2372 CreateBits();
2373 }
2374 if (in_process_) {
2375 if (v >= new_min_ && v <= new_max_ && bits_->Contains(v)) {
2376 bits_->DelayRemoveValue(v);
2377 }
2378 } else {
2379 if (bits_->RemoveValue(v)) {
2380 Push();
2381 }
2382 }
2383 }
2384}
2385
2386void DomainIntVar::RemoveInterval(int64_t l, int64_t u) {
2387 if (l <= min_.Value()) {
2388 SetMin(u + 1);
2389 } else if (u >= max_.Value()) {
2390 SetMax(l - 1);
2391 } else {
2392 for (int64_t v = l; v <= u; ++v) {
2393 RemoveValue(v);
2394 }
2395 }
2396}
2397
2398void DomainIntVar::CreateBits() {
2399 solver()->SaveValue(reinterpret_cast<void**>(&bits_));
2400 if (max_.Value() - min_.Value() < 64) {
2401 bits_ = solver()->RevAlloc(
2402 new SmallBitSet(solver(), min_.Value(), max_.Value()));
2403 } else {
2404 bits_ = solver()->RevAlloc(
2405 new SimpleBitSet(solver(), min_.Value(), max_.Value()));
2406 }
2407}
2408
2409void DomainIntVar::CleanInProcess() {
2410 in_process_ = false;
2411 if (bits_ != nullptr) {
2412 bits_->ClearHoles();
2413 }
2414}
2415
2416void DomainIntVar::Push() {
2417 const bool in_process = in_process_;
2418 EnqueueVar(&handler_);
2419 CHECK_EQ(in_process, in_process_);
2420}
2421
2422void DomainIntVar::Process() {
2423 CHECK(!in_process_);
2424 in_process_ = true;
2425 if (bits_ != nullptr) {
2426 bits_->ClearRemovedValues();
2427 }
2428 set_variable_to_clean_on_fail(this);
2429 new_min_ = min_.Value();
2430 new_max_ = max_.Value();
2431 const bool is_bound = min_.Value() == max_.Value();
2432 const bool range_changed =
2433 min_.Value() != OldMin() || max_.Value() != OldMax();
2434 // Process immediate demons.
2435 if (is_bound) {
2436 ExecuteAll(bound_demons_);
2437 }
2438 if (range_changed) {
2439 ExecuteAll(range_demons_);
2440 }
2441 ExecuteAll(domain_demons_);
2442
2443 // Process delayed demons.
2444 if (is_bound) {
2445 EnqueueAll(delayed_bound_demons_);
2446 }
2447 if (range_changed) {
2448 EnqueueAll(delayed_range_demons_);
2449 }
2450 EnqueueAll(delayed_domain_demons_);
2451
2452 // Everything went well if we arrive here. Let's clean the variable.
2453 set_variable_to_clean_on_fail(nullptr);
2454 CleanInProcess();
2455 old_min_ = min_.Value();
2456 old_max_ = max_.Value();
2457 if (min_.Value() < new_min_) {
2458 SetMin(new_min_);
2459 }
2460 if (max_.Value() > new_max_) {
2461 SetMax(new_max_);
2462 }
2463 if (bits_ != nullptr) {
2464 bits_->ApplyRemovedValues(this);
2465 }
2466}
2467
2468template <typename T>
2469T* CondRevAlloc(Solver* solver, bool reversible, T* object) {
2470 return reversible ? solver->RevAlloc(object) : object;
2471}
2472
2473IntVarIterator* DomainIntVar::MakeHoleIterator(bool reversible) const {
2474 return CondRevAlloc(solver(), reversible, new DomainIntVarHoleIterator(this));
2475}
2476
2477IntVarIterator* DomainIntVar::MakeDomainIterator(bool reversible) const {
2478 return CondRevAlloc(solver(), reversible,
2479 new DomainIntVarDomainIterator(this, reversible));
2480}
2481
2482std::string DomainIntVar::DebugString() const {
2483 std::string out;
2484 const std::string& var_name = name();
2485 if (!var_name.empty()) {
2486 out = var_name + "(";
2487 } else {
2488 out = "DomainIntVar(";
2489 }
2490 if (min_.Value() == max_.Value()) {
2491 absl::StrAppendFormat(&out, "%d", min_.Value());
2492 } else if (bits_ != nullptr) {
2493 out.append(bits_->pretty_DebugString(min_.Value(), max_.Value()));
2494 } else {
2495 absl::StrAppendFormat(&out, "%d..%d", min_.Value(), max_.Value());
2496 }
2497 out += ")";
2498 return out;
2499}
2500
2501// ----- Real Boolean Var -----
2502
2503class ConcreteBooleanVar : public BooleanVar {
2504 public:
2505 // Utility classes
2506 class Handler : public Demon {
2507 public:
2508 explicit Handler(ConcreteBooleanVar* const var) : Demon(), var_(var) {}
2509 ~Handler() override {}
2510 void Run(Solver* const s) override {
2511 s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
2512 var_->Process();
2513 s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
2514 }
2515 Solver::DemonPriority priority() const override {
2516 return Solver::VAR_PRIORITY;
2517 }
2518 std::string DebugString() const override {
2519 return absl::StrFormat("Handler(%s)", var_->DebugString());
2520 }
2521
2522 private:
2523 ConcreteBooleanVar* const var_;
2524 };
2525
2526 ConcreteBooleanVar(Solver* const s, const std::string& name)
2527 : BooleanVar(s, name), handler_(this) {}
2528
2529 ~ConcreteBooleanVar() override {}
2530
2531 void SetValue(int64_t v) override {
2532 if (value_ == kUnboundBooleanVarValue) {
2533 if ((v & 0xfffffffffffffffe) == 0) {
2534 InternalSaveBooleanVarValue(solver(), this);
2535 value_ = static_cast<int>(v);
2536 EnqueueVar(&handler_);
2537 return;
2538 }
2539 } else if (v == value_) {
2540 return;
2541 }
2542 solver()->Fail();
2543 }
2544
2545 void Process() {
2546 DCHECK_NE(value_, kUnboundBooleanVarValue);
2547 ExecuteAll(bound_demons_);
2548 for (SimpleRevFIFO<Demon*>::Iterator it(&delayed_bound_demons_); it.ok();
2549 ++it) {
2550 EnqueueDelayedDemon(*it);
2551 }
2552 }
2553
2554 int64_t OldMin() const override { return 0LL; }
2555 int64_t OldMax() const override { return 1LL; }
2556 void RestoreValue() override { value_ = kUnboundBooleanVarValue; }
2557
2558 private:
2559 Handler handler_;
2560};
2561
2562// ----- IntConst -----
2563
2564class IntConst : public IntVar {
2565 public:
2566 IntConst(Solver* const s, int64_t value, const std::string& name = "")
2567 : IntVar(s, name), value_(value) {}
2568 ~IntConst() override {}
2569
2570 int64_t Min() const override { return value_; }
2571 void SetMin(int64_t m) override {
2572 if (m > value_) {
2573 solver()->Fail();
2574 }
2575 }
2576 int64_t Max() const override { return value_; }
2577 void SetMax(int64_t m) override {
2578 if (m < value_) {
2579 solver()->Fail();
2580 }
2581 }
2582 void SetRange(int64_t l, int64_t u) override {
2583 if (l > value_ || u < value_) {
2584 solver()->Fail();
2585 }
2586 }
2587 void SetValue(int64_t v) override {
2588 if (v != value_) {
2589 solver()->Fail();
2590 }
2591 }
2592 bool Bound() const override { return true; }
2593 int64_t Value() const override { return value_; }
2594 void RemoveValue(int64_t v) override {
2595 if (v == value_) {
2596 solver()->Fail();
2597 }
2598 }
2599 void RemoveInterval(int64_t l, int64_t u) override {
2600 if (l <= value_ && value_ <= u) {
2601 solver()->Fail();
2602 }
2603 }
2604 void WhenBound(Demon* d) override {}
2605 void WhenRange(Demon* d) override {}
2606 void WhenDomain(Demon* d) override {}
2607 uint64_t Size() const override { return 1; }
2608 bool Contains(int64_t v) const override { return (v == value_); }
2609 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2610 return CondRevAlloc(solver(), reversible, new EmptyIterator());
2611 }
2612 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2613 return CondRevAlloc(solver(), reversible, new RangeIterator(this));
2614 }
2615 int64_t OldMin() const override { return value_; }
2616 int64_t OldMax() const override { return value_; }
2617 std::string DebugString() const override {
2618 std::string out;
2619 if (solver()->HasName(this)) {
2620 const std::string& var_name = name();
2621 absl::StrAppendFormat(&out, "%s(%d)", var_name, value_);
2622 } else {
2623 absl::StrAppendFormat(&out, "IntConst(%d)", value_);
2624 }
2625 return out;
2626 }
2627
2628 int VarType() const override { return CONST_VAR; }
2629
2630 IntVar* IsEqual(int64_t constant) override {
2631 if (constant == value_) {
2632 return solver()->MakeIntConst(1);
2633 } else {
2634 return solver()->MakeIntConst(0);
2635 }
2636 }
2637
2638 IntVar* IsDifferent(int64_t constant) override {
2639 if (constant == value_) {
2640 return solver()->MakeIntConst(0);
2641 } else {
2642 return solver()->MakeIntConst(1);
2643 }
2644 }
2645
2646 IntVar* IsGreaterOrEqual(int64_t constant) override {
2647 return solver()->MakeIntConst(value_ >= constant);
2648 }
2649
2650 IntVar* IsLessOrEqual(int64_t constant) override {
2651 return solver()->MakeIntConst(value_ <= constant);
2652 }
2653
2654 std::string name() const override {
2655 if (solver()->HasName(this)) {
2656 return PropagationBaseObject::name();
2657 } else {
2658 return absl::StrCat(value_);
2659 }
2660 }
2661
2662 private:
2663 int64_t value_;
2664};
2665
2666// ----- x + c variable, optimized case -----
2667
2668class PlusCstVar : public IntVar {
2669 public:
2670 PlusCstVar(Solver* const s, IntVar* v, int64_t c)
2671 : IntVar(s), var_(v), cst_(c) {}
2672
2673 ~PlusCstVar() override {}
2674
2675 void WhenRange(Demon* d) override { var_->WhenRange(d); }
2676
2677 void WhenBound(Demon* d) override { var_->WhenBound(d); }
2678
2679 void WhenDomain(Demon* d) override { var_->WhenDomain(d); }
2680
2681 int64_t OldMin() const override { return CapAdd(var_->OldMin(), cst_); }
2682
2683 int64_t OldMax() const override { return CapAdd(var_->OldMax(), cst_); }
2684
2685 std::string DebugString() const override {
2686 if (HasName()) {
2687 return absl::StrFormat("%s(%s + %d)", name(), var_->DebugString(), cst_);
2688 } else {
2689 return absl::StrFormat("(%s + %d)", var_->DebugString(), cst_);
2690 }
2691 }
2692
2693 int VarType() const override { return VAR_ADD_CST; }
2694
2695 void Accept(ModelVisitor* const visitor) const override {
2696 visitor->VisitIntegerVariable(this, ModelVisitor::kSumOperation, cst_,
2697 var_);
2698 }
2699
2700 IntVar* IsEqual(int64_t constant) override {
2701 return var_->IsEqual(constant - cst_);
2702 }
2703
2704 IntVar* IsDifferent(int64_t constant) override {
2705 return var_->IsDifferent(constant - cst_);
2706 }
2707
2708 IntVar* IsGreaterOrEqual(int64_t constant) override {
2709 return var_->IsGreaterOrEqual(constant - cst_);
2710 }
2711
2712 IntVar* IsLessOrEqual(int64_t constant) override {
2713 return var_->IsLessOrEqual(constant - cst_);
2714 }
2715
2716 IntVar* SubVar() const { return var_; }
2717
2718 int64_t Constant() const { return cst_; }
2719
2720 protected:
2721 IntVar* const var_;
2722 const int64_t cst_;
2723};
2724
2725class PlusCstIntVar : public PlusCstVar {
2726 public:
2727 class PlusCstIntVarIterator : public UnaryIterator {
2728 public:
2729 PlusCstIntVarIterator(const IntVar* const v, int64_t c, bool hole, bool rev)
2730 : UnaryIterator(v, hole, rev), cst_(c) {}
2731
2732 ~PlusCstIntVarIterator() override {}
2733
2734 int64_t Value() const override { return iterator_->Value() + cst_; }
2735
2736 private:
2737 const int64_t cst_;
2738 };
2739
2740 PlusCstIntVar(Solver* const s, IntVar* v, int64_t c) : PlusCstVar(s, v, c) {}
2741
2742 ~PlusCstIntVar() override {}
2743
2744 int64_t Min() const override { return var_->Min() + cst_; }
2745
2746 void SetMin(int64_t m) override { var_->SetMin(CapSub(m, cst_)); }
2747
2748 int64_t Max() const override { return var_->Max() + cst_; }
2749
2750 void SetMax(int64_t m) override { var_->SetMax(CapSub(m, cst_)); }
2751
2752 void SetRange(int64_t l, int64_t u) override {
2753 var_->SetRange(CapSub(l, cst_), CapSub(u, cst_));
2754 }
2755
2756 void SetValue(int64_t v) override { var_->SetValue(v - cst_); }
2757
2758 int64_t Value() const override { return var_->Value() + cst_; }
2759
2760 bool Bound() const override { return var_->Bound(); }
2761
2762 void RemoveValue(int64_t v) override { var_->RemoveValue(v - cst_); }
2763
2764 void RemoveInterval(int64_t l, int64_t u) override {
2765 var_->RemoveInterval(l - cst_, u - cst_);
2766 }
2767
2768 uint64_t Size() const override { return var_->Size(); }
2769
2770 bool Contains(int64_t v) const override { return var_->Contains(v - cst_); }
2771
2772 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2773 return CondRevAlloc(
2774 solver(), reversible,
2775 new PlusCstIntVarIterator(var_, cst_, true, reversible));
2776 }
2777 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2778 return CondRevAlloc(
2779 solver(), reversible,
2780 new PlusCstIntVarIterator(var_, cst_, false, reversible));
2781 }
2782};
2783
2784class PlusCstDomainIntVar : public PlusCstVar {
2785 public:
2786 class PlusCstDomainIntVarIterator : public UnaryIterator {
2787 public:
2788 PlusCstDomainIntVarIterator(const IntVar* const v, int64_t c, bool hole,
2789 bool reversible)
2790 : UnaryIterator(v, hole, reversible), cst_(c) {}
2791
2792 ~PlusCstDomainIntVarIterator() override {}
2793
2794 int64_t Value() const override { return iterator_->Value() + cst_; }
2795
2796 private:
2797 const int64_t cst_;
2798 };
2799
2800 PlusCstDomainIntVar(Solver* const s, DomainIntVar* v, int64_t c)
2801 : PlusCstVar(s, v, c) {}
2802
2803 ~PlusCstDomainIntVar() override {}
2804
2805 int64_t Min() const override;
2806 void SetMin(int64_t m) override;
2807 int64_t Max() const override;
2808 void SetMax(int64_t m) override;
2809 void SetRange(int64_t l, int64_t u) override;
2810 void SetValue(int64_t v) override;
2811 bool Bound() const override;
2812 int64_t Value() const override;
2813 void RemoveValue(int64_t v) override;
2814 void RemoveInterval(int64_t l, int64_t u) override;
2815 uint64_t Size() const override;
2816 bool Contains(int64_t v) const override;
2817
2818 DomainIntVar* domain_int_var() const {
2819 return reinterpret_cast<DomainIntVar*>(var_);
2820 }
2821
2822 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2823 return CondRevAlloc(
2824 solver(), reversible,
2825 new PlusCstDomainIntVarIterator(var_, cst_, true, reversible));
2826 }
2827 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2828 return CondRevAlloc(
2829 solver(), reversible,
2830 new PlusCstDomainIntVarIterator(var_, cst_, false, reversible));
2831 }
2832};
2833
2834int64_t PlusCstDomainIntVar::Min() const {
2835 return domain_int_var()->min_.Value() + cst_;
2836}
2837
2838void PlusCstDomainIntVar::SetMin(int64_t m) {
2839 domain_int_var()->DomainIntVar::SetMin(CapSub(m, cst_));
2840}
2841
2842int64_t PlusCstDomainIntVar::Max() const {
2843 return domain_int_var()->max_.Value() + cst_;
2844}
2845
2846void PlusCstDomainIntVar::SetMax(int64_t m) {
2847 domain_int_var()->DomainIntVar::SetMax(CapSub(m, cst_));
2848}
2849
2850void PlusCstDomainIntVar::SetRange(int64_t l, int64_t u) {
2851 domain_int_var()->DomainIntVar::SetRange(l - cst_, u - cst_);
2852}
2853
2854void PlusCstDomainIntVar::SetValue(int64_t v) {
2855 domain_int_var()->DomainIntVar::SetValue(v - cst_);
2856}
2857
2858bool PlusCstDomainIntVar::Bound() const {
2859 return domain_int_var()->min_.Value() == domain_int_var()->max_.Value();
2860}
2861
2862int64_t PlusCstDomainIntVar::Value() const {
2863 CHECK_EQ(domain_int_var()->min_.Value(), domain_int_var()->max_.Value())
2864 << " variable is not bound";
2865 return domain_int_var()->min_.Value() + cst_;
2866}
2867
2868void PlusCstDomainIntVar::RemoveValue(int64_t v) {
2869 domain_int_var()->DomainIntVar::RemoveValue(v - cst_);
2870}
2871
2872void PlusCstDomainIntVar::RemoveInterval(int64_t l, int64_t u) {
2873 domain_int_var()->DomainIntVar::RemoveInterval(l - cst_, u - cst_);
2874}
2875
2876uint64_t PlusCstDomainIntVar::Size() const {
2877 return domain_int_var()->DomainIntVar::Size();
2878}
2879
2880bool PlusCstDomainIntVar::Contains(int64_t v) const {
2881 return domain_int_var()->DomainIntVar::Contains(v - cst_);
2882}
2883
2884// c - x variable, optimized case
2885
2886class SubCstIntVar : public IntVar {
2887 public:
2888 class SubCstIntVarIterator : public UnaryIterator {
2889 public:
2890 SubCstIntVarIterator(const IntVar* const v, int64_t c, bool hole, bool rev)
2891 : UnaryIterator(v, hole, rev), cst_(c) {}
2892 ~SubCstIntVarIterator() override {}
2893
2894 int64_t Value() const override { return cst_ - iterator_->Value(); }
2895
2896 private:
2897 const int64_t cst_;
2898 };
2899
2900 SubCstIntVar(Solver* s, IntVar* v, int64_t c);
2901 ~SubCstIntVar() override;
2902
2903 int64_t Min() const override;
2904 void SetMin(int64_t m) override;
2905 int64_t Max() const override;
2906 void SetMax(int64_t m) override;
2907 void SetRange(int64_t l, int64_t u) override;
2908 void SetValue(int64_t v) override;
2909 bool Bound() const override;
2910 int64_t Value() const override;
2911 void RemoveValue(int64_t v) override;
2912 void RemoveInterval(int64_t l, int64_t u) override;
2913 uint64_t Size() const override;
2914 bool Contains(int64_t v) const override;
2915 void WhenRange(Demon* d) override;
2916 void WhenBound(Demon* d) override;
2917 void WhenDomain(Demon* d) override;
2918 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2919 return CondRevAlloc(solver(), reversible,
2920 new SubCstIntVarIterator(var_, cst_, true, reversible));
2921 }
2922 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2923 return CondRevAlloc(
2924 solver(), reversible,
2925 new SubCstIntVarIterator(var_, cst_, false, reversible));
2926 }
2927 int64_t OldMin() const override { return CapSub(cst_, var_->OldMax()); }
2928 int64_t OldMax() const override { return CapSub(cst_, var_->OldMin()); }
2929 std::string DebugString() const override;
2930 std::string name() const override;
2931 int VarType() const override { return CST_SUB_VAR; }
2932
2933 void Accept(ModelVisitor* const visitor) const override {
2934 visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation,
2935 cst_, var_);
2936 }
2937
2938 IntVar* IsEqual(int64_t constant) override {
2939 return var_->IsEqual(cst_ - constant);
2940 }
2941
2942 IntVar* IsDifferent(int64_t constant) override {
2943 return var_->IsDifferent(cst_ - constant);
2944 }
2945
2946 IntVar* IsGreaterOrEqual(int64_t constant) override {
2947 return var_->IsLessOrEqual(cst_ - constant);
2948 }
2949
2950 IntVar* IsLessOrEqual(int64_t constant) override {
2951 return var_->IsGreaterOrEqual(cst_ - constant);
2952 }
2953
2954 IntVar* SubVar() const { return var_; }
2955 int64_t Constant() const { return cst_; }
2956
2957 private:
2958 IntVar* const var_;
2959 const int64_t cst_;
2960};
2961
2962SubCstIntVar::SubCstIntVar(Solver* const s, IntVar* v, int64_t c)
2963 : IntVar(s), var_(v), cst_(c) {}
2964
2965SubCstIntVar::~SubCstIntVar() {}
2966
2967int64_t SubCstIntVar::Min() const { return cst_ - var_->Max(); }
2968
2969void SubCstIntVar::SetMin(int64_t m) { var_->SetMax(CapSub(cst_, m)); }
2970
2971int64_t SubCstIntVar::Max() const { return cst_ - var_->Min(); }
2972
2973void SubCstIntVar::SetMax(int64_t m) { var_->SetMin(CapSub(cst_, m)); }
2974
2975void SubCstIntVar::SetRange(int64_t l, int64_t u) {
2976 var_->SetRange(CapSub(cst_, u), CapSub(cst_, l));
2977}
2978
2979void SubCstIntVar::SetValue(int64_t v) { var_->SetValue(cst_ - v); }
2980
2981bool SubCstIntVar::Bound() const { return var_->Bound(); }
2982
2983void SubCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
2984
2985int64_t SubCstIntVar::Value() const { return cst_ - var_->Value(); }
2986
2987void SubCstIntVar::RemoveValue(int64_t v) { var_->RemoveValue(cst_ - v); }
2988
2989void SubCstIntVar::RemoveInterval(int64_t l, int64_t u) {
2990 var_->RemoveInterval(cst_ - u, cst_ - l);
2991}
2992
2993void SubCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
2994
2995void SubCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
2996
2997uint64_t SubCstIntVar::Size() const { return var_->Size(); }
2998
2999bool SubCstIntVar::Contains(int64_t v) const {
3000 return var_->Contains(cst_ - v);
3001}
3002
3003std::string SubCstIntVar::DebugString() const {
3004 if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
3005 return absl::StrFormat("Not(%s)", var_->DebugString());
3006 } else {
3007 return absl::StrFormat("(%d - %s)", cst_, var_->DebugString());
3008 }
3009}
3010
3011std::string SubCstIntVar::name() const {
3012 if (solver()->HasName(this)) {
3013 return PropagationBaseObject::name();
3014 } else if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
3015 return absl::StrFormat("Not(%s)", var_->name());
3016 } else {
3017 return absl::StrFormat("(%d - %s)", cst_, var_->name());
3018 }
3019}
3020
3021// -x variable, optimized case
3022
3023class OppIntVar : public IntVar {
3024 public:
3025 class OppIntVarIterator : public UnaryIterator {
3026 public:
3027 OppIntVarIterator(const IntVar* const v, bool hole, bool reversible)
3028 : UnaryIterator(v, hole, reversible) {}
3029 ~OppIntVarIterator() override {}
3030
3031 int64_t Value() const override { return -iterator_->Value(); }
3032 };
3033
3034 OppIntVar(Solver* s, IntVar* v);
3035 ~OppIntVar() override;
3036
3037 int64_t Min() const override;
3038 void SetMin(int64_t m) override;
3039 int64_t Max() const override;
3040 void SetMax(int64_t m) override;
3041 void SetRange(int64_t l, int64_t u) override;
3042 void SetValue(int64_t v) override;
3043 bool Bound() const override;
3044 int64_t Value() const override;
3045 void RemoveValue(int64_t v) override;
3046 void RemoveInterval(int64_t l, int64_t u) override;
3047 uint64_t Size() const override;
3048 bool Contains(int64_t v) const override;
3049 void WhenRange(Demon* d) override;
3050 void WhenBound(Demon* d) override;
3051 void WhenDomain(Demon* d) override;
3052 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3053 return CondRevAlloc(solver(), reversible,
3054 new OppIntVarIterator(var_, true, reversible));
3055 }
3056 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3057 return CondRevAlloc(solver(), reversible,
3058 new OppIntVarIterator(var_, false, reversible));
3059 }
3060 int64_t OldMin() const override { return CapOpp(var_->OldMax()); }
3061 int64_t OldMax() const override { return CapOpp(var_->OldMin()); }
3062 std::string DebugString() const override;
3063 int VarType() const override { return OPP_VAR; }
3064
3065 void Accept(ModelVisitor* const visitor) const override {
3066 visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation, 0,
3067 var_);
3068 }
3069
3070 IntVar* IsEqual(int64_t constant) override {
3071 return var_->IsEqual(-constant);
3072 }
3073
3074 IntVar* IsDifferent(int64_t constant) override {
3075 return var_->IsDifferent(-constant);
3076 }
3077
3078 IntVar* IsGreaterOrEqual(int64_t constant) override {
3079 return var_->IsLessOrEqual(-constant);
3080 }
3081
3082 IntVar* IsLessOrEqual(int64_t constant) override {
3083 return var_->IsGreaterOrEqual(-constant);
3084 }
3085
3086 IntVar* SubVar() const { return var_; }
3087
3088 private:
3089 IntVar* const var_;
3090};
3091
3092OppIntVar::OppIntVar(Solver* const s, IntVar* v) : IntVar(s), var_(v) {}
3093
3094OppIntVar::~OppIntVar() {}
3095
3096int64_t OppIntVar::Min() const { return -var_->Max(); }
3097
3098void OppIntVar::SetMin(int64_t m) { var_->SetMax(CapOpp(m)); }
3099
3100int64_t OppIntVar::Max() const { return -var_->Min(); }
3101
3102void OppIntVar::SetMax(int64_t m) { var_->SetMin(CapOpp(m)); }
3103
3104void OppIntVar::SetRange(int64_t l, int64_t u) {
3105 var_->SetRange(CapOpp(u), CapOpp(l));
3106}
3107
3108void OppIntVar::SetValue(int64_t v) { var_->SetValue(CapOpp(v)); }
3109
3110bool OppIntVar::Bound() const { return var_->Bound(); }
3111
3112void OppIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3113
3114int64_t OppIntVar::Value() const { return -var_->Value(); }
3115
3116void OppIntVar::RemoveValue(int64_t v) { var_->RemoveValue(-v); }
3117
3118void OppIntVar::RemoveInterval(int64_t l, int64_t u) {
3119 var_->RemoveInterval(-u, -l);
3120}
3121
3122void OppIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3123
3124void OppIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3125
3126uint64_t OppIntVar::Size() const { return var_->Size(); }
3127
3128bool OppIntVar::Contains(int64_t v) const { return var_->Contains(-v); }
3129
3130std::string OppIntVar::DebugString() const {
3131 return absl::StrFormat("-(%s)", var_->DebugString());
3132}
3133
3134// ----- Utility functions -----
3135
3136// x * c variable, optimized case
3137
3138class TimesCstIntVar : public IntVar {
3139 public:
3140 TimesCstIntVar(Solver* const s, IntVar* v, int64_t c)
3141 : IntVar(s), var_(v), cst_(c) {}
3142 ~TimesCstIntVar() override {}
3143
3144 IntVar* SubVar() const { return var_; }
3145 int64_t Constant() const { return cst_; }
3146
3147 void Accept(ModelVisitor* const visitor) const override {
3148 visitor->VisitIntegerVariable(this, ModelVisitor::kProductOperation, cst_,
3149 var_);
3150 }
3151
3152 IntVar* IsEqual(int64_t constant) override {
3153 if (constant % cst_ == 0) {
3154 return var_->IsEqual(constant / cst_);
3155 } else {
3156 return solver()->MakeIntConst(0);
3157 }
3158 }
3159
3160 IntVar* IsDifferent(int64_t constant) override {
3161 if (constant % cst_ == 0) {
3162 return var_->IsDifferent(constant / cst_);
3163 } else {
3164 return solver()->MakeIntConst(1);
3165 }
3166 }
3167
3168 IntVar* IsGreaterOrEqual(int64_t constant) override {
3169 if (cst_ > 0) {
3170 return var_->IsGreaterOrEqual(PosIntDivUp(constant, cst_));
3171 } else {
3172 return var_->IsLessOrEqual(PosIntDivDown(-constant, -cst_));
3173 }
3174 }
3175
3176 IntVar* IsLessOrEqual(int64_t constant) override {
3177 if (cst_ > 0) {
3178 return var_->IsLessOrEqual(PosIntDivDown(constant, cst_));
3179 } else {
3180 return var_->IsGreaterOrEqual(PosIntDivUp(-constant, -cst_));
3181 }
3182 }
3183
3184 std::string DebugString() const override {
3185 return absl::StrFormat("(%s * %d)", var_->DebugString(), cst_);
3186 }
3187
3188 int VarType() const override { return VAR_TIMES_CST; }
3189
3190 protected:
3191 IntVar* const var_;
3192 const int64_t cst_;
3193};
3194
3195class TimesPosCstIntVar : public TimesCstIntVar {
3196 public:
3197 class TimesPosCstIntVarIterator : public UnaryIterator {
3198 public:
3199 TimesPosCstIntVarIterator(const IntVar* const v, int64_t c, bool hole,
3200 bool reversible)
3201 : UnaryIterator(v, hole, reversible), cst_(c) {}
3202 ~TimesPosCstIntVarIterator() override {}
3203
3204 int64_t Value() const override { return iterator_->Value() * cst_; }
3205
3206 private:
3207 const int64_t cst_;
3208 };
3209
3210 TimesPosCstIntVar(Solver* s, IntVar* v, int64_t c);
3211 ~TimesPosCstIntVar() override;
3212
3213 int64_t Min() const override;
3214 void SetMin(int64_t m) override;
3215 int64_t Max() const override;
3216 void SetMax(int64_t m) override;
3217 void SetRange(int64_t l, int64_t u) override;
3218 void SetValue(int64_t v) override;
3219 bool Bound() const override;
3220 int64_t Value() const override;
3221 void RemoveValue(int64_t v) override;
3222 void RemoveInterval(int64_t l, int64_t u) override;
3223 uint64_t Size() const override;
3224 bool Contains(int64_t v) const override;
3225 void WhenRange(Demon* d) override;
3226 void WhenBound(Demon* d) override;
3227 void WhenDomain(Demon* d) override;
3228 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3229 return CondRevAlloc(
3230 solver(), reversible,
3231 new TimesPosCstIntVarIterator(var_, cst_, true, reversible));
3232 }
3233 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3234 return CondRevAlloc(
3235 solver(), reversible,
3236 new TimesPosCstIntVarIterator(var_, cst_, false, reversible));
3237 }
3238 int64_t OldMin() const override { return CapProd(var_->OldMin(), cst_); }
3239 int64_t OldMax() const override { return CapProd(var_->OldMax(), cst_); }
3240};
3241
3242// ----- TimesPosCstIntVar -----
3243
3244TimesPosCstIntVar::TimesPosCstIntVar(Solver* const s, IntVar* v, int64_t c)
3245 : TimesCstIntVar(s, v, c) {}
3246
3247TimesPosCstIntVar::~TimesPosCstIntVar() {}
3248
3249int64_t TimesPosCstIntVar::Min() const { return CapProd(var_->Min(), cst_); }
3250
3251void TimesPosCstIntVar::SetMin(int64_t m) {
3252 if (m != std::numeric_limits<int64_t>::min()) {
3253 var_->SetMin(PosIntDivUp(m, cst_));
3254 }
3255}
3256
3257int64_t TimesPosCstIntVar::Max() const { return CapProd(var_->Max(), cst_); }
3258
3259void TimesPosCstIntVar::SetMax(int64_t m) {
3260 if (m != std::numeric_limits<int64_t>::max()) {
3261 var_->SetMax(PosIntDivDown(m, cst_));
3262 }
3263}
3264
3265void TimesPosCstIntVar::SetRange(int64_t l, int64_t u) {
3266 var_->SetRange(PosIntDivUp(l, cst_), PosIntDivDown(u, cst_));
3267}
3268
3269void TimesPosCstIntVar::SetValue(int64_t v) {
3270 if (v % cst_ != 0) {
3271 solver()->Fail();
3272 }
3273 var_->SetValue(v / cst_);
3274}
3275
3276bool TimesPosCstIntVar::Bound() const { return var_->Bound(); }
3277
3278void TimesPosCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3279
3280int64_t TimesPosCstIntVar::Value() const {
3281 return CapProd(var_->Value(), cst_);
3282}
3283
3284void TimesPosCstIntVar::RemoveValue(int64_t v) {
3285 if (v % cst_ == 0) {
3286 var_->RemoveValue(v / cst_);
3287 }
3288}
3289
3290void TimesPosCstIntVar::RemoveInterval(int64_t l, int64_t u) {
3291 for (int64_t v = l; v <= u; ++v) {
3292 RemoveValue(v);
3293 }
3294 // TODO(user) : Improve me
3295}
3296
3297void TimesPosCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3298
3299void TimesPosCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3300
3301uint64_t TimesPosCstIntVar::Size() const { return var_->Size(); }
3302
3303bool TimesPosCstIntVar::Contains(int64_t v) const {
3304 return (v % cst_ == 0 && var_->Contains(v / cst_));
3305}
3306
3307// b * c variable, optimized case
3308
3309class TimesPosCstBoolVar : public TimesCstIntVar {
3310 public:
3311 class TimesPosCstBoolVarIterator : public UnaryIterator {
3312 public:
3313 // TODO(user) : optimize this.
3314 TimesPosCstBoolVarIterator(const IntVar* const v, int64_t c, bool hole,
3315 bool reversible)
3316 : UnaryIterator(v, hole, reversible), cst_(c) {}
3317 ~TimesPosCstBoolVarIterator() override {}
3318
3319 int64_t Value() const override { return iterator_->Value() * cst_; }
3320
3321 private:
3322 const int64_t cst_;
3323 };
3324
3325 TimesPosCstBoolVar(Solver* s, BooleanVar* v, int64_t c);
3326 ~TimesPosCstBoolVar() override;
3327
3328 int64_t Min() const override;
3329 void SetMin(int64_t m) override;
3330 int64_t Max() const override;
3331 void SetMax(int64_t m) override;
3332 void SetRange(int64_t l, int64_t u) override;
3333 void SetValue(int64_t v) override;
3334 bool Bound() const override;
3335 int64_t Value() const override;
3336 void RemoveValue(int64_t v) override;
3337 void RemoveInterval(int64_t l, int64_t u) override;
3338 uint64_t Size() const override;
3339 bool Contains(int64_t v) const override;
3340 void WhenRange(Demon* d) override;
3341 void WhenBound(Demon* d) override;
3342 void WhenDomain(Demon* d) override;
3343 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3344 return CondRevAlloc(solver(), reversible, new EmptyIterator());
3345 }
3346 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3347 return CondRevAlloc(
3348 solver(), reversible,
3349 new TimesPosCstBoolVarIterator(boolean_var(), cst_, false, reversible));
3350 }
3351 int64_t OldMin() const override { return 0; }
3352 int64_t OldMax() const override { return cst_; }
3353
3354 BooleanVar* boolean_var() const {
3355 return reinterpret_cast<BooleanVar*>(var_);
3356 }
3357};
3358
3359// ----- TimesPosCstBoolVar -----
3360
3361TimesPosCstBoolVar::TimesPosCstBoolVar(Solver* const s, BooleanVar* v,
3362 int64_t c)
3363 : TimesCstIntVar(s, v, c) {}
3364
3365TimesPosCstBoolVar::~TimesPosCstBoolVar() {}
3366
3367int64_t TimesPosCstBoolVar::Min() const {
3368 return (boolean_var()->RawValue() == 1) * cst_;
3369}
3370
3371void TimesPosCstBoolVar::SetMin(int64_t m) {
3372 if (m > cst_) {
3373 solver()->Fail();
3374 } else if (m > 0) {
3375 boolean_var()->SetMin(1);
3376 }
3377}
3378
3379int64_t TimesPosCstBoolVar::Max() const {
3380 return (boolean_var()->RawValue() != 0) * cst_;
3381}
3382
3383void TimesPosCstBoolVar::SetMax(int64_t m) {
3384 if (m < 0) {
3385 solver()->Fail();
3386 } else if (m < cst_) {
3387 boolean_var()->SetMax(0);
3388 }
3389}
3390
3391void TimesPosCstBoolVar::SetRange(int64_t l, int64_t u) {
3392 if (u < 0 || l > cst_ || l > u) {
3393 solver()->Fail();
3394 }
3395 if (l > 0) {
3396 boolean_var()->SetMin(1);
3397 } else if (u < cst_) {
3398 boolean_var()->SetMax(0);
3399 }
3400}
3401
3402void TimesPosCstBoolVar::SetValue(int64_t v) {
3403 if (v == 0) {
3404 boolean_var()->SetValue(0);
3405 } else if (v == cst_) {
3406 boolean_var()->SetValue(1);
3407 } else {
3408 solver()->Fail();
3409 }
3410}
3411
3412bool TimesPosCstBoolVar::Bound() const {
3413 return boolean_var()->RawValue() != BooleanVar::kUnboundBooleanVarValue;
3414}
3415
3416void TimesPosCstBoolVar::WhenRange(Demon* d) { boolean_var()->WhenRange(d); }
3417
3418int64_t TimesPosCstBoolVar::Value() const {
3419 CHECK_NE(boolean_var()->RawValue(), BooleanVar::kUnboundBooleanVarValue)
3420 << " variable is not bound";
3421 return boolean_var()->RawValue() * cst_;
3422}
3423
3424void TimesPosCstBoolVar::RemoveValue(int64_t v) {
3425 if (v == 0) {
3426 boolean_var()->RemoveValue(0);
3427 } else if (v == cst_) {
3428 boolean_var()->RemoveValue(1);
3429 }
3430}
3431
3432void TimesPosCstBoolVar::RemoveInterval(int64_t l, int64_t u) {
3433 if (l <= 0 && u >= 0) {
3434 boolean_var()->RemoveValue(0);
3435 }
3436 if (l <= cst_ && u >= cst_) {
3437 boolean_var()->RemoveValue(1);
3438 }
3439}
3440
3441void TimesPosCstBoolVar::WhenBound(Demon* d) { boolean_var()->WhenBound(d); }
3442
3443void TimesPosCstBoolVar::WhenDomain(Demon* d) { boolean_var()->WhenDomain(d); }
3444
3445uint64_t TimesPosCstBoolVar::Size() const {
3446 return (1 +
3447 (boolean_var()->RawValue() == BooleanVar::kUnboundBooleanVarValue));
3448}
3449
3450bool TimesPosCstBoolVar::Contains(int64_t v) const {
3451 if (v == 0) {
3452 return boolean_var()->RawValue() != 1;
3453 } else if (v == cst_) {
3454 return boolean_var()->RawValue() != 0;
3455 }
3456 return false;
3457}
3458
3459// TimesNegCstIntVar
3460
3461class TimesNegCstIntVar : public TimesCstIntVar {
3462 public:
3463 class TimesNegCstIntVarIterator : public UnaryIterator {
3464 public:
3465 TimesNegCstIntVarIterator(const IntVar* const v, int64_t c, bool hole,
3466 bool reversible)
3467 : UnaryIterator(v, hole, reversible), cst_(c) {}
3468 ~TimesNegCstIntVarIterator() override {}
3469
3470 int64_t Value() const override { return iterator_->Value() * cst_; }
3471
3472 private:
3473 const int64_t cst_;
3474 };
3475
3476 TimesNegCstIntVar(Solver* s, IntVar* v, int64_t c);
3477 ~TimesNegCstIntVar() override;
3478
3479 int64_t Min() const override;
3480 void SetMin(int64_t m) override;
3481 int64_t Max() const override;
3482 void SetMax(int64_t m) override;
3483 void SetRange(int64_t l, int64_t u) override;
3484 void SetValue(int64_t v) override;
3485 bool Bound() const override;
3486 int64_t Value() const override;
3487 void RemoveValue(int64_t v) override;
3488 void RemoveInterval(int64_t l, int64_t u) override;
3489 uint64_t Size() const override;
3490 bool Contains(int64_t v) const override;
3491 void WhenRange(Demon* d) override;
3492 void WhenBound(Demon* d) override;
3493 void WhenDomain(Demon* d) override;
3494 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3495 return CondRevAlloc(
3496 solver(), reversible,
3497 new TimesNegCstIntVarIterator(var_, cst_, true, reversible));
3498 }
3499 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3500 return CondRevAlloc(
3501 solver(), reversible,
3502 new TimesNegCstIntVarIterator(var_, cst_, false, reversible));
3503 }
3504 int64_t OldMin() const override { return CapProd(var_->OldMax(), cst_); }
3505 int64_t OldMax() const override { return CapProd(var_->OldMin(), cst_); }
3506};
3507
3508// ----- TimesNegCstIntVar -----
3509
3510TimesNegCstIntVar::TimesNegCstIntVar(Solver* const s, IntVar* v, int64_t c)
3511 : TimesCstIntVar(s, v, c) {}
3512
3513TimesNegCstIntVar::~TimesNegCstIntVar() {}
3514
3515int64_t TimesNegCstIntVar::Min() const { return CapProd(var_->Max(), cst_); }
3516
3517void TimesNegCstIntVar::SetMin(int64_t m) {
3518 if (m != std::numeric_limits<int64_t>::min()) {
3519 var_->SetMax(PosIntDivDown(-m, -cst_));
3520 }
3521}
3522
3523int64_t TimesNegCstIntVar::Max() const { return CapProd(var_->Min(), cst_); }
3524
3525void TimesNegCstIntVar::SetMax(int64_t m) {
3526 if (m != std::numeric_limits<int64_t>::max()) {
3527 var_->SetMin(PosIntDivUp(-m, -cst_));
3528 }
3529}
3530
3531void TimesNegCstIntVar::SetRange(int64_t l, int64_t u) {
3532 var_->SetRange(PosIntDivUp(CapOpp(u), CapOpp(cst_)),
3533 PosIntDivDown(CapOpp(l), CapOpp(cst_)));
3534}
3535
3536void TimesNegCstIntVar::SetValue(int64_t v) {
3537 if (v % cst_ != 0) {
3538 solver()->Fail();
3539 }
3540 var_->SetValue(v / cst_);
3541}
3542
3543bool TimesNegCstIntVar::Bound() const { return var_->Bound(); }
3544
3545void TimesNegCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3546
3547int64_t TimesNegCstIntVar::Value() const {
3548 return CapProd(var_->Value(), cst_);
3549}
3550
3551void TimesNegCstIntVar::RemoveValue(int64_t v) {
3552 if (v % cst_ == 0) {
3553 var_->RemoveValue(v / cst_);
3554 }
3555}
3556
3557void TimesNegCstIntVar::RemoveInterval(int64_t l, int64_t u) {
3558 for (int64_t v = l; v <= u; ++v) {
3559 RemoveValue(v);
3560 }
3561 // TODO(user) : Improve me
3562}
3563
3564void TimesNegCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3565
3566void TimesNegCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3567
3568uint64_t TimesNegCstIntVar::Size() const { return var_->Size(); }
3569
3570bool TimesNegCstIntVar::Contains(int64_t v) const {
3571 return (v % cst_ == 0 && var_->Contains(v / cst_));
3572}
3573
3574// ---------- arithmetic expressions ----------
3575
3576// ----- PlusIntExpr -----
3577
3578class PlusIntExpr : public BaseIntExpr {
3579 public:
3580 PlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3581 : BaseIntExpr(s), left_(l), right_(r) {}
3582
3583 ~PlusIntExpr() override {}
3584
3585 int64_t Min() const override { return left_->Min() + right_->Min(); }
3586
3587 void SetMin(int64_t m) override {
3588 if (m > left_->Min() + right_->Min()) {
3589 // Catching potential overflow.
3590 if (m > right_->Max() + left_->Max()) solver()->Fail();
3591 left_->SetMin(m - right_->Max());
3592 right_->SetMin(m - left_->Max());
3593 }
3594 }
3595
3596 void SetRange(int64_t l, int64_t u) override {
3597 const int64_t left_min = left_->Min();
3598 const int64_t right_min = right_->Min();
3599 const int64_t left_max = left_->Max();
3600 const int64_t right_max = right_->Max();
3601 if (l > left_min + right_min) {
3602 // Catching potential overflow.
3603 if (l > right_max + left_max) solver()->Fail();
3604 left_->SetMin(l - right_max);
3605 right_->SetMin(l - left_max);
3606 }
3607 if (u < left_max + right_max) {
3608 // Catching potential overflow.
3609 if (u < right_min + left_min) solver()->Fail();
3610 left_->SetMax(u - right_min);
3611 right_->SetMax(u - left_min);
3612 }
3613 }
3614
3615 int64_t Max() const override { return left_->Max() + right_->Max(); }
3616
3617 void SetMax(int64_t m) override {
3618 if (m < left_->Max() + right_->Max()) {
3619 // Catching potential overflow.
3620 if (m < right_->Min() + left_->Min()) solver()->Fail();
3621 left_->SetMax(m - right_->Min());
3622 right_->SetMax(m - left_->Min());
3623 }
3624 }
3625
3626 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3627
3628 void Range(int64_t* const mi, int64_t* const ma) override {
3629 *mi = left_->Min() + right_->Min();
3630 *ma = left_->Max() + right_->Max();
3631 }
3632
3633 std::string name() const override {
3634 return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3635 }
3636
3637 std::string DebugString() const override {
3638 return absl::StrFormat("(%s + %s)", left_->DebugString(),
3639 right_->DebugString());
3640 }
3641
3642 void WhenRange(Demon* d) override {
3643 left_->WhenRange(d);
3644 right_->WhenRange(d);
3645 }
3646
3647 void ExpandPlusIntExpr(IntExpr* const expr, std::vector<IntExpr*>* subs) {
3648 PlusIntExpr* const casted = dynamic_cast<PlusIntExpr*>(expr);
3649 if (casted != nullptr) {
3650 ExpandPlusIntExpr(casted->left_, subs);
3651 ExpandPlusIntExpr(casted->right_, subs);
3652 } else {
3653 subs->push_back(expr);
3654 }
3655 }
3656
3657 IntVar* CastToVar() override {
3658 if (dynamic_cast<PlusIntExpr*>(left_) != nullptr ||
3659 dynamic_cast<PlusIntExpr*>(right_) != nullptr) {
3660 std::vector<IntExpr*> sub_exprs;
3661 ExpandPlusIntExpr(left_, &sub_exprs);
3662 ExpandPlusIntExpr(right_, &sub_exprs);
3663 if (sub_exprs.size() >= 3) {
3664 std::vector<IntVar*> sub_vars(sub_exprs.size());
3665 for (int i = 0; i < sub_exprs.size(); ++i) {
3666 sub_vars[i] = sub_exprs[i]->Var();
3667 }
3668 return solver()->MakeSum(sub_vars)->Var();
3669 }
3670 }
3671 return BaseIntExpr::CastToVar();
3672 }
3673
3674 void Accept(ModelVisitor* const visitor) const override {
3675 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3676 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3677 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3678 right_);
3679 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3680 }
3681
3682 private:
3683 IntExpr* const left_;
3684 IntExpr* const right_;
3685};
3686
3687class SafePlusIntExpr : public BaseIntExpr {
3688 public:
3689 SafePlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3690 : BaseIntExpr(s), left_(l), right_(r) {}
3691
3692 ~SafePlusIntExpr() override {}
3693
3694 int64_t Min() const override { return CapAdd(left_->Min(), right_->Min()); }
3695
3696 void SetMin(int64_t m) override {
3697 left_->SetMin(CapSub(m, right_->Max()));
3698 right_->SetMin(CapSub(m, left_->Max()));
3699 }
3700
3701 void SetRange(int64_t l, int64_t u) override {
3702 const int64_t left_min = left_->Min();
3703 const int64_t right_min = right_->Min();
3704 const int64_t left_max = left_->Max();
3705 const int64_t right_max = right_->Max();
3706 if (l > CapAdd(left_min, right_min)) {
3707 left_->SetMin(CapSub(l, right_max));
3708 right_->SetMin(CapSub(l, left_max));
3709 }
3710 if (u < CapAdd(left_max, right_max)) {
3711 left_->SetMax(CapSub(u, right_min));
3712 right_->SetMax(CapSub(u, left_min));
3713 }
3714 }
3715
3716 int64_t Max() const override { return CapAdd(left_->Max(), right_->Max()); }
3717
3718 void SetMax(int64_t m) override {
3719 left_->SetMax(CapSub(m, right_->Min()));
3720 right_->SetMax(CapSub(m, left_->Min()));
3721 }
3722
3723 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3724
3725 std::string name() const override {
3726 return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3727 }
3728
3729 std::string DebugString() const override {
3730 return absl::StrFormat("(%s + %s)", left_->DebugString(),
3731 right_->DebugString());
3732 }
3733
3734 void WhenRange(Demon* d) override {
3735 left_->WhenRange(d);
3736 right_->WhenRange(d);
3737 }
3738
3739 void Accept(ModelVisitor* const visitor) const override {
3740 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3741 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3742 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3743 right_);
3744 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3745 }
3746
3747 private:
3748 IntExpr* const left_;
3749 IntExpr* const right_;
3750};
3751
3752// ----- PlusIntCstExpr -----
3753
3754class PlusIntCstExpr : public BaseIntExpr {
3755 public:
3756 PlusIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3757 : BaseIntExpr(s), expr_(e), value_(v) {}
3758 ~PlusIntCstExpr() override {}
3759 int64_t Min() const override { return CapAdd(expr_->Min(), value_); }
3760 void SetMin(int64_t m) override { expr_->SetMin(CapSub(m, value_)); }
3761 int64_t Max() const override { return CapAdd(expr_->Max(), value_); }
3762 void SetMax(int64_t m) override { expr_->SetMax(CapSub(m, value_)); }
3763 bool Bound() const override { return (expr_->Bound()); }
3764 std::string name() const override {
3765 return absl::StrFormat("(%s + %d)", expr_->name(), value_);
3766 }
3767 std::string DebugString() const override {
3768 return absl::StrFormat("(%s + %d)", expr_->DebugString(), value_);
3769 }
3770 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3771 IntVar* CastToVar() override;
3772 void Accept(ModelVisitor* const visitor) const override {
3773 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3774 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3775 expr_);
3776 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3777 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3778 }
3779
3780 private:
3781 IntExpr* const expr_;
3782 const int64_t value_;
3783};
3784
3785IntVar* PlusIntCstExpr::CastToVar() {
3786 Solver* const s = solver();
3787 IntVar* const var = expr_->Var();
3788 IntVar* cast = nullptr;
3789 if (AddOverflows(value_, expr_->Max()) ||
3790 AddOverflows(value_, expr_->Min())) {
3791 return BaseIntExpr::CastToVar();
3792 }
3793 switch (var->VarType()) {
3794 case DOMAIN_INT_VAR:
3795 cast = s->RegisterIntVar(s->RevAlloc(new PlusCstDomainIntVar(
3796 s, reinterpret_cast<DomainIntVar*>(var), value_)));
3797 // FIXME: Break was inserted during fallthrough cleanup. Please check.
3798 break;
3799 default:
3800 cast = s->RegisterIntVar(s->RevAlloc(new PlusCstIntVar(s, var, value_)));
3801 break;
3802 }
3803 return cast;
3804}
3805
3806// ----- SubIntExpr -----
3807
3808class SubIntExpr : public BaseIntExpr {
3809 public:
3810 SubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3811 : BaseIntExpr(s), left_(l), right_(r) {}
3812
3813 ~SubIntExpr() override {}
3814
3815 int64_t Min() const override { return left_->Min() - right_->Max(); }
3816
3817 void SetMin(int64_t m) override {
3818 left_->SetMin(CapAdd(m, right_->Min()));
3819 right_->SetMax(CapSub(left_->Max(), m));
3820 }
3821
3822 int64_t Max() const override { return left_->Max() - right_->Min(); }
3823
3824 void SetMax(int64_t m) override {
3825 left_->SetMax(CapAdd(m, right_->Max()));
3826 right_->SetMin(CapSub(left_->Min(), m));
3827 }
3828
3829 void Range(int64_t* mi, int64_t* ma) override {
3830 *mi = left_->Min() - right_->Max();
3831 *ma = left_->Max() - right_->Min();
3832 }
3833
3834 void SetRange(int64_t l, int64_t u) override {
3835 const int64_t left_min = left_->Min();
3836 const int64_t right_min = right_->Min();
3837 const int64_t left_max = left_->Max();
3838 const int64_t right_max = right_->Max();
3839 if (l > left_min - right_max) {
3840 left_->SetMin(CapAdd(l, right_min));
3841 right_->SetMax(CapSub(left_max, l));
3842 }
3843 if (u < left_max - right_min) {
3844 left_->SetMax(CapAdd(u, right_max));
3845 right_->SetMin(CapSub(left_min, u));
3846 }
3847 }
3848
3849 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3850
3851 std::string name() const override {
3852 return absl::StrFormat("(%s - %s)", left_->name(), right_->name());
3853 }
3854
3855 std::string DebugString() const override {
3856 return absl::StrFormat("(%s - %s)", left_->DebugString(),
3857 right_->DebugString());
3858 }
3859
3860 void WhenRange(Demon* d) override {
3861 left_->WhenRange(d);
3862 right_->WhenRange(d);
3863 }
3864
3865 void Accept(ModelVisitor* const visitor) const override {
3866 visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3867 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3868 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3869 right_);
3870 visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3871 }
3872
3873 IntExpr* left() const { return left_; }
3874 IntExpr* right() const { return right_; }
3875
3876 protected:
3877 IntExpr* const left_;
3878 IntExpr* const right_;
3879};
3880
3881class SafeSubIntExpr : public SubIntExpr {
3882 public:
3883 SafeSubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3884 : SubIntExpr(s, l, r) {}
3885
3886 ~SafeSubIntExpr() override {}
3887
3888 int64_t Min() const override { return CapSub(left_->Min(), right_->Max()); }
3889
3890 void SetMin(int64_t m) override {
3891 left_->SetMin(CapAdd(m, right_->Min()));
3892 right_->SetMax(CapSub(left_->Max(), m));
3893 }
3894
3895 void SetRange(int64_t l, int64_t u) override {
3896 const int64_t left_min = left_->Min();
3897 const int64_t right_min = right_->Min();
3898 const int64_t left_max = left_->Max();
3899 const int64_t right_max = right_->Max();
3900 if (l > CapSub(left_min, right_max)) {
3901 left_->SetMin(CapAdd(l, right_min));
3902 right_->SetMax(CapSub(left_max, l));
3903 }
3904 if (u < CapSub(left_max, right_min)) {
3905 left_->SetMax(CapAdd(u, right_max));
3906 right_->SetMin(CapSub(left_min, u));
3907 }
3908 }
3909
3910 void Range(int64_t* mi, int64_t* ma) override {
3911 *mi = CapSub(left_->Min(), right_->Max());
3912 *ma = CapSub(left_->Max(), right_->Min());
3913 }
3914
3915 int64_t Max() const override { return CapSub(left_->Max(), right_->Min()); }
3916
3917 void SetMax(int64_t m) override {
3918 left_->SetMax(CapAdd(m, right_->Max()));
3919 right_->SetMin(CapSub(left_->Min(), m));
3920 }
3921};
3922
3923// l - r
3924
3925// ----- SubIntCstExpr -----
3926
3927class SubIntCstExpr : public BaseIntExpr {
3928 public:
3929 SubIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3930 : BaseIntExpr(s), expr_(e), value_(v) {}
3931 ~SubIntCstExpr() override {}
3932 int64_t Min() const override { return CapSub(value_, expr_->Max()); }
3933 void SetMin(int64_t m) override { expr_->SetMax(CapSub(value_, m)); }
3934 int64_t Max() const override { return CapSub(value_, expr_->Min()); }
3935 void SetMax(int64_t m) override { expr_->SetMin(CapSub(value_, m)); }
3936 bool Bound() const override { return (expr_->Bound()); }
3937 std::string name() const override {
3938 return absl::StrFormat("(%d - %s)", value_, expr_->name());
3939 }
3940 std::string DebugString() const override {
3941 return absl::StrFormat("(%d - %s)", value_, expr_->DebugString());
3942 }
3943 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3944 IntVar* CastToVar() override;
3945
3946 void Accept(ModelVisitor* const visitor) const override {
3947 visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3948 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3949 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3950 expr_);
3951 visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3952 }
3953
3954 private:
3955 IntExpr* const expr_;
3956 const int64_t value_;
3957};
3958
3959IntVar* SubIntCstExpr::CastToVar() {
3960 if (SubOverflows(value_, expr_->Min()) ||
3961 SubOverflows(value_, expr_->Max())) {
3962 return BaseIntExpr::CastToVar();
3963 }
3964 Solver* const s = solver();
3965 IntVar* const var =
3966 s->RegisterIntVar(s->RevAlloc(new SubCstIntVar(s, expr_->Var(), value_)));
3967 return var;
3968}
3969
3970// ----- OppIntExpr -----
3971
3972class OppIntExpr : public BaseIntExpr {
3973 public:
3974 OppIntExpr(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
3975 ~OppIntExpr() override {}
3976 int64_t Min() const override { return (CapOpp(expr_->Max())); }
3977 void SetMin(int64_t m) override { expr_->SetMax(CapOpp(m)); }
3978 int64_t Max() const override { return (CapOpp(expr_->Min())); }
3979 void SetMax(int64_t m) override { expr_->SetMin(CapOpp(m)); }
3980 bool Bound() const override { return (expr_->Bound()); }
3981 std::string name() const override {
3982 return absl::StrFormat("(-%s)", expr_->name());
3983 }
3984 std::string DebugString() const override {
3985 return absl::StrFormat("(-%s)", expr_->DebugString());
3986 }
3987 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3988 IntVar* CastToVar() override;
3989
3990 void Accept(ModelVisitor* const visitor) const override {
3991 visitor->BeginVisitIntegerExpression(ModelVisitor::kOpposite, this);
3992 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3993 expr_);
3994 visitor->EndVisitIntegerExpression(ModelVisitor::kOpposite, this);
3995 }
3996
3997 private:
3998 IntExpr* const expr_;
3999};
4000
4001IntVar* OppIntExpr::CastToVar() {
4002 Solver* const s = solver();
4003 IntVar* const var =
4004 s->RegisterIntVar(s->RevAlloc(new OppIntVar(s, expr_->Var())));
4005 return var;
4006}
4007
4008// ----- TimesIntCstExpr -----
4009
4010class TimesIntCstExpr : public BaseIntExpr {
4011 public:
4012 TimesIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4013 : BaseIntExpr(s), expr_(e), value_(v) {}
4014
4015 ~TimesIntCstExpr() override {}
4016
4017 bool Bound() const override { return (expr_->Bound()); }
4018
4019 std::string name() const override {
4020 return absl::StrFormat("(%s * %d)", expr_->name(), value_);
4021 }
4022
4023 std::string DebugString() const override {
4024 return absl::StrFormat("(%s * %d)", expr_->DebugString(), value_);
4025 }
4026
4027 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4028
4029 IntExpr* Expr() const { return expr_; }
4030
4031 int64_t Constant() const { return value_; }
4032
4033 void Accept(ModelVisitor* const visitor) const override {
4034 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4035 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4036 expr_);
4037 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4038 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4039 }
4040
4041 protected:
4042 IntExpr* const expr_;
4043 const int64_t value_;
4044};
4045
4046// ----- TimesPosIntCstExpr -----
4047
4048class TimesPosIntCstExpr : public TimesIntCstExpr {
4049 public:
4050 TimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4051 : TimesIntCstExpr(s, e, v) {
4052 CHECK_GT(v, 0);
4053 }
4054
4055 ~TimesPosIntCstExpr() override {}
4056
4057 int64_t Min() const override { return expr_->Min() * value_; }
4058
4059 void SetMin(int64_t m) override { expr_->SetMin(PosIntDivUp(m, value_)); }
4060
4061 int64_t Max() const override { return expr_->Max() * value_; }
4062
4063 void SetMax(int64_t m) override { expr_->SetMax(PosIntDivDown(m, value_)); }
4064
4065 IntVar* CastToVar() override {
4066 Solver* const s = solver();
4067 IntVar* var = nullptr;
4068 if (expr_->IsVar() &&
4069 reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4070 var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4071 s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4072 } else {
4073 var = s->RegisterIntVar(
4074 s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4075 }
4076 return var;
4077 }
4078};
4079
4080// This expressions adds safe arithmetic (w.r.t. overflows) compared
4081// to the previous one.
4082class SafeTimesPosIntCstExpr : public TimesIntCstExpr {
4083 public:
4084 SafeTimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4085 : TimesIntCstExpr(s, e, v) {
4086 CHECK_GT(v, 0);
4087 }
4088
4089 ~SafeTimesPosIntCstExpr() override {}
4090
4091 int64_t Min() const override { return CapProd(expr_->Min(), value_); }
4092
4093 void SetMin(int64_t m) override {
4094 if (m != std::numeric_limits<int64_t>::min()) {
4095 expr_->SetMin(PosIntDivUp(m, value_));
4096 }
4097 }
4098
4099 int64_t Max() const override { return CapProd(expr_->Max(), value_); }
4100
4101 void SetMax(int64_t m) override {
4102 if (m != std::numeric_limits<int64_t>::max()) {
4103 expr_->SetMax(PosIntDivDown(m, value_));
4104 }
4105 }
4106
4107 IntVar* CastToVar() override {
4108 Solver* const s = solver();
4109 IntVar* var = nullptr;
4110 if (expr_->IsVar() &&
4111 reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4112 var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4113 s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4114 } else {
4115 // TODO(user): Check overflows.
4116 var = s->RegisterIntVar(
4117 s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4118 }
4119 return var;
4120 }
4121};
4122
4123// ----- TimesIntNegCstExpr -----
4124
4125class TimesIntNegCstExpr : public TimesIntCstExpr {
4126 public:
4127 TimesIntNegCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4128 : TimesIntCstExpr(s, e, v) {
4129 CHECK_LT(v, 0);
4130 }
4131
4132 ~TimesIntNegCstExpr() override {}
4133
4134 int64_t Min() const override { return CapProd(expr_->Max(), value_); }
4135
4136 void SetMin(int64_t m) override {
4137 if (m != std::numeric_limits<int64_t>::min()) {
4138 expr_->SetMax(PosIntDivDown(-m, -value_));
4139 }
4140 }
4141
4142 int64_t Max() const override { return CapProd(expr_->Min(), value_); }
4143
4144 void SetMax(int64_t m) override {
4145 if (m != std::numeric_limits<int64_t>::max()) {
4146 expr_->SetMin(PosIntDivUp(-m, -value_));
4147 }
4148 }
4149
4150 IntVar* CastToVar() override {
4151 Solver* const s = solver();
4152 IntVar* var = nullptr;
4153 var = s->RegisterIntVar(
4154 s->RevAlloc(new TimesNegCstIntVar(s, expr_->Var(), value_)));
4155 return var;
4156 }
4157};
4158
4159// ----- Utilities for product expression -----
4160
4161// Propagates set_min on left * right, left and right >= 0.
4162void SetPosPosMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4163 DCHECK_GE(left->Min(), 0);
4164 DCHECK_GE(right->Min(), 0);
4165 const int64_t lmax = left->Max();
4166 const int64_t rmax = right->Max();
4167 if (m > CapProd(lmax, rmax)) {
4168 left->solver()->Fail();
4169 }
4170 if (m > CapProd(left->Min(), right->Min())) {
4171 // Ok for m == 0 due to left and right being positive
4172 if (0 != rmax) {
4173 left->SetMin(PosIntDivUp(m, rmax));
4174 }
4175 if (0 != lmax) {
4176 right->SetMin(PosIntDivUp(m, lmax));
4177 }
4178 }
4179}
4180
4181// Propagates set_max on left * right, left and right >= 0.
4182void SetPosPosMaxExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4183 DCHECK_GE(left->Min(), 0);
4184 DCHECK_GE(right->Min(), 0);
4185 const int64_t lmin = left->Min();
4186 const int64_t rmin = right->Min();
4187 if (m < CapProd(lmin, rmin)) {
4188 left->solver()->Fail();
4189 }
4190 if (m < CapProd(left->Max(), right->Max())) {
4191 if (0 != lmin) {
4192 right->SetMax(PosIntDivDown(m, lmin));
4193 }
4194 if (0 != rmin) {
4195 left->SetMax(PosIntDivDown(m, rmin));
4196 }
4197 // else do nothing: 0 is supporting any value from other expr.
4198 }
4199}
4200
4201// Propagates set_min on left * right, left >= 0, right across 0.
4202void SetPosGenMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4203 DCHECK_GE(left->Min(), 0);
4204 DCHECK_GT(right->Max(), 0);
4205 DCHECK_LT(right->Min(), 0);
4206 const int64_t lmax = left->Max();
4207 const int64_t rmax = right->Max();
4208 if (m > CapProd(lmax, rmax)) {
4209 left->solver()->Fail();
4210 }
4211 if (left->Max() == 0) { // left is bound to 0, product is bound to 0.
4212 DCHECK_EQ(0, left->Min());
4213 DCHECK_LE(m, 0);
4214 } else {
4215 if (m > 0) { // We deduce right > 0.
4216 left->SetMin(PosIntDivUp(m, rmax));
4217 right->SetMin(PosIntDivUp(m, lmax));
4218 } else if (m == 0) {
4219 const int64_t lmin = left->Min();
4220 if (lmin > 0) {
4221 right->SetMin(0);
4222 }
4223 } else { // m < 0
4224 const int64_t lmin = left->Min();
4225 if (0 != lmin) { // We cannot deduce anything if 0 is in the domain.
4226 right->SetMin(-PosIntDivDown(-m, lmin));
4227 }
4228 }
4229 }
4230}
4231
4232// Propagates set_min on left * right, left and right across 0.
4233void SetGenGenMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4234 DCHECK_LT(left->Min(), 0);
4235 DCHECK_GT(left->Max(), 0);
4236 DCHECK_GT(right->Max(), 0);
4237 DCHECK_LT(right->Min(), 0);
4238 const int64_t lmin = left->Min();
4239 const int64_t lmax = left->Max();
4240 const int64_t rmin = right->Min();
4241 const int64_t rmax = right->Max();
4242 if (m > std::max(CapProd(lmin, rmin), CapProd(lmax, rmax))) {
4243 left->solver()->Fail();
4244 }
4245 if (m >
4246 CapProd(lmin, rmin)) { // Must be positive section * positive section.
4247 left->SetMin(PosIntDivUp(m, rmax));
4248 right->SetMin(PosIntDivUp(m, lmax));
4249 } else if (m > CapProd(lmax, rmax)) { // Negative section * negative section.
4250 left->SetMax(CapOpp(PosIntDivUp(m, CapOpp(rmin))));
4251 right->SetMax(CapOpp(PosIntDivUp(m, CapOpp(lmin))));
4252 }
4253}
4254
4255void TimesSetMin(IntExpr* const left, IntExpr* const right,
4256 IntExpr* const minus_left, IntExpr* const minus_right,
4257 int64_t m) {
4258 if (left->Min() >= 0) {
4259 if (right->Min() >= 0) {
4260 SetPosPosMinExpr(left, right, m);
4261 } else if (right->Max() <= 0) {
4262 SetPosPosMaxExpr(left, minus_right, -m);
4263 } else { // right->Min() < 0 && right->Max() > 0
4264 SetPosGenMinExpr(left, right, m);
4265 }
4266 } else if (left->Max() <= 0) {
4267 if (right->Min() >= 0) {
4268 SetPosPosMaxExpr(right, minus_left, -m);
4269 } else if (right->Max() <= 0) {
4270 SetPosPosMinExpr(minus_left, minus_right, m);
4271 } else { // right->Min() < 0 && right->Max() > 0
4272 SetPosGenMinExpr(minus_left, minus_right, m);
4273 }
4274 } else if (right->Min() >= 0) { // left->Min() < 0 && left->Max() > 0
4275 SetPosGenMinExpr(right, left, m);
4276 } else if (right->Max() <= 0) { // left->Min() < 0 && left->Max() > 0
4277 SetPosGenMinExpr(minus_right, minus_left, m);
4278 } else { // left->Min() < 0 && left->Max() > 0 &&
4279 // right->Min() < 0 && right->Max() > 0
4280 SetGenGenMinExpr(left, right, m);
4281 }
4282}
4283
4284class TimesIntExpr : public BaseIntExpr {
4285 public:
4286 TimesIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4287 : BaseIntExpr(s),
4288 left_(l),
4289 right_(r),
4290 minus_left_(s->MakeOpposite(left_)),
4291 minus_right_(s->MakeOpposite(right_)) {}
4292 ~TimesIntExpr() override {}
4293 int64_t Min() const override {
4294 const int64_t lmin = left_->Min();
4295 const int64_t lmax = left_->Max();
4296 const int64_t rmin = right_->Min();
4297 const int64_t rmax = right_->Max();
4298 return std::min(std::min(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4299 std::min(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4300 }
4301 void SetMin(int64_t m) override;
4302 int64_t Max() const override {
4303 const int64_t lmin = left_->Min();
4304 const int64_t lmax = left_->Max();
4305 const int64_t rmin = right_->Min();
4306 const int64_t rmax = right_->Max();
4307 return std::max(std::max(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4308 std::max(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4309 }
4310 void SetMax(int64_t m) override;
4311 bool Bound() const override;
4312 std::string name() const override {
4313 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4314 }
4315 std::string DebugString() const override {
4316 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4317 right_->DebugString());
4318 }
4319 void WhenRange(Demon* d) override {
4320 left_->WhenRange(d);
4321 right_->WhenRange(d);
4322 }
4323
4324 void Accept(ModelVisitor* const visitor) const override {
4325 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4326 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4327 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4328 right_);
4329 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4330 }
4331
4332 private:
4333 IntExpr* const left_;
4334 IntExpr* const right_;
4335 IntExpr* const minus_left_;
4336 IntExpr* const minus_right_;
4337};
4338
4339void TimesIntExpr::SetMin(int64_t m) {
4340 if (m != std::numeric_limits<int64_t>::min()) {
4341 TimesSetMin(left_, right_, minus_left_, minus_right_, m);
4342 }
4343}
4344
4345void TimesIntExpr::SetMax(int64_t m) {
4346 if (m != std::numeric_limits<int64_t>::max()) {
4347 TimesSetMin(left_, minus_right_, minus_left_, right_, CapOpp(m));
4348 }
4349}
4350
4351bool TimesIntExpr::Bound() const {
4352 const bool left_bound = left_->Bound();
4353 const bool right_bound = right_->Bound();
4354 return ((left_bound && left_->Max() == 0) ||
4355 (right_bound && right_->Max() == 0) || (left_bound && right_bound));
4356}
4357
4358// ----- TimesPosIntExpr -----
4359
4360class TimesPosIntExpr : public BaseIntExpr {
4361 public:
4362 TimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4363 : BaseIntExpr(s), left_(l), right_(r) {}
4364 ~TimesPosIntExpr() override {}
4365 int64_t Min() const override { return (left_->Min() * right_->Min()); }
4366 void SetMin(int64_t m) override;
4367 int64_t Max() const override { return (left_->Max() * right_->Max()); }
4368 void SetMax(int64_t m) override;
4369 bool Bound() const override;
4370 std::string name() const override {
4371 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4372 }
4373 std::string DebugString() const override {
4374 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4375 right_->DebugString());
4376 }
4377 void WhenRange(Demon* d) override {
4378 left_->WhenRange(d);
4379 right_->WhenRange(d);
4380 }
4381
4382 void Accept(ModelVisitor* const visitor) const override {
4383 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4384 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4385 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4386 right_);
4387 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4388 }
4389
4390 private:
4391 IntExpr* const left_;
4392 IntExpr* const right_;
4393};
4394
4395void TimesPosIntExpr::SetMin(int64_t m) { SetPosPosMinExpr(left_, right_, m); }
4396
4397void TimesPosIntExpr::SetMax(int64_t m) { SetPosPosMaxExpr(left_, right_, m); }
4398
4399bool TimesPosIntExpr::Bound() const {
4400 return (left_->Max() == 0 || right_->Max() == 0 ||
4401 (left_->Bound() && right_->Bound()));
4402}
4403
4404// ----- SafeTimesPosIntExpr -----
4405
4406class SafeTimesPosIntExpr : public BaseIntExpr {
4407 public:
4408 SafeTimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4409 : BaseIntExpr(s), left_(l), right_(r) {}
4410 ~SafeTimesPosIntExpr() override {}
4411 int64_t Min() const override { return CapProd(left_->Min(), right_->Min()); }
4412 void SetMin(int64_t m) override {
4413 if (m != std::numeric_limits<int64_t>::min()) {
4414 SetPosPosMinExpr(left_, right_, m);
4415 }
4416 }
4417 int64_t Max() const override { return CapProd(left_->Max(), right_->Max()); }
4418 void SetMax(int64_t m) override {
4419 if (m != std::numeric_limits<int64_t>::max()) {
4420 SetPosPosMaxExpr(left_, right_, m);
4421 }
4422 }
4423 bool Bound() const override {
4424 return (left_->Max() == 0 || right_->Max() == 0 ||
4425 (left_->Bound() && right_->Bound()));
4426 }
4427 std::string name() const override {
4428 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4429 }
4430 std::string DebugString() const override {
4431 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4432 right_->DebugString());
4433 }
4434 void WhenRange(Demon* d) override {
4435 left_->WhenRange(d);
4436 right_->WhenRange(d);
4437 }
4438
4439 void Accept(ModelVisitor* const visitor) const override {
4440 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4441 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4442 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4443 right_);
4444 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4445 }
4446
4447 private:
4448 IntExpr* const left_;
4449 IntExpr* const right_;
4450};
4451
4452// ----- TimesBooleanPosIntExpr -----
4453
4454class TimesBooleanPosIntExpr : public BaseIntExpr {
4455 public:
4456 TimesBooleanPosIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4457 : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4458 ~TimesBooleanPosIntExpr() override {}
4459 int64_t Min() const override {
4460 return (boolvar_->RawValue() == 1 ? expr_->Min() : 0);
4461 }
4462 void SetMin(int64_t m) override;
4463 int64_t Max() const override {
4464 return (boolvar_->RawValue() == 0 ? 0 : expr_->Max());
4465 }
4466 void SetMax(int64_t m) override;
4467 void Range(int64_t* mi, int64_t* ma) override;
4468 void SetRange(int64_t mi, int64_t ma) override;
4469 bool Bound() const override;
4470 std::string name() const override {
4471 return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4472 }
4473 std::string DebugString() const override {
4474 return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4475 expr_->DebugString());
4476 }
4477 void WhenRange(Demon* d) override {
4478 boolvar_->WhenRange(d);
4479 expr_->WhenRange(d);
4480 }
4481
4482 void Accept(ModelVisitor* const visitor) const override {
4483 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4484 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4485 boolvar_);
4486 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4487 expr_);
4488 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4489 }
4490
4491 private:
4492 BooleanVar* const boolvar_;
4493 IntExpr* const expr_;
4494};
4495
4496void TimesBooleanPosIntExpr::SetMin(int64_t m) {
4497 if (m > 0) {
4498 boolvar_->SetValue(1);
4499 expr_->SetMin(m);
4500 }
4501}
4502
4503void TimesBooleanPosIntExpr::SetMax(int64_t m) {
4504 if (m < 0) {
4505 solver()->Fail();
4506 }
4507 if (m < expr_->Min()) {
4508 boolvar_->SetValue(0);
4509 }
4510 if (boolvar_->RawValue() == 1) {
4511 expr_->SetMax(m);
4512 }
4513}
4514
4515void TimesBooleanPosIntExpr::Range(int64_t* mi, int64_t* ma) {
4516 const int value = boolvar_->RawValue();
4517 if (value == 0) {
4518 *mi = 0;
4519 *ma = 0;
4520 } else if (value == 1) {
4521 expr_->Range(mi, ma);
4522 } else {
4523 *mi = 0;
4524 *ma = expr_->Max();
4525 }
4526}
4527
4528void TimesBooleanPosIntExpr::SetRange(int64_t mi, int64_t ma) {
4529 if (ma < 0 || mi > ma) {
4530 solver()->Fail();
4531 }
4532 if (mi > 0) {
4533 boolvar_->SetValue(1);
4534 expr_->SetMin(mi);
4535 }
4536 if (ma < expr_->Min()) {
4537 boolvar_->SetValue(0);
4538 }
4539 if (boolvar_->RawValue() == 1) {
4540 expr_->SetMax(ma);
4541 }
4542}
4543
4544bool TimesBooleanPosIntExpr::Bound() const {
4545 return (boolvar_->RawValue() == 0 || expr_->Max() == 0 ||
4546 (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue &&
4547 expr_->Bound()));
4548}
4549
4550// ----- TimesBooleanIntExpr -----
4551
4552class TimesBooleanIntExpr : public BaseIntExpr {
4553 public:
4554 TimesBooleanIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4555 : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4556 ~TimesBooleanIntExpr() override {}
4557 int64_t Min() const override {
4558 switch (boolvar_->RawValue()) {
4559 case 0: {
4560 return 0LL;
4561 }
4562 case 1: {
4563 return expr_->Min();
4564 }
4565 default: {
4566 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4567 return std::min(int64_t{0}, expr_->Min());
4568 }
4569 }
4570 }
4571 void SetMin(int64_t m) override;
4572 int64_t Max() const override {
4573 switch (boolvar_->RawValue()) {
4574 case 0: {
4575 return 0LL;
4576 }
4577 case 1: {
4578 return expr_->Max();
4579 }
4580 default: {
4581 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4582 return std::max(int64_t{0}, expr_->Max());
4583 }
4584 }
4585 }
4586 void SetMax(int64_t m) override;
4587 void Range(int64_t* mi, int64_t* ma) override;
4588 void SetRange(int64_t mi, int64_t ma) override;
4589 bool Bound() const override;
4590 std::string name() const override {
4591 return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4592 }
4593 std::string DebugString() const override {
4594 return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4595 expr_->DebugString());
4596 }
4597 void WhenRange(Demon* d) override {
4598 boolvar_->WhenRange(d);
4599 expr_->WhenRange(d);
4600 }
4601
4602 void Accept(ModelVisitor* const visitor) const override {
4603 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4604 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4605 boolvar_);
4606 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4607 expr_);
4608 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4609 }
4610
4611 private:
4612 BooleanVar* const boolvar_;
4613 IntExpr* const expr_;
4614};
4615
4616void TimesBooleanIntExpr::SetMin(int64_t m) {
4617 switch (boolvar_->RawValue()) {
4618 case 0: {
4619 if (m > 0) {
4620 solver()->Fail();
4621 }
4622 break;
4623 }
4624 case 1: {
4625 expr_->SetMin(m);
4626 break;
4627 }
4628 default: {
4629 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4630 if (m > 0) { // 0 is no longer possible for boolvar because min > 0.
4631 boolvar_->SetValue(1);
4632 expr_->SetMin(m);
4633 } else if (m <= 0 && expr_->Max() < m) {
4634 boolvar_->SetValue(0);
4635 }
4636 }
4637 }
4638}
4639
4640void TimesBooleanIntExpr::SetMax(int64_t m) {
4641 switch (boolvar_->RawValue()) {
4642 case 0: {
4643 if (m < 0) {
4644 solver()->Fail();
4645 }
4646 break;
4647 }
4648 case 1: {
4649 expr_->SetMax(m);
4650 break;
4651 }
4652 default: {
4653 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4654 if (m < 0) { // 0 is no longer possible for boolvar because max < 0.
4655 boolvar_->SetValue(1);
4656 expr_->SetMax(m);
4657 } else if (m >= 0 && expr_->Min() > m) {
4658 boolvar_->SetValue(0);
4659 }
4660 }
4661 }
4662}
4663
4664void TimesBooleanIntExpr::Range(int64_t* mi, int64_t* ma) {
4665 switch (boolvar_->RawValue()) {
4666 case 0: {
4667 *mi = 0;
4668 *ma = 0;
4669 break;
4670 }
4671 case 1: {
4672 *mi = expr_->Min();
4673 *ma = expr_->Max();
4674 break;
4675 }
4676 default: {
4677 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4678 *mi = std::min(int64_t{0}, expr_->Min());
4679 *ma = std::max(int64_t{0}, expr_->Max());
4680 break;
4681 }
4682 }
4683}
4684
4685void TimesBooleanIntExpr::SetRange(int64_t mi, int64_t ma) {
4686 if (mi > ma) {
4687 solver()->Fail();
4688 }
4689 switch (boolvar_->RawValue()) {
4690 case 0: {
4691 if (mi > 0 || ma < 0) {
4692 solver()->Fail();
4693 }
4694 break;
4695 }
4696 case 1: {
4697 expr_->SetRange(mi, ma);
4698 break;
4699 }
4700 default: {
4701 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4702 if (mi > 0) {
4703 boolvar_->SetValue(1);
4704 expr_->SetMin(mi);
4705 } else if (mi == 0 && expr_->Max() < 0) {
4706 boolvar_->SetValue(0);
4707 }
4708 if (ma < 0) {
4709 boolvar_->SetValue(1);
4710 expr_->SetMax(ma);
4711 } else if (ma == 0 && expr_->Min() > 0) {
4712 boolvar_->SetValue(0);
4713 }
4714 break;
4715 }
4716 }
4717}
4718
4719bool TimesBooleanIntExpr::Bound() const {
4720 return (boolvar_->RawValue() == 0 ||
4721 (expr_->Bound() &&
4722 (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue ||
4723 expr_->Max() == 0)));
4724}
4725
4726// ----- DivPosIntCstExpr -----
4727
4728class DivPosIntCstExpr : public BaseIntExpr {
4729 public:
4730 DivPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4731 : BaseIntExpr(s), expr_(e), value_(v) {
4732 CHECK_GE(v, 0);
4733 }
4734 ~DivPosIntCstExpr() override {}
4735
4736 int64_t Min() const override { return expr_->Min() / value_; }
4737
4738 void SetMin(int64_t m) override {
4739 if (m > 0) {
4740 expr_->SetMin(m * value_);
4741 } else {
4742 expr_->SetMin((m - 1) * value_ + 1);
4743 }
4744 }
4745 int64_t Max() const override { return expr_->Max() / value_; }
4746
4747 void SetMax(int64_t m) override {
4748 if (m >= 0) {
4749 expr_->SetMax((m + 1) * value_ - 1);
4750 } else {
4751 expr_->SetMax(m * value_);
4752 }
4753 }
4754
4755 std::string name() const override {
4756 return absl::StrFormat("(%s div %d)", expr_->name(), value_);
4757 }
4758
4759 std::string DebugString() const override {
4760 return absl::StrFormat("(%s div %d)", expr_->DebugString(), value_);
4761 }
4762
4763 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4764
4765 void Accept(ModelVisitor* const visitor) const override {
4766 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4767 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4768 expr_);
4769 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4770 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4771 }
4772
4773 private:
4774 IntExpr* const expr_;
4775 const int64_t value_;
4776};
4777
4778// DivPosIntExpr
4779
4780class DivPosIntExpr : public BaseIntExpr {
4781 public:
4782 DivPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4783 : BaseIntExpr(s),
4784 num_(num),
4785 denom_(denom),
4786 opp_num_(s->MakeOpposite(num)) {}
4787
4788 ~DivPosIntExpr() override {}
4789
4790 int64_t Min() const override {
4791 return num_->Min() >= 0
4792 ? num_->Min() / denom_->Max()
4793 : (denom_->Min() == 0 ? num_->Min()
4794 : num_->Min() / denom_->Min());
4795 }
4796
4797 int64_t Max() const override {
4798 return num_->Max() >= 0 ? (denom_->Min() == 0 ? num_->Max()
4799 : num_->Max() / denom_->Min())
4800 : num_->Max() / denom_->Max();
4801 }
4802
4803 static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64_t m) {
4804 num->SetMin(m * denom->Min());
4805 denom->SetMax(num->Max() / m);
4806 }
4807
4808 static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64_t m) {
4809 num->SetMax((m + 1) * denom->Max() - 1);
4810 denom->SetMin(num->Min() / (m + 1) + 1);
4811 }
4812
4813 void SetMin(int64_t m) override {
4814 if (m > 0) {
4815 SetPosMin(num_, denom_, m);
4816 } else {
4817 SetPosMax(opp_num_, denom_, -m);
4818 }
4819 }
4820
4821 void SetMax(int64_t m) override {
4822 if (m >= 0) {
4823 SetPosMax(num_, denom_, m);
4824 } else {
4825 SetPosMin(opp_num_, denom_, -m);
4826 }
4827 }
4828
4829 std::string name() const override {
4830 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4831 }
4832 std::string DebugString() const override {
4833 return absl::StrFormat("(%s div %s)", num_->DebugString(),
4834 denom_->DebugString());
4835 }
4836 void WhenRange(Demon* d) override {
4837 num_->WhenRange(d);
4838 denom_->WhenRange(d);
4839 }
4840
4841 void Accept(ModelVisitor* const visitor) const override {
4842 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4843 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4844 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4845 denom_);
4846 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4847 }
4848
4849 private:
4850 IntExpr* const num_;
4851 IntExpr* const denom_;
4852 IntExpr* const opp_num_;
4853};
4854
4855class DivPosPosIntExpr : public BaseIntExpr {
4856 public:
4857 DivPosPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4858 : BaseIntExpr(s), num_(num), denom_(denom) {}
4859
4860 ~DivPosPosIntExpr() override {}
4861
4862 int64_t Min() const override {
4863 if (denom_->Max() == 0) {
4864 solver()->Fail();
4865 }
4866 return num_->Min() / denom_->Max();
4867 }
4868
4869 int64_t Max() const override {
4870 if (denom_->Min() == 0) {
4871 return num_->Max();
4872 } else {
4873 return num_->Max() / denom_->Min();
4874 }
4875 }
4876
4877 void SetMin(int64_t m) override {
4878 if (m > 0) {
4879 num_->SetMin(m * denom_->Min());
4880 denom_->SetMax(num_->Max() / m);
4881 }
4882 }
4883
4884 void SetMax(int64_t m) override {
4885 if (m >= 0) {
4886 num_->SetMax((m + 1) * denom_->Max() - 1);
4887 denom_->SetMin(num_->Min() / (m + 1) + 1);
4888 } else {
4889 solver()->Fail();
4890 }
4891 }
4892
4893 std::string name() const override {
4894 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4895 }
4896
4897 std::string DebugString() const override {
4898 return absl::StrFormat("(%s div %s)", num_->DebugString(),
4899 denom_->DebugString());
4900 }
4901
4902 void WhenRange(Demon* d) override {
4903 num_->WhenRange(d);
4904 denom_->WhenRange(d);
4905 }
4906
4907 void Accept(ModelVisitor* const visitor) const override {
4908 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4909 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4910 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4911 denom_);
4912 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4913 }
4914
4915 private:
4916 IntExpr* const num_;
4917 IntExpr* const denom_;
4918};
4919
4920// DivIntExpr
4921
4922class DivIntExpr : public BaseIntExpr {
4923 public:
4924 DivIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4925 : BaseIntExpr(s),
4926 num_(num),
4927 denom_(denom),
4928 opp_num_(s->MakeOpposite(num)) {}
4929
4930 ~DivIntExpr() override {}
4931
4932 int64_t Min() const override {
4933 const int64_t num_min = num_->Min();
4934 const int64_t num_max = num_->Max();
4935 const int64_t denom_min = denom_->Min();
4936 const int64_t denom_max = denom_->Max();
4937
4938 if (denom_min == 0 && denom_max == 0) {
4939 return std::numeric_limits<int64_t>::max(); // TODO(user): Check this
4940 // convention.
4941 }
4942
4943 if (denom_min >= 0) { // Denominator strictly positive.
4944 DCHECK_GT(denom_max, 0);
4945 const int64_t adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4946 return num_min >= 0 ? num_min / denom_max : num_min / adjusted_denom_min;
4947 } else if (denom_max <= 0) { // Denominator strictly negative.
4948 DCHECK_LT(denom_min, 0);
4949 const int64_t adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4950 return num_max >= 0 ? num_max / adjusted_denom_max : num_max / denom_min;
4951 } else { // Denominator across 0.
4952 return std::min(num_min, -num_max);
4953 }
4954 }
4955
4956 int64_t Max() const override {
4957 const int64_t num_min = num_->Min();
4958 const int64_t num_max = num_->Max();
4959 const int64_t denom_min = denom_->Min();
4960 const int64_t denom_max = denom_->Max();
4961
4962 if (denom_min == 0 && denom_max == 0) {
4963 return std::numeric_limits<int64_t>::min(); // TODO(user): Check this
4964 // convention.
4965 }
4966
4967 if (denom_min >= 0) { // Denominator strictly positive.
4968 DCHECK_GT(denom_max, 0);
4969 const int64_t adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4970 return num_max >= 0 ? num_max / adjusted_denom_min : num_max / denom_max;
4971 } else if (denom_max <= 0) { // Denominator strictly negative.
4972 DCHECK_LT(denom_min, 0);
4973 const int64_t adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4974 return num_min >= 0 ? num_min / denom_min
4975 : -num_min / -adjusted_denom_max;
4976 } else { // Denominator across 0.
4977 return std::max(num_max, -num_min);
4978 }
4979 }
4980
4981 void AdjustDenominator() {
4982 if (denom_->Min() == 0) {
4983 denom_->SetMin(1);
4984 } else if (denom_->Max() == 0) {
4985 denom_->SetMax(-1);
4986 }
4987 }
4988
4989 // m > 0.
4990 static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64_t m) {
4991 DCHECK_GT(m, 0);
4992 const int64_t num_min = num->Min();
4993 const int64_t num_max = num->Max();
4994 const int64_t denom_min = denom->Min();
4995 const int64_t denom_max = denom->Max();
4996 DCHECK_NE(denom_min, 0);
4997 DCHECK_NE(denom_max, 0);
4998 if (denom_min > 0) { // Denominator strictly positive.
4999 num->SetMin(m * denom_min);
5000 denom->SetMax(num_max / m);
5001 } else if (denom_max < 0) { // Denominator strictly negative.
5002 num->SetMax(m * denom_max);
5003 denom->SetMin(num_min / m);
5004 } else { // Denominator across 0.
5005 if (num_min >= 0) {
5006 num->SetMin(m);
5007 denom->SetRange(1, num_max / m);
5008 } else if (num_max <= 0) {
5009 num->SetMax(-m);
5010 denom->SetRange(num_min / m, -1);
5011 } else {
5012 if (m > -num_min) { // Denominator is forced positive.
5013 num->SetMin(m);
5014 denom->SetRange(1, num_max / m);
5015 } else if (m > num_max) { // Denominator is forced negative.
5016 num->SetMax(-m);
5017 denom->SetRange(num_min / m, -1);
5018 } else {
5019 denom->SetRange(num_min / m, num_max / m);
5020 }
5021 }
5022 }
5023 }
5024
5025 // m >= 0.
5026 static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64_t m) {
5027 DCHECK_GE(m, 0);
5028 const int64_t num_min = num->Min();
5029 const int64_t num_max = num->Max();
5030 const int64_t denom_min = denom->Min();
5031 const int64_t denom_max = denom->Max();
5032 DCHECK_NE(denom_min, 0);
5033 DCHECK_NE(denom_max, 0);
5034 if (denom_min > 0) { // Denominator strictly positive.
5035 num->SetMax((m + 1) * denom_max - 1);
5036 denom->SetMin((num_min / (m + 1)) + 1);
5037 } else if (denom_max < 0) {
5038 num->SetMin((m + 1) * denom_min + 1);
5039 denom->SetMax(num_max / (m + 1) - 1);
5040 } else if (num_min > (m + 1) * denom_max - 1) {
5041 denom->SetMax(-1);
5042 } else if (num_max < (m + 1) * denom_min + 1) {
5043 denom->SetMin(1);
5044 }
5045 }
5046
5047 void SetMin(int64_t m) override {
5048 AdjustDenominator();
5049 if (m > 0) {
5050 SetPosMin(num_, denom_, m);
5051 } else {
5052 SetPosMax(opp_num_, denom_, -m);
5053 }
5054 }
5055
5056 void SetMax(int64_t m) override {
5057 AdjustDenominator();
5058 if (m >= 0) {
5059 SetPosMax(num_, denom_, m);
5060 } else {
5061 SetPosMin(opp_num_, denom_, -m);
5062 }
5063 }
5064
5065 std::string name() const override {
5066 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
5067 }
5068 std::string DebugString() const override {
5069 return absl::StrFormat("(%s div %s)", num_->DebugString(),
5070 denom_->DebugString());
5071 }
5072 void WhenRange(Demon* d) override {
5073 num_->WhenRange(d);
5074 denom_->WhenRange(d);
5075 }
5076
5077 void Accept(ModelVisitor* const visitor) const override {
5078 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
5079 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
5080 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5081 denom_);
5082 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
5083 }
5084
5085 private:
5086 IntExpr* const num_;
5087 IntExpr* const denom_;
5088 IntExpr* const opp_num_;
5089};
5090
5091// ----- IntAbs And IntAbsConstraint ------
5092
5093class IntAbsConstraint : public CastConstraint {
5094 public:
5095 IntAbsConstraint(Solver* const s, IntVar* const sub, IntVar* const target)
5096 : CastConstraint(s, target), sub_(sub) {}
5097
5098 ~IntAbsConstraint() override {}
5099
5100 void Post() override {
5101 Demon* const sub_demon = MakeConstraintDemon0(
5102 solver(), this, &IntAbsConstraint::PropagateSub, "PropagateSub");
5103 sub_->WhenRange(sub_demon);
5104 Demon* const target_demon = MakeConstraintDemon0(
5105 solver(), this, &IntAbsConstraint::PropagateTarget, "PropagateTarget");
5106 target_var_->WhenRange(target_demon);
5107 }
5108
5109 void InitialPropagate() override {
5110 PropagateSub();
5111 PropagateTarget();
5112 }
5113
5114 void PropagateSub() {
5115 const int64_t smin = sub_->Min();
5116 const int64_t smax = sub_->Max();
5117 if (smax <= 0) {
5118 target_var_->SetRange(-smax, -smin);
5119 } else if (smin >= 0) {
5120 target_var_->SetRange(smin, smax);
5121 } else {
5122 target_var_->SetRange(0, std::max(-smin, smax));
5123 }
5124 }
5125
5126 void PropagateTarget() {
5127 const int64_t target_max = target_var_->Max();
5128 sub_->SetRange(-target_max, target_max);
5129 const int64_t target_min = target_var_->Min();
5130 if (target_min > 0) {
5131 if (sub_->Min() > -target_min) {
5132 sub_->SetMin(target_min);
5133 } else if (sub_->Max() < target_min) {
5134 sub_->SetMax(-target_min);
5135 }
5136 }
5137 }
5138
5139 std::string DebugString() const override {
5140 return absl::StrFormat("IntAbsConstraint(%s, %s)", sub_->DebugString(),
5141 target_var_->DebugString());
5142 }
5143
5144 void Accept(ModelVisitor* const visitor) const override {
5145 visitor->BeginVisitConstraint(ModelVisitor::kAbsEqual, this);
5146 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5147 sub_);
5148 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
5149 target_var_);
5150 visitor->EndVisitConstraint(ModelVisitor::kAbsEqual, this);
5151 }
5152
5153 private:
5154 IntVar* const sub_;
5155};
5156
5157class IntAbs : public BaseIntExpr {
5158 public:
5159 IntAbs(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5160
5161 ~IntAbs() override {}
5162
5163 int64_t Min() const override {
5164 int64_t emin = 0;
5165 int64_t emax = 0;
5166 expr_->Range(&emin, &emax);
5167 if (emin >= 0) {
5168 return emin;
5169 }
5170 if (emax <= 0) {
5171 return -emax;
5172 }
5173 return 0;
5174 }
5175
5176 void SetMin(int64_t m) override {
5177 if (m > 0) {
5178 int64_t emin = 0;
5179 int64_t emax = 0;
5180 expr_->Range(&emin, &emax);
5181 if (emin > -m) {
5182 expr_->SetMin(m);
5183 } else if (emax < m) {
5184 expr_->SetMax(-m);
5185 }
5186 }
5187 }
5188
5189 int64_t Max() const override {
5190 int64_t emin = 0;
5191 int64_t emax = 0;
5192 expr_->Range(&emin, &emax);
5193 return std::max(-emin, emax);
5194 }
5195
5196 void SetMax(int64_t m) override { expr_->SetRange(-m, m); }
5197
5198 void SetRange(int64_t mi, int64_t ma) override {
5199 expr_->SetRange(-ma, ma);
5200 if (mi > 0) {
5201 int64_t emin = 0;
5202 int64_t emax = 0;
5203 expr_->Range(&emin, &emax);
5204 if (emin > -mi) {
5205 expr_->SetMin(mi);
5206 } else if (emax < mi) {
5207 expr_->SetMax(-mi);
5208 }
5209 }
5210 }
5211
5212 void Range(int64_t* mi, int64_t* ma) override {
5213 int64_t emin = 0;
5214 int64_t emax = 0;
5215 expr_->Range(&emin, &emax);
5216 if (emin >= 0) {
5217 *mi = emin;
5218 *ma = emax;
5219 } else if (emax <= 0) {
5220 *mi = -emax;
5221 *ma = -emin;
5222 } else {
5223 *mi = 0;
5224 *ma = std::max(-emin, emax);
5225 }
5226 }
5227
5228 bool Bound() const override { return expr_->Bound(); }
5229
5230 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5231
5232 std::string name() const override {
5233 return absl::StrFormat("IntAbs(%s)", expr_->name());
5234 }
5235
5236 std::string DebugString() const override {
5237 return absl::StrFormat("IntAbs(%s)", expr_->DebugString());
5238 }
5239
5240 void Accept(ModelVisitor* const visitor) const override {
5241 visitor->BeginVisitIntegerExpression(ModelVisitor::kAbs, this);
5242 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5243 expr_);
5244 visitor->EndVisitIntegerExpression(ModelVisitor::kAbs, this);
5245 }
5246
5247 IntVar* CastToVar() override {
5248 int64_t min_value = 0;
5249 int64_t max_value = 0;
5250 Range(&min_value, &max_value);
5251 Solver* const s = solver();
5252 const std::string name = absl::StrFormat("AbsVar(%s)", expr_->name());
5253 IntVar* const target = s->MakeIntVar(min_value, max_value, name);
5254 CastConstraint* const ct =
5255 s->RevAlloc(new IntAbsConstraint(s, expr_->Var(), target));
5256 s->AddCastConstraint(ct, target, this);
5257 return target;
5258 }
5259
5260 private:
5261 IntExpr* const expr_;
5262};
5263
5264// ----- Square -----
5265
5266// TODO(user): shouldn't we compare to kint32max^2 instead of kint64max?
5267class IntSquare : public BaseIntExpr {
5268 public:
5269 IntSquare(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5270 ~IntSquare() override {}
5271
5272 int64_t Min() const override {
5273 const int64_t emin = expr_->Min();
5274 if (emin >= 0) {
5275 return emin >= std::numeric_limits<int32_t>::max()
5276 ? std::numeric_limits<int64_t>::max()
5277 : emin * emin;
5278 }
5279 const int64_t emax = expr_->Max();
5280 if (emax < 0) {
5281 return emax <= -std::numeric_limits<int32_t>::max()
5282 ? std::numeric_limits<int64_t>::max()
5283 : emax * emax;
5284 }
5285 return 0LL;
5286 }
5287 void SetMin(int64_t m) override {
5288 if (m <= 0) {
5289 return;
5290 }
5291 // TODO(user): What happens if m is kint64max?
5292 const int64_t emin = expr_->Min();
5293 const int64_t emax = expr_->Max();
5294 const int64_t root =
5295 static_cast<int64_t>(ceil(sqrt(static_cast<double>(m))));
5296 if (emin >= 0) {
5297 expr_->SetMin(root);
5298 } else if (emax <= 0) {
5299 expr_->SetMax(-root);
5300 } else if (expr_->IsVar()) {
5301 reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5302 }
5303 }
5304 int64_t Max() const override {
5305 const int64_t emax = expr_->Max();
5306 const int64_t emin = expr_->Min();
5307 if (emax >= std::numeric_limits<int32_t>::max() ||
5308 emin <= -std::numeric_limits<int32_t>::max()) {
5309 return std::numeric_limits<int64_t>::max();
5310 }
5311 return std::max(emin * emin, emax * emax);
5312 }
5313 void SetMax(int64_t m) override {
5314 if (m < 0) {
5315 solver()->Fail();
5316 }
5317 if (m == std::numeric_limits<int64_t>::max()) {
5318 return;
5319 }
5320 const int64_t root =
5321 static_cast<int64_t>(floor(sqrt(static_cast<double>(m))));
5322 expr_->SetRange(-root, root);
5323 }
5324 bool Bound() const override { return expr_->Bound(); }
5325 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5326 std::string name() const override {
5327 return absl::StrFormat("IntSquare(%s)", expr_->name());
5328 }
5329 std::string DebugString() const override {
5330 return absl::StrFormat("IntSquare(%s)", expr_->DebugString());
5331 }
5332
5333 void Accept(ModelVisitor* const visitor) const override {
5334 visitor->BeginVisitIntegerExpression(ModelVisitor::kSquare, this);
5335 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5336 expr_);
5337 visitor->EndVisitIntegerExpression(ModelVisitor::kSquare, this);
5338 }
5339
5340 IntExpr* expr() const { return expr_; }
5341
5342 protected:
5343 IntExpr* const expr_;
5344};
5345
5346class PosIntSquare : public IntSquare {
5347 public:
5348 PosIntSquare(Solver* const s, IntExpr* const e) : IntSquare(s, e) {}
5349 ~PosIntSquare() override {}
5350
5351 int64_t Min() const override {
5352 const int64_t emin = expr_->Min();
5353 return emin >= std::numeric_limits<int32_t>::max()
5354 ? std::numeric_limits<int64_t>::max()
5355 : emin * emin;
5356 }
5357 void SetMin(int64_t m) override {
5358 if (m <= 0) {
5359 return;
5360 }
5361 int64_t root = static_cast<int64_t>(ceil(sqrt(static_cast<double>(m))));
5362 if (CapProd(root, root) < m) {
5363 root++;
5364 }
5365 expr_->SetMin(root);
5366 }
5367 int64_t Max() const override {
5368 const int64_t emax = expr_->Max();
5369 return emax >= std::numeric_limits<int32_t>::max()
5370 ? std::numeric_limits<int64_t>::max()
5371 : emax * emax;
5372 }
5373 void SetMax(int64_t m) override {
5374 if (m < 0) {
5375 solver()->Fail();
5376 }
5377 if (m == std::numeric_limits<int64_t>::max()) {
5378 return;
5379 }
5380 int64_t root = static_cast<int64_t>(floor(sqrt(static_cast<double>(m))));
5381 if (CapProd(root, root) > m) {
5382 root--;
5383 }
5384
5385 expr_->SetMax(root);
5386 }
5387};
5388
5389// ----- EvenPower -----
5390
// Returns value^power using exponentiation by squaring (O(log power)
// multiplications instead of power - 1).
//
// The multiplications are not overflow-checked. BasePower::Pown only calls
// this after checking |value| against OverflowLimit(power). The base is
// squared only while another exponent bit remains, so every intermediate
// magnitude stays at most |value|^power. A non-positive power yields 1 (the
// empty product).
int64_t IntPower(int64_t value, int64_t power) {
  int64_t result = 1;
  int64_t base = value;
  while (power > 0) {
    if (power & 1) {
      result *= base;
    }
    power >>= 1;
    if (power > 0) {
      base *= base;
    }
  }
  return result;
}
5399
// Returns floor(kint64max^(1 / power)), evaluated in floating point as
// exp(log(kint64max) / power). BasePower::Pown treats any |value| at or above
// this threshold as overflowing value^power.
int64_t OverflowLimit(int64_t power) {
  const double log_of_max =
      log(static_cast<double>(std::numeric_limits<int64_t>::max()));
  const double real_root = exp(log_of_max / power);
  return static_cast<int64_t>(floor(real_root));
}
5404
5405class BasePower : public BaseIntExpr {
5406 public:
5407 BasePower(Solver* const s, IntExpr* const e, int64_t n)
5408 : BaseIntExpr(s), expr_(e), pow_(n), limit_(OverflowLimit(n)) {
5409 CHECK_GT(n, 0);
5410 }
5411
5412 ~BasePower() override {}
5413
5414 bool Bound() const override { return expr_->Bound(); }
5415
5416 IntExpr* expr() const { return expr_; }
5417
5418 int64_t exponant() const { return pow_; }
5419
5420 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5421
5422 std::string name() const override {
5423 return absl::StrFormat("IntPower(%s, %d)", expr_->name(), pow_);
5424 }
5425
5426 std::string DebugString() const override {
5427 return absl::StrFormat("IntPower(%s, %d)", expr_->DebugString(), pow_);
5428 }
5429
5430 void Accept(ModelVisitor* const visitor) const override {
5431 visitor->BeginVisitIntegerExpression(ModelVisitor::kPower, this);
5432 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5433 expr_);
5434 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, pow_);
5435 visitor->EndVisitIntegerExpression(ModelVisitor::kPower, this);
5436 }
5437
5438 protected:
5439 int64_t Pown(int64_t value) const {
5440 if (value >= limit_) {
5441 return std::numeric_limits<int64_t>::max();
5442 }
5443 if (value <= -limit_) {
5444 if (pow_ % 2 == 0) {
5445 return std::numeric_limits<int64_t>::max();
5446 } else {
5447 return std::numeric_limits<int64_t>::min();
5448 }
5449 }
5450 return IntPower(value, pow_);
5451 }
5452
5453 int64_t SqrnDown(int64_t value) const {
5454 if (value == std::numeric_limits<int64_t>::min()) {
5455 return std::numeric_limits<int64_t>::min();
5456 }
5457 if (value == std::numeric_limits<int64_t>::max()) {
5458 return std::numeric_limits<int64_t>::max();
5459 }
5460 int64_t res = 0;
5461 const double d_value = static_cast<double>(value);
5462 if (value >= 0) {
5463 const double sq = exp(log(d_value) / pow_);
5464 res = static_cast<int64_t>(floor(sq));
5465 } else {
5466 CHECK_EQ(1, pow_ % 2);
5467 const double sq = exp(log(-d_value) / pow_);
5468 res = -static_cast<int64_t>(ceil(sq));
5469 }
5470 const int64_t pow_res = Pown(res + 1);
5471 if (pow_res <= value) {
5472 return res + 1;
5473 } else {
5474 return res;
5475 }
5476 }
5477
5478 int64_t SqrnUp(int64_t value) const {
5479 if (value == std::numeric_limits<int64_t>::min()) {
5480 return std::numeric_limits<int64_t>::min();
5481 }
5482 if (value == std::numeric_limits<int64_t>::max()) {
5483 return std::numeric_limits<int64_t>::max();
5484 }
5485 int64_t res = 0;
5486 const double d_value = static_cast<double>(value);
5487 if (value >= 0) {
5488 const double sq = exp(log(d_value) / pow_);
5489 res = static_cast<int64_t>(ceil(sq));
5490 } else {
5491 CHECK_EQ(1, pow_ % 2);
5492 const double sq = exp(log(-d_value) / pow_);
5493 res = -static_cast<int64_t>(floor(sq));
5494 }
5495 const int64_t pow_res = Pown(res - 1);
5496 if (pow_res >= value) {
5497 return res - 1;
5498 } else {
5499 return res;
5500 }
5501 }
5502
5503 IntExpr* const expr_;
5504 const int64_t pow_;
5505 const int64_t limit_;
5506};
5507
5508class IntEvenPower : public BasePower {
5509 public:
5510 IntEvenPower(Solver* const s, IntExpr* const e, int64_t n)
5511 : BasePower(s, e, n) {
5512 CHECK_EQ(0, n % 2);
5513 }
5514
5515 ~IntEvenPower() override {}
5516
5517 int64_t Min() const override {
5518 int64_t emin = 0;
5519 int64_t emax = 0;
5520 expr_->Range(&emin, &emax);
5521 if (emin >= 0) {
5522 return Pown(emin);
5523 }
5524 if (emax < 0) {
5525 return Pown(emax);
5526 }
5527 return 0LL;
5528 }
5529 void SetMin(int64_t m) override {
5530 if (m <= 0) {
5531 return;
5532 }
5533 int64_t emin = 0;
5534 int64_t emax = 0;
5535 expr_->Range(&emin, &emax);
5536 const int64_t root = SqrnUp(m);
5537 if (emin > -root) {
5538 expr_->SetMin(root);
5539 } else if (emax < root) {
5540 expr_->SetMax(-root);
5541 } else if (expr_->IsVar()) {
5542 reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5543 }
5544 }
5545
5546 int64_t Max() const override {
5547 return std::max(Pown(expr_->Min()), Pown(expr_->Max()));
5548 }
5549
5550 void SetMax(int64_t m) override {
5551 if (m < 0) {
5552 solver()->Fail();
5553 }
5554 if (m == std::numeric_limits<int64_t>::max()) {
5555 return;
5556 }
5557 const int64_t root = SqrnDown(m);
5558 expr_->SetRange(-root, root);
5559 }
5560};
5561
5562class PosIntEvenPower : public BasePower {
5563 public:
5564 PosIntEvenPower(Solver* const s, IntExpr* const e, int64_t pow)
5565 : BasePower(s, e, pow) {
5566 CHECK_EQ(0, pow % 2);
5567 }
5568
5569 ~PosIntEvenPower() override {}
5570
5571 int64_t Min() const override { return Pown(expr_->Min()); }
5572
5573 void SetMin(int64_t m) override {
5574 if (m <= 0) {
5575 return;
5576 }
5577 expr_->SetMin(SqrnUp(m));
5578 }
5579 int64_t Max() const override { return Pown(expr_->Max()); }
5580
5581 void SetMax(int64_t m) override {
5582 if (m < 0) {
5583 solver()->Fail();
5584 }
5585 if (m == std::numeric_limits<int64_t>::max()) {
5586 return;
5587 }
5588 expr_->SetMax(SqrnDown(m));
5589 }
5590};
5591
5592class IntOddPower : public BasePower {
5593 public:
5594 IntOddPower(Solver* const s, IntExpr* const e, int64_t n)
5595 : BasePower(s, e, n) {
5596 CHECK_EQ(1, n % 2);
5597 }
5598
5599 ~IntOddPower() override {}
5600
5601 int64_t Min() const override { return Pown(expr_->Min()); }
5602
5603 void SetMin(int64_t m) override { expr_->SetMin(SqrnUp(m)); }
5604
5605 int64_t Max() const override { return Pown(expr_->Max()); }
5606
5607 void SetMax(int64_t m) override { expr_->SetMax(SqrnDown(m)); }
5608};
5609
5610// ----- Min(expr, expr) -----
5611
5612class MinIntExpr : public BaseIntExpr {
5613 public:
5614 MinIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5615 : BaseIntExpr(s), left_(l), right_(r) {}
5616 ~MinIntExpr() override {}
5617 int64_t Min() const override {
5618 const int64_t lmin = left_->Min();
5619 const int64_t rmin = right_->Min();
5620 return std::min(lmin, rmin);
5621 }
5622 void SetMin(int64_t m) override {
5623 left_->SetMin(m);
5624 right_->SetMin(m);
5625 }
5626 int64_t Max() const override {
5627 const int64_t lmax = left_->Max();
5628 const int64_t rmax = right_->Max();
5629 return std::min(lmax, rmax);
5630 }
5631 void SetMax(int64_t m) override {
5632 if (left_->Min() > m) {
5633 right_->SetMax(m);
5634 }
5635 if (right_->Min() > m) {
5636 left_->SetMax(m);
5637 }
5638 }
5639 std::string name() const override {
5640 return absl::StrFormat("MinIntExpr(%s, %s)", left_->name(), right_->name());
5641 }
5642 std::string DebugString() const override {
5643 return absl::StrFormat("MinIntExpr(%s, %s)", left_->DebugString(),
5644 right_->DebugString());
5645 }
5646 void WhenRange(Demon* d) override {
5647 left_->WhenRange(d);
5648 right_->WhenRange(d);
5649 }
5650
5651 void Accept(ModelVisitor* const visitor) const override {
5652 visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5653 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5654 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5655 right_);
5656 visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5657 }
5658
5659 private:
5660 IntExpr* const left_;
5661 IntExpr* const right_;
5662};
5663
5664// ----- Min(expr, constant) -----
5665
5666class MinCstIntExpr : public BaseIntExpr {
5667 public:
5668 MinCstIntExpr(Solver* const s, IntExpr* const e, int64_t v)
5669 : BaseIntExpr(s), expr_(e), value_(v) {}
5670
5671 ~MinCstIntExpr() override {}
5672
5673 int64_t Min() const override { return std::min(expr_->Min(), value_); }
5674
5675 void SetMin(int64_t m) override {
5676 if (m > value_) {
5677 solver()->Fail();
5678 }
5679 expr_->SetMin(m);
5680 }
5681
5682 int64_t Max() const override { return std::min(expr_->Max(), value_); }
5683
5684 void SetMax(int64_t m) override {
5685 if (value_ > m) {
5686 expr_->SetMax(m);
5687 }
5688 }
5689
5690 bool Bound() const override {
5691 return (expr_->Bound() || expr_->Min() >= value_);
5692 }
5693
5694 std::string name() const override {
5695 return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->name(), value_);
5696 }
5697
5698 std::string DebugString() const override {
5699 return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->DebugString(),
5700 value_);
5701 }
5702
5703 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5704
5705 void Accept(ModelVisitor* const visitor) const override {
5706 visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5707 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5708 expr_);
5709 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5710 visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5711 }
5712
5713 private:
5714 IntExpr* const expr_;
5715 const int64_t value_;
5716};
5717
5718// ----- Max(expr, expr) -----
5719
5720class MaxIntExpr : public BaseIntExpr {
5721 public:
5722 MaxIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5723 : BaseIntExpr(s), left_(l), right_(r) {}
5724
5725 ~MaxIntExpr() override {}
5726
5727 int64_t Min() const override { return std::max(left_->Min(), right_->Min()); }
5728
5729 void SetMin(int64_t m) override {
5730 if (left_->Max() < m) {
5731 right_->SetMin(m);
5732 } else {
5733 if (right_->Max() < m) {
5734 left_->SetMin(m);
5735 }
5736 }
5737 }
5738
5739 int64_t Max() const override { return std::max(left_->Max(), right_->Max()); }
5740
5741 void SetMax(int64_t m) override {
5742 left_->SetMax(m);
5743 right_->SetMax(m);
5744 }
5745
5746 std::string name() const override {
5747 return absl::StrFormat("MaxIntExpr(%s, %s)", left_->name(), right_->name());
5748 }
5749
5750 std::string DebugString() const override {
5751 return absl::StrFormat("MaxIntExpr(%s, %s)", left_->DebugString(),
5752 right_->DebugString());
5753 }
5754
5755 void WhenRange(Demon* d) override {
5756 left_->WhenRange(d);
5757 right_->WhenRange(d);
5758 }
5759
5760 void Accept(ModelVisitor* const visitor) const override {
5761 visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5762 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5763 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5764 right_);
5765 visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5766 }
5767
5768 private:
5769 IntExpr* const left_;
5770 IntExpr* const right_;
5771};
5772
5773// ----- Max(expr, constant) -----
5774
5775class MaxCstIntExpr : public BaseIntExpr {
5776 public:
5777 MaxCstIntExpr(Solver* const s, IntExpr* const e, int64_t v)
5778 : BaseIntExpr(s), expr_(e), value_(v) {}
5779
5780 ~MaxCstIntExpr() override {}
5781
5782 int64_t Min() const override { return std::max(expr_->Min(), value_); }
5783
5784 void SetMin(int64_t m) override {
5785 if (value_ < m) {
5786 expr_->SetMin(m);
5787 }
5788 }
5789
5790 int64_t Max() const override { return std::max(expr_->Max(), value_); }
5791
5792 void SetMax(int64_t m) override {
5793 if (m < value_) {
5794 solver()->Fail();
5795 }
5796 expr_->SetMax(m);
5797 }
5798
5799 bool Bound() const override {
5800 return (expr_->Bound() || expr_->Max() <= value_);
5801 }
5802
5803 std::string name() const override {
5804 return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->name(), value_);
5805 }
5806
5807 std::string DebugString() const override {
5808 return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->DebugString(),
5809 value_);
5810 }
5811
5812 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5813
5814 void Accept(ModelVisitor* const visitor) const override {
5815 visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5816 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5817 expr_);
5818 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5819 visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5820 }
5821
5822 private:
5823 IntExpr* const expr_;
5824 const int64_t value_;
5825};
5826
5827// ----- Convex Piecewise -----
5828
5829// This class is a very simple convex piecewise linear function. The
5830// argument of the function is the expression. Between early_date and
5831// late_date, the value of the function is 0. Before early date, it
5832// is affine and the cost is early_cost * (early_date - x). After
5833// late_date, the cost is late_cost * (x - late_date).
5834
5835class SimpleConvexPiecewiseExpr : public BaseIntExpr {
5836 public:
5837 SimpleConvexPiecewiseExpr(Solver* const s, IntExpr* const e, int64_t ec,
5838 int64_t ed, int64_t ld, int64_t lc)
5839 : BaseIntExpr(s),
5840 expr_(e),
5841 early_cost_(ec),
5842 early_date_(ec == 0 ? std::numeric_limits<int64_t>::min() : ed),
5843 late_date_(lc == 0 ? std::numeric_limits<int64_t>::max() : ld),
5844 late_cost_(lc) {
5845 DCHECK_GE(ec, int64_t{0});
5846 DCHECK_GE(lc, int64_t{0});
5847 DCHECK_GE(ld, ed);
5848
5849 // If the penalty is 0, we can push the "comfort zone or zone of no cost
5850 // towards infinity.
5851 }
5852
5853 ~SimpleConvexPiecewiseExpr() override {}
5854
5855 int64_t Min() const override {
5856 const int64_t vmin = expr_->Min();
5857 const int64_t vmax = expr_->Max();
5858 if (vmin >= late_date_) {
5859 return (vmin - late_date_) * late_cost_;
5860 } else if (vmax <= early_date_) {
5861 return (early_date_ - vmax) * early_cost_;
5862 } else {
5863 return 0LL;
5864 }
5865 }
5866
5867 void SetMin(int64_t m) override {
5868 if (m <= 0) {
5869 return;
5870 }
5871 int64_t vmin = 0;
5872 int64_t vmax = 0;
5873 expr_->Range(&vmin, &vmax);
5874
5875 const int64_t rb =
5876 (late_cost_ == 0 ? vmax : late_date_ + PosIntDivUp(m, late_cost_) - 1);
5877 const int64_t lb =
5878 (early_cost_ == 0 ? vmin
5879 : early_date_ - PosIntDivUp(m, early_cost_) + 1);
5880
5881 if (expr_->IsVar()) {
5882 expr_->Var()->RemoveInterval(lb, rb);
5883 }
5884 }
5885
5886 int64_t Max() const override {
5887 const int64_t vmin = expr_->Min();
5888 const int64_t vmax = expr_->Max();
5889 const int64_t mr = vmax > late_date_ ? (vmax - late_date_) * late_cost_ : 0;
5890 const int64_t ml =
5891 vmin < early_date_ ? (early_date_ - vmin) * early_cost_ : 0;
5892 return std::max(mr, ml);
5893 }
5894
5895 void SetMax(int64_t m) override {
5896 if (m < 0) {
5897 solver()->Fail();
5898 }
5899 if (late_cost_ != 0LL) {
5900 const int64_t rb = late_date_ + PosIntDivDown(m, late_cost_);
5901 if (early_cost_ != 0LL) {
5902 const int64_t lb = early_date_ - PosIntDivDown(m, early_cost_);
5903 expr_->SetRange(lb, rb);
5904 } else {
5905 expr_->SetMax(rb);
5906 }
5907 } else {
5908 if (early_cost_ != 0LL) {
5909 const int64_t lb = early_date_ - PosIntDivDown(m, early_cost_);
5910 expr_->SetMin(lb);
5911 }
5912 }
5913 }
5914
5915 std::string name() const override {
5916 return absl::StrFormat(
5917 "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5918 expr_->name(), early_cost_, early_date_, late_date_, late_cost_);
5919 }
5920
5921 std::string DebugString() const override {
5922 return absl::StrFormat(
5923 "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5924 expr_->DebugString(), early_cost_, early_date_, late_date_, late_cost_);
5925 }
5926
5927 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5928
5929 void Accept(ModelVisitor* const visitor) const override {
5930 visitor->BeginVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5931 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5932 expr_);
5933 visitor->VisitIntegerArgument(ModelVisitor::kEarlyCostArgument,
5934 early_cost_);
5935 visitor->VisitIntegerArgument(ModelVisitor::kEarlyDateArgument,
5936 early_date_);
5937 visitor->VisitIntegerArgument(ModelVisitor::kLateCostArgument, late_cost_);
5938 visitor->VisitIntegerArgument(ModelVisitor::kLateDateArgument, late_date_);
5939 visitor->EndVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5940 }
5941
5942 private:
5943 IntExpr* const expr_;
5944 const int64_t early_cost_;
5945 const int64_t early_date_;
5946 const int64_t late_date_;
5947 const int64_t late_cost_;
5948};
5949
5950// ----- Semi Continuous -----
5951
5952class SemiContinuousExpr : public BaseIntExpr {
5953 public:
5954 SemiContinuousExpr(Solver* const s, IntExpr* const e, int64_t fixed_charge,
5955 int64_t step)
5956 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge), step_(step) {
5957 DCHECK_GE(fixed_charge, int64_t{0});
5958 DCHECK_GT(step, int64_t{0});
5959 }
5960
5961 ~SemiContinuousExpr() override {}
5962
5963 int64_t Value(int64_t x) const {
5964 if (x <= 0) {
5965 return 0;
5966 } else {
5967 return CapAdd(fixed_charge_, CapProd(x, step_));
5968 }
5969 }
5970
5971 int64_t Min() const override { return Value(expr_->Min()); }
5972
5973 void SetMin(int64_t m) override {
5974 if (m >= CapAdd(fixed_charge_, step_)) {
5975 const int64_t y = PosIntDivUp(CapSub(m, fixed_charge_), step_);
5976 expr_->SetMin(y);
5977 } else if (m > 0) {
5978 expr_->SetMin(1);
5979 }
5980 }
5981
5982 int64_t Max() const override { return Value(expr_->Max()); }
5983
5984 void SetMax(int64_t m) override {
5985 if (m < 0) {
5986 solver()->Fail();
5987 }
5988 if (m == std::numeric_limits<int64_t>::max()) {
5989 return;
5990 }
5991 if (m < CapAdd(fixed_charge_, step_)) {
5992 expr_->SetMax(0);
5993 } else {
5994 const int64_t y = PosIntDivDown(CapSub(m, fixed_charge_), step_);
5995 expr_->SetMax(y);
5996 }
5997 }
5998
5999 std::string name() const override {
6000 return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
6001 expr_->name(), fixed_charge_, step_);
6002 }
6003
6004 std::string DebugString() const override {
6005 return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
6006 expr_->DebugString(), fixed_charge_, step_);
6007 }
6008
6009 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6010
6011 void Accept(ModelVisitor* const visitor) const override {
6012 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6013 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6014 expr_);
6015 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6016 fixed_charge_);
6017 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, step_);
6018 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6019 }
6020
6021 private:
6022 IntExpr* const expr_;
6023 const int64_t fixed_charge_;
6024 const int64_t step_;
6025};
6026
6027class SemiContinuousStepOneExpr : public BaseIntExpr {
6028 public:
6029 SemiContinuousStepOneExpr(Solver* const s, IntExpr* const e,
6030 int64_t fixed_charge)
6031 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6032 DCHECK_GE(fixed_charge, int64_t{0});
6033 }
6034
6035 ~SemiContinuousStepOneExpr() override {}
6036
6037 int64_t Value(int64_t x) const {
6038 if (x <= 0) {
6039 return 0;
6040 } else {
6041 return fixed_charge_ + x;
6042 }
6043 }
6044
6045 int64_t Min() const override { return Value(expr_->Min()); }
6046
6047 void SetMin(int64_t m) override {
6048 if (m >= fixed_charge_ + 1) {
6049 expr_->SetMin(m - fixed_charge_);
6050 } else if (m > 0) {
6051 expr_->SetMin(1);
6052 }
6053 }
6054
6055 int64_t Max() const override { return Value(expr_->Max()); }
6056
6057 void SetMax(int64_t m) override {
6058 if (m < 0) {
6059 solver()->Fail();
6060 }
6061 if (m < fixed_charge_ + 1) {
6062 expr_->SetMax(0);
6063 } else {
6064 expr_->SetMax(m - fixed_charge_);
6065 }
6066 }
6067
6068 std::string name() const override {
6069 return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6070 expr_->name(), fixed_charge_);
6071 }
6072
6073 std::string DebugString() const override {
6074 return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6075 expr_->DebugString(), fixed_charge_);
6076 }
6077
6078 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6079
6080 void Accept(ModelVisitor* const visitor) const override {
6081 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6082 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6083 expr_);
6084 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6085 fixed_charge_);
6086 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 1);
6087 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6088 }
6089
6090 private:
6091 IntExpr* const expr_;
6092 const int64_t fixed_charge_;
6093};
6094
6095class SemiContinuousStepZeroExpr : public BaseIntExpr {
6096 public:
6097 SemiContinuousStepZeroExpr(Solver* const s, IntExpr* const e,
6098 int64_t fixed_charge)
6099 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6100 DCHECK_GT(fixed_charge, int64_t{0});
6101 }
6102
6103 ~SemiContinuousStepZeroExpr() override {}
6104
6105 int64_t Value(int64_t x) const {
6106 if (x <= 0) {
6107 return 0;
6108 } else {
6109 return fixed_charge_;
6110 }
6111 }
6112
6113 int64_t Min() const override { return Value(expr_->Min()); }
6114
6115 void SetMin(int64_t m) override {
6116 if (m >= fixed_charge_) {
6117 solver()->Fail();
6118 } else if (m > 0) {
6119 expr_->SetMin(1);
6120 }
6121 }
6122
6123 int64_t Max() const override { return Value(expr_->Max()); }
6124
6125 void SetMax(int64_t m) override {
6126 if (m < 0) {
6127 solver()->Fail();
6128 }
6129 if (m < fixed_charge_) {
6130 expr_->SetMax(0);
6131 }
6132 }
6133
6134 std::string name() const override {
6135 return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6136 expr_->name(), fixed_charge_);
6137 }
6138
6139 std::string DebugString() const override {
6140 return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6141 expr_->DebugString(), fixed_charge_);
6142 }
6143
6144 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6145
6146 void Accept(ModelVisitor* const visitor) const override {
6147 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6148 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6149 expr_);
6150 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6151 fixed_charge_);
6152 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 0);
6153 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6154 }
6155
6156 private:
6157 IntExpr* const expr_;
6158 const int64_t fixed_charge_;
6159};
6160
6161// This constraints links an expression and the variable it is casted into
6162class LinkExprAndVar : public CastConstraint {
6163 public:
6164 LinkExprAndVar(Solver* const s, IntExpr* const expr, IntVar* const var)
6165 : CastConstraint(s, var), expr_(expr) {}
6166
6167 ~LinkExprAndVar() override {}
6168
6169 void Post() override {
6170 Solver* const s = solver();
6171 Demon* d = s->MakeConstraintInitialPropagateCallback(this);
6172 expr_->WhenRange(d);
6173 target_var_->WhenRange(d);
6174 }
6175
6176 void InitialPropagate() override {
6177 expr_->SetRange(target_var_->Min(), target_var_->Max());
6178 int64_t l, u;
6179 expr_->Range(&l, &u);
6180 target_var_->SetRange(l, u);
6181 }
6182
6183 std::string DebugString() const override {
6184 return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6185 target_var_->DebugString());
6186 }
6187
6188 void Accept(ModelVisitor* const visitor) const override {
6189 visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6190 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6191 expr_);
6192 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6193 target_var_);
6194 visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6195 }
6196
6197 private:
6198 IntExpr* const expr_;
6199};
6200
6201// ----- Conditional Expression -----
6202
6203class ExprWithEscapeValue : public BaseIntExpr {
6204 public:
6205 ExprWithEscapeValue(Solver* const s, IntVar* const c, IntExpr* const e,
6206 int64_t unperformed_value)
6207 : BaseIntExpr(s),
6208 condition_(c),
6209 expression_(e),
6210 unperformed_value_(unperformed_value) {}
6211
6212 // This type is neither copyable nor movable.
6213 ExprWithEscapeValue(const ExprWithEscapeValue&) = delete;
6214 ExprWithEscapeValue& operator=(const ExprWithEscapeValue&) = delete;
6215
6216 ~ExprWithEscapeValue() override {}
6217
6218 int64_t Min() const override {
6219 if (condition_->Min() == 1) {
6220 return expression_->Min();
6221 } else if (condition_->Max() == 1) {
6222 return std::min(unperformed_value_, expression_->Min());
6223 } else {
6224 return unperformed_value_;
6225 }
6226 }
6227
6228 void SetMin(int64_t m) override {
6229 if (m > unperformed_value_) {
6230 condition_->SetValue(1);
6231 expression_->SetMin(m);
6232 } else if (condition_->Min() == 1) {
6233 expression_->SetMin(m);
6234 } else if (m > expression_->Max()) {
6235 condition_->SetValue(0);
6236 }
6237 }
6238
6239 int64_t Max() const override {
6240 if (condition_->Min() == 1) {
6241 return expression_->Max();
6242 } else if (condition_->Max() == 1) {
6243 return std::max(unperformed_value_, expression_->Max());
6244 } else {
6245 return unperformed_value_;
6246 }
6247 }
6248
6249 void SetMax(int64_t m) override {
6250 if (m < unperformed_value_) {
6251 condition_->SetValue(1);
6252 expression_->SetMax(m);
6253 } else if (condition_->Min() == 1) {
6254 expression_->SetMax(m);
6255 } else if (m < expression_->Min()) {
6256 condition_->SetValue(0);
6257 }
6258 }
6259
6260 void SetRange(int64_t mi, int64_t ma) override {
6261 if (ma < unperformed_value_ || mi > unperformed_value_) {
6262 condition_->SetValue(1);
6263 expression_->SetRange(mi, ma);
6264 } else if (condition_->Min() == 1) {
6265 expression_->SetRange(mi, ma);
6266 } else if (ma < expression_->Min() || mi > expression_->Max()) {
6267 condition_->SetValue(0);
6268 }
6269 }
6270
6271 void SetValue(int64_t v) override {
6272 if (v != unperformed_value_) {
6273 condition_->SetValue(1);
6274 expression_->SetValue(v);
6275 } else if (condition_->Min() == 1) {
6276 expression_->SetValue(v);
6277 } else if (v < expression_->Min() || v > expression_->Max()) {
6278 condition_->SetValue(0);
6279 }
6280 }
6281
6282 bool Bound() const override {
6283 return condition_->Max() == 0 || expression_->Bound();
6284 }
6285
6286 void WhenRange(Demon* d) override {
6287 expression_->WhenRange(d);
6288 condition_->WhenBound(d);
6289 }
6290
6291 std::string DebugString() const override {
6292 return absl::StrFormat("ConditionExpr(%s, %s, %d)",
6293 condition_->DebugString(),
6294 expression_->DebugString(), unperformed_value_);
6295 }
6296
6297 void Accept(ModelVisitor* const visitor) const override {
6298 visitor->BeginVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6299 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
6300 condition_);
6301 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6302 expression_);
6303 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument,
6304 unperformed_value_);
6305 visitor->EndVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6306 }
6307
6308 private:
// 0-1 selector: when fixed to 1 the value is expression_'s value; when fixed
// to 0 it is unperformed_value_.
6309 IntVar* const condition_;
// Expression whose value is used while condition_ holds.
6310 IntExpr* const expression_;
// Escape value reported when condition_ is 0.
6311 const int64_t unperformed_value_;
6312};
6313
6314// ----- Specialized cast constraint, used when the target variable is known to be a DomainIntVar -----
6315class LinkExprAndDomainIntVar : public CastConstraint {
6316 public:
6317 LinkExprAndDomainIntVar(Solver* const s, IntExpr* const expr,
6318 DomainIntVar* const var)
6319 : CastConstraint(s, var),
6320 expr_(expr),
6321 cached_min_(std::numeric_limits<int64_t>::min()),
6322 cached_max_(std::numeric_limits<int64_t>::max()),
6323 fail_stamp_(uint64_t{0}) {}
6324
6325 ~LinkExprAndDomainIntVar() override {}
6326
6327 DomainIntVar* var() const {
6328 return reinterpret_cast<DomainIntVar*>(target_var_);
6329 }
6330
6331 void Post() override {
6332 Solver* const s = solver();
6333 Demon* const d = s->MakeConstraintInitialPropagateCallback(this);
6334 expr_->WhenRange(d);
6335 Demon* const target_var_demon = MakeConstraintDemon0(
6336 solver(), this, &LinkExprAndDomainIntVar::Propagate, "Propagate");
6337 target_var_->WhenRange(target_var_demon);
6338 }
6339
6340 void InitialPropagate() override {
6341 expr_->SetRange(var()->min_.Value(), var()->max_.Value());
6342 expr_->Range(&cached_min_, &cached_max_);
6343 var()->DomainIntVar::SetRange(cached_min_, cached_max_);
6344 }
6345
6346 void Propagate() {
6347 if (var()->min_.Value() > cached_min_ ||
6348 var()->max_.Value() < cached_max_ ||
6349 solver()->fail_stamp() != fail_stamp_) {
6350 InitialPropagate();
6351 fail_stamp_ = solver()->fail_stamp();
6352 }
6353 }
6354
6355 std::string DebugString() const override {
6356 return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6357 target_var_->DebugString());
6358 }
6359
6360 void Accept(ModelVisitor* const visitor) const override {
6361 visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6362 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6363 expr_);
6364 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6365 target_var_);
6366 visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6367 }
6368
6369 private:
6370 IntExpr* const expr_;
6371 int64_t cached_min_;
6372 int64_t cached_max_;
6373 uint64_t fail_stamp_;
6374};
6375} // namespace
6376
6377// ----- Misc -----
6378
// NOTE(review): the signature line of this function is missing from this
// listing (presumably BooleanVar::MakeHoleIterator(bool reversible) — TODO
// confirm). Returns an EmptyIterator allocated through CondRevAlloc, passing
// the caller's `reversible` flag through unchanged.
6380 return CondRevAlloc(solver(), reversible, new EmptyIterator());
6381}
// NOTE(review): the signature line of this function is missing from this
// listing (presumably BooleanVar::MakeDomainIterator(bool reversible) — TODO
// confirm). Returns a RangeIterator over this variable, allocated through
// CondRevAlloc with the caller's `reversible` flag.
6383 return CondRevAlloc(solver(), reversible, new RangeIterator(this));
6384}
6385
6386// ----- API -----
6387
// NOTE(review): the signature line of this function is missing from this
// listing. Downcasts `var` to DomainIntVar (the type is verified only by the
// debug-mode DCHECK) and calls CleanInProcess() on it.
6389 DCHECK_EQ(DOMAIN_INT_VAR, var->VarType());
6390 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6391 dvar->CleanInProcess();
6392}
6393
6394Constraint* SetIsEqual(IntVar* const var, absl::Span<const int64_t> values,
6395 const std::vector<IntVar*>& vars) {
6396 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6397 CHECK(dvar != nullptr);
6398 return dvar->SetIsEqual(values, vars);
6399}
6400
6402 absl::Span<const int64_t> values,
6403 const std::vector<IntVar*>& vars) {
// Forwards to DomainIntVar::SetIsGreaterOrEqual (the first signature line is
// missing from this listing). NOTE(review): reinterpret_cast does not check
// the dynamic type, so the CHECK below only rejects a null `var`; it does not
// verify that `var` really is a DomainIntVar.
6404 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6405 CHECK(dvar != nullptr);
6406 return dvar->SetIsGreaterOrEqual(values, vars);
6407}
6408
// NOTE(review): the signature line of this function is missing from this
// listing. Downcasts `var` to BooleanVar (the type is verified only by the
// debug-mode DCHECK) and calls RestoreValue() on it.
6410 DCHECK_EQ(BOOLEAN_VAR, var->VarType());
6411 BooleanVar* const boolean_var = reinterpret_cast<BooleanVar*>(var);
6412 boolean_var->RestoreValue();
6413}
6414
6415// ----- API -----
6416
6417IntVar* Solver::MakeIntVar(int64_t min, int64_t max, const std::string& name) {
6418 if (min == max) {
6419 return MakeIntConst(min, name);
6420 }
6421 if (min == 0 && max == 1) {
6422 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6423 } else if (CapSub(max, min) == 1) {
6424 const std::string inner_name = "inner_" + name;
6425 return RegisterIntVar(
6426 MakeSum(RevAlloc(new ConcreteBooleanVar(this, inner_name)), min)
6427 ->VarWithName(name));
6428 } else {
6429 return RegisterIntVar(RevAlloc(new DomainIntVar(this, min, max, name)));
6430 }
6431}
6432
6433IntVar* Solver::MakeIntVar(int64_t min, int64_t max) {
6434 return MakeIntVar(min, max, "");
6435}
6436
6437IntVar* Solver::MakeBoolVar(const std::string& name) {
6438 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6439}
6440
// NOTE(review): the signature line of this function is missing from this
// listing (presumably the unnamed Solver::MakeBoolVar() overload — TODO
// confirm). Creates and registers a ConcreteBooleanVar with an empty name.
6442 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, "")));
6443}
6444
6445IntVar* Solver::MakeIntVar(const std::vector<int64_t>& values,
6446 const std::string& name) {
6447 DCHECK(!values.empty());
6448 // Fast-track the case where we have a single value.
6449 if (values.size() == 1) return MakeIntConst(values[0], name);
6450 // Sort and remove duplicates.
6451 std::vector<int64_t> unique_sorted_values = values;
6452 gtl::STLSortAndRemoveDuplicates(&unique_sorted_values);
6453 // Case when we have a single value, after clean-up.
6454 if (unique_sorted_values.size() == 1) return MakeIntConst(values[0], name);
6455 // Case when the values are a dense interval of integers.
6456 if (unique_sorted_values.size() ==
6457 unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6458 return MakeIntVar(unique_sorted_values.front(), unique_sorted_values.back(),
6459 name);
6460 }
6461 // Compute the GCD: if it's not 1, we can express the variable's domain as
6462 // the product of the GCD and of a domain with smaller values.
6463 int64_t gcd = 0;
6464 for (const int64_t v : unique_sorted_values) {
6465 if (gcd == 0) {
6466 gcd = std::abs(v);
6467 } else {
6468 gcd = MathUtil::GCD64(gcd, std::abs(v)); // Supports v==0.
6469 }
6470 if (gcd == 1) {
6471 // If it's 1, though, we can't do anything special, so we
6472 // immediately return a new DomainIntVar.
6473 return RegisterIntVar(
6474 RevAlloc(new DomainIntVar(this, unique_sorted_values, name)));
6475 }
6476 }
6477 DCHECK_GT(gcd, 1);
6478 for (int64_t& v : unique_sorted_values) {
6479 DCHECK_EQ(0, v % gcd);
6480 v /= gcd;
6481 }
6482 const std::string new_name = name.empty() ? "" : "inner_" + name;
6483 // Catch the case where the divided values are a dense set of integers.
6484 IntVar* inner_intvar = nullptr;
6485 if (unique_sorted_values.size() ==
6486 unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6487 inner_intvar = MakeIntVar(unique_sorted_values.front(),
6488 unique_sorted_values.back(), new_name);
6489 } else {
6490 inner_intvar = RegisterIntVar(
6491 RevAlloc(new DomainIntVar(this, unique_sorted_values, new_name)));
6492 }
6493 return MakeProd(inner_intvar, gcd)->Var();
6494}
6495
6496IntVar* Solver::MakeIntVar(const std::vector<int64_t>& values) {
6497 return MakeIntVar(values, "");
6498}
6499
6500IntVar* Solver::MakeIntVar(const std::vector<int>& values,
6501 const std::string& name) {
6502 return MakeIntVar(ToInt64Vector(values), name);
6503}
6504
6505IntVar* Solver::MakeIntVar(const std::vector<int>& values) {
6506 return MakeIntVar(values, "");
6507}
6508
6509IntVar* Solver::MakeIntConst(int64_t val, const std::string& name) {
6510 // If IntConst is going to be named after its creation,
6511 // cp_share_int_consts should be set to false otherwise names can potentially
6512 // be overwritten.
6513 if (absl::GetFlag(FLAGS_cp_share_int_consts) && name.empty() &&
6514 val >= MIN_CACHED_INT_CONST && val <= MAX_CACHED_INT_CONST) {
6515 return cached_constants_[val - MIN_CACHED_INT_CONST];
6516 }
6517 return RevAlloc(new IntConst(this, val, name));
6518}
6519
6520IntVar* Solver::MakeIntConst(int64_t val) { return MakeIntConst(val, ""); }
6521
6522// ----- Int Var and associated methods -----
6523
6524namespace {
6525std::string IndexedName(absl::string_view prefix, int index, int max_index) {
6526#if 0
6527#if defined(_MSC_VER)
6528 const int digits = max_index > 0 ?
6529 static_cast<int>(log(1.0L * max_index) / log(10.0L)) + 1 :
6530 1;
6531#else
6532 const int digits = max_index > 0 ? static_cast<int>(log10(max_index)) + 1: 1;
6533#endif
6534 return absl::StrFormat("%s%0*d", prefix, digits, index);
6535#else
6536 return absl::StrCat(prefix, index);
6537#endif
6538}
6539} // namespace
6540
6541void Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6542 const std::string& name,
6543 std::vector<IntVar*>* vars) {
6544 for (int i = 0; i < var_count; ++i) {
6545 vars->push_back(MakeIntVar(vmin, vmax, IndexedName(name, i, var_count)));
6546 }
6547}
6548
6549void Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6550 std::vector<IntVar*>* vars) {
6551 for (int i = 0; i < var_count; ++i) {
6552 vars->push_back(MakeIntVar(vmin, vmax));
6553 }
6554}
6555
6556IntVar** Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6557 const std::string& name) {
6558 IntVar** vars = new IntVar*[var_count];
6559 for (int i = 0; i < var_count; ++i) {
6560 vars[i] = MakeIntVar(vmin, vmax, IndexedName(name, i, var_count));
6561 }
6562 return vars;
6563}
6564
6565void Solver::MakeBoolVarArray(int var_count, const std::string& name,
6566 std::vector<IntVar*>* vars) {
6567 for (int i = 0; i < var_count; ++i) {
6568 vars->push_back(MakeBoolVar(IndexedName(name, i, var_count)));
6569 }
6570}
6571
6572void Solver::MakeBoolVarArray(int var_count, std::vector<IntVar*>* vars) {
6573 for (int i = 0; i < var_count; ++i) {
6574 vars->push_back(MakeBoolVar());
6575 }
6576}
6577
6578IntVar** Solver::MakeBoolVarArray(int var_count, const std::string& name) {
6579 IntVar** vars = new IntVar*[var_count];
6580 for (int i = 0; i < var_count; ++i) {
6581 vars[i] = MakeBoolVar(IndexedName(name, i, var_count));
6582 }
6583 return vars;
6584}
6585
6586void Solver::InitCachedIntConstants() {
6587 for (int i = MIN_CACHED_INT_CONST; i <= MAX_CACHED_INT_CONST; ++i) {
6588 cached_constants_[i - MIN_CACHED_INT_CONST] =
6589 RevAlloc(new IntConst(this, i, "")); // note the empty name
6590 }
6591}
6592
// Returns an expression equal to left + right.
// A bound operand is folded into the (expression, constant) overload, and
// x + x becomes MakeProd(x, 2). A sum of the same two operands built earlier
// (in either order) is returned from the model cache. Otherwise a new sum is
// created and cached. SafePlusIntExpr is used instead of PlusIntExpr when
// adding the operands' Max or Min bounds overflows (AddOverflows).
// NOTE(review): the continuation line of the reversed-order lookup and of
// InsertExprExprExpression below is missing from this listing.
6593IntExpr* Solver::MakeSum(IntExpr* const left, IntExpr* const right) {
6594 CHECK_EQ(this, left->solver());
6595 CHECK_EQ(this, right->solver());
6596 if (right->Bound()) {
6597 return MakeSum(left, right->Min());
6598 }
6599 if (left->Bound()) {
6600 return MakeSum(right, left->Min());
6601 }
6602 if (left == right) {
6603 return MakeProd(left, 2);
6604 }
6605 IntExpr* cache = model_cache_->FindExprExprExpression(
6606 left, right, ModelCache::EXPR_EXPR_SUM);
// Addition commutes: also look for a cached right + left.
6607 if (cache == nullptr) {
6608 cache = model_cache_->FindExprExprExpression(right, left,
6610 }
6611 if (cache != nullptr) {
6612 return cache;
6613 } else {
// Pick the overflow-aware implementation when the bound sums overflow.
6614 IntExpr* const result =
6615 AddOverflows(left->Max(), right->Max()) ||
6616 AddOverflows(left->Min(), right->Min())
6617 ? RegisterIntExpr(RevAlloc(new SafePlusIntExpr(this, left, right)))
6618 : RegisterIntExpr(RevAlloc(new PlusIntExpr(this, left, right)));
6619 model_cache_->InsertExprExprExpression(result, left, right,
6621 return result;
6622 }
6623}
6624
// Returns an expression equal to expr + value.
// A bound expr folds to the constant CapAdd(expr->Min(), value), and
// value == 0 returns expr itself. Earlier results for (expr, value) come from
// the model cache. When expr is a variable and adding value to its bounds
// cannot overflow, a specialized variable is chosen from the variable's kind,
// so chains of offsets and negations collapse onto the underlying variable.
// Anything else becomes a generic PlusIntCstExpr.
// NOTE(review): the last argument line of InsertExprConstantExpression is
// missing from this listing.
6625IntExpr* Solver::MakeSum(IntExpr* const expr, int64_t value) {
6626 CHECK_EQ(this, expr->solver());
6627 if (expr->Bound()) {
6628 return MakeIntConst(CapAdd(expr->Min(), value));
6629 }
6630 if (value == 0) {
6631 return expr;
6632 }
6633 IntExpr* result = Cache()->FindExprConstantExpression(
6634 expr, value, ModelCache::EXPR_CONSTANT_SUM);
6635 if (result == nullptr) {
6636 if (expr->IsVar() && !AddOverflows(value, expr->Max()) &&
6637 !AddOverflows(value, expr->Min())) {
6638 IntVar* const var = expr->Var();
6639 switch (var->VarType()) {
6640 case DOMAIN_INT_VAR: {
6641 result = RegisterIntExpr(RevAlloc(new PlusCstDomainIntVar(
6642 this, reinterpret_cast<DomainIntVar*>(var), value)));
6643 break;
6644 }
// NOTE(review): presumably unreachable — a constant is Bound() and is
// already handled at the top of this function.
6645 case CONST_VAR: {
6646 result = RegisterIntExpr(MakeIntConst(var->Min() + value));
6647 break;
6648 }
// value + (sub + c) == sub + (value + c); an offset of 0 yields sub itself.
// NOTE(review): value + c is not overflow-checked (only value + expr bounds
// are), so it can overflow when sub's range excludes 0.
6649 case VAR_ADD_CST: {
6650 PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6651 IntVar* const sub_var = add_var->SubVar();
6652 const int64_t new_constant = value + add_var->Constant();
6653 if (new_constant == 0) {
6654 result = sub_var;
6655 } else {
6656 if (sub_var->VarType() == DOMAIN_INT_VAR) {
6657 DomainIntVar* const dvar =
6658 reinterpret_cast<DomainIntVar*>(sub_var);
6659 result = RegisterIntExpr(
6660 RevAlloc(new PlusCstDomainIntVar(this, dvar, new_constant)));
6661 } else {
6662 result = RegisterIntExpr(
6663 RevAlloc(new PlusCstIntVar(this, sub_var, new_constant)));
6664 }
6665 }
6666 break;
6667 }
// value + (c - sub) == (value + c) - sub (same overflow caveat as above).
6668 case CST_SUB_VAR: {
6669 SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6670 IntVar* const sub_var = add_var->SubVar();
6671 const int64_t new_constant = value + add_var->Constant();
6672 result = RegisterIntExpr(
6673 RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6674 break;
6675 }
// value + (-sub) == value - sub.
6676 case OPP_VAR: {
6677 OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6678 IntVar* const sub_var = add_var->SubVar();
6679 result =
6680 RegisterIntExpr(RevAlloc(new SubCstIntVar(this, sub_var, value)));
6681 break;
6682 }
6683 default:
6684 result =
6685 RegisterIntExpr(RevAlloc(new PlusCstIntVar(this, var, value)));
6686 }
6687 } else {
6688 result = RegisterIntExpr(RevAlloc(new PlusIntCstExpr(this, expr, value)));
6689 }
6690 Cache()->InsertExprConstantExpression(result, expr, value,
6692 }
6693 return result;
6694}
6695
// Returns an expression equal to left - right.
// Bound operands are folded into the constant overloads. Suppose both sides
// are products by constants (IsProduct) whose coefficients share a factor
// g > 1. Then the result is built as g * (sub_left * (left_coef / g) -
// sub_right * (right_coef / g)). Otherwise a cached or new SubIntExpr is
// returned; SafeSubIntExpr is used when subtracting the bounds could
// overflow.
// NOTE(review): `-right->Min()` overflows if right is bound to kint64min, and
// std::abs of a kint64min coefficient is undefined. The last argument lines of
// the two cache calls are missing from this listing.
6696IntExpr* Solver::MakeDifference(IntExpr* const left, IntExpr* const right) {
6697 CHECK_EQ(this, left->solver());
6698 CHECK_EQ(this, right->solver());
6699 if (left->Bound()) {
6700 return MakeDifference(left->Min(), right);
6701 }
6702 if (right->Bound()) {
6703 return MakeSum(left, -right->Min());
6704 }
6705 IntExpr* sub_left = nullptr;
6706 IntExpr* sub_right = nullptr;
6707 int64_t left_coef = 1;
6708 int64_t right_coef = 1;
// Factor out a common divisor of the two coefficients, if any.
6709 if (IsProduct(left, &sub_left, &left_coef) &&
6710 IsProduct(right, &sub_right, &right_coef)) {
6711 const int64_t abs_gcd =
6712 MathUtil::GCD64(std::abs(left_coef), std::abs(right_coef));
6713 if (abs_gcd != 0 && abs_gcd != 1) {
6714 return MakeProd(MakeDifference(MakeProd(sub_left, left_coef / abs_gcd),
6715 MakeProd(sub_right, right_coef / abs_gcd)),
6716 abs_gcd);
6717 }
6718 }
6719
6720 IntExpr* result = Cache()->FindExprExprExpression(
6722 if (result == nullptr) {
6723 if (!SubOverflows(left->Min(), right->Max()) &&
6724 !SubOverflows(left->Max(), right->Min())) {
6725 result = RegisterIntExpr(RevAlloc(new SubIntExpr(this, left, right)));
6726 } else {
6727 result = RegisterIntExpr(RevAlloc(new SafeSubIntExpr(this, left, right)));
6728 }
6729 Cache()->InsertExprExprExpression(result, left, right,
6731 }
6732 return result;
6733}
6734
6735// warning: this is 'value - expr'.
// Returns an expression equal to value - expr.
// A bound expr folds to a constant, and value == 0 yields MakeOpposite(expr).
// Suppose expr is a variable whose bounds can be subtracted from value
// without overflow. Then chains of offsets and negations collapse onto the
// underlying variable. Anything else becomes a generic SubIntCstExpr. Results
// are cached under (expr, value).
// NOTE(review): the bound case `value - expr->Min()` is not overflow-protected
// (MakeOpposite uses CapOpp for the same situation), nor is
// `value - add_var->Constant()`. The last argument lines of the two cache
// calls are missing from this listing.
6736IntExpr* Solver::MakeDifference(int64_t value, IntExpr* const expr) {
6737 CHECK_EQ(this, expr->solver());
6738 if (expr->Bound()) {
6739 return MakeIntConst(value - expr->Min());
6740 }
6741 if (value == 0) {
6742 return MakeOpposite(expr);
6743 }
6744 IntExpr* result = Cache()->FindExprConstantExpression(
6746 if (result == nullptr) {
// Excluding kint64min keeps the negation implied by value - expr safe.
6747 if (expr->IsVar() && expr->Min() != std::numeric_limits<int64_t>::min() &&
6748 !SubOverflows(value, expr->Min()) &&
6749 !SubOverflows(value, expr->Max())) {
6750 IntVar* const var = expr->Var();
6751 switch (var->VarType()) {
// value - (sub + c) == (value - c) - sub.
// BUG(review): when new_constant == 0 the result is -sub, but the code below
// returns sub_var itself (apparently copied from MakeSum, where
// value + (sub + c) == sub is correct). It should presumably be
// MakeOpposite(sub_var); the SubCstIntVar(sub_var, 0) of the else branch
// would also be correct.
6752 case VAR_ADD_CST: {
6753 PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6754 IntVar* const sub_var = add_var->SubVar();
6755 const int64_t new_constant = value - add_var->Constant();
6756 if (new_constant == 0) {
6757 result = sub_var;
6758 } else {
6759 result = RegisterIntExpr(
6760 RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6761 }
6762 break;
6763 }
// value - (c - sub) == sub + (value - c).
6764 case CST_SUB_VAR: {
6765 SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6766 IntVar* const sub_var = add_var->SubVar();
6767 const int64_t new_constant = value - add_var->Constant();
6768 result = MakeSum(sub_var, new_constant);
6769 break;
6770 }
// value - (-sub) == sub + value.
6771 case OPP_VAR: {
6772 OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6773 IntVar* const sub_var = add_var->SubVar();
6774 result = MakeSum(sub_var, value);
6775 break;
6776 }
6777 default:
6778 result =
6779 RegisterIntExpr(RevAlloc(new SubCstIntVar(this, var, value)));
6780 }
6781 } else {
6782 result = RegisterIntExpr(RevAlloc(new SubIntCstExpr(this, expr, value)));
6783 }
6784 Cache()->InsertExprConstantExpression(result, expr, value,
6786 }
6787 return result;
6788}
6789
// NOTE(review): the signature line of this function is missing from this
// listing (presumably Solver::MakeOpposite(IntExpr* const expr) — TODO
// confirm). Returns -expr. A bound expr folds to the saturated constant
// CapOpp(expr->Min()). Otherwise an OppIntExpr is built, turned into a
// registered variable when expr is a variable, and cached.
6791 CHECK_EQ(this, expr->solver());
6792 if (expr->Bound()) {
6793 return MakeIntConst(CapOpp(expr->Min()));
6794 }
6795 IntExpr* result =
6796 Cache()->FindExprExpression(expr, ModelCache::EXPR_OPPOSITE);
6797 if (result == nullptr) {
6798 if (expr->IsVar()) {
6799 result = RegisterIntVar(RevAlloc(new OppIntExpr(this, expr))->Var());
6800 } else {
6801 result = RegisterIntExpr(RevAlloc(new OppIntExpr(this, expr)));
6802 }
6803 Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_OPPOSITE);
6804 }
6805 return result;
6806}
6807
// Returns an expression equal to expr * value.
// Earlier results for (expr, value) are returned from the model cache. A
// product expression (IsProduct) is flattened by multiplying its coefficient
// by value with saturation (CapProd). A bound expression folds to a constant,
// and coefficients 1, -1 and 0 are special-cased.
// - Positive coefficients use SafeTimesPosIntCstExpr when the bounds could
//   overflow, TimesPosIntCstExpr otherwise.
// - Negative coefficients use TimesIntNegCstExpr.
// When the multiplied expression is a variable, the result is turned into a
// variable unless --cp_disable_expression_optimization is set.
// NOTE(review): the early returns (bound, 1, -1) skip the cache insertion.
// The negative branch has no overflow-safe variant; confirm that
// TimesIntNegCstExpr handles overflow itself. The last argument line of
// InsertExprConstantExpression is missing from this listing.
6808IntExpr* Solver::MakeProd(IntExpr* const expr, int64_t value) {
6809 CHECK_EQ(this, expr->solver());
6810 IntExpr* result = Cache()->FindExprConstantExpression(
6811 expr, value, ModelCache::EXPR_CONSTANT_PROD);
6812 if (result != nullptr) {
6813 return result;
6814 } else {
6815 IntExpr* m_expr = nullptr;
6816 int64_t coefficient = 1;
6817 if (IsProduct(expr, &m_expr, &coefficient)) {
6818 coefficient = CapProd(coefficient, value);
6819 } else {
6820 m_expr = expr;
6821 coefficient = value;
6822 }
6823 if (m_expr->Bound()) {
6824 return MakeIntConst(CapProd(coefficient, m_expr->Min()));
6825 } else if (coefficient == 1) {
6826 return m_expr;
6827 } else if (coefficient == -1) {
6828 return MakeOpposite(m_expr);
6829 } else if (coefficient > 0) {
// Overflow test done by division, so the product itself is never computed.
6830 if (m_expr->Max() > std::numeric_limits<int64_t>::max() / coefficient ||
6831 m_expr->Min() < std::numeric_limits<int64_t>::min() / coefficient) {
6832 result = RegisterIntExpr(
6833 RevAlloc(new SafeTimesPosIntCstExpr(this, m_expr, coefficient)));
6834 } else {
6835 result = RegisterIntExpr(
6836 RevAlloc(new TimesPosIntCstExpr(this, m_expr, coefficient)));
6837 }
6838 } else if (coefficient == 0) {
6839 result = MakeIntConst(0);
6840 } else { // coefficient < 0.
6841 result = RegisterIntExpr(
6842 RevAlloc(new TimesIntNegCstExpr(this, m_expr, coefficient)));
6843 }
6844 if (m_expr->IsVar() &&
6845 !absl::GetFlag(FLAGS_cp_disable_expression_optimization)) {
6846 result = result->Var();
6847 }
6848 Cache()->InsertExprConstantExpression(result, expr, value,
6850 return result;
6851 }
6852}
6853
6854namespace {
6855void ExtractPower(IntExpr** const expr, int64_t* const exponant) {
6856 if (dynamic_cast<BasePower*>(*expr) != nullptr) {
6857 BasePower* const power = dynamic_cast<BasePower*>(*expr);
6858 *expr = power->expr();
6859 *exponant = power->exponant();
6860 }
6861 if (dynamic_cast<IntSquare*>(*expr) != nullptr) {
6862 IntSquare* const power = dynamic_cast<IntSquare*>(*expr);
6863 *expr = power->expr();
6864 *exponant = 2;
6865 }
6866 if ((*expr)->IsVar()) {
6867 IntVar* const var = (*expr)->Var();
6868 IntExpr* const sub = var->solver()->CastExpression(var);
6869 if (sub != nullptr && dynamic_cast<BasePower*>(sub) != nullptr) {
6870 BasePower* const power = dynamic_cast<BasePower*>(sub);
6871 *expr = power->expr();
6872 *exponant = power->exponant();
6873 }
6874 if (sub != nullptr && dynamic_cast<IntSquare*>(sub) != nullptr) {
6875 IntSquare* const power = dynamic_cast<IntSquare*>(sub);
6876 *expr = power->expr();
6877 *exponant = 2;
6878 }
6879 }
6880}
6881
6882void ExtractProduct(IntExpr** const expr, int64_t* const coefficient,
6883 bool* modified) {
6884 if (dynamic_cast<TimesCstIntVar*>(*expr) != nullptr) {
6885 TimesCstIntVar* const left_prod = dynamic_cast<TimesCstIntVar*>(*expr);
6886 *coefficient *= left_prod->Constant();
6887 *expr = left_prod->SubVar();
6888 *modified = true;
6889 } else if (dynamic_cast<TimesIntCstExpr*>(*expr) != nullptr) {
6890 TimesIntCstExpr* const left_prod = dynamic_cast<TimesIntCstExpr*>(*expr);
6891 *coefficient *= left_prod->Constant();
6892 *expr = left_prod->Expr();
6893 *modified = true;
6894 }
6895}
6896} // namespace
6897
// Returns an expression equal to left * right. Simplifications, in order:
//  1. a bound operand reduces to the (expression, constant) overload;
//  2. powers/squares of one base are merged: x^a * x^b -> x^(a+b);
//  3. constant factors are pulled out: (a*x) * (b*y) -> (a*b) * (x*y);
//  4. otherwise a cached product (in either operand order) is reused, or a
//     specialized product is built and cached. The choices are: boolean
//     variable times expression, two non-negative operands (overflow-safe
//     variant when CapProd of the maxima saturates), or the general case.
// NOTE(review): the solver CHECK_EQs only run after the early returns above
// them. The continuation lines of the reversed-order lookup and of
// InsertExprExprExpression are missing from this listing.
6898IntExpr* Solver::MakeProd(IntExpr* const left, IntExpr* const right) {
6899 if (left->Bound()) {
6900 return MakeProd(right, left->Min());
6901 }
6902
6903 if (right->Bound()) {
6904 return MakeProd(left, right->Min());
6905 }
6906
6907 // ----- Discover squares and powers -----
6908
6909 IntExpr* m_left = left;
6910 IntExpr* m_right = right;
6911 int64_t left_exponant = 1;
6912 int64_t right_exponant = 1;
6913 ExtractPower(&m_left, &left_exponant);
6914 ExtractPower(&m_right, &right_exponant);
6915
6916 if (m_left == m_right) {
6917 return MakePower(m_left, left_exponant + right_exponant);
6918 }
6919
6920 // ----- Discover nested products -----
6921
6922 m_left = left;
6923 m_right = right;
6924 int64_t coefficient = 1;
6925 bool modified = false;
6926
6927 ExtractProduct(&m_left, &coefficient, &modified);
6928 ExtractProduct(&m_right, &coefficient, &modified);
6929 if (modified) {
6930 return MakeProd(MakeProd(m_left, m_right), coefficient);
6931 }
6932
6933 // ----- Standard build -----
6934
6935 CHECK_EQ(this, left->solver());
6936 CHECK_EQ(this, right->solver());
6937 IntExpr* result = model_cache_->FindExprExprExpression(
6938 left, right, ModelCache::EXPR_EXPR_PROD);
6939 if (result == nullptr) {
6940 result = model_cache_->FindExprExprExpression(right, left,
6942 }
6943 if (result != nullptr) {
6944 return result;
6945 }
// A 0-1 variable on either side gets a dedicated boolean product; the "Pos"
// variant is used when the other operand is non-negative.
6946 if (left->IsVar() && left->Var()->VarType() == BOOLEAN_VAR) {
6947 if (right->Min() >= 0) {
6948 result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6949 this, reinterpret_cast<BooleanVar*>(left), right)));
6950 } else {
6951 result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6952 this, reinterpret_cast<BooleanVar*>(left), right)));
6953 }
6954 } else if (right->IsVar() &&
6955 reinterpret_cast<IntVar*>(right)->VarType() == BOOLEAN_VAR) {
6956 if (left->Min() >= 0) {
6957 result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6958 this, reinterpret_cast<BooleanVar*>(right), left)));
6959 } else {
6960 result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6961 this, reinterpret_cast<BooleanVar*>(right), left)));
6962 }
6963 } else if (left->Min() >= 0 && right->Min() >= 0) {
// CapProd saturates at kint64max, so reaching it signals a possible overflow
// (conservatively including an exact product of kint64max).
6964 if (CapProd(left->Max(), right->Max()) ==
6965 std::numeric_limits<int64_t>::max()) { // Potential overflow.
6966 result =
6967 RegisterIntExpr(RevAlloc(new SafeTimesPosIntExpr(this, left, right)));
6968 } else {
6969 result =
6970 RegisterIntExpr(RevAlloc(new TimesPosIntExpr(this, left, right)));
6971 }
6972 } else {
6973 result = RegisterIntExpr(RevAlloc(new TimesIntExpr(this, left, right)));
6974 }
6975 model_cache_->InsertExprExprExpression(result, left, right,
6977 return result;
6978}
6979
// Returns an expression equal to numerator / denominator.
// A bound denominator reduces to the (expression, constant) overload, and an
// earlier result for the same operands is returned from the model cache.
// When the denominator's domain contains 0, the constraint
// denominator != 0 is added to the model. The expression class is then chosen
// from the operands' signs; a non-positive denominator is handled by negating
// operands or the result through MakeOpposite.
// NOTE(review): unlike most builders in this file, the new expressions here
// are not passed through RegisterIntExpr, and there is no CHECK that both
// operands belong to this solver. The last argument line of
// InsertExprExprExpression is missing from this listing.
6980IntExpr* Solver::MakeDiv(IntExpr* const numerator, IntExpr* const denominator) {
6981 CHECK(numerator != nullptr);
6982 CHECK(denominator != nullptr);
6983 if (denominator->Bound()) {
6984 return MakeDiv(numerator, denominator->Min());
6985 }
6986 IntExpr* result = model_cache_->FindExprExprExpression(
6987 numerator, denominator, ModelCache::EXPR_EXPR_DIV);
6988 if (result != nullptr) {
6989 return result;
6990 }
6991
6992 if (denominator->Min() <= 0 && denominator->Max() >= 0) {
6993 AddConstraint(MakeNonEquality(denominator, 0));
6994 }
6995
6996 if (denominator->Min() >= 0) {
6997 if (numerator->Min() >= 0) {
6998 result = RevAlloc(new DivPosPosIntExpr(this, numerator, denominator));
6999 } else {
7000 result = RevAlloc(new DivPosIntExpr(this, numerator, denominator));
7001 }
7002 } else if (denominator->Max() <= 0) {
// n / d == (-n) / (-d) == -(n / (-d)): reduce to a positive denominator.
7003 if (numerator->Max() <= 0) {
7004 result = RevAlloc(new DivPosPosIntExpr(this, MakeOpposite(numerator),
7005 MakeOpposite(denominator)));
7006 } else {
7007 result = MakeOpposite(RevAlloc(
7008 new DivPosIntExpr(this, numerator, MakeOpposite(denominator))));
7009 }
7010 } else {
7011 result = RevAlloc(new DivIntExpr(this, numerator, denominator));
7012 }
7013 model_cache_->InsertExprExprExpression(result, numerator, denominator,
7015 return result;
7016}
7017
7018IntExpr* Solver::MakeDiv(IntExpr* const expr, int64_t value) {
7019 CHECK(expr != nullptr);
7020 CHECK_EQ(this, expr->solver());
7021 if (expr->Bound()) {
7022 return MakeIntConst(expr->Min() / value);
7023 } else if (value == 1) {
7024 return expr;
7025 } else if (value == -1) {
7026 return MakeOpposite(expr);
7027 } else if (value > 0) {
7028 return RegisterIntExpr(RevAlloc(new DivPosIntCstExpr(this, expr, value)));
7029 } else if (value == 0) {
7030 LOG(FATAL) << "Cannot divide by 0";
7031 return nullptr;
7032 } else {
7033 return RegisterIntExpr(
7034 MakeOpposite(RevAlloc(new DivPosIntCstExpr(this, expr, -value))));
7035 // TODO(user) : implement special case.
7036 }
7037}
7038
7039Constraint* Solver::MakeAbsEquality(IntVar* const var, IntVar* const abs_var) {
7040 if (Cache()->FindExprExpression(var, ModelCache::EXPR_ABS) == nullptr) {
7041 Cache()->InsertExprExpression(abs_var, var, ModelCache::EXPR_ABS);
7042 }
7043 return RevAlloc(new IntAbsConstraint(this, var, abs_var));
7044}
7045
// NOTE(review): the signature line of this function is missing from this
// listing (presumably Solver::MakeAbs(IntExpr* const e) — TODO confirm).
// Returns |e|:
// - e itself when e >= 0, and MakeOpposite(e) when e <= 0;
// - otherwise a cached result if any;
// - otherwise |expr| * |c| when e is a product expr * c (IsProduct);
// - otherwise a new IntAbs expression.
// NOTE(review): std::abs(coefficient) is undefined for a kint64min
// coefficient.
7047 CHECK_EQ(this, e->solver());
7048 if (e->Min() >= 0) {
7049 return e;
7050 } else if (e->Max() <= 0) {
7051 return MakeOpposite(e);
7052 }
7053 IntExpr* result = Cache()->FindExprExpression(e, ModelCache::EXPR_ABS);
7054 if (result == nullptr) {
7055 int64_t coefficient = 1;
7056 IntExpr* expr = nullptr;
7057 if (IsProduct(e, &expr, &coefficient)) {
7058 result = MakeProd(MakeAbs(expr), std::abs(coefficient));
7059 } else {
7060 result = RegisterIntExpr(RevAlloc(new IntAbs(this, e)));
7061 }
7062 Cache()->InsertExprExpression(result, e, ModelCache::EXPR_ABS);
7063 }
7064 return result;
7065}
7066
// NOTE(review): the signature line of this function is missing from this
// listing (presumably Solver::MakeSquare(IntExpr* const expr) — TODO
// confirm). Returns expr^2: a constant when expr is bound, otherwise a cached
// PosIntSquare (when expr >= 0) or IntSquare expression.
// NOTE(review): `v * v` is not overflow-protected (MakePower saturates bound
// results instead).
7068 CHECK_EQ(this, expr->solver());
7069 if (expr->Bound()) {
7070 const int64_t v = expr->Min();
7071 return MakeIntConst(v * v);
7072 }
7073 IntExpr* result = Cache()->FindExprExpression(expr, ModelCache::EXPR_SQUARE);
7074 if (result == nullptr) {
7075 if (expr->Min() >= 0) {
7076 result = RegisterIntExpr(RevAlloc(new PosIntSquare(this, expr)));
7077 } else {
7078 result = RegisterIntExpr(RevAlloc(new IntSquare(this, expr)));
7079 }
7080 Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_SQUARE);
7081 }
7082 return result;
7083}
7084
7085IntExpr* Solver::MakePower(IntExpr* const expr, int64_t n) {
7086 CHECK_EQ(this, expr->solver());
7087 CHECK_GE(n, 0);
7088 if (expr->Bound()) {
7089 const int64_t v = expr->Min();
7090 if (v >= OverflowLimit(n)) { // Overflow.
7091 return MakeIntConst(std::numeric_limits<int64_t>::max());
7092 }
7093 return MakeIntConst(IntPower(v, n));
7094 }
7095 switch (n) {
7096 case 0:
7097 return MakeIntConst(1);
7098 case 1:
7099 return expr;
7100 case 2:
7101 return MakeSquare(expr);
7102 default: {
7103 IntExpr* result = nullptr;
7104 if (n % 2 == 0) { // even.
7105 if (expr->Min() >= 0) {
7106 result =
7107 RegisterIntExpr(RevAlloc(new PosIntEvenPower(this, expr, n)));
7108 } else {
7109 result = RegisterIntExpr(RevAlloc(new IntEvenPower(this, expr, n)));
7110 }
7111 } else {
7112 result = RegisterIntExpr(RevAlloc(new IntOddPower(this, expr, n)));
7113 }
7114 return result;
7115 }
7116 }
7117}
7118
7119IntExpr* Solver::MakeMin(IntExpr* const left, IntExpr* const right) {
7120 CHECK_EQ(this, left->solver());
7121 CHECK_EQ(this, right->solver());
7122 if (left->Bound()) {
7123 return MakeMin(right, left->Min());
7124 }
7125 if (right->Bound()) {
7126 return MakeMin(left, right->Min());
7127 }
7128 if (left->Min() >= right->Max()) {
7129 return right;
7130 }
7131 if (right->Min() >= left->Max()) {
7132 return left;
7133 }
7134 return RegisterIntExpr(RevAlloc(new MinIntExpr(this, left, right)));
7135}
7136
7137IntExpr* Solver::MakeMin(IntExpr* const expr, int64_t value) {
7138 CHECK_EQ(this, expr->solver());
7139 if (value <= expr->Min()) {
7140 return MakeIntConst(value);
7141 }
7142 if (expr->Bound()) {
7143 return MakeIntConst(std::min(expr->Min(), value));
7144 }
7145 if (expr->Max() <= value) {
7146 return expr;
7147 }
7148 return RegisterIntExpr(RevAlloc(new MinCstIntExpr(this, expr, value)));
7149}
7150
7151IntExpr* Solver::MakeMin(IntExpr* const expr, int value) {
7152 return MakeMin(expr, static_cast<int64_t>(value));
7153}
7154
7155IntExpr* Solver::MakeMax(IntExpr* const left, IntExpr* const right) {
7156 CHECK_EQ(this, left->solver());
7157 CHECK_EQ(this, right->solver());
7158 if (left->Bound()) {
7159 return MakeMax(right, left->Min());
7160 }
7161 if (right->Bound()) {
7162 return MakeMax(left, right->Min());
7163 }
7164 if (left->Min() >= right->Max()) {
7165 return left;
7166 }
7167 if (right->Min() >= left->Max()) {
7168 return right;
7169 }
7170 return RegisterIntExpr(RevAlloc(new MaxIntExpr(this, left, right)));
7171}
7172
7173IntExpr* Solver::MakeMax(IntExpr* const expr, int64_t value) {
7174 CHECK_EQ(this, expr->solver());
7175 if (expr->Bound()) {
7176 return MakeIntConst(std::max(expr->Min(), value));
7177 }
7178 if (value <= expr->Min()) {
7179 return expr;
7180 }
7181 if (expr->Max() <= value) {
7182 return MakeIntConst(value);
7183 }
7184 return RegisterIntExpr(RevAlloc(new MaxCstIntExpr(this, expr, value)));
7185}
7186
7187IntExpr* Solver::MakeMax(IntExpr* const expr, int value) {
7188 return MakeMax(expr, static_cast<int64_t>(value));
7189}
7190
7192 int64_t early_date, int64_t late_date,
7193 int64_t late_cost) {
// NOTE(review): the first signature line is missing from this listing
// (presumably Solver::MakeConvexPiecewiseExpr(IntExpr* expr, int64_t
// early_cost, ... — TODO confirm). Builds and registers a
// SimpleConvexPiecewiseExpr from expr and the given costs and dates.
7194 return RegisterIntExpr(RevAlloc(new SimpleConvexPiecewiseExpr(
7195 this, expr, early_cost, early_date, late_date, late_cost)));
7196}
7197
7199 int64_t fixed_charge, int64_t step) {
// NOTE(review): the first signature line is missing from this listing
// (presumably Solver::MakeSemiContinuousExpr(IntExpr* const expr, ... — TODO
// confirm). Chooses the implementation from `step`:
// - step == 0 with no fixed charge: the constant 0;
// - step == 0 otherwise, and step == 1: dedicated expression classes;
// - any other step: the general SemiContinuousExpr.
7200 if (step == 0) {
7201 if (fixed_charge == 0) {
7202 return MakeIntConst(int64_t{0});
7203 } else {
7204 return RegisterIntExpr(
7205 RevAlloc(new SemiContinuousStepZeroExpr(this, expr, fixed_charge)));
7206 }
7207 } else if (step == 1) {
7208 return RegisterIntExpr(
7209 RevAlloc(new SemiContinuousStepOneExpr(this, expr, fixed_charge)));
7210 } else {
7211 return RegisterIntExpr(
7212 RevAlloc(new SemiContinuousExpr(this, expr, fixed_charge, step)));
7213 }
7214 // TODO(user) : benchmark with virtualization of
7215 // PosIntDivDown and PosIntDivUp - or function pointers.
7216}
7217
7218// ----- Piecewise Linear -----
7219
// NOTE(review): the class header and the first constructor line are missing
// from this listing (presumably class PiecewiseLinearExpr : public
// BaseIntExpr — TODO confirm). Expression f(expr) for a
// PiecewiseLinearFunction f, which is stored by value.
7221 public:
7223 const PiecewiseLinearFunction& f)
7224 : BaseIntExpr(solver), expr_(expr), f_(f) {}
// Lower bound: minimum of f over the current range of expr_.
7226 int64_t Min() const override {
7227 return f_.GetMinimum(expr_->Min(), expr_->Max());
7228 }
// Narrows expr_ to the range reported by GetSmallestRangeGreaterThanValue
// (presumably the tightest range where f can still reach >= m).
7229 void SetMin(int64_t m) override {
7230 const auto& range =
7231 f_.GetSmallestRangeGreaterThanValue(expr_->Min(), expr_->Max(), m);
7232 expr_->SetRange(range.first, range.second);
7233 }
7234
// Upper bound: maximum of f over the current range of expr_.
7235 int64_t Max() const override {
7236 return f_.GetMaximum(expr_->Min(), expr_->Max());
7237 }
7238
// Narrows expr_ to the range reported by GetSmallestRangeLessThanValue.
7239 void SetMax(int64_t m) override {
7240 const auto& range =
7241 f_.GetSmallestRangeLessThanValue(expr_->Min(), expr_->Max(), m);
7242 expr_->SetRange(range.first, range.second);
7243 }
7244
// Narrows expr_ to the range reported by GetSmallestRangeInValueRange(l, u).
7245 void SetRange(int64_t l, int64_t u) override {
7246 const auto& range =
7247 f_.GetSmallestRangeInValueRange(expr_->Min(), expr_->Max(), l, u);
7248 expr_->SetRange(range.first, range.second);
7249 }
7250 std::string name() const override {
7251 return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->name(),
7252 f_.DebugString());
7253 }
7254
7255 std::string DebugString() const override {
7256 return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->DebugString(),
7257 f_.DebugString());
7258 }
7259
// The bounds only depend on expr_'s range.
7260 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
7261
// NOTE(review): no-op, so ModelVisitors do not see this expression.
7262 void Accept(ModelVisitor* const visitor) const override {
7263 // TODO(user): Implement visitor.
7264 }
7265
7266 private:
7267 IntExpr* const expr_;
7268 const PiecewiseLinearFunction f_;
7269};
7270
7275
7276// ----- Conditional Expression -----
7277
7279 IntExpr* const expr,
7280 int64_t unperformed_value) {
// NOTE(review): the first signature line is missing from this listing
// (presumably Solver::MakeConditionalExpression(IntVar* const condition, ...
// — TODO confirm). Returns:
// - expr when condition is fixed to 1;
// - the constant unperformed_value when condition is fixed to 0;
// - otherwise a cached ExprWithEscapeValue(condition, expr,
//   unperformed_value).
// NOTE(review): the new expression is not passed through RegisterIntExpr,
// unlike most builders in this file. The continuation lines of the two cache
// calls are missing from this listing.
7281 if (condition->Min() == 1) {
7282 return expr;
7283 } else if (condition->Max() == 0) {
7284 return MakeIntConst(unperformed_value);
7285 } else {
7286 IntExpr* cache = Cache()->FindExprExprConstantExpression(
7287 condition, expr, unperformed_value,
7289 if (cache == nullptr) {
7290 cache = RevAlloc(
7291 new ExprWithEscapeValue(this, condition, expr, unperformed_value));
7292 Cache()->InsertExprExprConstantExpression(
7293 cache, condition, expr, unperformed_value,
7295 }
7296 return cache;
7297 }
7298}
7299
7300// ----- Modulo -----
7301
7302IntExpr* Solver::MakeModulo(IntExpr* const x, int64_t mod) {
7303 IntVar* const result =
7304 MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7305 if (mod >= 0) {
7306 AddConstraint(MakeBetweenCt(result, 0, mod - 1));
7307 } else {
7308 AddConstraint(MakeBetweenCt(result, mod + 1, 0));
7309 }
7310 return result;
7311}
7312
7314 if (mod->Bound()) {
7315 return MakeModulo(x, mod->Min());
7316 }
7317 IntVar* const result =
7318 MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7319 AddConstraint(MakeLess(result, MakeAbs(mod)));
7321 return result;
7322}
7323
7324// --------- IntVar ---------
7325
7326int IntVar::VarType() const { return UNSPECIFIED; }
7327
7328void IntVar::RemoveValues(const std::vector<int64_t>& values) {
7329 // TODO(user): Check and maybe inline this code.
7330 const int size = values.size();
7331 DCHECK_GE(size, 0);
7332 switch (size) {
7333 case 0: {
7334 return;
7335 }
7336 case 1: {
7337 RemoveValue(values[0]);
7338 return;
7339 }
7340 case 2: {
7341 RemoveValue(values[0]);
7342 RemoveValue(values[1]);
7343 return;
7344 }
7345 case 3: {
7346 RemoveValue(values[0]);
7347 RemoveValue(values[1]);
7348 RemoveValue(values[2]);
7349 return;
7350 }
7351 default: {
7352 // 4 values, let's start doing some more clever things.
7353 // TODO(user) : Sort values!
7354 int start_index = 0;
7355 int64_t new_min = Min();
7356 if (values[start_index] <= new_min) {
7357 while (start_index < size - 1 &&
7358 values[start_index + 1] == values[start_index] + 1) {
7359 new_min = values[start_index + 1] + 1;
7360 start_index++;
7361 }
7362 }
7363 int end_index = size - 1;
7364 int64_t new_max = Max();
7365 if (values[end_index] >= new_max) {
7366 while (end_index > start_index + 1 &&
7367 values[end_index - 1] == values[end_index] - 1) {
7368 new_max = values[end_index - 1] - 1;
7369 end_index--;
7370 }
7371 }
7372 SetRange(new_min, new_max);
7373 for (int i = start_index; i <= end_index; ++i) {
7374 RemoveValue(values[i]);
7375 }
7376 }
7377 }
7378}
7379
7380void IntVar::Accept(ModelVisitor* const visitor) const {
7381 IntExpr* const casted = solver()->CastExpression(this);
7382 visitor->VisitIntegerVariable(this, casted);
7383}
7384
7385void IntVar::SetValues(const std::vector<int64_t>& values) {
7386 switch (values.size()) {
7387 case 0: {
7388 solver()->Fail();
7389 break;
7390 }
7391 case 1: {
7392 SetValue(values.back());
7393 break;
7394 }
7395 case 2: {
7396 if (Contains(values[0])) {
7397 if (Contains(values[1])) {
7398 const int64_t l = std::min(values[0], values[1]);
7399 const int64_t u = std::max(values[0], values[1]);
7400 SetRange(l, u);
7401 if (u > l + 1) {
7402 RemoveInterval(l + 1, u - 1);
7403 }
7404 } else {
7405 SetValue(values[0]);
7406 }
7407 } else {
7408 SetValue(values[1]);
7409 }
7410 break;
7411 }
7412 default: {
7413 // TODO(user): use a clean and safe SortedUniqueCopy() class
7414 // that uses a global, static shared (and locked) storage.
7415 // TODO(user): [optional] consider porting
7416 // STLSortAndRemoveDuplicates from ortools/base/stl_util.h to the
7417 // existing base/stl_util.h and using it here.
7418 // TODO(user): We could filter out values not in the var.
7419 std::vector<int64_t>& tmp = solver()->tmp_vector_;
7420 tmp.clear();
7421 tmp.insert(tmp.end(), values.begin(), values.end());
7422 std::sort(tmp.begin(), tmp.end());
7423 tmp.erase(std::unique(tmp.begin(), tmp.end()), tmp.end());
7424 const int size = tmp.size();
7425 const int64_t vmin = Min();
7426 const int64_t vmax = Max();
7427 int first = 0;
7428 int last = size - 1;
7429 if (tmp.front() > vmax || tmp.back() < vmin) {
7430 solver()->Fail();
7431 }
7432 // TODO(user) : We could find the first position >= vmin by dichotomy.
7433 while (tmp[first] < vmin || !Contains(tmp[first])) {
7434 ++first;
7435 if (first > last || tmp[first] > vmax) {
7436 solver()->Fail();
7437 }
7438 }
7439 while (last > first && (tmp[last] > vmax || !Contains(tmp[last]))) {
7440 // Note that last >= first implies tmp[last] >= vmin.
7441 --last;
7442 }
7443 DCHECK_GE(last, first);
7444 SetRange(tmp[first], tmp[last]);
7445 while (first < last) {
7446 const int64_t start = tmp[first] + 1;
7447 const int64_t end = tmp[first + 1] - 1;
7448 if (start <= end) {
7449 RemoveInterval(start, end);
7450 }
7451 first++;
7452 }
7453 }
7454 }
7455}
7456// ---------- BaseIntExpr ---------
7457
7458void LinkVarExpr(Solver* s, IntExpr* expr, IntVar* var) {
7459 if (!var->Bound()) {
7460 if (var->VarType() == DOMAIN_INT_VAR) {
7461 DomainIntVar* dvar = reinterpret_cast<DomainIntVar*>(var);
7463 s->RevAlloc(new LinkExprAndDomainIntVar(s, expr, dvar)), dvar, expr);
7464 } else {
7465 s->AddCastConstraint(s->RevAlloc(new LinkExprAndVar(s, expr, var)), var,
7466 expr);
7467 }
7468 }
7469}
7470
7472 if (var_ == nullptr) {
7473 solver()->SaveValue(reinterpret_cast<void**>(&var_));
7474 var_ = CastToVar();
7475 }
7476 return var_;
7477}
7478
7480 int64_t vmin, vmax;
7481 Range(&vmin, &vmax);
7482 IntVar* const var = solver()->MakeIntVar(vmin, vmax);
7483 LinkVarExpr(solver(), this, var);
7484 return var;
7485}
7486
7487// Discovery methods
7488bool Solver::IsADifference(IntExpr* expr, IntExpr** const left,
7489 IntExpr** const right) {
7490 if (expr->IsVar()) {
7491 IntVar* const expr_var = expr->Var();
7492 expr = CastExpression(expr_var);
7493 }
7494 // This is a dynamic cast to check the type of expr.
7495 // It returns nullptr is expr is not a subclass of SubIntExpr.
7496 SubIntExpr* const sub_expr = dynamic_cast<SubIntExpr*>(expr);
7497 if (sub_expr != nullptr) {
7498 *left = sub_expr->left();
7499 *right = sub_expr->right();
7500 return true;
7501 }
7502 return false;
7503}
7504
7505bool Solver::IsBooleanVar(IntExpr* const expr, IntVar** inner_var,
7506 bool* is_negated) const {
7507 if (expr->IsVar() && expr->Var()->VarType() == BOOLEAN_VAR) {
7508 *inner_var = expr->Var();
7509 *is_negated = false;
7510 return true;
7511 } else if (expr->IsVar() && expr->Var()->VarType() == CST_SUB_VAR) {
7512 SubCstIntVar* const sub_var = reinterpret_cast<SubCstIntVar*>(expr);
7513 if (sub_var != nullptr && sub_var->Constant() == 1 &&
7514 sub_var->SubVar()->VarType() == BOOLEAN_VAR) {
7515 *is_negated = true;
7516 *inner_var = sub_var->SubVar();
7517 return true;
7518 }
7519 }
7520 return false;
7521}
7522
7523bool Solver::IsProduct(IntExpr* const expr, IntExpr** inner_expr,
7524 int64_t* coefficient) {
7525 if (dynamic_cast<TimesCstIntVar*>(expr) != nullptr) {
7526 TimesCstIntVar* const var = dynamic_cast<TimesCstIntVar*>(expr);
7527 *coefficient = var->Constant();
7528 *inner_expr = var->SubVar();
7529 return true;
7530 } else if (dynamic_cast<TimesIntCstExpr*>(expr) != nullptr) {
7531 TimesIntCstExpr* const prod = dynamic_cast<TimesIntCstExpr*>(expr);
7532 *coefficient = prod->Constant();
7533 *inner_expr = prod->Expr();
7534 return true;
7535 }
7536 *inner_expr = expr;
7537 *coefficient = 1;
7538 return false;
7539}
7540
7541} // namespace operations_research
IntVar * Var() override
Creates a variable from the expression.
bool Contains(int64_t v) const override
IntVarIterator * MakeDomainIterator(bool reversible) const override
IntVar * IsGreaterOrEqual(int64_t constant) override
IntVarIterator * MakeHoleIterator(bool reversible) const override
----- Misc -----
IntVar * IsEqual(int64_t constant) override
IsEqual.
void RemoveValue(int64_t v) override
This method removes the value 'v' from the domain of the variable.
IntVar * IsDifferent(int64_t constant) override
SimpleRevFIFO< Demon * > delayed_bound_demons_
void SetMax(int64_t m) override
void SetRange(int64_t mi, int64_t ma) override
This method sets both the min and the max of the expression.
void RemoveInterval(int64_t l, int64_t u) override
void SetMin(int64_t m) override
SimpleRevFIFO< Demon * > bound_demons_
uint64_t Size() const override
This method returns the number of values in the domain of the variable.
static const int kUnboundBooleanVarValue
----- Boolean variable -----
std::string DebugString() const override
void WhenBound(Demon *d) override
IntVar * IsLessOrEqual(int64_t constant) override
void WhenRange(Demon *d) override
Attach a demon that will watch the min or the max of the expression.
virtual Solver::DemonPriority priority() const
---------------— Demon class -------------—
virtual void SetValue(int64_t v)
This method sets the value of the expression.
virtual bool Bound() const
Returns true if the min and the max of the expression are equal.
virtual void SetMax(int64_t m)=0
virtual bool IsVar() const
Returns true if the expression is indeed a variable.
virtual void SetRange(int64_t l, int64_t u)
This method sets both the min and the max of the expression.
virtual int64_t Min() const =0
virtual void SetMin(int64_t m)=0
virtual void WhenRange(Demon *d)=0
Attach a demon that will watch the min or the max of the expression.
virtual void Range(int64_t *l, int64_t *u)
IntVar * VarWithName(const std::string &name)
---------- IntExpr ----------
virtual IntVar * Var()=0
Creates a variable from the expression.
virtual int64_t Max() const =0
virtual void WhenBound(Demon *d)=0
virtual void WhenDomain(Demon *d)=0
virtual void SetValues(const std::vector< int64_t > &values)
This method intersects the current domain with the values in the array.
void Accept(ModelVisitor *visitor) const override
Accepts the given visitor.
virtual IntVar * IsDifferent(int64_t constant)=0
IntVar(Solver *s)
---------- IntVar ----------
virtual int64_t OldMax() const =0
Returns the previous max.
virtual IntVar * IsLessOrEqual(int64_t constant)=0
virtual bool Contains(int64_t v) const =0
virtual int64_t Value() const =0
virtual int VarType() const
--------- IntVar ---------
virtual void RemoveValue(int64_t v)=0
This method removes the value 'v' from the domain of the variable.
virtual uint64_t Size() const =0
This method returns the number of values in the domain of the variable.
virtual IntVar * IsGreaterOrEqual(int64_t constant)=0
virtual IntVar * IsEqual(int64_t constant)=0
IsEqual.
virtual void RemoveInterval(int64_t l, int64_t u)=0
virtual int64_t OldMin() const =0
Returns the previous min.
virtual void RemoveValues(const std::vector< int64_t > &values)
This method remove the values from the domain of the variable.
static int64_t GCD64(int64_t x, int64_t y)
Definition mathutil.h:107
virtual void VisitIntegerVariable(const IntVar *variable, IntExpr *delegate)
void SetRange(int64_t l, int64_t u) override
This method sets both the min and the max of the expression.
void Accept(ModelVisitor *const visitor) const override
Accepts the given visitor.
PiecewiseLinearExpr(Solver *solver, IntExpr *expr, const PiecewiseLinearFunction &f)
void WhenRange(Demon *d) override
Attach a demon that will watch the min or the max of the expression.
std::string name() const override
Object naming.
std::string DebugString() const override
virtual std::string name() const
Object naming.
For the time being, Solver is neither MT_SAFE nor MT_HOT.
IntExpr * MakeDiv(IntExpr *expr, int64_t value)
expr / value (integer division)
IntVar * MakeBoolVar(const std::string &name)
MakeBoolVar will create a variable with a {0, 1} domain.
Constraint * MakeNonEquality(IntExpr *left, IntExpr *right)
left != right
Definition range_cst.cc:570
IntExpr * MakeMax(const std::vector< IntVar * > &vars)
std::max(vars)
IntExpr * MakeDifference(IntExpr *left, IntExpr *right)
left - right
IntVar * MakeBoolVar()
MakeBoolVar will create a variable with a {0, 1} domain.
IntExpr * MakeSum(IntExpr *left, IntExpr *right)
----- Integer Expressions -----
IntExpr * MakeMin(const std::vector< IntVar * > &vars)
std::min(vars)
void Fail()
Abandon the current branch in the search tree. A backtrack will follow.
@ VAR_PRIORITY
VAR_PRIORITY is between DELAYED_PRIORITY and NORMAL_PRIORITY.
@ OUTSIDE_SEARCH
Before search, after search.
IntVar * RegisterIntVar(IntVar *var)
Registers a new IntVar and wraps it inside a TraceIntVar if necessary.
Definition trace.cc:860
bool IsBooleanVar(IntExpr *expr, IntVar **inner_var, bool *is_negated) const
Constraint * MakeLess(IntExpr *left, IntExpr *right)
left < right
Definition range_cst.cc:552
ModelCache * Cache() const
Returns the cache of the model.
IntExpr * MakeOpposite(IntExpr *expr)
-expr
Constraint * MakeAbsEquality(IntVar *var, IntVar *abs_var)
Creates the constraint abs(var) == abs_var.
void MakeBoolVarArray(int var_count, const std::string &name, std::vector< IntVar * > *vars)
IntExpr * MakePiecewiseLinearExpr(IntExpr *expr, const PiecewiseLinearFunction &f)
IntExpr * RegisterIntExpr(IntExpr *expr)
Registers a new IntExpr and wraps it inside a TraceIntExpr if necessary.
Definition trace.cc:848
IntExpr * MakeModulo(IntExpr *x, int64_t mod)
Modulo expression x % mod (with the python convention for modulo).
Constraint * MakeGreater(IntExpr *left, IntExpr *right)
left > right
Definition range_cst.cc:566
IntExpr * MakeAbs(IntExpr *expr)
|expr|
IntExpr * MakeSquare(IntExpr *expr)
expr * expr
Constraint * MakeBetweenCt(IntExpr *expr, int64_t l, int64_t u)
----- Between and related constraints -----
Definition expr_cst.cc:929
void MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax, const std::string &name, std::vector< IntVar * > *vars)
IntVar * MakeIntVar(int64_t min, int64_t max, const std::string &name)
----- Int Variables and Constants -----
bool IsProduct(IntExpr *expr, IntExpr **inner_expr, int64_t *coefficient)
IntExpr * MakeConvexPiecewiseExpr(IntExpr *expr, int64_t early_cost, int64_t early_date, int64_t late_date, int64_t late_cost)
Convex piecewise function.
IntExpr * MakePower(IntExpr *expr, int64_t n)
expr ^ n (n > 0)
IntExpr * MakeConditionalExpression(IntVar *condition, IntExpr *expr, int64_t unperformed_value)
Conditional Expr condition ? expr : unperformed_value.
IntExpr * MakeProd(IntExpr *left, IntExpr *right)
left * right
IntVar * MakeIntConst(int64_t val, const std::string &name)
IntConst will create a constant expression.
IntExpr * MakeSemiContinuousExpr(IntExpr *expr, int64_t fixed_charge, int64_t step)
void AddCastConstraint(CastConstraint *constraint, IntVar *target_var, IntExpr *expr)
ABSL_FLAG(bool, cp_disable_expression_optimization, false, "Disable special optimization when creating expressions.")
int RemoveAt(RepeatedType *array, const IndexContainer &indices)
const Collection::value_type::second_type FindPtrOrNull(const Collection &collection, const typename Collection::value_type::first_type &key)
Definition map_util.h:94
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
Definition stl_util.h:55
For infeasible and unbounded see Not checked if options check_solutions_if_inf_or_unbounded and the If options first_solution_only is false
problem is infeasible or unbounded (default).
Definition matchers.h:468
std::pair< double, double > Range
A range of values, first is the minimum, second is the maximum.
Definition statistics.h:27
dual_gradient T(y - `dual_solution`) class DiagonalTrustRegionProblemFromQp
std::function< int64_t(const Model &)> Value(IntegerVariable v)
This checks that the variable is fixed.
Definition integer.h:1601
In SWIG mode, we don't want anything besides these top-level includes.
int64_t SubOverflows(int64_t x, int64_t y)
int64_t UnsafeMostSignificantBitPosition64(const uint64_t *bitset, uint64_t start, uint64_t end)
void InternalSaveBooleanVarValue(Solver *const solver, IntVar *const var)
int64_t CapAdd(int64_t x, int64_t y)
void RestoreBoolValue(IntVar *var)
----- Trail -----
int64_t UnsafeLeastSignificantBitPosition64(const uint64_t *bitset, uint64_t start, uint64_t end)
int64_t CapSub(int64_t x, int64_t y)
Constraint * SetIsGreaterOrEqual(IntVar *const var, absl::Span< const int64_t > values, const std::vector< IntVar * > &vars)
ClosedInterval::Iterator end(ClosedInterval interval)
Constraint * SetIsEqual(IntVar *const var, absl::Span< const int64_t > values, const std::vector< IntVar * > &vars)
bool AddOverflows(int64_t x, int64_t y)
void RegisterDemon(Solver *const solver, Demon *const demon, DemonProfiler *const monitor)
----- Exported Methods for Unit Tests -----
static const uint64_t kAllBits64
Basic bit operations.
Definition bitset.h:36
void LinkVarExpr(Solver *s, IntExpr *expr, IntVar *var)
----- IntExprElement -----
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
int64_t CapProd(int64_t x, int64_t y)
std::vector< int64_t > ToInt64Vector(const std::vector< int > &input)
----- Vector manipulations -----
Definition utilities.cc:829
uint64_t OneRange64(uint64_t s, uint64_t e)
Returns a word with bits from s to e set.
Definition bitset.h:288
uint32_t BitPos64(uint64_t pos)
Bit operators used to manipulates bitsets.
Definition bitset.h:333
uint64_t BitCountRange64(const uint64_t *bitset, uint64_t start, uint64_t end)
Returns the number of bits set in bitset between positions start and end.
uint64_t BitCount64(uint64_t n)
Returns the number of bits set in n.
Definition bitset.h:45
bool IsBitSet64(const uint64_t *const bitset, uint64_t pos)
Returns true if the bit pos is set in bitset.
Definition bitset.h:349
uint64_t OneBit64(int pos)
Returns a word with only bit pos set.
Definition bitset.h:41
uint64_t BitOffset64(uint64_t pos)
Returns the word number corresponding to bit number pos.
Definition bitset.h:337
int64_t PosIntDivDown(int64_t e, int64_t v)
uint64_t BitLength64(uint64_t size)
Returns the number of words needed to store size bits.
Definition bitset.h:341
int LeastSignificantBitPosition64(uint64_t n)
Definition bitset.h:130
void CleanVariableOnFail(IntVar *var)
---------------— Queue class ---------------—
int64_t CapOpp(int64_t v)
Note(user): -kint64min != kint64max, but kint64max == ~kint64min.
int MostSignificantBitPosition64(uint64_t n)
Definition bitset.h:234
int64_t PosIntDivUp(int64_t e, int64_t v)
STL namespace.
false true
Definition numbers.cc:223
trees with all degrees equal w the current value of w
bool Next()