Google OR-Tools v9.11
a fast and portable software suite for combinatorial optimization
Loading...
Searching...
No Matches
expr_cst.cc
Go to the documentation of this file.
1// Copyright 2010-2024 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//
15// Expression constraints
16
17#include <algorithm>
18#include <cstddef>
19#include <cstdint>
20#include <limits>
21#include <set>
22#include <string>
23#include <utility>
24#include <vector>
25
26#include "absl/strings/str_format.h"
27#include "absl/strings/str_join.h"
31#include "ortools/base/types.h"
36
// Defines the integer command-line flag `cache_initial_size` with a default
// of 1024. According to its help text it is the initial size of the array
// backing the hash table that caches objects such as Var(x == 3).
// NOTE(review): no read of this flag is visible in this part of the file;
// confirm where it is consumed.
37ABSL_FLAG(int, cache_initial_size, 1024,
38 "Initial size of the array of the hash "
39 "table of caches for objects of type Var(x == 3)");
40
41namespace operations_research {
42
43//-----------------------------------------------------------------------------
44// Equality
45
46namespace {
47class EqualityExprCst : public Constraint {
48 public:
49 EqualityExprCst(Solver* s, IntExpr* e, int64_t v);
50 ~EqualityExprCst() override {}
51 void Post() override;
52 void InitialPropagate() override;
53 IntVar* Var() override {
54 return solver()->MakeIsEqualCstVar(expr_->Var(), value_);
55 }
56 std::string DebugString() const override;
57
58 void Accept(ModelVisitor* const visitor) const override {
59 visitor->BeginVisitConstraint(ModelVisitor::kEquality, this);
60 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
61 expr_);
62 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
63 visitor->EndVisitConstraint(ModelVisitor::kEquality, this);
64 }
65
66 private:
67 IntExpr* const expr_;
68 int64_t value_;
69};
70
71EqualityExprCst::EqualityExprCst(Solver* const s, IntExpr* const e, int64_t v)
72 : Constraint(s), expr_(e), value_(v) {}
73
74void EqualityExprCst::Post() {
75 if (!expr_->IsVar()) {
76 Demon* d = solver()->MakeConstraintInitialPropagateCallback(this);
77 expr_->WhenRange(d);
78 }
79}
80
81void EqualityExprCst::InitialPropagate() { expr_->SetValue(value_); }
82
83std::string EqualityExprCst::DebugString() const {
84 return absl::StrFormat("(%s == %d)", expr_->DebugString(), value_);
85}
86} // namespace
87
88Constraint* Solver::MakeEquality(IntExpr* const e, int64_t v) {
89 CHECK_EQ(this, e->solver());
90 IntExpr* left = nullptr;
91 IntExpr* right = nullptr;
92 if (IsADifference(e, &left, &right)) {
93 return MakeEquality(left, MakeSum(right, v));
94 } else if (e->IsVar() && !e->Var()->Contains(v)) {
95 return MakeFalseConstraint();
96 } else if (e->Min() == e->Max() && e->Min() == v) {
97 return MakeTrueConstraint();
98 } else {
99 return RevAlloc(new EqualityExprCst(this, e, v));
100 }
101}
102
103Constraint* Solver::MakeEquality(IntExpr* const e, int v) {
104 CHECK_EQ(this, e->solver());
105 IntExpr* left = nullptr;
106 IntExpr* right = nullptr;
107 if (IsADifference(e, &left, &right)) {
108 return MakeEquality(left, MakeSum(right, v));
109 } else if (e->IsVar() && !e->Var()->Contains(v)) {
110 return MakeFalseConstraint();
111 } else if (e->Min() == e->Max() && e->Min() == v) {
112 return MakeTrueConstraint();
113 } else {
114 return RevAlloc(new EqualityExprCst(this, e, v));
115 }
116}
117
118//-----------------------------------------------------------------------------
119// Greater or equal constraint
120
121namespace {
122class GreaterEqExprCst : public Constraint {
123 public:
124 GreaterEqExprCst(Solver* s, IntExpr* e, int64_t v);
125 ~GreaterEqExprCst() override {}
126 void Post() override;
127 void InitialPropagate() override;
128 std::string DebugString() const override;
129 IntVar* Var() override {
130 return solver()->MakeIsGreaterOrEqualCstVar(expr_->Var(), value_);
131 }
132
133 void Accept(ModelVisitor* const visitor) const override {
134 visitor->BeginVisitConstraint(ModelVisitor::kGreaterOrEqual, this);
135 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
136 expr_);
137 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
138 visitor->EndVisitConstraint(ModelVisitor::kGreaterOrEqual, this);
139 }
140
141 private:
142 IntExpr* const expr_;
143 int64_t value_;
144 Demon* demon_;
145};
146
147GreaterEqExprCst::GreaterEqExprCst(Solver* const s, IntExpr* const e, int64_t v)
148 : Constraint(s), expr_(e), value_(v), demon_(nullptr) {}
149
150void GreaterEqExprCst::Post() {
151 if (!expr_->IsVar() && expr_->Min() < value_) {
152 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
153 expr_->WhenRange(demon_);
154 } else {
155 // Let's clean the demon in case the constraint is posted during search.
156 demon_ = nullptr;
157 }
158}
159
160void GreaterEqExprCst::InitialPropagate() {
161 expr_->SetMin(value_);
162 if (demon_ != nullptr && expr_->Min() >= value_) {
163 demon_->inhibit(solver());
164 }
165}
166
167std::string GreaterEqExprCst::DebugString() const {
168 return absl::StrFormat("(%s >= %d)", expr_->DebugString(), value_);
169}
170} // namespace
171
172Constraint* Solver::MakeGreaterOrEqual(IntExpr* const e, int64_t v) {
173 CHECK_EQ(this, e->solver());
174 if (e->Min() >= v) {
175 return MakeTrueConstraint();
176 } else if (e->Max() < v) {
177 return MakeFalseConstraint();
178 } else {
179 return RevAlloc(new GreaterEqExprCst(this, e, v));
180 }
181}
182
183Constraint* Solver::MakeGreaterOrEqual(IntExpr* const e, int v) {
184 CHECK_EQ(this, e->solver());
185 if (e->Min() >= v) {
186 return MakeTrueConstraint();
187 } else if (e->Max() < v) {
188 return MakeFalseConstraint();
189 } else {
190 return RevAlloc(new GreaterEqExprCst(this, e, v));
191 }
192}
193
194Constraint* Solver::MakeGreater(IntExpr* const e, int64_t v) {
195 CHECK_EQ(this, e->solver());
196 if (e->Min() > v) {
197 return MakeTrueConstraint();
198 } else if (e->Max() <= v) {
199 return MakeFalseConstraint();
200 } else {
201 return RevAlloc(new GreaterEqExprCst(this, e, v + 1));
202 }
203}
204
205Constraint* Solver::MakeGreater(IntExpr* const e, int v) {
206 CHECK_EQ(this, e->solver());
207 if (e->Min() > v) {
208 return MakeTrueConstraint();
209 } else if (e->Max() <= v) {
210 return MakeFalseConstraint();
211 } else {
212 return RevAlloc(new GreaterEqExprCst(this, e, v + 1));
213 }
214}
215
216//-----------------------------------------------------------------------------
217// Less or equal constraint
218
219namespace {
220class LessEqExprCst : public Constraint {
221 public:
222 LessEqExprCst(Solver* s, IntExpr* e, int64_t v);
223 ~LessEqExprCst() override {}
224 void Post() override;
225 void InitialPropagate() override;
226 std::string DebugString() const override;
227 IntVar* Var() override {
228 return solver()->MakeIsLessOrEqualCstVar(expr_->Var(), value_);
229 }
230 void Accept(ModelVisitor* const visitor) const override {
231 visitor->BeginVisitConstraint(ModelVisitor::kLessOrEqual, this);
232 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
233 expr_);
234 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
235 visitor->EndVisitConstraint(ModelVisitor::kLessOrEqual, this);
236 }
237
238 private:
239 IntExpr* const expr_;
240 int64_t value_;
241 Demon* demon_;
242};
243
244LessEqExprCst::LessEqExprCst(Solver* const s, IntExpr* const e, int64_t v)
245 : Constraint(s), expr_(e), value_(v), demon_(nullptr) {}
246
247void LessEqExprCst::Post() {
248 if (!expr_->IsVar() && expr_->Max() > value_) {
249 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
250 expr_->WhenRange(demon_);
251 } else {
252 // Let's clean the demon in case the constraint is posted during search.
253 demon_ = nullptr;
254 }
255}
256
257void LessEqExprCst::InitialPropagate() {
258 expr_->SetMax(value_);
259 if (demon_ != nullptr && expr_->Max() <= value_) {
260 demon_->inhibit(solver());
261 }
262}
263
264std::string LessEqExprCst::DebugString() const {
265 return absl::StrFormat("(%s <= %d)", expr_->DebugString(), value_);
266}
267} // namespace
268
269Constraint* Solver::MakeLessOrEqual(IntExpr* const e, int64_t v) {
270 CHECK_EQ(this, e->solver());
271 if (e->Max() <= v) {
272 return MakeTrueConstraint();
273 } else if (e->Min() > v) {
274 return MakeFalseConstraint();
275 } else {
276 return RevAlloc(new LessEqExprCst(this, e, v));
277 }
278}
279
280Constraint* Solver::MakeLessOrEqual(IntExpr* const e, int v) {
281 CHECK_EQ(this, e->solver());
282 if (e->Max() <= v) {
283 return MakeTrueConstraint();
284 } else if (e->Min() > v) {
285 return MakeFalseConstraint();
286 } else {
287 return RevAlloc(new LessEqExprCst(this, e, v));
288 }
289}
290
291Constraint* Solver::MakeLess(IntExpr* const e, int64_t v) {
292 CHECK_EQ(this, e->solver());
293 if (e->Max() < v) {
294 return MakeTrueConstraint();
295 } else if (e->Min() >= v) {
296 return MakeFalseConstraint();
297 } else {
298 return RevAlloc(new LessEqExprCst(this, e, v - 1));
299 }
300}
301
302Constraint* Solver::MakeLess(IntExpr* const e, int v) {
303 CHECK_EQ(this, e->solver());
304 if (e->Max() < v) {
305 return MakeTrueConstraint();
306 } else if (e->Min() >= v) {
307 return MakeFalseConstraint();
308 } else {
309 return RevAlloc(new LessEqExprCst(this, e, v - 1));
310 }
311}
312
313//-----------------------------------------------------------------------------
314// Different constraints
315
316namespace {
317class DiffCst : public Constraint {
318 public:
319 DiffCst(Solver* s, IntVar* var, int64_t value);
320 ~DiffCst() override {}
321 void Post() override {}
322 void InitialPropagate() override;
323 void BoundPropagate();
324 std::string DebugString() const override;
325 IntVar* Var() override {
326 return solver()->MakeIsDifferentCstVar(var_, value_);
327 }
328 void Accept(ModelVisitor* const visitor) const override {
329 visitor->BeginVisitConstraint(ModelVisitor::kNonEqual, this);
330 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
331 var_);
332 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
333 visitor->EndVisitConstraint(ModelVisitor::kNonEqual, this);
334 }
335
336 private:
337 bool HasLargeDomain(IntVar* var);
338
339 IntVar* const var_;
340 int64_t value_;
341 Demon* demon_;
342};
343
344DiffCst::DiffCst(Solver* const s, IntVar* const var, int64_t value)
345 : Constraint(s), var_(var), value_(value), demon_(nullptr) {}
346
347void DiffCst::InitialPropagate() {
348 if (HasLargeDomain(var_)) {
349 demon_ = MakeConstraintDemon0(solver(), this, &DiffCst::BoundPropagate,
350 "BoundPropagate");
351 var_->WhenRange(demon_);
352 } else {
353 var_->RemoveValue(value_);
354 }
355}
356
357void DiffCst::BoundPropagate() {
358 const int64_t var_min = var_->Min();
359 const int64_t var_max = var_->Max();
360 if (var_min > value_ || var_max < value_) {
361 demon_->inhibit(solver());
362 } else if (var_min == value_) {
363 var_->SetMin(CapAdd(value_, 1));
364 } else if (var_max == value_) {
365 var_->SetMax(CapSub(value_, 1));
366 } else if (!HasLargeDomain(var_)) {
367 demon_->inhibit(solver());
368 var_->RemoveValue(value_);
369 }
370}
371
372std::string DiffCst::DebugString() const {
373 return absl::StrFormat("(%s != %d)", var_->DebugString(), value_);
374}
375
376bool DiffCst::HasLargeDomain(IntVar* var) {
377 return CapSub(var->Max(), var->Min()) > 0xFFFFFF;
378}
379} // namespace
380
381Constraint* Solver::MakeNonEquality(IntExpr* const e, int64_t v) {
382 CHECK_EQ(this, e->solver());
383 IntExpr* left = nullptr;
384 IntExpr* right = nullptr;
385 if (IsADifference(e, &left, &right)) {
386 return MakeNonEquality(left, MakeSum(right, v));
387 } else if (e->IsVar() && !e->Var()->Contains(v)) {
388 return MakeTrueConstraint();
389 } else if (e->Bound() && e->Min() == v) {
390 return MakeFalseConstraint();
391 } else {
392 return RevAlloc(new DiffCst(this, e->Var(), v));
393 }
394}
395
396Constraint* Solver::MakeNonEquality(IntExpr* const e, int v) {
397 CHECK_EQ(this, e->solver());
398 IntExpr* left = nullptr;
399 IntExpr* right = nullptr;
400 if (IsADifference(e, &left, &right)) {
401 return MakeNonEquality(left, MakeSum(right, v));
402 } else if (e->IsVar() && !e->Var()->Contains(v)) {
403 return MakeTrueConstraint();
404 } else if (e->Bound() && e->Min() == v) {
405 return MakeFalseConstraint();
406 } else {
407 return RevAlloc(new DiffCst(this, e->Var(), v));
408 }
409}
410// ----- is_equal_cst Constraint -----
411
412namespace {
413class IsEqualCstCt : public CastConstraint {
414 public:
415 IsEqualCstCt(Solver* const s, IntVar* const v, int64_t c, IntVar* const b)
416 : CastConstraint(s, b), var_(v), cst_(c), demon_(nullptr) {}
417 void Post() override {
418 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
419 var_->WhenDomain(demon_);
420 target_var_->WhenBound(demon_);
421 }
422 void InitialPropagate() override {
423 bool inhibit = var_->Bound();
424 int64_t u = var_->Contains(cst_);
425 int64_t l = inhibit ? u : 0;
426 target_var_->SetRange(l, u);
427 if (target_var_->Bound()) {
428 if (target_var_->Min() == 0) {
429 if (var_->Size() <= 0xFFFFFF) {
430 var_->RemoveValue(cst_);
431 inhibit = true;
432 }
433 } else {
434 var_->SetValue(cst_);
435 inhibit = true;
436 }
437 }
438 if (inhibit) {
439 demon_->inhibit(solver());
440 }
441 }
442 std::string DebugString() const override {
443 return absl::StrFormat("IsEqualCstCt(%s, %d, %s)", var_->DebugString(),
444 cst_, target_var_->DebugString());
445 }
446
447 void Accept(ModelVisitor* const visitor) const override {
448 visitor->BeginVisitConstraint(ModelVisitor::kIsEqual, this);
449 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
450 var_);
451 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
452 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
453 target_var_);
454 visitor->EndVisitConstraint(ModelVisitor::kIsEqual, this);
455 }
456
457 private:
458 IntVar* const var_;
459 int64_t cst_;
460 Demon* demon_;
461};
462} // namespace
463
464IntVar* Solver::MakeIsEqualCstVar(IntExpr* const var, int64_t value) {
465 IntExpr* left = nullptr;
466 IntExpr* right = nullptr;
467 if (IsADifference(var, &left, &right)) {
468 return MakeIsEqualVar(left, MakeSum(right, value));
469 }
470 if (CapSub(var->Max(), var->Min()) == 1) {
471 if (value == var->Min()) {
472 return MakeDifference(value + 1, var)->Var();
473 } else if (value == var->Max()) {
474 return MakeSum(var, -value + 1)->Var();
475 } else {
476 return MakeIntConst(0);
477 }
478 }
479 if (var->IsVar()) {
480 return var->Var()->IsEqual(value);
481 } else {
482 IntVar* const boolvar =
483 MakeBoolVar(absl::StrFormat("Is(%s == %d)", var->DebugString(), value));
484 AddConstraint(MakeIsEqualCstCt(var, value, boolvar));
485 return boolvar;
486 }
487}
488
// Builds a constraint linking `boolvar` to the condition `var == value`.
//
// Simplifications tried before allocating an IsEqualCstCt:
//  - `value` equal to the current minimum (resp. maximum) of `var`: the
//    condition is expressed as a <= (resp. >=) test, or as an affine equality
//    with `boolvar` when Max - Min of `var` is exactly 1;
//  - `boolvar` already bound: a plain equality or non-equality is returned.
//
// NOTE(review): in this listing the argument line of the
// InsertExprConstantExpression() call below is missing (the listing number
// jumps from 514 to 516); restore it from the original file before building.
489Constraint* Solver::MakeIsEqualCstCt(IntExpr* const var, int64_t value,
490 IntVar* const boolvar) {
491 CHECK_EQ(this, var->solver());
492 CHECK_EQ(this, boolvar->solver());
493 if (value == var->Min()) {
494 if (CapSub(var->Max(), var->Min()) == 1) {
495 return MakeEquality(MakeDifference(value + 1, var), boolvar);
496 }
497 return MakeIsLessOrEqualCstCt(var, value, boolvar);
498 }
499 if (value == var->Max()) {
500 if (CapSub(var->Max(), var->Min()) == 1) {
501 return MakeEquality(MakeSum(var, -value + 1), boolvar);
502 }
503 return MakeIsGreaterOrEqualCstCt(var, value, boolvar);
504 }
505 if (boolvar->Bound()) {
506 if (boolvar->Min() == 0) {
507 return MakeNonEquality(var, value);
508 } else {
509 return MakeEquality(var, value);
510 }
511 }
512 // TODO(user) : what happens if the constraint is not posted?
513 // The cache becomes tainted.
514 model_cache_->InsertExprConstantExpression(
516 IntExpr* left = nullptr;
517 IntExpr* right = nullptr;
// When IsADifference() splits `var`, the constant is added to the second
// sub-expression and the expression/expression form is used instead.
518 if (IsADifference(var, &left, &right)) {
519 return MakeIsEqualCt(left, MakeSum(right, value), boolvar);
520 } else {
521 return RevAlloc(new IsEqualCstCt(this, var->Var(), value, boolvar));
522 }
523}
524
525// ----- is_diff_cst Constraint -----
526
527namespace {
528class IsDiffCstCt : public CastConstraint {
529 public:
530 IsDiffCstCt(Solver* const s, IntVar* const v, int64_t c, IntVar* const b)
531 : CastConstraint(s, b), var_(v), cst_(c), demon_(nullptr) {}
532
533 void Post() override {
534 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
535 var_->WhenDomain(demon_);
536 target_var_->WhenBound(demon_);
537 }
538
539 void InitialPropagate() override {
540 bool inhibit = var_->Bound();
541 int64_t l = 1 - var_->Contains(cst_);
542 int64_t u = inhibit ? l : 1;
543 target_var_->SetRange(l, u);
544 if (target_var_->Bound()) {
545 if (target_var_->Min() == 1) {
546 if (var_->Size() <= 0xFFFFFF) {
547 var_->RemoveValue(cst_);
548 inhibit = true;
549 }
550 } else {
551 var_->SetValue(cst_);
552 inhibit = true;
553 }
554 }
555 if (inhibit) {
556 demon_->inhibit(solver());
557 }
558 }
559
560 std::string DebugString() const override {
561 return absl::StrFormat("IsDiffCstCt(%s, %d, %s)", var_->DebugString(), cst_,
562 target_var_->DebugString());
563 }
564
565 void Accept(ModelVisitor* const visitor) const override {
566 visitor->BeginVisitConstraint(ModelVisitor::kIsDifferent, this);
567 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
568 var_);
569 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
570 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
571 target_var_);
572 visitor->EndVisitConstraint(ModelVisitor::kIsDifferent, this);
573 }
574
575 private:
576 IntVar* const var_;
577 int64_t cst_;
578 Demon* demon_;
579};
580} // namespace
581
582IntVar* Solver::MakeIsDifferentCstVar(IntExpr* const var, int64_t value) {
583 IntExpr* left = nullptr;
584 IntExpr* right = nullptr;
585 if (IsADifference(var, &left, &right)) {
586 return MakeIsDifferentVar(left, MakeSum(right, value));
587 }
588 return var->Var()->IsDifferent(value);
589}
590
// Builds a constraint linking `boolvar` to the condition `var != value`.
//
// Simplifications tried before allocating an IsDiffCstCt:
//  - `value` equal to the current minimum (resp. maximum) of `var`: the
//    condition is expressed as `var >= value + 1` (resp. `var <= value - 1`);
//  - a variable whose domain lacks `value`: `boolvar` is forced to 1;
//  - `var` bound to `value`: `boolvar` is forced to 0;
//  - `boolvar` already bound: a plain equality or non-equality is returned.
//
// NOTE(review): in this listing the argument line of the
// InsertExprConstantExpression() call below is missing (the listing number
// jumps from 614 to 616); restore it from the original file before building.
591Constraint* Solver::MakeIsDifferentCstCt(IntExpr* const var, int64_t value,
592 IntVar* const boolvar) {
593 CHECK_EQ(this, var->solver());
594 CHECK_EQ(this, boolvar->solver());
595 if (value == var->Min()) {
596 return MakeIsGreaterOrEqualCstCt(var, value + 1, boolvar);
597 }
598 if (value == var->Max()) {
599 return MakeIsLessOrEqualCstCt(var, value - 1, boolvar);
600 }
601 if (var->IsVar() && !var->Var()->Contains(value)) {
602 return MakeEquality(boolvar, int64_t{1});
603 }
604 if (var->Bound() && var->Min() == value) {
605 return MakeEquality(boolvar, Zero());
606 }
607 if (boolvar->Bound()) {
608 if (boolvar->Min() == 0) {
609 return MakeEquality(var, value);
610 } else {
611 return MakeNonEquality(var, value);
612 }
613 }
614 model_cache_->InsertExprConstantExpression(
616 IntExpr* left = nullptr;
617 IntExpr* right = nullptr;
// When IsADifference() splits `var`, the constant is added to the second
// sub-expression and the expression/expression form is used instead.
618 if (IsADifference(var, &left, &right)) {
619 return MakeIsDifferentCt(left, MakeSum(right, value), boolvar);
620 } else {
621 return RevAlloc(new IsDiffCstCt(this, var->Var(), value, boolvar));
622 }
623}
624
625// ----- is_greater_equal_cst Constraint -----
626
627namespace {
628class IsGreaterEqualCstCt : public CastConstraint {
629 public:
630 IsGreaterEqualCstCt(Solver* const s, IntExpr* const v, int64_t c,
631 IntVar* const b)
632 : CastConstraint(s, b), expr_(v), cst_(c), demon_(nullptr) {}
633 void Post() override {
634 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
635 expr_->WhenRange(demon_);
636 target_var_->WhenBound(demon_);
637 }
638 void InitialPropagate() override {
639 bool inhibit = false;
640 int64_t u = expr_->Max() >= cst_;
641 int64_t l = expr_->Min() >= cst_;
642 target_var_->SetRange(l, u);
643 if (target_var_->Bound()) {
644 inhibit = true;
645 if (target_var_->Min() == 0) {
646 expr_->SetMax(cst_ - 1);
647 } else {
648 expr_->SetMin(cst_);
649 }
650 }
651 if (inhibit && ((target_var_->Max() == 0 && expr_->Max() < cst_) ||
652 (target_var_->Min() == 1 && expr_->Min() >= cst_))) {
653 // Can we safely inhibit? Sometimes an expression is not
654 // persistent, just monotonic.
655 demon_->inhibit(solver());
656 }
657 }
658 std::string DebugString() const override {
659 return absl::StrFormat("IsGreaterEqualCstCt(%s, %d, %s)",
660 expr_->DebugString(), cst_,
661 target_var_->DebugString());
662 }
663
664 void Accept(ModelVisitor* const visitor) const override {
665 visitor->BeginVisitConstraint(ModelVisitor::kIsGreaterOrEqual, this);
666 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
667 expr_);
668 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
669 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
670 target_var_);
671 visitor->EndVisitConstraint(ModelVisitor::kIsGreaterOrEqual, this);
672 }
673
674 private:
675 IntExpr* const expr_;
676 int64_t cst_;
677 Demon* demon_;
678};
679} // namespace
680
681IntVar* Solver::MakeIsGreaterOrEqualCstVar(IntExpr* const var, int64_t value) {
682 if (var->Min() >= value) {
683 return MakeIntConst(int64_t{1});
684 }
685 if (var->Max() < value) {
686 return MakeIntConst(int64_t{0});
687 }
688 if (var->IsVar()) {
689 return var->Var()->IsGreaterOrEqual(value);
690 } else {
691 IntVar* const boolvar =
692 MakeBoolVar(absl::StrFormat("Is(%s >= %d)", var->DebugString(), value));
693 AddConstraint(MakeIsGreaterOrEqualCstCt(var, value, boolvar));
694 return boolvar;
695 }
696}
697
698IntVar* Solver::MakeIsGreaterCstVar(IntExpr* const var, int64_t value) {
699 return MakeIsGreaterOrEqualCstVar(var, value + 1);
700}
701
703 IntVar* const boolvar) {
704 if (boolvar->Bound()) {
705 if (boolvar->Min() == 0) {
706 return MakeLess(var, value);
707 } else {
708 return MakeGreaterOrEqual(var, value);
709 }
710 }
711 CHECK_EQ(this, var->solver());
712 CHECK_EQ(this, boolvar->solver());
713 model_cache_->InsertExprConstantExpression(
715 return RevAlloc(new IsGreaterEqualCstCt(this, var, value, boolvar));
716}
717
718Constraint* Solver::MakeIsGreaterCstCt(IntExpr* const v, int64_t c,
719 IntVar* const b) {
720 return MakeIsGreaterOrEqualCstCt(v, c + 1, b);
721}
723// ----- is_lesser_equal_cst Constraint -----
724
725namespace {
726class IsLessEqualCstCt : public CastConstraint {
727 public:
728 IsLessEqualCstCt(Solver* const s, IntExpr* const v, int64_t c,
729 IntVar* const b)
730 : CastConstraint(s, b), expr_(v), cst_(c), demon_(nullptr) {}
731
732 void Post() override {
733 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
734 expr_->WhenRange(demon_);
735 target_var_->WhenBound(demon_);
736 }
737
738 void InitialPropagate() override {
739 bool inhibit = false;
740 int64_t u = expr_->Min() <= cst_;
741 int64_t l = expr_->Max() <= cst_;
742 target_var_->SetRange(l, u);
743 if (target_var_->Bound()) {
744 inhibit = true;
745 if (target_var_->Min() == 0) {
746 expr_->SetMin(cst_ + 1);
747 } else {
748 expr_->SetMax(cst_);
749 }
750 }
751 if (inhibit && ((target_var_->Max() == 0 && expr_->Min() > cst_) ||
752 (target_var_->Min() == 1 && expr_->Max() <= cst_))) {
753 // Can we safely inhibit? Sometimes an expression is not
754 // persistent, just monotonic.
755 demon_->inhibit(solver());
756 }
757 }
758
759 std::string DebugString() const override {
760 return absl::StrFormat("IsLessEqualCstCt(%s, %d, %s)", expr_->DebugString(),
761 cst_, target_var_->DebugString());
762 }
763
764 void Accept(ModelVisitor* const visitor) const override {
765 visitor->BeginVisitConstraint(ModelVisitor::kIsLessOrEqual, this);
766 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
767 expr_);
768 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
769 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
770 target_var_);
771 visitor->EndVisitConstraint(ModelVisitor::kIsLessOrEqual, this);
772 }
773
774 private:
775 IntExpr* const expr_;
776 int64_t cst_;
777 Demon* demon_;
778};
779} // namespace
780
782 if (var->Max() <= value) {
783 return MakeIntConst(int64_t{1});
784 }
785 if (var->Min() > value) {
786 return MakeIntConst(int64_t{0});
787 }
788 if (var->IsVar()) {
789 return var->Var()->IsLessOrEqual(value);
790 } else {
791 IntVar* const boolvar =
792 MakeBoolVar(absl::StrFormat("Is(%s <= %d)", var->DebugString(), value));
793 AddConstraint(MakeIsLessOrEqualCstCt(var, value, boolvar));
794 return boolvar;
795 }
796}
797
798IntVar* Solver::MakeIsLessCstVar(IntExpr* const var, int64_t value) {
799 return MakeIsLessOrEqualCstVar(var, value - 1);
800}
801
803 IntVar* const boolvar) {
804 if (boolvar->Bound()) {
805 if (boolvar->Min() == 0) {
806 return MakeGreater(var, value);
807 } else {
808 return MakeLessOrEqual(var, value);
809 }
810 }
811 CHECK_EQ(this, var->solver());
812 CHECK_EQ(this, boolvar->solver());
813 model_cache_->InsertExprConstantExpression(
815 return RevAlloc(new IsLessEqualCstCt(this, var, value, boolvar));
816}
817
818Constraint* Solver::MakeIsLessCstCt(IntExpr* const v, int64_t c,
819 IntVar* const b) {
820 return MakeIsLessOrEqualCstCt(v, c - 1, b);
821}
823// ----- BetweenCt -----
824
825namespace {
826class BetweenCt : public Constraint {
827 public:
828 BetweenCt(Solver* const s, IntExpr* const v, int64_t l, int64_t u)
829 : Constraint(s), expr_(v), min_(l), max_(u), demon_(nullptr) {}
830
831 void Post() override {
832 if (!expr_->IsVar()) {
833 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
834 expr_->WhenRange(demon_);
835 }
836 }
837
838 void InitialPropagate() override {
839 expr_->SetRange(min_, max_);
840 int64_t emin = 0;
841 int64_t emax = 0;
842 expr_->Range(&emin, &emax);
843 if (demon_ != nullptr && emin >= min_ && emax <= max_) {
844 demon_->inhibit(solver());
845 }
846 }
847
848 std::string DebugString() const override {
849 return absl::StrFormat("BetweenCt(%s, %d, %d)", expr_->DebugString(), min_,
850 max_);
851 }
852
853 void Accept(ModelVisitor* const visitor) const override {
854 visitor->BeginVisitConstraint(ModelVisitor::kBetween, this);
855 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
856 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
857 expr_);
858 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
859 visitor->EndVisitConstraint(ModelVisitor::kBetween, this);
860 }
861
862 private:
863 IntExpr* const expr_;
864 int64_t min_;
865 int64_t max_;
866 Demon* demon_;
867};
868
869// ----- NotBetween constraint -----
870
871class NotBetweenCt : public Constraint {
872 public:
873 NotBetweenCt(Solver* const s, IntExpr* const v, int64_t l, int64_t u)
874 : Constraint(s), expr_(v), min_(l), max_(u), demon_(nullptr) {}
875
876 void Post() override {
877 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
878 expr_->WhenRange(demon_);
879 }
880
881 void InitialPropagate() override {
882 int64_t emin = 0;
883 int64_t emax = 0;
884 expr_->Range(&emin, &emax);
885 if (emin >= min_) {
886 expr_->SetMin(max_ + 1);
887 } else if (emax <= max_) {
888 expr_->SetMax(min_ - 1);
889 }
890
891 if (!expr_->IsVar() && (emax < min_ || emin > max_)) {
892 demon_->inhibit(solver());
893 }
894 }
895
896 std::string DebugString() const override {
897 return absl::StrFormat("NotBetweenCt(%s, %d, %d)", expr_->DebugString(),
898 min_, max_);
899 }
900
901 void Accept(ModelVisitor* const visitor) const override {
902 visitor->BeginVisitConstraint(ModelVisitor::kNotBetween, this);
903 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
904 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
905 expr_);
906 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
907 visitor->EndVisitConstraint(ModelVisitor::kBetween, this);
908 }
909
910 private:
911 IntExpr* const expr_;
912 int64_t min_;
913 int64_t max_;
914 Demon* demon_;
915};
916
917int64_t ExtractExprProductCoeff(IntExpr** expr) {
918 int64_t prod = 1;
919 int64_t coeff = 1;
920 while ((*expr)->solver()->IsProduct(*expr, expr, &coeff)) prod *= coeff;
921 return prod;
922}
923} // namespace
924
925Constraint* Solver::MakeBetweenCt(IntExpr* expr, int64_t l, int64_t u) {
926 DCHECK_EQ(this, expr->solver());
927 // Catch empty and singleton intervals.
928 if (l >= u) {
929 if (l > u) return MakeFalseConstraint();
930 return MakeEquality(expr, l);
931 }
932 int64_t emin = 0;
933 int64_t emax = 0;
934 expr->Range(&emin, &emax);
935 // Catch the trivial cases first.
936 if (emax < l || emin > u) return MakeFalseConstraint();
937 if (emin >= l && emax <= u) return MakeTrueConstraint();
938 // Catch one-sided constraints.
939 if (emax <= u) return MakeGreaterOrEqual(expr, l);
940 if (emin >= l) return MakeLessOrEqual(expr, u);
941 // Simplify the common factor, if any.
942 int64_t coeff = ExtractExprProductCoeff(&expr);
943 if (coeff != 1) {
944 CHECK_NE(coeff, 0); // Would have been caught by the trivial cases already.
945 if (coeff < 0) {
946 std::swap(u, l);
947 u = -u;
948 l = -l;
949 coeff = -coeff;
950 }
951 return MakeBetweenCt(expr, PosIntDivUp(l, coeff), PosIntDivDown(u, coeff));
952 } else {
953 // No further reduction is possible.
954 return RevAlloc(new BetweenCt(this, expr, l, u));
955 }
956}
957
958Constraint* Solver::MakeNotBetweenCt(IntExpr* expr, int64_t l, int64_t u) {
959 DCHECK_EQ(this, expr->solver());
960 // Catch empty interval.
961 if (l > u) {
962 return MakeTrueConstraint();
963 }
964
965 int64_t emin = 0;
966 int64_t emax = 0;
967 expr->Range(&emin, &emax);
968 // Catch the trivial cases first.
969 if (emax < l || emin > u) return MakeTrueConstraint();
970 if (emin >= l && emax <= u) return MakeFalseConstraint();
971 // Catch one-sided constraints.
972 if (emin >= l) return MakeGreater(expr, u);
973 if (emax <= u) return MakeLess(expr, l);
974 // TODO(user): Add back simplification code if expr is constant *
975 // other_expr.
976 return RevAlloc(new NotBetweenCt(this, expr, l, u));
977}
978
979// ----- is_between_cst Constraint -----
980
981namespace {
982class IsBetweenCt : public Constraint {
983 public:
984 IsBetweenCt(Solver* const s, IntExpr* const e, int64_t l, int64_t u,
985 IntVar* const b)
986 : Constraint(s),
987 expr_(e),
988 min_(l),
989 max_(u),
990 boolvar_(b),
991 demon_(nullptr) {}
992
993 void Post() override {
994 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
995 expr_->WhenRange(demon_);
996 boolvar_->WhenBound(demon_);
997 }
998
999 void InitialPropagate() override {
1000 bool inhibit = false;
1001 int64_t emin = 0;
1002 int64_t emax = 0;
1003 expr_->Range(&emin, &emax);
1004 int64_t u = 1 - (emin > max_ || emax < min_);
1005 int64_t l = emax <= max_ && emin >= min_;
1006 boolvar_->SetRange(l, u);
1007 if (boolvar_->Bound()) {
1008 inhibit = true;
1009 if (boolvar_->Min() == 0) {
1010 if (expr_->IsVar()) {
1011 expr_->Var()->RemoveInterval(min_, max_);
1012 inhibit = true;
1013 } else if (emin > min_) {
1014 expr_->SetMin(max_ + 1);
1015 } else if (emax < max_) {
1016 expr_->SetMax(min_ - 1);
1017 }
1018 } else {
1019 expr_->SetRange(min_, max_);
1020 inhibit = true;
1021 }
1022 if (inhibit && expr_->IsVar()) {
1023 demon_->inhibit(solver());
1024 }
1025 }
1026 }
1027
1028 std::string DebugString() const override {
1029 return absl::StrFormat("IsBetweenCt(%s, %d, %d, %s)", expr_->DebugString(),
1030 min_, max_, boolvar_->DebugString());
1031 }
1032
1033 void Accept(ModelVisitor* const visitor) const override {
1034 visitor->BeginVisitConstraint(ModelVisitor::kIsBetween, this);
1035 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
1036 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1037 expr_);
1038 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
1039 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1040 boolvar_);
1041 visitor->EndVisitConstraint(ModelVisitor::kIsBetween, this);
1042 }
1043
1044 private:
1045 IntExpr* const expr_;
1046 int64_t min_;
1047 int64_t max_;
1048 IntVar* const boolvar_;
1049 Demon* demon_;
1050};
1051} // namespace
1052
1053Constraint* Solver::MakeIsBetweenCt(IntExpr* expr, int64_t l, int64_t u,
1054 IntVar* const b) {
1055 CHECK_EQ(this, expr->solver());
1056 CHECK_EQ(this, b->solver());
1057 // Catch empty and singleton intervals.
1058 if (l >= u) {
1059 if (l > u) return MakeEquality(b, Zero());
1060 return MakeIsEqualCstCt(expr, l, b);
1061 }
1062 int64_t emin = 0;
1063 int64_t emax = 0;
1064 expr->Range(&emin, &emax);
1065 // Catch the trivial cases first.
1066 if (emax < l || emin > u) return MakeEquality(b, Zero());
1067 if (emin >= l && emax <= u) return MakeEquality(b, 1);
1068 // Catch one-sided constraints.
1069 if (emax <= u) return MakeIsGreaterOrEqualCstCt(expr, l, b);
1070 if (emin >= l) return MakeIsLessOrEqualCstCt(expr, u, b);
1071 // Simplify the common factor, if any.
1072 int64_t coeff = ExtractExprProductCoeff(&expr);
1073 if (coeff != 1) {
1074 CHECK_NE(coeff, 0); // Would have been caught by the trivial cases already.
1075 if (coeff < 0) {
1076 std::swap(u, l);
1077 u = -u;
1078 l = -l;
1079 coeff = -coeff;
1080 }
1081 return MakeIsBetweenCt(expr, PosIntDivUp(l, coeff), PosIntDivDown(u, coeff),
1082 b);
1083 } else {
1084 // No further reduction is possible.
1085 return RevAlloc(new IsBetweenCt(this, expr, l, u, b));
1086 }
1087}
1088
1089IntVar* Solver::MakeIsBetweenVar(IntExpr* const v, int64_t l, int64_t u) {
1090 CHECK_EQ(this, v->solver());
1091 IntVar* const b = MakeBoolVar();
1092 AddConstraint(MakeIsBetweenCt(v, l, u, b));
1093 return b;
1094}
1095
1096// ---------- Member ----------
1097
1098// ----- Member(IntVar, IntSet) -----
1099
1100namespace {
1101// TODO(user): Do not create holes on expressions.
1102class MemberCt : public Constraint {
1103 public:
1104 MemberCt(Solver* const s, IntVar* const v,
1105 const std::vector<int64_t>& sorted_values)
1106 : Constraint(s), var_(v), values_(sorted_values) {
1107 DCHECK(v != nullptr);
1108 DCHECK(s != nullptr);
1109 }
1110
1111 void Post() override {}
1112
1113 void InitialPropagate() override { var_->SetValues(values_); }
1114
1115 std::string DebugString() const override {
1116 return absl::StrFormat("Member(%s, %s)", var_->DebugString(),
1117 absl::StrJoin(values_, ", "));
1118 }
1119
1120 void Accept(ModelVisitor* const visitor) const override {
1121 visitor->BeginVisitConstraint(ModelVisitor::kMember, this);
1122 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1123 var_);
1124 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
1125 visitor->EndVisitConstraint(ModelVisitor::kMember, this);
1126 }
1127
1128 private:
1129 IntVar* const var_;
1130 const std::vector<int64_t> values_;
1131};
1132
1133class NotMemberCt : public Constraint {
1134 public:
1135 NotMemberCt(Solver* const s, IntVar* const v,
1136 const std::vector<int64_t>& sorted_values)
1137 : Constraint(s), var_(v), values_(sorted_values) {
1138 DCHECK(v != nullptr);
1139 DCHECK(s != nullptr);
1140 }
1141
1142 void Post() override {}
1143
1144 void InitialPropagate() override { var_->RemoveValues(values_); }
1145
1146 std::string DebugString() const override {
1147 return absl::StrFormat("NotMember(%s, %s)", var_->DebugString(),
1148 absl::StrJoin(values_, ", "));
1149 }
1150
1151 void Accept(ModelVisitor* const visitor) const override {
1152 visitor->BeginVisitConstraint(ModelVisitor::kMember, this);
1153 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1154 var_);
1155 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
1156 visitor->EndVisitConstraint(ModelVisitor::kMember, this);
1157 }
1158
1159 private:
1160 IntVar* const var_;
1161 const std::vector<int64_t> values_;
1162};
1163} // namespace
1164
1165Constraint* Solver::MakeMemberCt(IntExpr* expr,
1166 const std::vector<int64_t>& values) {
1167 const int64_t coeff = ExtractExprProductCoeff(&expr);
1168 if (coeff == 0) {
1169 return std::find(values.begin(), values.end(), 0) == values.end()
1170 ? MakeFalseConstraint()
1171 : MakeTrueConstraint();
1172 }
1173 std::vector<int64_t> copied_values = values;
1174 // If the expression is a non-trivial product, we filter out the values that
1175 // aren't multiples of "coeff", and divide them.
1176 if (coeff != 1) {
1177 int num_kept = 0;
1178 for (const int64_t v : copied_values) {
1179 if (v % coeff == 0) copied_values[num_kept++] = v / coeff;
1180 }
1181 copied_values.resize(num_kept);
1182 }
1183 // Filter out the values that are outside the [Min, Max] interval.
1184 int num_kept = 0;
1185 int64_t emin;
1186 int64_t emax;
1187 expr->Range(&emin, &emax);
1188 for (const int64_t v : copied_values) {
1189 if (v >= emin && v <= emax) copied_values[num_kept++] = v;
1190 }
1191 copied_values.resize(num_kept);
1192 // Catch empty set.
1193 if (copied_values.empty()) return MakeFalseConstraint();
1194 // Sort and remove duplicates.
1195 gtl::STLSortAndRemoveDuplicates(&copied_values);
1196 // Special case for singleton.
1197 if (copied_values.size() == 1) return MakeEquality(expr, copied_values[0]);
1198 // Catch contiguous intervals.
1199 if (copied_values.size() ==
1200 copied_values.back() - copied_values.front() + 1) {
1201 // Note: MakeBetweenCt() has a fast-track for trivially true constraints.
1202 return MakeBetweenCt(expr, copied_values.front(), copied_values.back());
1203 }
1204 // If the set of values in [expr.Min(), expr.Max()] that are *not* in
1205 // "values" is smaller than "values", then it's more efficient to use
1206 // NotMemberCt. Catch that case here.
1207 if (emax - emin < 2 * copied_values.size()) {
1208 // Convert "copied_values" to list the values *not* allowed.
1209 std::vector<bool> is_among_input_values(emax - emin + 1, false);
1210 for (const int64_t v : copied_values)
1211 is_among_input_values[v - emin] = true;
1212 // We use the zero valued indices of is_among_input_values to build the
1213 // complement of copied_values.
1214 copied_values.clear();
1215 for (int64_t v_off = 0; v_off < is_among_input_values.size(); ++v_off) {
1216 if (!is_among_input_values[v_off]) copied_values.push_back(v_off + emin);
1217 }
1218 // The empty' case (all values in range [expr.Min(), expr.Max()] are in the
1219 // "values" input) was caught earlier, by the "contiguous interval" case.
1220 DCHECK_GE(copied_values.size(), 1);
1221 if (copied_values.size() == 1) {
1222 return MakeNonEquality(expr, copied_values[0]);
1223 }
1224 return RevAlloc(new NotMemberCt(this, expr->Var(), copied_values));
1225 }
1226 // Otherwise, just use MemberCt. No further reduction is possible.
1227 return RevAlloc(new MemberCt(this, expr->Var(), copied_values));
1228}
1229
1230Constraint* Solver::MakeMemberCt(IntExpr* const expr,
1231 const std::vector<int>& values) {
1232 return MakeMemberCt(expr, ToInt64Vector(values));
1233}
1236 const std::vector<int64_t>& values) {
1237 const int64_t coeff = ExtractExprProductCoeff(&expr);
1238 if (coeff == 0) {
1239 return std::find(values.begin(), values.end(), 0) == values.end()
1240 ? MakeTrueConstraint()
1241 : MakeFalseConstraint();
1242 }
1243 std::vector<int64_t> copied_values = values;
1244 // If the expression is a non-trivial product, we filter out the values that
1245 // aren't multiples of "coeff", and divide them.
1246 if (coeff != 1) {
1247 int num_kept = 0;
1248 for (const int64_t v : copied_values) {
1249 if (v % coeff == 0) copied_values[num_kept++] = v / coeff;
1250 }
1251 copied_values.resize(num_kept);
1252 }
1253 // Filter out the values that are outside the [Min, Max] interval.
1254 int num_kept = 0;
1255 int64_t emin;
1256 int64_t emax;
1257 expr->Range(&emin, &emax);
1258 for (const int64_t v : copied_values) {
1259 if (v >= emin && v <= emax) copied_values[num_kept++] = v;
1260 }
1261 copied_values.resize(num_kept);
1262 // Catch empty set.
1263 if (copied_values.empty()) return MakeTrueConstraint();
1264 // Sort and remove duplicates.
1265 gtl::STLSortAndRemoveDuplicates(&copied_values);
1266 // Special case for singleton.
1267 if (copied_values.size() == 1) return MakeNonEquality(expr, copied_values[0]);
1268 // Catch contiguous intervals.
1269 if (copied_values.size() ==
1270 copied_values.back() - copied_values.front() + 1) {
1271 return MakeNotBetweenCt(expr, copied_values.front(), copied_values.back());
1272 }
1273 // If the set of values in [expr.Min(), expr.Max()] that are *not* in
1274 // "values" is smaller than "values", then it's more efficient to use
1275 // MemberCt. Catch that case here.
1276 if (emax - emin < 2 * copied_values.size()) {
1277 // Convert "copied_values" to a dense boolean vector.
1278 std::vector<bool> is_among_input_values(emax - emin + 1, false);
1279 for (const int64_t v : copied_values)
1280 is_among_input_values[v - emin] = true;
1281 // Use zero valued indices for is_among_input_values to build the
1282 // complement of copied_values.
1283 copied_values.clear();
1284 for (int64_t v_off = 0; v_off < is_among_input_values.size(); ++v_off) {
1285 if (!is_among_input_values[v_off]) copied_values.push_back(v_off + emin);
1286 }
1287 // The empty' case (all values in range [expr.Min(), expr.Max()] are in the
1288 // "values" input) was caught earlier, by the "contiguous interval" case.
1289 DCHECK_GE(copied_values.size(), 1);
1290 if (copied_values.size() == 1) {
1291 return MakeEquality(expr, copied_values[0]);
1292 }
1293 return RevAlloc(new MemberCt(this, expr->Var(), copied_values));
1294 }
1295 // Otherwise, just use NotMemberCt. No further reduction is possible.
1296 return RevAlloc(new NotMemberCt(this, expr->Var(), copied_values));
1297}
1298
1299Constraint* Solver::MakeNotMemberCt(IntExpr* const expr,
1300 const std::vector<int>& values) {
1301 return MakeNotMemberCt(expr, ToInt64Vector(values));
1302}
1304// ----- IsMemberCt -----
1305
1306namespace {
1307class IsMemberCt : public Constraint {
1308 public:
1309 IsMemberCt(Solver* const s, IntVar* const v,
1310 const std::vector<int64_t>& sorted_values, IntVar* const b)
1311 : Constraint(s),
1312 var_(v),
1313 values_as_set_(sorted_values.begin(), sorted_values.end()),
1314 values_(sorted_values),
1315 boolvar_(b),
1316 support_(0),
1317 demon_(nullptr),
1318 domain_(var_->MakeDomainIterator(true)),
1319 neg_support_(std::numeric_limits<int64_t>::min()) {
1320 DCHECK(v != nullptr);
1321 DCHECK(s != nullptr);
1322 DCHECK(b != nullptr);
1323 while (values_as_set_.contains(neg_support_)) {
1324 neg_support_++;
1325 }
1326 }
1327
1328 void Post() override {
1329 demon_ = MakeConstraintDemon0(solver(), this, &IsMemberCt::VarDomain,
1330 "VarDomain");
1331 if (!var_->Bound()) {
1332 var_->WhenDomain(demon_);
1333 }
1334 if (!boolvar_->Bound()) {
1335 Demon* const bdemon = MakeConstraintDemon0(
1336 solver(), this, &IsMemberCt::TargetBound, "TargetBound");
1337 boolvar_->WhenBound(bdemon);
1338 }
1339 }
1340
1341 void InitialPropagate() override {
1342 boolvar_->SetRange(0, 1);
1343 if (boolvar_->Bound()) {
1344 TargetBound();
1345 } else {
1346 VarDomain();
1347 }
1348 }
1349
1350 std::string DebugString() const override {
1351 return absl::StrFormat("IsMemberCt(%s, %s, %s)", var_->DebugString(),
1352 absl::StrJoin(values_, ", "),
1353 boolvar_->DebugString());
1354 }
1355
1356 void Accept(ModelVisitor* const visitor) const override {
1357 visitor->BeginVisitConstraint(ModelVisitor::kIsMember, this);
1358 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1359 var_);
1360 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
1361 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1362 boolvar_);
1363 visitor->EndVisitConstraint(ModelVisitor::kIsMember, this);
1364 }
1365
1366 private:
1367 void VarDomain() {
1368 if (boolvar_->Bound()) {
1369 TargetBound();
1370 } else {
1371 for (int offset = 0; offset < values_.size(); ++offset) {
1372 const int candidate = (support_ + offset) % values_.size();
1373 if (var_->Contains(values_[candidate])) {
1374 support_ = candidate;
1375 if (var_->Bound()) {
1376 demon_->inhibit(solver());
1377 boolvar_->SetValue(1);
1378 return;
1379 }
1380 // We have found a positive support. Let's check the
1381 // negative support.
1382 if (var_->Contains(neg_support_)) {
1383 return;
1384 } else {
1385 // Look for a new negative support.
1386 for (const int64_t value : InitAndGetValues(domain_)) {
1387 if (!values_as_set_.contains(value)) {
1388 neg_support_ = value;
1389 return;
1390 }
1391 }
1392 }
1393 // No negative support, setting boolvar to true.
1394 demon_->inhibit(solver());
1395 boolvar_->SetValue(1);
1396 return;
1397 }
1398 }
1399 // No positive support, setting boolvar to false.
1400 demon_->inhibit(solver());
1401 boolvar_->SetValue(0);
1402 }
1403 }
1404
1405 void TargetBound() {
1406 DCHECK(boolvar_->Bound());
1407 if (boolvar_->Min() == 1LL) {
1408 demon_->inhibit(solver());
1409 var_->SetValues(values_);
1410 } else {
1411 demon_->inhibit(solver());
1412 var_->RemoveValues(values_);
1413 }
1414 }
1415
1416 IntVar* const var_;
1417 absl::flat_hash_set<int64_t> values_as_set_;
1418 std::vector<int64_t> values_;
1419 IntVar* const boolvar_;
1420 int support_;
1421 Demon* demon_;
1422 IntVarIterator* const domain_;
1423 int64_t neg_support_;
1424};
1425
1426template <class T>
1427Constraint* BuildIsMemberCt(Solver* const solver, IntExpr* const expr,
1428 const std::vector<T>& values,
1429 IntVar* const boolvar) {
1430 // TODO(user): optimize this by copying the code from MakeMemberCt.
1431 // Simplify and filter if expr is a product.
1432 IntExpr* sub = nullptr;
1433 int64_t coef = 1;
1434 if (solver->IsProduct(expr, &sub, &coef) && coef != 0 && coef != 1) {
1435 std::vector<int64_t> new_values;
1436 new_values.reserve(values.size());
1437 for (const int64_t value : values) {
1438 if (value % coef == 0) {
1439 new_values.push_back(value / coef);
1440 }
1441 }
1442 return BuildIsMemberCt(solver, sub, new_values, boolvar);
1443 }
1444
1445 std::set<T> set_of_values(values.begin(), values.end());
1446 std::vector<int64_t> filtered_values;
1447 bool all_values = false;
1448 if (expr->IsVar()) {
1449 IntVar* const var = expr->Var();
1450 for (const T value : set_of_values) {
1451 if (var->Contains(value)) {
1452 filtered_values.push_back(value);
1453 }
1454 }
1455 all_values = (filtered_values.size() == var->Size());
1456 } else {
1457 int64_t emin = 0;
1458 int64_t emax = 0;
1459 expr->Range(&emin, &emax);
1460 for (const T value : set_of_values) {
1461 if (value >= emin && value <= emax) {
1462 filtered_values.push_back(value);
1463 }
1464 }
1465 all_values = (filtered_values.size() == emax - emin + 1);
1466 }
1467 if (filtered_values.empty()) {
1468 return solver->MakeEquality(boolvar, Zero());
1469 } else if (all_values) {
1470 return solver->MakeEquality(boolvar, 1);
1471 } else if (filtered_values.size() == 1) {
1472 return solver->MakeIsEqualCstCt(expr, filtered_values.back(), boolvar);
1473 } else if (filtered_values.back() ==
1474 filtered_values.front() + filtered_values.size() - 1) {
1475 // Contiguous
1476 return solver->MakeIsBetweenCt(expr, filtered_values.front(),
1477 filtered_values.back(), boolvar);
1478 } else {
1479 return solver->RevAlloc(
1480 new IsMemberCt(solver, expr->Var(), filtered_values, boolvar));
1481 }
1482}
1483} // namespace
1484
1486 const std::vector<int64_t>& values,
1487 IntVar* const boolvar) {
1488 return BuildIsMemberCt(this, expr, values, boolvar);
1490
1492 const std::vector<int>& values,
1493 IntVar* const boolvar) {
1494 return BuildIsMemberCt(this, expr, values, boolvar);
1496
1498 const std::vector<int64_t>& values) {
1499 IntVar* const b = MakeBoolVar();
1500 AddConstraint(MakeIsMemberCt(expr, values, b));
1501 return b;
1502}
1503
1505 const std::vector<int>& values) {
1506 IntVar* const b = MakeBoolVar();
1507 AddConstraint(MakeIsMemberCt(expr, values, b));
1508 return b;
1509}
1510
1511namespace {
1512class SortedDisjointForbiddenIntervalsConstraint : public Constraint {
1513 public:
1514 SortedDisjointForbiddenIntervalsConstraint(
1515 Solver* const solver, IntVar* const var,
1517 : Constraint(solver), var_(var), intervals_(std::move(intervals)) {}
1518
1519 ~SortedDisjointForbiddenIntervalsConstraint() override {}
1520
1521 void Post() override {
1522 Demon* const demon = solver()->MakeConstraintInitialPropagateCallback(this);
1523 var_->WhenRange(demon);
1524 }
1525
1526 void InitialPropagate() override {
1527 const int64_t vmin = var_->Min();
1528 const int64_t vmax = var_->Max();
1529 const auto first_interval_it = intervals_.FirstIntervalGreaterOrEqual(vmin);
1530 if (first_interval_it == intervals_.end()) {
1531 // No interval intersects the variable's range. Nothing to do.
1532 return;
1533 }
1534 const auto last_interval_it = intervals_.LastIntervalLessOrEqual(vmax);
1535 if (last_interval_it == intervals_.end()) {
1536 // No interval intersects the variable's range. Nothing to do.
1537 return;
1538 }
1539 // TODO(user): Quick fail if first_interval_it == last_interval_it, which
1540 // would imply that the interval contains the entire range of the variable?
1541 if (vmin >= first_interval_it->start) {
1542 // The variable's minimum is inside a forbidden interval. Move it to the
1543 // interval's end.
1544 var_->SetMin(CapAdd(first_interval_it->end, 1));
1545 }
1546 if (vmax <= last_interval_it->end) {
1547 // Ditto, on the other side.
1548 var_->SetMax(CapSub(last_interval_it->start, 1));
1549 }
1550 }
1551
1552 std::string DebugString() const override {
1553 return absl::StrFormat("ForbiddenIntervalCt(%s, %s)", var_->DebugString(),
1554 intervals_.DebugString());
1555 }
1556
1557 void Accept(ModelVisitor* const visitor) const override {
1558 visitor->BeginVisitConstraint(ModelVisitor::kNotMember, this);
1559 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1560 var_);
1561 std::vector<int64_t> starts;
1562 std::vector<int64_t> ends;
1563 for (auto& interval : intervals_) {
1564 starts.push_back(interval.start);
1565 ends.push_back(interval.end);
1566 }
1567 visitor->VisitIntegerArrayArgument(ModelVisitor::kStartsArgument, starts);
1568 visitor->VisitIntegerArrayArgument(ModelVisitor::kEndsArgument, ends);
1569 visitor->EndVisitConstraint(ModelVisitor::kNotMember, this);
1570 }
1571
1572 private:
1573 IntVar* const var_;
1574 const SortedDisjointIntervalList intervals_;
1575};
1576} // namespace
1577
1578Constraint* Solver::MakeNotMemberCt(IntExpr* const expr,
1579 std::vector<int64_t> starts,
1580 std::vector<int64_t> ends) {
1581 return RevAlloc(new SortedDisjointForbiddenIntervalsConstraint(
1582 this, expr->Var(), {starts, ends}));
1583}
1584
1586 std::vector<int> starts,
1587 std::vector<int> ends) {
1588 return RevAlloc(new SortedDisjointForbiddenIntervalsConstraint(
1589 this, expr->Var(), {starts, ends}));
1590}
1591
1593 SortedDisjointIntervalList intervals) {
1594 return RevAlloc(new SortedDisjointForbiddenIntervalsConstraint(
1595 this, expr->Var(), std::move(intervals)));
1597} // namespace operations_research
virtual void SetValue(int64_t v)
This method sets the value of the expression.
virtual bool Bound() const
Returns true if the min and the max of the expression are equal.
virtual void SetMax(int64_t m)=0
virtual bool IsVar() const
Returns true if the expression is indeed a variable.
virtual void SetRange(int64_t l, int64_t u)
This method sets both the min and the max of the expression.
virtual int64_t Min() const =0
virtual void SetMin(int64_t m)=0
virtual void WhenRange(Demon *d)=0
Attach a demon that will watch the min or the max of the expression.
virtual void Range(int64_t *l, int64_t *u)
virtual IntVar * Var()=0
Creates a variable from the expression.
virtual int64_t Max() const =0
virtual void WhenBound(Demon *d)=0
virtual void WhenDomain(Demon *d)=0
virtual void SetValues(const std::vector< int64_t > &values)
This method intersects the current domain with the values in the array.
virtual IntVar * IsDifferent(int64_t constant)=0
IntVar * Var() override
Creates a variable from the expression.
virtual bool Contains(int64_t v) const =0
virtual void RemoveValue(int64_t v)=0
This method removes the value 'v' from the domain of the variable.
virtual uint64_t Size() const =0
This method returns the number of values in the domain of the variable.
virtual void RemoveValues(const std::vector< int64_t > &values)
This method removes the values from the domain of the variable.
virtual void VisitIntegerArgument(const std::string &arg_name, int64_t value)
Visit integer arguments.
virtual void VisitIntegerExpressionArgument(const std::string &arg_name, IntExpr *argument)
Visit integer expression argument.
virtual void VisitIntegerArrayArgument(const std::string &arg_name, const std::vector< int64_t > &values)
virtual void BeginVisitConstraint(const std::string &type_name, const Constraint *constraint)
virtual void EndVisitConstraint(const std::string &type_name, const Constraint *constraint)
For the time being, Solver is neither MT_SAFE nor MT_HOT.
Constraint * MakeIsEqualCstCt(IntExpr *var, int64_t value, IntVar *boolvar)
boolvar == (var == value)
Definition expr_cst.cc:493
Constraint * MakeNotMemberCt(IntExpr *expr, const std::vector< int64_t > &values)
expr not in set.
Definition expr_cst.cc:1239
IntVar * MakeIsMemberVar(IntExpr *expr, const std::vector< int64_t > &values)
Definition expr_cst.cc:1501
Constraint * MakeNonEquality(IntExpr *left, IntExpr *right)
left != right
Definition range_cst.cc:570
IntVar * MakeIsEqualCstVar(IntExpr *var, int64_t value)
status var of (var == value)
Definition expr_cst.cc:468
Constraint * MakeMemberCt(IntExpr *expr, const std::vector< int64_t > &values)
----- Member and related constraints -----
Definition expr_cst.cc:1169
Constraint * MakeGreaterOrEqual(IntExpr *left, IntExpr *right)
left >= right
Definition range_cst.cc:548
IntVar * MakeIsLessOrEqualCstVar(IntExpr *var, int64_t value)
status var of (var <= value)
Definition expr_cst.cc:785
Constraint * MakeIsMemberCt(IntExpr *expr, const std::vector< int64_t > &values, IntVar *boolvar)
boolvar == (expr in set)
Definition expr_cst.cc:1489
Constraint * MakeIsGreaterCstCt(IntExpr *v, int64_t c, IntVar *b)
b == (v > c)
Definition expr_cst.cc:722
IntVar * MakeIsGreaterCstVar(IntExpr *var, int64_t value)
status var of (var > value)
Definition expr_cst.cc:702
Constraint * MakeIsLessCstCt(IntExpr *v, int64_t c, IntVar *b)
b == (v < c)
Definition expr_cst.cc:822
Constraint * MakeLess(IntExpr *left, IntExpr *right)
left < right
Definition range_cst.cc:552
Constraint * MakeGreater(IntExpr *left, IntExpr *right)
left > right
Definition range_cst.cc:566
IntVar * MakeIsDifferentCstVar(IntExpr *var, int64_t value)
status var of (var != value)
Definition expr_cst.cc:586
IntVar * MakeIsBetweenVar(IntExpr *v, int64_t l, int64_t u)
Definition expr_cst.cc:1093
IntVar * MakeIsGreaterOrEqualCstVar(IntExpr *var, int64_t value)
status var of (var >= value)
Definition expr_cst.cc:685
Constraint * MakeEquality(IntExpr *left, IntExpr *right)
left == right
Definition range_cst.cc:518
Constraint * MakeLessOrEqual(IntExpr *left, IntExpr *right)
left <= right
Definition range_cst.cc:532
Constraint * MakeIsGreaterOrEqualCstCt(IntExpr *var, int64_t value, IntVar *boolvar)
boolvar == (var >= value)
Definition expr_cst.cc:706
Constraint * MakeBetweenCt(IntExpr *expr, int64_t l, int64_t u)
----- Between and related constraints -----
Definition expr_cst.cc:929
Demon * MakeConstraintInitialPropagateCallback(Constraint *ct)
bool IsProduct(IntExpr *expr, IntExpr **inner_expr, int64_t *coefficient)
IntVar * MakeIsLessCstVar(IntExpr *var, int64_t value)
status var of (var < value)
Definition expr_cst.cc:802
Constraint * MakeIsBetweenCt(IntExpr *expr, int64_t l, int64_t u, IntVar *b)
b == (l <= expr <= u)
Definition expr_cst.cc:1057
Constraint * MakeIsLessOrEqualCstCt(IntExpr *var, int64_t value, IntVar *boolvar)
boolvar == (var <= value)
Definition expr_cst.cc:806
Constraint * MakeNotBetweenCt(IntExpr *expr, int64_t l, int64_t u)
Definition expr_cst.cc:962
Constraint * MakeIsDifferentCstCt(IntExpr *var, int64_t value, IntVar *boolvar)
boolvar == (var != value)
Definition expr_cst.cc:595
int64_t b
Definition table.cc:45
int64_t value
IntVar *const expr_
Definition element.cc:88
IntVar * var
int64_t coef
ABSL_FLAG(int, cache_initial_size, 1024, "Initial size of the array of the hash " "table of caches for objects of type Var(x == 3)")
Expression constraints.
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
Definition stl_util.h:58
In SWIG mode, we don't want anything besides these top-level includes.
int64_t CapAdd(int64_t x, int64_t y)
int64_t CapSub(int64_t x, int64_t y)
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
std::vector< int64_t > ToInt64Vector(const std::vector< int > &input)
----- Vector manipulations -----
Definition utilities.cc:829
int64_t PosIntDivDown(int64_t e, int64_t v)
int64_t PosIntDivUp(int64_t e, int64_t v)
IntervalVar * interval
Definition resource.cc:101
std::optional< int64_t > end