Google OR-Tools v9.15
a fast and portable software suite for combinatorial optimization
Loading...
Searching...
No Matches
expr_cst.cc
Go to the documentation of this file.
1// Copyright 2010-2025 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14// Expression constraints
15
16#include <algorithm>
17#include <cstddef>
18#include <cstdint>
19#include <limits>
20#include <set>
21#include <string>
22#include <utility>
23#include <vector>
24
25#include "absl/strings/str_format.h"
26#include "absl/strings/str_join.h"
30#include "ortools/base/types.h"
35
36ABSL_FLAG(int, cache_initial_size, 1024,
37 "Initial size of the array of the hash "
38 "table of caches for objects of type Var(x == 3)");
39
40namespace operations_research {
41
42//-----------------------------------------------------------------------------
43// Equality
44
45namespace {
46class EqualityExprCst : public Constraint {
47 public:
48 EqualityExprCst(Solver* s, IntExpr* e, int64_t v);
49 ~EqualityExprCst() override {}
50 void Post() override;
51 void InitialPropagate() override;
52 IntVar* Var() override {
53 return solver()->MakeIsEqualCstVar(expr_->Var(), value_);
54 }
55 std::string DebugString() const override;
56
57 void Accept(ModelVisitor* const visitor) const override {
58 visitor->BeginVisitConstraint(ModelVisitor::kEquality, this);
59 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
60 expr_);
61 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
62 visitor->EndVisitConstraint(ModelVisitor::kEquality, this);
63 }
64
65 private:
66 IntExpr* const expr_;
67 int64_t value_;
68};
69
70EqualityExprCst::EqualityExprCst(Solver* const s, IntExpr* const e, int64_t v)
71 : Constraint(s), expr_(e), value_(v) {}
72
73void EqualityExprCst::Post() {
74 if (!expr_->IsVar()) {
75 Demon* d = solver()->MakeConstraintInitialPropagateCallback(this);
76 expr_->WhenRange(d);
77 }
78}
79
80void EqualityExprCst::InitialPropagate() { expr_->SetValue(value_); }
81
82std::string EqualityExprCst::DebugString() const {
83 return absl::StrFormat("(%s == %d)", expr_->DebugString(), value_);
84}
85} // namespace
86
87Constraint* Solver::MakeEquality(IntExpr* const e, int64_t v) {
88 CHECK_EQ(this, e->solver());
89 IntExpr* left = nullptr;
90 IntExpr* right = nullptr;
91 if (IsADifference(e, &left, &right)) {
92 return MakeEquality(left, MakeSum(right, v));
93 } else if (e->IsVar() && !e->Var()->Contains(v)) {
94 return MakeFalseConstraint();
95 } else if (e->Min() == e->Max() && e->Min() == v) {
96 return MakeTrueConstraint();
97 } else {
98 return RevAlloc(new EqualityExprCst(this, e, v));
99 }
100}
101
102Constraint* Solver::MakeEquality(IntExpr* const e, int v) {
103 CHECK_EQ(this, e->solver());
104 IntExpr* left = nullptr;
105 IntExpr* right = nullptr;
106 if (IsADifference(e, &left, &right)) {
107 return MakeEquality(left, MakeSum(right, v));
108 } else if (e->IsVar() && !e->Var()->Contains(v)) {
109 return MakeFalseConstraint();
110 } else if (e->Min() == e->Max() && e->Min() == v) {
111 return MakeTrueConstraint();
112 } else {
113 return RevAlloc(new EqualityExprCst(this, e, v));
114 }
115}
116
117//-----------------------------------------------------------------------------
118// Greater or equal constraint
119
120namespace {
121class GreaterEqExprCst : public Constraint {
122 public:
123 GreaterEqExprCst(Solver* s, IntExpr* e, int64_t v);
124 ~GreaterEqExprCst() override {}
125 void Post() override;
126 void InitialPropagate() override;
127 std::string DebugString() const override;
128 IntVar* Var() override {
129 return solver()->MakeIsGreaterOrEqualCstVar(expr_->Var(), value_);
130 }
131
132 void Accept(ModelVisitor* const visitor) const override {
133 visitor->BeginVisitConstraint(ModelVisitor::kGreaterOrEqual, this);
134 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
135 expr_);
136 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
137 visitor->EndVisitConstraint(ModelVisitor::kGreaterOrEqual, this);
138 }
139
140 private:
141 IntExpr* const expr_;
142 int64_t value_;
143 Demon* demon_;
144};
145
146GreaterEqExprCst::GreaterEqExprCst(Solver* const s, IntExpr* const e, int64_t v)
147 : Constraint(s), expr_(e), value_(v), demon_(nullptr) {}
148
149void GreaterEqExprCst::Post() {
150 if (!expr_->IsVar() && expr_->Min() < value_) {
151 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
152 expr_->WhenRange(demon_);
153 } else {
154 // Let's clean the demon in case the constraint is posted during search.
155 demon_ = nullptr;
156 }
157}
158
159void GreaterEqExprCst::InitialPropagate() {
160 expr_->SetMin(value_);
161 if (demon_ != nullptr && expr_->Min() >= value_) {
162 demon_->inhibit(solver());
163 }
164}
165
166std::string GreaterEqExprCst::DebugString() const {
167 return absl::StrFormat("(%s >= %d)", expr_->DebugString(), value_);
168}
169} // namespace
170
171Constraint* Solver::MakeGreaterOrEqual(IntExpr* const e, int64_t v) {
172 CHECK_EQ(this, e->solver());
173 if (e->Min() >= v) {
174 return MakeTrueConstraint();
175 } else if (e->Max() < v) {
176 return MakeFalseConstraint();
177 } else {
178 return RevAlloc(new GreaterEqExprCst(this, e, v));
179 }
180}
181
182Constraint* Solver::MakeGreaterOrEqual(IntExpr* const e, int v) {
183 CHECK_EQ(this, e->solver());
184 if (e->Min() >= v) {
185 return MakeTrueConstraint();
186 } else if (e->Max() < v) {
187 return MakeFalseConstraint();
188 } else {
189 return RevAlloc(new GreaterEqExprCst(this, e, v));
190 }
191}
192
193Constraint* Solver::MakeGreater(IntExpr* const e, int64_t v) {
194 CHECK_EQ(this, e->solver());
195 if (e->Min() > v) {
196 return MakeTrueConstraint();
197 } else if (e->Max() <= v) {
198 return MakeFalseConstraint();
199 } else {
200 return RevAlloc(new GreaterEqExprCst(this, e, v + 1));
201 }
202}
203
204Constraint* Solver::MakeGreater(IntExpr* const e, int v) {
205 CHECK_EQ(this, e->solver());
206 if (e->Min() > v) {
207 return MakeTrueConstraint();
208 } else if (e->Max() <= v) {
209 return MakeFalseConstraint();
210 } else {
211 return RevAlloc(new GreaterEqExprCst(this, e, v + 1));
212 }
213}
214
215//-----------------------------------------------------------------------------
216// Less or equal constraint
217
218namespace {
219class LessEqExprCst : public Constraint {
220 public:
221 LessEqExprCst(Solver* s, IntExpr* e, int64_t v);
222 ~LessEqExprCst() override {}
223 void Post() override;
224 void InitialPropagate() override;
225 std::string DebugString() const override;
226 IntVar* Var() override {
227 return solver()->MakeIsLessOrEqualCstVar(expr_->Var(), value_);
228 }
229 void Accept(ModelVisitor* const visitor) const override {
230 visitor->BeginVisitConstraint(ModelVisitor::kLessOrEqual, this);
231 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
232 expr_);
233 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
234 visitor->EndVisitConstraint(ModelVisitor::kLessOrEqual, this);
235 }
236
237 private:
238 IntExpr* const expr_;
239 int64_t value_;
240 Demon* demon_;
241};
242
243LessEqExprCst::LessEqExprCst(Solver* const s, IntExpr* const e, int64_t v)
244 : Constraint(s), expr_(e), value_(v), demon_(nullptr) {}
245
246void LessEqExprCst::Post() {
247 if (!expr_->IsVar() && expr_->Max() > value_) {
248 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
249 expr_->WhenRange(demon_);
250 } else {
251 // Let's clean the demon in case the constraint is posted during search.
252 demon_ = nullptr;
253 }
254}
255
256void LessEqExprCst::InitialPropagate() {
257 expr_->SetMax(value_);
258 if (demon_ != nullptr && expr_->Max() <= value_) {
259 demon_->inhibit(solver());
260 }
261}
262
263std::string LessEqExprCst::DebugString() const {
264 return absl::StrFormat("(%s <= %d)", expr_->DebugString(), value_);
265}
266} // namespace
267
268Constraint* Solver::MakeLessOrEqual(IntExpr* const e, int64_t v) {
269 CHECK_EQ(this, e->solver());
270 if (e->Max() <= v) {
272 } else if (e->Min() > v) {
273 return MakeFalseConstraint();
274 } else {
275 return RevAlloc(new LessEqExprCst(this, e, v));
276 }
277}
278
279Constraint* Solver::MakeLessOrEqual(IntExpr* const e, int v) {
280 CHECK_EQ(this, e->solver());
281 if (e->Max() <= v) {
283 } else if (e->Min() > v) {
284 return MakeFalseConstraint();
285 } else {
286 return RevAlloc(new LessEqExprCst(this, e, v));
287 }
288}
289
290Constraint* Solver::MakeLess(IntExpr* const e, int64_t v) {
291 CHECK_EQ(this, e->solver());
292 if (e->Max() < v) {
294 } else if (e->Min() >= v) {
295 return MakeFalseConstraint();
296 } else {
297 return RevAlloc(new LessEqExprCst(this, e, v - 1));
298 }
299}
300
301Constraint* Solver::MakeLess(IntExpr* const e, int v) {
302 CHECK_EQ(this, e->solver());
303 if (e->Max() < v) {
305 } else if (e->Min() >= v) {
306 return MakeFalseConstraint();
307 } else {
308 return RevAlloc(new LessEqExprCst(this, e, v - 1));
309 }
310}
311
312//-----------------------------------------------------------------------------
313// Different constraints
314
315namespace {
316class DiffCst : public Constraint {
317 public:
318 DiffCst(Solver* s, IntVar* var, int64_t value);
319 ~DiffCst() override {}
320 void Post() override {}
321 void InitialPropagate() override;
322 void BoundPropagate();
323 std::string DebugString() const override;
324 IntVar* Var() override {
325 return solver()->MakeIsDifferentCstVar(var_, value_);
326 }
327 void Accept(ModelVisitor* const visitor) const override {
328 visitor->BeginVisitConstraint(ModelVisitor::kNonEqual, this);
329 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
330 var_);
331 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
332 visitor->EndVisitConstraint(ModelVisitor::kNonEqual, this);
333 }
334
335 private:
336 bool HasLargeDomain(IntVar* var);
337
338 IntVar* const var_;
339 int64_t value_;
340 Demon* demon_;
341};
342
343DiffCst::DiffCst(Solver* const s, IntVar* const var, int64_t value)
344 : Constraint(s), var_(var), value_(value), demon_(nullptr) {}
345
346void DiffCst::InitialPropagate() {
347 if (HasLargeDomain(var_)) {
348 demon_ = MakeConstraintDemon0(solver(), this, &DiffCst::BoundPropagate,
349 "BoundPropagate");
350 var_->WhenRange(demon_);
351 } else {
352 var_->RemoveValue(value_);
353 }
354}
355
356void DiffCst::BoundPropagate() {
357 const int64_t var_min = var_->Min();
358 const int64_t var_max = var_->Max();
359 if (var_min > value_ || var_max < value_) {
360 demon_->inhibit(solver());
361 } else if (var_min == value_) {
362 var_->SetMin(CapAdd(value_, 1));
363 } else if (var_max == value_) {
364 var_->SetMax(CapSub(value_, 1));
365 } else if (!HasLargeDomain(var_)) {
366 demon_->inhibit(solver());
367 var_->RemoveValue(value_);
368 }
369}
370
371std::string DiffCst::DebugString() const {
372 return absl::StrFormat("(%s != %d)", var_->DebugString(), value_);
373}
374
375bool DiffCst::HasLargeDomain(IntVar* var) {
376 return CapSub(var->Max(), var->Min()) > 0xFFFFFF;
377}
378} // namespace
379
380Constraint* Solver::MakeNonEquality(IntExpr* const e, int64_t v) {
381 CHECK_EQ(this, e->solver());
382 IntExpr* left = nullptr;
383 IntExpr* right = nullptr;
384 if (IsADifference(e, &left, &right)) {
385 return MakeNonEquality(left, MakeSum(right, v));
386 } else if (e->IsVar() && !e->Var()->Contains(v)) {
387 return MakeTrueConstraint();
388 } else if (e->Bound() && e->Min() == v) {
389 return MakeFalseConstraint();
390 } else {
391 return RevAlloc(new DiffCst(this, e->Var(), v));
392 }
393}
394
395Constraint* Solver::MakeNonEquality(IntExpr* const e, int v) {
396 CHECK_EQ(this, e->solver());
397 IntExpr* left = nullptr;
398 IntExpr* right = nullptr;
399 if (IsADifference(e, &left, &right)) {
400 return MakeNonEquality(left, MakeSum(right, v));
401 } else if (e->IsVar() && !e->Var()->Contains(v)) {
402 return MakeTrueConstraint();
403 } else if (e->Bound() && e->Min() == v) {
404 return MakeFalseConstraint();
405 } else {
406 return RevAlloc(new DiffCst(this, e->Var(), v));
407 }
408}
409// ----- is_equal_cst Constraint -----
410
411namespace {
412class IsEqualCstCt : public CastConstraint {
413 public:
414 IsEqualCstCt(Solver* const s, IntVar* const v, int64_t c, IntVar* const b)
415 : CastConstraint(s, b), var_(v), cst_(c), demon_(nullptr) {}
416 void Post() override {
417 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
418 var_->WhenDomain(demon_);
419 target_var_->WhenBound(demon_);
420 }
421 void InitialPropagate() override {
422 bool inhibit = var_->Bound();
423 int64_t u = var_->Contains(cst_);
424 int64_t l = inhibit ? u : 0;
425 target_var_->SetRange(l, u);
426 if (target_var_->Bound()) {
427 if (target_var_->Min() == 0) {
428 if (var_->Size() <= 0xFFFFFF) {
429 var_->RemoveValue(cst_);
430 inhibit = true;
431 }
432 } else {
433 var_->SetValue(cst_);
434 inhibit = true;
435 }
436 }
437 if (inhibit) {
438 demon_->inhibit(solver());
439 }
440 }
441 std::string DebugString() const override {
442 return absl::StrFormat("IsEqualCstCt(%s, %d, %s)", var_->DebugString(),
443 cst_, target_var_->DebugString());
444 }
445
446 void Accept(ModelVisitor* const visitor) const override {
447 visitor->BeginVisitConstraint(ModelVisitor::kIsEqual, this);
448 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
449 var_);
450 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
451 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
452 target_var_);
453 visitor->EndVisitConstraint(ModelVisitor::kIsEqual, this);
454 }
455
456 private:
457 IntVar* const var_;
458 int64_t cst_;
459 Demon* demon_;
460};
461} // namespace
462
463IntVar* Solver::MakeIsEqualCstVar(IntExpr* const var, int64_t value) {
464 IntExpr* left = nullptr;
465 IntExpr* right = nullptr;
466 if (IsADifference(var, &left, &right)) {
467 return MakeIsEqualVar(left, MakeSum(right, value));
468 }
469 if (CapSub(var->Max(), var->Min()) == 1) {
470 if (value == var->Min()) {
471 return MakeDifference(value + 1, var)->Var();
472 } else if (value == var->Max()) {
473 return MakeSum(var, -value + 1)->Var();
474 } else {
475 return MakeIntConst(0);
476 }
477 }
478 if (var->IsVar()) {
479 return var->Var()->IsEqual(value);
480 } else {
481 IntVar* const boolvar =
482 MakeBoolVar(absl::StrFormat("Is(%s == %d)", var->DebugString(), value));
483 AddConstraint(MakeIsEqualCstCt(var, value, boolvar));
484 return boolvar;
485 }
486}
487
488Constraint* Solver::MakeIsEqualCstCt(IntExpr* const var, int64_t value,
489 IntVar* const boolvar) {
490 CHECK_EQ(this, var->solver());
491 CHECK_EQ(this, boolvar->solver());
492 if (value == var->Min()) {
493 if (CapSub(var->Max(), var->Min()) == 1) {
494 return MakeEquality(MakeDifference(value + 1, var), boolvar);
495 }
496 return MakeIsLessOrEqualCstCt(var, value, boolvar);
497 }
498 if (value == var->Max()) {
499 if (CapSub(var->Max(), var->Min()) == 1) {
500 return MakeEquality(MakeSum(var, -value + 1), boolvar);
501 }
502 return MakeIsGreaterOrEqualCstCt(var, value, boolvar);
503 }
504 if (boolvar->Bound()) {
505 if (boolvar->Min() == 0) {
506 return MakeNonEquality(var, value);
507 } else {
508 return MakeEquality(var, value);
509 }
510 }
511 // TODO(user) : what happens if the constraint is not posted?
512 // The cache becomes tainted.
513 model_cache_->InsertExprConstantExpression(
514 boolvar, var, value, ModelCache::EXPR_CONSTANT_IS_EQUAL);
515 IntExpr* left = nullptr;
516 IntExpr* right = nullptr;
517 if (IsADifference(var, &left, &right)) {
518 return MakeIsEqualCt(left, MakeSum(right, value), boolvar);
519 } else {
520 return RevAlloc(new IsEqualCstCt(this, var->Var(), value, boolvar));
521 }
522}
523
524// ----- is_diff_cst Constraint -----
525
526namespace {
527class IsDiffCstCt : public CastConstraint {
528 public:
529 IsDiffCstCt(Solver* const s, IntVar* const v, int64_t c, IntVar* const b)
530 : CastConstraint(s, b), var_(v), cst_(c), demon_(nullptr) {}
531
532 void Post() override {
533 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
534 var_->WhenDomain(demon_);
535 target_var_->WhenBound(demon_);
536 }
537
538 void InitialPropagate() override {
539 bool inhibit = var_->Bound();
540 int64_t l = 1 - var_->Contains(cst_);
541 int64_t u = inhibit ? l : 1;
542 target_var_->SetRange(l, u);
543 if (target_var_->Bound()) {
544 if (target_var_->Min() == 1) {
545 if (var_->Size() <= 0xFFFFFF) {
546 var_->RemoveValue(cst_);
547 inhibit = true;
548 }
549 } else {
550 var_->SetValue(cst_);
551 inhibit = true;
552 }
553 }
554 if (inhibit) {
555 demon_->inhibit(solver());
556 }
557 }
558
559 std::string DebugString() const override {
560 return absl::StrFormat("IsDiffCstCt(%s, %d, %s)", var_->DebugString(), cst_,
561 target_var_->DebugString());
562 }
563
564 void Accept(ModelVisitor* const visitor) const override {
565 visitor->BeginVisitConstraint(ModelVisitor::kIsDifferent, this);
566 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
567 var_);
568 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
569 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
570 target_var_);
571 visitor->EndVisitConstraint(ModelVisitor::kIsDifferent, this);
572 }
573
574 private:
575 IntVar* const var_;
576 int64_t cst_;
577 Demon* demon_;
578};
579} // namespace
580
581IntVar* Solver::MakeIsDifferentCstVar(IntExpr* const var, int64_t value) {
582 IntExpr* left = nullptr;
583 IntExpr* right = nullptr;
584 if (IsADifference(var, &left, &right)) {
585 return MakeIsDifferentVar(left, MakeSum(right, value));
586 }
587 return var->Var()->IsDifferent(value);
588}
589
590Constraint* Solver::MakeIsDifferentCstCt(IntExpr* const var, int64_t value,
591 IntVar* const boolvar) {
592 CHECK_EQ(this, var->solver());
593 CHECK_EQ(this, boolvar->solver());
594 if (value == var->Min()) {
595 return MakeIsGreaterOrEqualCstCt(var, value + 1, boolvar);
596 }
597 if (value == var->Max()) {
598 return MakeIsLessOrEqualCstCt(var, value - 1, boolvar);
599 }
600 if (var->IsVar() && !var->Var()->Contains(value)) {
601 return MakeEquality(boolvar, int64_t{1});
602 }
603 if (var->Bound() && var->Min() == value) {
604 return MakeEquality(boolvar, Zero());
605 }
606 if (boolvar->Bound()) {
607 if (boolvar->Min() == 0) {
608 return MakeEquality(var, value);
609 } else {
610 return MakeNonEquality(var, value);
611 }
612 }
613 model_cache_->InsertExprConstantExpression(
614 boolvar, var, value, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
615 IntExpr* left = nullptr;
616 IntExpr* right = nullptr;
617 if (IsADifference(var, &left, &right)) {
618 return MakeIsDifferentCt(left, MakeSum(right, value), boolvar);
619 } else {
620 return RevAlloc(new IsDiffCstCt(this, var->Var(), value, boolvar));
621 }
622}
623
624// ----- is_greater_equal_cst Constraint -----
625
626namespace {
627class IsGreaterEqualCstCt : public CastConstraint {
628 public:
629 IsGreaterEqualCstCt(Solver* const s, IntExpr* const v, int64_t c,
630 IntVar* const b)
631 : CastConstraint(s, b), expr_(v), cst_(c), demon_(nullptr) {}
632 void Post() override {
633 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
634 expr_->WhenRange(demon_);
635 target_var_->WhenBound(demon_);
636 }
637 void InitialPropagate() override {
638 bool inhibit = false;
639 int64_t u = expr_->Max() >= cst_;
640 int64_t l = expr_->Min() >= cst_;
641 target_var_->SetRange(l, u);
642 if (target_var_->Bound()) {
643 inhibit = true;
644 if (target_var_->Min() == 0) {
645 expr_->SetMax(cst_ - 1);
646 } else {
647 expr_->SetMin(cst_);
648 }
649 }
650 if (inhibit && ((target_var_->Max() == 0 && expr_->Max() < cst_) ||
651 (target_var_->Min() == 1 && expr_->Min() >= cst_))) {
652 // Can we safely inhibit? Sometimes an expression is not
653 // persistent, just monotonic.
654 demon_->inhibit(solver());
655 }
656 }
657 std::string DebugString() const override {
658 return absl::StrFormat("IsGreaterEqualCstCt(%s, %d, %s)",
659 expr_->DebugString(), cst_,
660 target_var_->DebugString());
661 }
662
663 void Accept(ModelVisitor* const visitor) const override {
664 visitor->BeginVisitConstraint(ModelVisitor::kIsGreaterOrEqual, this);
665 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
666 expr_);
667 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
668 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
669 target_var_);
670 visitor->EndVisitConstraint(ModelVisitor::kIsGreaterOrEqual, this);
671 }
672
673 private:
674 IntExpr* const expr_;
675 int64_t cst_;
676 Demon* demon_;
677};
678} // namespace
679
680IntVar* Solver::MakeIsGreaterOrEqualCstVar(IntExpr* const var, int64_t value) {
681 if (var->Min() >= value) {
682 return MakeIntConst(int64_t{1});
683 }
684 if (var->Max() < value) {
685 return MakeIntConst(int64_t{0});
686 }
687 if (var->IsVar()) {
688 return var->Var()->IsGreaterOrEqual(value);
689 } else {
690 IntVar* const boolvar =
691 MakeBoolVar(absl::StrFormat("Is(%s >= %d)", var->DebugString(), value));
692 AddConstraint(MakeIsGreaterOrEqualCstCt(var, value, boolvar));
693 return boolvar;
694 }
695}
696
697IntVar* Solver::MakeIsGreaterCstVar(IntExpr* const var, int64_t value) {
698 return MakeIsGreaterOrEqualCstVar(var, value + 1);
699}
700
702 IntVar* const boolvar) {
703 if (boolvar->Bound()) {
704 if (boolvar->Min() == 0) {
705 return MakeLess(var, value);
706 } else {
707 return MakeGreaterOrEqual(var, value);
708 }
709 }
710 CHECK_EQ(this, var->solver());
711 CHECK_EQ(this, boolvar->solver());
712 model_cache_->InsertExprConstantExpression(
714 return RevAlloc(new IsGreaterEqualCstCt(this, var, value, boolvar));
715}
716
717Constraint* Solver::MakeIsGreaterCstCt(IntExpr* const v, int64_t c,
718 IntVar* const b) {
719 return MakeIsGreaterOrEqualCstCt(v, c + 1, b);
720}
722// ----- is_lesser_equal_cst Constraint -----
723
724namespace {
725class IsLessEqualCstCt : public CastConstraint {
726 public:
727 IsLessEqualCstCt(Solver* const s, IntExpr* const v, int64_t c,
728 IntVar* const b)
729 : CastConstraint(s, b), expr_(v), cst_(c), demon_(nullptr) {}
730
731 void Post() override {
732 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
733 expr_->WhenRange(demon_);
734 target_var_->WhenBound(demon_);
735 }
736
737 void InitialPropagate() override {
738 bool inhibit = false;
739 int64_t u = expr_->Min() <= cst_;
740 int64_t l = expr_->Max() <= cst_;
741 target_var_->SetRange(l, u);
742 if (target_var_->Bound()) {
743 inhibit = true;
744 if (target_var_->Min() == 0) {
745 expr_->SetMin(cst_ + 1);
746 } else {
747 expr_->SetMax(cst_);
748 }
749 }
750 if (inhibit && ((target_var_->Max() == 0 && expr_->Min() > cst_) ||
751 (target_var_->Min() == 1 && expr_->Max() <= cst_))) {
752 // Can we safely inhibit? Sometimes an expression is not
753 // persistent, just monotonic.
754 demon_->inhibit(solver());
755 }
756 }
757
758 std::string DebugString() const override {
759 return absl::StrFormat("IsLessEqualCstCt(%s, %d, %s)", expr_->DebugString(),
760 cst_, target_var_->DebugString());
761 }
762
763 void Accept(ModelVisitor* const visitor) const override {
766 expr_);
769 target_var_);
771 }
772
773 private:
774 IntExpr* const expr_;
775 int64_t cst_;
776 Demon* demon_;
777};
778} // namespace
779
780IntVar* Solver::MakeIsLessOrEqualCstVar(IntExpr* const var, int64_t value) {
781 if (var->Max() <= value) {
782 return MakeIntConst(int64_t{1});
783 }
784 if (var->Min() > value) {
785 return MakeIntConst(int64_t{0});
786 }
787 if (var->IsVar()) {
788 return var->Var()->IsLessOrEqual(value);
789 } else {
790 IntVar* const boolvar =
791 MakeBoolVar(absl::StrFormat("Is(%s <= %d)", var->DebugString(), value));
792 AddConstraint(MakeIsLessOrEqualCstCt(var, value, boolvar));
793 return boolvar;
794 }
795}
796
797IntVar* Solver::MakeIsLessCstVar(IntExpr* const var, int64_t value) {
798 return MakeIsLessOrEqualCstVar(var, value - 1);
799}
800
802 IntVar* const boolvar) {
803 if (boolvar->Bound()) {
804 if (boolvar->Min() == 0) {
805 return MakeGreater(var, value);
806 } else {
807 return MakeLessOrEqual(var, value);
808 }
809 }
810 CHECK_EQ(this, var->solver());
811 CHECK_EQ(this, boolvar->solver());
812 model_cache_->InsertExprConstantExpression(
813 boolvar, var, value, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
814 return RevAlloc(new IsLessEqualCstCt(this, var, value, boolvar));
815}
816
817Constraint* Solver::MakeIsLessCstCt(IntExpr* const v, int64_t c,
818 IntVar* const b) {
819 return MakeIsLessOrEqualCstCt(v, c - 1, b);
820}
822// ----- BetweenCt -----
823
824namespace {
825class BetweenCt : public Constraint {
826 public:
827 BetweenCt(Solver* const s, IntExpr* const v, int64_t l, int64_t u)
828 : Constraint(s), expr_(v), min_(l), max_(u), demon_(nullptr) {}
829
830 void Post() override {
831 if (!expr_->IsVar()) {
832 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
833 expr_->WhenRange(demon_);
834 }
835 }
836
837 void InitialPropagate() override {
838 expr_->SetRange(min_, max_);
839 int64_t emin = 0;
840 int64_t emax = 0;
841 expr_->Range(&emin, &emax);
842 if (demon_ != nullptr && emin >= min_ && emax <= max_) {
843 demon_->inhibit(solver());
844 }
845 }
846
847 std::string DebugString() const override {
848 return absl::StrFormat("BetweenCt(%s, %d, %d)", expr_->DebugString(), min_,
849 max_);
850 }
851
852 void Accept(ModelVisitor* const visitor) const override {
856 expr_);
859 }
860
861 private:
862 IntExpr* const expr_;
863 int64_t min_;
864 int64_t max_;
865 Demon* demon_;
866};
867
// ----- NotBetween constraint -----
869
870class NotBetweenCt : public Constraint {
871 public:
872 NotBetweenCt(Solver* const s, IntExpr* const v, int64_t l, int64_t u)
873 : Constraint(s), expr_(v), min_(l), max_(u), demon_(nullptr) {}
874
875 void Post() override {
876 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
877 expr_->WhenRange(demon_);
878 }
879
880 void InitialPropagate() override {
881 int64_t emin = 0;
882 int64_t emax = 0;
883 expr_->Range(&emin, &emax);
884 if (emin >= min_) {
885 expr_->SetMin(max_ + 1);
886 } else if (emax <= max_) {
887 expr_->SetMax(min_ - 1);
888 }
889
890 if (!expr_->IsVar() && (emax < min_ || emin > max_)) {
891 demon_->inhibit(solver());
892 }
893 }
894
895 std::string DebugString() const override {
896 return absl::StrFormat("NotBetweenCt(%s, %d, %d)", expr_->DebugString(),
897 min_, max_);
898 }
899
900 void Accept(ModelVisitor* const visitor) const override {
904 expr_);
907 }
908
909 private:
910 IntExpr* const expr_;
911 int64_t min_;
912 int64_t max_;
913 Demon* demon_;
914};
915
916int64_t ExtractExprProductCoeff(IntExpr** expr) {
917 int64_t prod = 1;
918 int64_t coeff = 1;
919 while ((*expr)->solver()->IsProduct(*expr, expr, &coeff)) prod *= coeff;
920 return prod;
921}
922} // namespace
923
924Constraint* Solver::MakeBetweenCt(IntExpr* expr, int64_t l, int64_t u) {
925 DCHECK_EQ(this, expr->solver());
926 // Catch empty and singleton intervals.
927 if (l >= u) {
928 if (l > u) return MakeFalseConstraint();
929 return MakeEquality(expr, l);
930 }
931 int64_t emin = 0;
932 int64_t emax = 0;
933 expr->Range(&emin, &emax);
934 // Catch the trivial cases first.
935 if (emax < l || emin > u) return MakeFalseConstraint();
936 if (emin >= l && emax <= u) return MakeTrueConstraint();
937 // Catch one-sided constraints.
938 if (emax <= u) return MakeGreaterOrEqual(expr, l);
939 if (emin >= l) return MakeLessOrEqual(expr, u);
940 // Simplify the common factor, if any.
941 int64_t coeff = ExtractExprProductCoeff(&expr);
942 if (coeff != 1) {
943 CHECK_NE(coeff, 0); // Would have been caught by the trivial cases already.
944 if (coeff < 0) {
945 std::swap(u, l);
946 u = -u;
947 l = -l;
948 coeff = -coeff;
949 }
950 return MakeBetweenCt(expr, PosIntDivUp(l, coeff), PosIntDivDown(u, coeff));
951 } else {
952 // No further reduction is possible.
953 return RevAlloc(new BetweenCt(this, expr, l, u));
954 }
955}
956
957Constraint* Solver::MakeNotBetweenCt(IntExpr* expr, int64_t l, int64_t u) {
958 DCHECK_EQ(this, expr->solver());
959 // Catch empty interval.
960 if (l > u) {
962 }
963
964 int64_t emin = 0;
965 int64_t emax = 0;
966 expr->Range(&emin, &emax);
967 // Catch the trivial cases first.
968 if (emax < l || emin > u) return MakeTrueConstraint();
969 if (emin >= l && emax <= u) return MakeFalseConstraint();
970 // Catch one-sided constraints.
971 if (emin >= l) return MakeGreater(expr, u);
972 if (emax <= u) return MakeLess(expr, l);
973 // TODO(user): Add back simplification code if expr is constant *
974 // other_expr.
975 return RevAlloc(new NotBetweenCt(this, expr, l, u));
976}
977
978// ----- is_between_cst Constraint -----
979
980namespace {
981class IsBetweenCt : public Constraint {
982 public:
983 IsBetweenCt(Solver* const s, IntExpr* const e, int64_t l, int64_t u,
984 IntVar* const b)
985 : Constraint(s),
986 expr_(e),
987 min_(l),
988 max_(u),
989 boolvar_(b),
990 demon_(nullptr) {}
991
992 void Post() override {
993 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
994 expr_->WhenRange(demon_);
995 boolvar_->WhenBound(demon_);
996 }
997
998 void InitialPropagate() override {
999 bool inhibit = false;
1000 int64_t emin = 0;
1001 int64_t emax = 0;
1002 expr_->Range(&emin, &emax);
1003 int64_t u = 1 - (emin > max_ || emax < min_);
1004 int64_t l = emax <= max_ && emin >= min_;
1005 boolvar_->SetRange(l, u);
1006 if (boolvar_->Bound()) {
1007 inhibit = true;
1008 if (boolvar_->Min() == 0) {
1009 if (expr_->IsVar()) {
1010 expr_->Var()->RemoveInterval(min_, max_);
1011 inhibit = true;
1012 } else if (emin > min_) {
1013 expr_->SetMin(max_ + 1);
1014 } else if (emax < max_) {
1015 expr_->SetMax(min_ - 1);
1016 }
1017 } else {
1018 expr_->SetRange(min_, max_);
1019 inhibit = true;
1020 }
1021 if (inhibit && expr_->IsVar()) {
1022 demon_->inhibit(solver());
1023 }
1024 }
1025 }
1026
1027 std::string DebugString() const override {
1028 return absl::StrFormat("IsBetweenCt(%s, %d, %d, %s)", expr_->DebugString(),
1029 min_, max_, boolvar_->DebugString());
1030 }
1031
1032 void Accept(ModelVisitor* const visitor) const override {
1033 visitor->BeginVisitConstraint(ModelVisitor::kIsBetween, this);
1034 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
1035 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1036 expr_);
1037 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
1038 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1039 boolvar_);
1040 visitor->EndVisitConstraint(ModelVisitor::kIsBetween, this);
1041 }
1042
1043 private:
1044 IntExpr* const expr_;
1045 int64_t min_;
1046 int64_t max_;
1047 IntVar* const boolvar_;
1048 Demon* demon_;
1049};
1050} // namespace
1051
1052Constraint* Solver::MakeIsBetweenCt(IntExpr* expr, int64_t l, int64_t u,
1053 IntVar* const b) {
1054 CHECK_EQ(this, expr->solver());
1055 CHECK_EQ(this, b->solver());
1056 // Catch empty and singleton intervals.
1057 if (l >= u) {
1058 if (l > u) return MakeEquality(b, Zero());
1059 return MakeIsEqualCstCt(expr, l, b);
1060 }
1061 int64_t emin = 0;
1062 int64_t emax = 0;
1063 expr->Range(&emin, &emax);
1064 // Catch the trivial cases first.
1065 if (emax < l || emin > u) return MakeEquality(b, Zero());
1066 if (emin >= l && emax <= u) return MakeEquality(b, 1);
1067 // Catch one-sided constraints.
1068 if (emax <= u) return MakeIsGreaterOrEqualCstCt(expr, l, b);
1069 if (emin >= l) return MakeIsLessOrEqualCstCt(expr, u, b);
1070 // Simplify the common factor, if any.
1071 int64_t coeff = ExtractExprProductCoeff(&expr);
1072 if (coeff != 1) {
1073 CHECK_NE(coeff, 0); // Would have been caught by the trivial cases already.
1074 if (coeff < 0) {
1075 std::swap(u, l);
1076 u = -u;
1077 l = -l;
1078 coeff = -coeff;
1079 }
1080 return MakeIsBetweenCt(expr, PosIntDivUp(l, coeff), PosIntDivDown(u, coeff),
1081 b);
1082 } else {
1083 // No further reduction is possible.
1084 return RevAlloc(new IsBetweenCt(this, expr, l, u, b));
1085 }
1086}
1087
1088IntVar* Solver::MakeIsBetweenVar(IntExpr* const v, int64_t l, int64_t u) {
1089 CHECK_EQ(this, v->solver());
1090 IntVar* const b = MakeBoolVar();
1091 AddConstraint(MakeIsBetweenCt(v, l, u, b));
1092 return b;
1093}
1094
1095// ---------- Member ----------
1096
1097// ----- Member(IntVar, IntSet) -----
1098
1099namespace {
1100// TODO(user): Do not create holes on expressions.
1101class MemberCt : public Constraint {
1102 public:
1103 MemberCt(Solver* const s, IntVar* const v,
1104 const std::vector<int64_t>& sorted_values)
1105 : Constraint(s), var_(v), values_(sorted_values) {
1106 DCHECK(v != nullptr);
1107 DCHECK(s != nullptr);
1108 }
1109
1110 void Post() override {}
1111
1112 void InitialPropagate() override { var_->SetValues(values_); }
1113
1114 std::string DebugString() const override {
1115 return absl::StrFormat("Member(%s, %s)", var_->DebugString(),
1116 absl::StrJoin(values_, ", "));
1117 }
1118
1119 void Accept(ModelVisitor* const visitor) const override {
1122 var_);
1125 }
1126
1127 private:
1128 IntVar* const var_;
1129 const std::vector<int64_t> values_;
1130};
1131
1132class NotMemberCt : public Constraint {
1133 public:
1134 NotMemberCt(Solver* const s, IntVar* const v,
1135 const std::vector<int64_t>& sorted_values)
1136 : Constraint(s), var_(v), values_(sorted_values) {
1137 DCHECK(v != nullptr);
1138 DCHECK(s != nullptr);
1139 }
1140
1141 void Post() override {}
1142
1143 void InitialPropagate() override { var_->RemoveValues(values_); }
1144
1145 std::string DebugString() const override {
1146 return absl::StrFormat("NotMember(%s, %s)", var_->DebugString(),
1147 absl::StrJoin(values_, ", "));
1148 }
1149
1150 void Accept(ModelVisitor* const visitor) const override {
1151 visitor->BeginVisitConstraint(ModelVisitor::kMember, this);
1152 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1153 var_);
1154 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
1155 visitor->EndVisitConstraint(ModelVisitor::kMember, this);
1156 }
1157
1158 private:
1159 IntVar* const var_;
1160 const std::vector<int64_t> values_;
1161};
1162} // namespace
1163
1164Constraint* Solver::MakeMemberCt(IntExpr* expr,
1165 const std::vector<int64_t>& values) {
1166 const int64_t coeff = ExtractExprProductCoeff(&expr);
1167 if (coeff == 0) {
1168 return std::find(values.begin(), values.end(), 0) == values.end()
1171 }
1172 std::vector<int64_t> copied_values = values;
1173 // If the expression is a non-trivial product, we filter out the values that
1174 // aren't multiples of "coeff", and divide them.
1175 if (coeff != 1) {
1176 int num_kept = 0;
1177 for (const int64_t v : copied_values) {
1178 if (v % coeff == 0) copied_values[num_kept++] = v / coeff;
1179 }
1180 copied_values.resize(num_kept);
1181 }
1182 // Filter out the values that are outside the [Min, Max] interval.
1183 int num_kept = 0;
1184 int64_t emin;
1185 int64_t emax;
1186 expr->Range(&emin, &emax);
1187 for (const int64_t v : copied_values) {
1188 if (v >= emin && v <= emax) copied_values[num_kept++] = v;
1189 }
1190 copied_values.resize(num_kept);
1191 // Catch empty set.
1192 if (copied_values.empty()) return MakeFalseConstraint();
1193 // Sort and remove duplicates.
1194 gtl::STLSortAndRemoveDuplicates(&copied_values);
1195 // Special case for singleton.
1196 if (copied_values.size() == 1) return MakeEquality(expr, copied_values[0]);
1197 // Catch contiguous intervals.
1198 if (copied_values.size() ==
1199 copied_values.back() - copied_values.front() + 1) {
1200 // Note: MakeBetweenCt() has a fast-track for trivially true constraints.
1201 return MakeBetweenCt(expr, copied_values.front(), copied_values.back());
1202 }
1203 // If the set of values in [expr.Min(), expr.Max()] that are *not* in
1204 // "values" is smaller than "values", then it's more efficient to use
1205 // NotMemberCt. Catch that case here.
1206 if (emax - emin < 2 * copied_values.size()) {
1207 // Convert "copied_values" to list the values *not* allowed.
1208 std::vector<bool> is_among_input_values(emax - emin + 1, false);
1209 for (const int64_t v : copied_values)
1210 is_among_input_values[v - emin] = true;
1211 // We use the zero valued indices of is_among_input_values to build the
1212 // complement of copied_values.
1213 copied_values.clear();
1214 for (int64_t v_off = 0; v_off < is_among_input_values.size(); ++v_off) {
1215 if (!is_among_input_values[v_off]) copied_values.push_back(v_off + emin);
1216 }
1217 // The empty' case (all values in range [expr.Min(), expr.Max()] are in the
1218 // "values" input) was caught earlier, by the "contiguous interval" case.
1219 DCHECK_GE(copied_values.size(), 1);
1220 if (copied_values.size() == 1) {
1221 return MakeNonEquality(expr, copied_values[0]);
1222 }
1223 return RevAlloc(new NotMemberCt(this, expr->Var(), copied_values));
1224 }
1225 // Otherwise, just use MemberCt. No further reduction is possible.
1226 return RevAlloc(new MemberCt(this, expr->Var(), copied_values));
1227}
1228
1229Constraint* Solver::MakeMemberCt(IntExpr* const expr,
1230 const std::vector<int>& values) {
1231 return MakeMemberCt(expr, ToInt64Vector(values));
1232}
1235 const std::vector<int64_t>& values) {
1236 const int64_t coeff = ExtractExprProductCoeff(&expr);
1237 if (coeff == 0) {
1238 return std::find(values.begin(), values.end(), 0) == values.end()
1241 }
1242 std::vector<int64_t> copied_values = values;
1243 // If the expression is a non-trivial product, we filter out the values that
1244 // aren't multiples of "coeff", and divide them.
1245 if (coeff != 1) {
1246 int num_kept = 0;
1247 for (const int64_t v : copied_values) {
1248 if (v % coeff == 0) copied_values[num_kept++] = v / coeff;
1249 }
1250 copied_values.resize(num_kept);
1251 }
1252 // Filter out the values that are outside the [Min, Max] interval.
1253 int num_kept = 0;
1254 int64_t emin;
1255 int64_t emax;
1256 expr->Range(&emin, &emax);
1257 for (const int64_t v : copied_values) {
1258 if (v >= emin && v <= emax) copied_values[num_kept++] = v;
1259 }
1260 copied_values.resize(num_kept);
1261 // Catch empty set.
1262 if (copied_values.empty()) return MakeTrueConstraint();
1263 // Sort and remove duplicates.
1264 gtl::STLSortAndRemoveDuplicates(&copied_values);
1265 // Special case for singleton.
1266 if (copied_values.size() == 1) return MakeNonEquality(expr, copied_values[0]);
1267 // Catch contiguous intervals.
1268 if (copied_values.size() ==
1269 copied_values.back() - copied_values.front() + 1) {
1270 return MakeNotBetweenCt(expr, copied_values.front(), copied_values.back());
1271 }
1272 // If the set of values in [expr.Min(), expr.Max()] that are *not* in
1273 // "values" is smaller than "values", then it's more efficient to use
1274 // MemberCt. Catch that case here.
1275 if (emax - emin < 2 * copied_values.size()) {
1276 // Convert "copied_values" to a dense boolean vector.
1277 std::vector<bool> is_among_input_values(emax - emin + 1, false);
1278 for (const int64_t v : copied_values)
1279 is_among_input_values[v - emin] = true;
1280 // Use zero valued indices for is_among_input_values to build the
1281 // complement of copied_values.
1282 copied_values.clear();
1283 for (int64_t v_off = 0; v_off < is_among_input_values.size(); ++v_off) {
1284 if (!is_among_input_values[v_off]) copied_values.push_back(v_off + emin);
1285 }
1286 // The empty' case (all values in range [expr.Min(), expr.Max()] are in the
1287 // "values" input) was caught earlier, by the "contiguous interval" case.
1288 DCHECK_GE(copied_values.size(), 1);
1289 if (copied_values.size() == 1) {
1290 return MakeEquality(expr, copied_values[0]);
1291 }
1292 return RevAlloc(new MemberCt(this, expr->Var(), copied_values));
1293 }
1294 // Otherwise, just use NotMemberCt. No further reduction is possible.
1295 return RevAlloc(new NotMemberCt(this, expr->Var(), copied_values));
1296}
1297
1298Constraint* Solver::MakeNotMemberCt(IntExpr* const expr,
1299 const std::vector<int>& values) {
1300 return MakeNotMemberCt(expr, ToInt64Vector(values));
1301}
1303// ----- IsMemberCt -----
1304
1305namespace {
1306class IsMemberCt : public Constraint {
1307 public:
1308 IsMemberCt(Solver* const s, IntVar* const v,
1309 const std::vector<int64_t>& sorted_values, IntVar* const b)
1310 : Constraint(s),
1311 var_(v),
1312 values_as_set_(sorted_values.begin(), sorted_values.end()),
1313 values_(sorted_values),
1314 boolvar_(b),
1315 support_(0),
1316 demon_(nullptr),
1317 domain_(var_->MakeDomainIterator(true)),
1318 neg_support_(std::numeric_limits<int64_t>::min()) {
1319 DCHECK(v != nullptr);
1320 DCHECK(s != nullptr);
1321 DCHECK(b != nullptr);
1322 while (values_as_set_.contains(neg_support_)) {
1323 neg_support_++;
1324 }
1325 }
1326
1327 void Post() override {
1328 demon_ = MakeConstraintDemon0(solver(), this, &IsMemberCt::VarDomain,
1329 "VarDomain");
1330 if (!var_->Bound()) {
1331 var_->WhenDomain(demon_);
1332 }
1333 if (!boolvar_->Bound()) {
1334 Demon* const bdemon = MakeConstraintDemon0(
1335 solver(), this, &IsMemberCt::TargetBound, "TargetBound");
1336 boolvar_->WhenBound(bdemon);
1337 }
1338 }
1339
1340 void InitialPropagate() override {
1341 boolvar_->SetRange(0, 1);
1342 if (boolvar_->Bound()) {
1343 TargetBound();
1344 } else {
1345 VarDomain();
1346 }
1347 }
1348
1349 std::string DebugString() const override {
1350 return absl::StrFormat("IsMemberCt(%s, %s, %s)", var_->DebugString(),
1351 absl::StrJoin(values_, ", "),
1352 boolvar_->DebugString());
1353 }
1354
1355 void Accept(ModelVisitor* const visitor) const override {
1358 var_);
1361 boolvar_);
1363 }
1364
1365 private:
1366 void VarDomain() {
1367 if (boolvar_->Bound()) {
1368 TargetBound();
1369 } else {
1370 for (int offset = 0; offset < values_.size(); ++offset) {
1371 const int candidate = (support_ + offset) % values_.size();
1372 if (var_->Contains(values_[candidate])) {
1373 support_ = candidate;
1374 if (var_->Bound()) {
1375 demon_->inhibit(solver());
1376 boolvar_->SetValue(1);
1377 return;
1378 }
1379 // We have found a positive support. Let's check the
1380 // negative support.
1381 if (var_->Contains(neg_support_)) {
1382 return;
1383 } else {
1384 // Look for a new negative support.
1385 for (const int64_t value : InitAndGetValues(domain_)) {
1386 if (!values_as_set_.contains(value)) {
1387 neg_support_ = value;
1388 return;
1389 }
1390 }
1391 }
1392 // No negative support, setting boolvar to true.
1393 demon_->inhibit(solver());
1394 boolvar_->SetValue(1);
1395 return;
1396 }
1397 }
1398 // No positive support, setting boolvar to false.
1399 demon_->inhibit(solver());
1400 boolvar_->SetValue(0);
1401 }
1402 }
1403
1404 void TargetBound() {
1405 DCHECK(boolvar_->Bound());
1406 if (boolvar_->Min() == 1LL) {
1407 demon_->inhibit(solver());
1408 var_->SetValues(values_);
1409 } else {
1410 demon_->inhibit(solver());
1411 var_->RemoveValues(values_);
1412 }
1413 }
1414
1415 IntVar* const var_;
1416 absl::flat_hash_set<int64_t> values_as_set_;
1417 std::vector<int64_t> values_;
1418 IntVar* const boolvar_;
1419 int support_;
1420 Demon* demon_;
1421 IntVarIterator* const domain_;
1422 int64_t neg_support_;
1423};
1424
1425template <class T>
1426Constraint* BuildIsMemberCt(Solver* const solver, IntExpr* const expr,
1427 const std::vector<T>& values,
1428 IntVar* const boolvar) {
1429 // TODO(user): optimize this by copying the code from MakeMemberCt.
1430 // Simplify and filter if expr is a product.
1431 IntExpr* sub = nullptr;
1432 int64_t coef = 1;
1433 if (solver->IsProduct(expr, &sub, &coef) && coef != 0 && coef != 1) {
1434 std::vector<int64_t> new_values;
1435 new_values.reserve(values.size());
1436 for (const int64_t value : values) {
1437 if (value % coef == 0) {
1438 new_values.push_back(value / coef);
1439 }
1440 }
1441 return BuildIsMemberCt(solver, sub, new_values, boolvar);
1442 }
1443
1444 std::set<T> set_of_values(values.begin(), values.end());
1445 std::vector<int64_t> filtered_values;
1446 bool all_values = false;
1447 if (expr->IsVar()) {
1448 IntVar* const var = expr->Var();
1449 for (const T value : set_of_values) {
1450 if (var->Contains(value)) {
1451 filtered_values.push_back(value);
1452 }
1453 }
1454 all_values = (filtered_values.size() == var->Size());
1455 } else {
1456 int64_t emin = 0;
1457 int64_t emax = 0;
1458 expr->Range(&emin, &emax);
1459 for (const T value : set_of_values) {
1460 if (value >= emin && value <= emax) {
1461 filtered_values.push_back(value);
1462 }
1463 }
1464 all_values = (filtered_values.size() == emax - emin + 1);
1465 }
1466 if (filtered_values.empty()) {
1467 return solver->MakeEquality(boolvar, Zero());
1468 } else if (all_values) {
1469 return solver->MakeEquality(boolvar, 1);
1470 } else if (filtered_values.size() == 1) {
1471 return solver->MakeIsEqualCstCt(expr, filtered_values.back(), boolvar);
1472 } else if (filtered_values.back() ==
1473 filtered_values.front() + filtered_values.size() - 1) {
1474 // Contiguous
1475 return solver->MakeIsBetweenCt(expr, filtered_values.front(),
1476 filtered_values.back(), boolvar);
1477 } else {
1478 return solver->RevAlloc(
1479 new IsMemberCt(solver, expr->Var(), filtered_values, boolvar));
1480 }
1481}
1482} // namespace
1483
1485 const std::vector<int64_t>& values,
1486 IntVar* const boolvar) {
1487 return BuildIsMemberCt(this, expr, values, boolvar);
1489
1491 const std::vector<int>& values,
1492 IntVar* const boolvar) {
1493 return BuildIsMemberCt(this, expr, values, boolvar);
1495
1497 const std::vector<int64_t>& values) {
1498 IntVar* const b = MakeBoolVar();
1499 AddConstraint(MakeIsMemberCt(expr, values, b));
1500 return b;
1501}
1502
1504 const std::vector<int>& values) {
1505 IntVar* const b = MakeBoolVar();
1506 AddConstraint(MakeIsMemberCt(expr, values, b));
1507 return b;
1508}
1509
1510namespace {
1511class SortedDisjointForbiddenIntervalsConstraint : public Constraint {
1512 public:
1513 SortedDisjointForbiddenIntervalsConstraint(
1514 Solver* const solver, IntVar* const var,
1516 : Constraint(solver), var_(var), intervals_(std::move(intervals)) {}
1517
1518 ~SortedDisjointForbiddenIntervalsConstraint() override {}
1519
1520 void Post() override {
1521 Demon* const demon = solver()->MakeConstraintInitialPropagateCallback(this);
1522 var_->WhenRange(demon);
1523 }
1524
1525 void InitialPropagate() override {
1526 const int64_t vmin = var_->Min();
1527 const int64_t vmax = var_->Max();
1528 const auto first_interval_it = intervals_.FirstIntervalGreaterOrEqual(vmin);
1529 if (first_interval_it == intervals_.end()) {
1530 // No interval intersects the variable's range. Nothing to do.
1531 return;
1532 }
1533 const auto last_interval_it = intervals_.LastIntervalLessOrEqual(vmax);
1534 if (last_interval_it == intervals_.end()) {
1535 // No interval intersects the variable's range. Nothing to do.
1536 return;
1537 }
1538 // TODO(user): Quick fail if first_interval_it == last_interval_it, which
1539 // would imply that the interval contains the entire range of the variable?
1540 if (vmin >= first_interval_it->start) {
1541 // The variable's minimum is inside a forbidden interval. Move it to the
1542 // interval's end.
1543 var_->SetMin(CapAdd(first_interval_it->end, 1));
1544 }
1545 if (vmax <= last_interval_it->end) {
1546 // Ditto, on the other side.
1547 var_->SetMax(CapSub(last_interval_it->start, 1));
1548 }
1549 }
1550
1551 std::string DebugString() const override {
1552 return absl::StrFormat("ForbiddenIntervalCt(%s, %s)", var_->DebugString(),
1553 intervals_.DebugString());
1554 }
1555
1556 void Accept(ModelVisitor* const visitor) const override {
1559 var_);
1560 std::vector<int64_t> starts;
1561 std::vector<int64_t> ends;
1562 for (auto& interval : intervals_) {
1563 starts.push_back(interval.start);
1564 ends.push_back(interval.end);
1565 }
1569 }
1570
1571 private:
1572 IntVar* const var_;
1573 const SortedDisjointIntervalList intervals_;
1574};
1575} // namespace
1576
1577Constraint* Solver::MakeNotMemberCt(IntExpr* const expr,
1578 std::vector<int64_t> starts,
1579 std::vector<int64_t> ends) {
1580 return RevAlloc(new SortedDisjointForbiddenIntervalsConstraint(
1581 this, expr->Var(), {starts, ends}));
1582}
1583
1585 std::vector<int> starts,
1586 std::vector<int> ends) {
1587 return RevAlloc(new SortedDisjointForbiddenIntervalsConstraint(
1588 this, expr->Var(), {starts, ends}));
1589}
1590
1592 SortedDisjointIntervalList intervals) {
1593 return RevAlloc(new SortedDisjointForbiddenIntervalsConstraint(
1594 this, expr->Var(), std::move(intervals)));
1596} // namespace operations_research
virtual void SetValue(int64_t v)
This method sets the value of the expression.
virtual bool Bound() const
Returns true if the min and the max of the expression are equal.
virtual void SetMax(int64_t m)=0
virtual bool IsVar() const
Returns true if the expression is indeed a variable.
virtual void SetRange(int64_t l, int64_t u)
This method sets both the min and the max of the expression.
virtual int64_t Min() const =0
virtual void SetMin(int64_t m)=0
virtual void WhenRange(Demon *d)=0
Attach a demon that will watch the min or the max of the expression.
virtual void Range(int64_t *l, int64_t *u)
virtual IntVar * Var()=0
Creates a variable from the expression.
virtual int64_t Max() const =0
virtual void WhenBound(Demon *d)=0
virtual void WhenDomain(Demon *d)=0
virtual void SetValues(const std::vector< int64_t > &values)
This method intersects the current domain with the values in the array.
virtual IntVar * IsDifferent(int64_t constant)=0
IntVar * Var() override
Creates a variable from the expression.
virtual IntVar * IsLessOrEqual(int64_t constant)=0
virtual bool Contains(int64_t v) const =0
virtual void RemoveValue(int64_t v)=0
This method removes the value 'v' from the domain of the variable.
virtual uint64_t Size() const =0
This method returns the number of values in the domain of the variable.
virtual void RemoveInterval(int64_t l, int64_t u)=0
virtual void RemoveValues(const std::vector< int64_t > &values)
This method removes the values from the domain of the variable.
virtual void VisitIntegerArgument(const std::string &arg_name, int64_t value)
Visit integer arguments.
virtual void VisitIntegerExpressionArgument(const std::string &arg_name, IntExpr *argument)
Visit integer expression argument.
virtual void VisitIntegerArrayArgument(const std::string &arg_name, const std::vector< int64_t > &values)
virtual void BeginVisitConstraint(const std::string &type_name, const Constraint *constraint)
virtual void EndVisitConstraint(const std::string &type_name, const Constraint *constraint)
Constraint * MakeIsEqualCstCt(IntExpr *var, int64_t value, IntVar *boolvar)
boolvar == (var == value)
Definition expr_cst.cc:492
Constraint * MakeNotMemberCt(IntExpr *expr, const std::vector< int64_t > &values)
expr not in set.
Definition expr_cst.cc:1238
IntVar * MakeIsMemberVar(IntExpr *expr, const std::vector< int64_t > &values)
Definition expr_cst.cc:1500
Constraint * MakeNonEquality(IntExpr *left, IntExpr *right)
left != right
Definition range_cst.cc:569
IntVar * MakeIsEqualCstVar(IntExpr *var, int64_t value)
status var of (var == value)
Definition expr_cst.cc:467
IntExpr * MakeDifference(IntExpr *left, IntExpr *right)
left - right
Constraint * MakeMemberCt(IntExpr *expr, const std::vector< int64_t > &values)
Definition expr_cst.cc:1168
Constraint * MakeGreaterOrEqual(IntExpr *left, IntExpr *right)
left >= right
Definition range_cst.cc:547
IntVar * MakeIsLessOrEqualCstVar(IntExpr *var, int64_t value)
status var of (var <= value)
Definition expr_cst.cc:784
IntExpr * MakeSum(IntExpr *left, IntExpr *right)
left + right.
Constraint * MakeIsMemberCt(IntExpr *expr, const std::vector< int64_t > &values, IntVar *boolvar)
boolvar == (expr in set)
Definition expr_cst.cc:1488
IntVar * MakeIsDifferentVar(IntExpr *v1, IntExpr *v2)
status var of (v1 != v2)
Definition range_cst.cc:646
Constraint * MakeIsGreaterCstCt(IntExpr *v, int64_t c, IntVar *b)
b == (v > c)
Definition expr_cst.cc:721
IntVar * MakeIsGreaterCstVar(IntExpr *var, int64_t value)
status var of (var > value)
Definition expr_cst.cc:701
Constraint * MakeIsLessCstCt(IntExpr *v, int64_t c, IntVar *b)
b == (v < c)
Definition expr_cst.cc:821
Constraint * MakeLess(IntExpr *left, IntExpr *right)
left < right
Definition range_cst.cc:551
Constraint * MakeGreater(IntExpr *left, IntExpr *right)
left > right
Definition range_cst.cc:565
IntVar * MakeIsEqualVar(IntExpr *v1, IntExpr *v2)
status var of (v1 == v2)
Definition range_cst.cc:582
IntVar * MakeIsDifferentCstVar(IntExpr *var, int64_t value)
status var of (var != value)
Definition expr_cst.cc:585
IntVar * MakeIsBetweenVar(IntExpr *v, int64_t l, int64_t u)
Definition expr_cst.cc:1092
IntVar * MakeIsGreaterOrEqualCstVar(IntExpr *var, int64_t value)
status var of (var >= value)
Definition expr_cst.cc:684
void Accept(ModelVisitor *visitor) const
Accepts the given model visitor.
Constraint * MakeEquality(IntExpr *left, IntExpr *right)
left == right
Definition range_cst.cc:517
Constraint * MakeTrueConstraint()
This constraint always succeeds.
Constraint * MakeLessOrEqual(IntExpr *left, IntExpr *right)
left <= right
Definition range_cst.cc:531
Constraint * MakeIsGreaterOrEqualCstCt(IntExpr *var, int64_t value, IntVar *boolvar)
boolvar == (var >= value)
Definition expr_cst.cc:705
std::string DebugString() const
!defined(SWIG)
Constraint * MakeBetweenCt(IntExpr *expr, int64_t l, int64_t u)
(l <= expr <= u)
Definition expr_cst.cc:928
Demon * MakeConstraintInitialPropagateCallback(Constraint *ct)
Constraint * MakeFalseConstraint()
This constraint always fails.
bool IsProduct(IntExpr *expr, IntExpr **inner_expr, int64_t *coefficient)
IntVar * MakeIsLessCstVar(IntExpr *var, int64_t value)
status var of (var < value)
Definition expr_cst.cc:801
Constraint * MakeIsBetweenCt(IntExpr *expr, int64_t l, int64_t u, IntVar *b)
b == (l <= expr <= u)
Definition expr_cst.cc:1056
Constraint * MakeIsLessOrEqualCstCt(IntExpr *var, int64_t value, IntVar *boolvar)
boolvar == (var <= value)
Definition expr_cst.cc:805
IntVar * MakeIntConst(int64_t val, const std::string &name)
IntConst will create a constant expression.
Solver(const std::string &name)
Solver API.
Constraint * MakeNotBetweenCt(IntExpr *expr, int64_t l, int64_t u)
Definition expr_cst.cc:961
Constraint * MakeIsDifferentCstCt(IntExpr *var, int64_t value, IntVar *boolvar)
boolvar == (var != value)
Definition expr_cst.cc:594
ABSL_FLAG(int, cache_initial_size, 1024, "Initial size of the array of the hash " "table of caches for objects of type Var(x == 3)")
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
Definition stl_util.h:55
OR-Tools root namespace.
int64_t CapAdd(int64_t x, int64_t y)
int64_t CapSub(int64_t x, int64_t y)
ClosedInterval::Iterator end(ClosedInterval interval)
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
std::vector< int64_t > ToInt64Vector(const std::vector< int > &input)
Definition utilities.cc:824
int64_t PosIntDivDown(int64_t e, int64_t v)
int64_t PosIntDivUp(int64_t e, int64_t v)
std::string DebugString() const
Definition model.cc:805