Google OR-Tools v9.11
a fast and portable software suite for combinatorial optimization
Loading...
Searching...
No Matches
element.cc
Go to the documentation of this file.
1// Copyright 2010-2024 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14#include <algorithm>
15#include <cstdint>
16#include <functional>
17#include <limits>
18#include <memory>
19#include <numeric>
20#include <string>
21#include <utility>
22#include <vector>
23
24#include "absl/strings/str_format.h"
25#include "absl/strings/str_join.h"
27#include "ortools/base/types.h"
32
33ABSL_FLAG(bool, cp_disable_element_cache, true,
34 "If true, caching for IntElement is disabled.");
35
36namespace operations_research {
37
38// ----- IntExprElement -----
39void LinkVarExpr(Solver* s, IntExpr* expr, IntVar* var);
40
41namespace {
42
// Comparator over positions: orders two indices by the entries they select
// in an externally owned vector. Only a pointer is stored, so the vector must
// stay alive for as long as the comparator is used.
template <class T>
class VectorLess {
 public:
  explicit VectorLess(const std::vector<T>* values) : values_(values) {}

  // True when the entry at position `x` is strictly smaller than the entry at
  // position `y`.
  bool operator()(const T& x, const T& y) const {
    const std::vector<T>& table = *values_;
    return table[x] < table[y];
  }

 private:
  const std::vector<T>* values_;  // Not owned.
};
54
// Comparator over positions: orders two indices by the entries they select
// in an externally owned vector, largest entry first. Only a pointer is
// stored, so the vector must stay alive for as long as the comparator is used.
template <class T>
class VectorGreater {
 public:
  explicit VectorGreater(const std::vector<T>* values) : values_(values) {}

  // True when the entry at position `x` is strictly larger than the entry at
  // position `y`.
  bool operator()(const T& x, const T& y) const {
    const std::vector<T>& table = *values_;
    return table[x] > table[y];
  }

 private:
  const std::vector<T>* values_;  // Not owned.
};
66
67// ----- BaseIntExprElement -----
68
69class BaseIntExprElement : public BaseIntExpr {
70 public:
71 BaseIntExprElement(Solver* s, IntVar* e);
72 ~BaseIntExprElement() override {}
73 int64_t Min() const override;
74 int64_t Max() const override;
75 void Range(int64_t* mi, int64_t* ma) override;
76 void SetMin(int64_t m) override;
77 void SetMax(int64_t m) override;
78 void SetRange(int64_t mi, int64_t ma) override;
79 bool Bound() const override { return (expr_->Bound()); }
80 // TODO(user) : improve me, the previous test is not always true
81 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
82
83 protected:
84 virtual int64_t ElementValue(int index) const = 0;
85 virtual int64_t ExprMin() const = 0;
86 virtual int64_t ExprMax() const = 0;
87
88 IntVar* const expr_;
89
90 private:
91 void UpdateSupports() const;
92 template <typename T>
93 void UpdateElementIndexBounds(T check_value) {
94 const int64_t emin = ExprMin();
95 const int64_t emax = ExprMax();
96 int64_t nmin = emin;
97 int64_t value = ElementValue(nmin);
98 while (nmin < emax && check_value(value)) {
99 nmin++;
100 value = ElementValue(nmin);
101 }
102 if (nmin == emax && check_value(value)) {
103 solver()->Fail();
104 }
105 int64_t nmax = emax;
106 value = ElementValue(nmax);
107 while (nmax >= nmin && check_value(value)) {
108 nmax--;
109 value = ElementValue(nmax);
110 }
111 expr_->SetRange(nmin, nmax);
112 }
113
114 mutable int64_t min_;
115 mutable int min_support_;
116 mutable int64_t max_;
117 mutable int max_support_;
118 mutable bool initial_update_;
119 IntVarIterator* const expr_iterator_;
120};
121
122BaseIntExprElement::BaseIntExprElement(Solver* const s, IntVar* const e)
123 : BaseIntExpr(s),
124 expr_(e),
125 min_(0),
126 min_support_(-1),
127 max_(0),
128 max_support_(-1),
129 initial_update_(true),
130 expr_iterator_(expr_->MakeDomainIterator(true)) {
131 CHECK(s != nullptr);
132 CHECK(e != nullptr);
133}
134
135int64_t BaseIntExprElement::Min() const {
136 UpdateSupports();
137 return min_;
138}
139
140int64_t BaseIntExprElement::Max() const {
141 UpdateSupports();
142 return max_;
143}
144
145void BaseIntExprElement::Range(int64_t* mi, int64_t* ma) {
146 UpdateSupports();
147 *mi = min_;
148 *ma = max_;
149}
150
151void BaseIntExprElement::SetMin(int64_t m) {
152 UpdateElementIndexBounds([m](int64_t value) { return value < m; });
153}
154
155void BaseIntExprElement::SetMax(int64_t m) {
156 UpdateElementIndexBounds([m](int64_t value) { return value > m; });
157}
158
159void BaseIntExprElement::SetRange(int64_t mi, int64_t ma) {
160 if (mi > ma) {
161 solver()->Fail();
162 }
163 UpdateElementIndexBounds(
164 [mi, ma](int64_t value) { return value < mi || value > ma; });
165}
166
167void BaseIntExprElement::UpdateSupports() const {
168 if (initial_update_ || !expr_->Contains(min_support_) ||
169 !expr_->Contains(max_support_)) {
170 const int64_t emin = ExprMin();
171 const int64_t emax = ExprMax();
172 int64_t min_value = ElementValue(emax);
173 int64_t max_value = min_value;
174 int min_support = emax;
175 int max_support = emax;
176 const uint64_t expr_size = expr_->Size();
177 if (expr_size > 1) {
178 if (expr_size == emax - emin + 1) {
179 // Value(emax) already stored in min_value, max_value.
180 for (int64_t index = emin; index < emax; ++index) {
181 const int64_t value = ElementValue(index);
182 if (value > max_value) {
183 max_value = value;
184 max_support = index;
185 } else if (value < min_value) {
186 min_value = value;
187 min_support = index;
188 }
189 }
190 } else {
191 for (const int64_t index : InitAndGetValues(expr_iterator_)) {
192 if (index >= emin && index <= emax) {
193 const int64_t value = ElementValue(index);
194 if (value > max_value) {
195 max_value = value;
196 max_support = index;
197 } else if (value < min_value) {
198 min_value = value;
199 min_support = index;
200 }
201 }
202 }
203 }
204 }
205 Solver* s = solver();
206 s->SaveAndSetValue(&min_, min_value);
207 s->SaveAndSetValue(&min_support_, min_support);
208 s->SaveAndSetValue(&max_, max_value);
209 s->SaveAndSetValue(&max_support_, max_support);
210 s->SaveAndSetValue(&initial_update_, false);
211 }
212}
213
214// ----- IntElementConstraint -----
215
216// This constraint implements 'elem' == 'values'['index'].
217// It scans the bounds of 'elem' to propagate on the domain of 'index'.
218// It scans the domain of 'index' to compute the new bounds of 'elem'.
219class IntElementConstraint : public CastConstraint {
220 public:
221 IntElementConstraint(Solver* const s, const std::vector<int64_t>& values,
222 IntVar* const index, IntVar* const elem)
223 : CastConstraint(s, elem),
224 values_(values),
225 index_(index),
226 index_iterator_(index_->MakeDomainIterator(true)) {
227 CHECK(index != nullptr);
228 }
229
230 void Post() override {
231 Demon* const d =
232 solver()->MakeDelayedConstraintInitialPropagateCallback(this);
233 index_->WhenDomain(d);
234 target_var_->WhenRange(d);
235 }
236
237 void InitialPropagate() override {
238 index_->SetRange(0, values_.size() - 1);
239 const int64_t target_var_min = target_var_->Min();
240 const int64_t target_var_max = target_var_->Max();
241 int64_t new_min = target_var_max;
242 int64_t new_max = target_var_min;
243 to_remove_.clear();
244 for (const int64_t index : InitAndGetValues(index_iterator_)) {
245 const int64_t value = values_[index];
246 if (value < target_var_min || value > target_var_max) {
247 to_remove_.push_back(index);
248 } else {
249 if (value < new_min) {
250 new_min = value;
251 }
252 if (value > new_max) {
253 new_max = value;
254 }
255 }
256 }
257 target_var_->SetRange(new_min, new_max);
258 if (!to_remove_.empty()) {
259 index_->RemoveValues(to_remove_);
260 }
261 }
262
263 std::string DebugString() const override {
264 return absl::StrFormat("IntElementConstraint(%s, %s, %s)",
265 absl::StrJoin(values_, ", "), index_->DebugString(),
266 target_var_->DebugString());
267 }
268
269 void Accept(ModelVisitor* const visitor) const override {
270 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
271 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
272 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
273 index_);
274 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
275 target_var_);
276 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
277 }
278
279 private:
280 const std::vector<int64_t> values_;
281 IntVar* const index_;
282 IntVarIterator* const index_iterator_;
283 std::vector<int64_t> to_remove_;
284};
285
286// ----- IntExprElement
287
288IntVar* BuildDomainIntVar(Solver* solver, std::vector<int64_t>* values);
289
290class IntExprElement : public BaseIntExprElement {
291 public:
292 IntExprElement(Solver* const s, const std::vector<int64_t>& vals,
293 IntVar* const expr)
294 : BaseIntExprElement(s, expr), values_(vals) {}
295
296 ~IntExprElement() override {}
297
298 std::string name() const override {
299 const int size = values_.size();
300 if (size > 10) {
301 return absl::StrFormat("IntElement(array of size %d, %s)", size,
302 expr_->name());
303 } else {
304 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
305 expr_->name());
306 }
307 }
308
309 std::string DebugString() const override {
310 const int size = values_.size();
311 if (size > 10) {
312 return absl::StrFormat("IntElement(array of size %d, %s)", size,
313 expr_->DebugString());
314 } else {
315 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
316 expr_->DebugString());
317 }
318 }
319
320 IntVar* CastToVar() override {
321 Solver* const s = solver();
322 IntVar* const var = s->MakeIntVar(values_);
323 s->AddCastConstraint(
324 s->RevAlloc(new IntElementConstraint(s, values_, expr_, var)), var,
325 this);
326 return var;
327 }
328
329 void Accept(ModelVisitor* const visitor) const override {
330 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
331 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
332 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
333 expr_);
334 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
335 }
336
337 protected:
338 int64_t ElementValue(int index) const override {
339 DCHECK_LT(index, values_.size());
340 return values_[index];
341 }
342 int64_t ExprMin() const override {
343 return std::max<int64_t>(0, expr_->Min());
344 }
345 int64_t ExprMax() const override {
346 return values_.empty()
347 ? 0
348 : std::min<int64_t>(values_.size() - 1, expr_->Max());
349 }
350
351 private:
352 const std::vector<int64_t> values_;
353};
354
355// ----- Range Minimum Query-based Element -----
356
357class RangeMinimumQueryExprElement : public BaseIntExpr {
358 public:
359 RangeMinimumQueryExprElement(Solver* solver,
360 const std::vector<int64_t>& values,
361 IntVar* index);
362 ~RangeMinimumQueryExprElement() override {}
363 int64_t Min() const override;
364 int64_t Max() const override;
365 void Range(int64_t* mi, int64_t* ma) override;
366 void SetMin(int64_t m) override;
367 void SetMax(int64_t m) override;
368 void SetRange(int64_t mi, int64_t ma) override;
369 bool Bound() const override { return (index_->Bound()); }
370 // TODO(user) : improve me, the previous test is not always true
371 void WhenRange(Demon* d) override { index_->WhenRange(d); }
372 IntVar* CastToVar() override {
373 // TODO(user): Should we try to make holes in the domain of index_, as we
374 // do here, or should we only propagate bounds as we do in
375 // IncreasingIntExprElement ?
376 IntVar* const var = solver()->MakeIntVar(min_rmq_.array());
377 solver()->AddCastConstraint(solver()->RevAlloc(new IntElementConstraint(
378 solver(), min_rmq_.array(), index_, var)),
379 var, this);
380 return var;
381 }
382 void Accept(ModelVisitor* const visitor) const override {
383 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
384 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
385 min_rmq_.array());
386 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
387 index_);
388 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
389 }
390
391 private:
392 int64_t IndexMin() const { return std::max<int64_t>(0, index_->Min()); }
393 int64_t IndexMax() const {
394 return std::min<int64_t>(min_rmq_.array().size() - 1, index_->Max());
395 }
396
397 IntVar* const index_;
398 const RangeMinimumQuery<int64_t, std::less<int64_t>> min_rmq_;
399 const RangeMinimumQuery<int64_t, std::greater<int64_t>> max_rmq_;
400};
401
402RangeMinimumQueryExprElement::RangeMinimumQueryExprElement(
403 Solver* solver, const std::vector<int64_t>& values, IntVar* index)
404 : BaseIntExpr(solver), index_(index), min_rmq_(values), max_rmq_(values) {
405 CHECK(solver != nullptr);
406 CHECK(index != nullptr);
407}
408
409int64_t RangeMinimumQueryExprElement::Min() const {
410 return min_rmq_.GetMinimumFromRange(IndexMin(), IndexMax() + 1);
411}
412
413int64_t RangeMinimumQueryExprElement::Max() const {
414 return max_rmq_.GetMinimumFromRange(IndexMin(), IndexMax() + 1);
415}
416
417void RangeMinimumQueryExprElement::Range(int64_t* mi, int64_t* ma) {
418 const int64_t range_min = IndexMin();
419 const int64_t range_max = IndexMax() + 1;
420 *mi = min_rmq_.GetMinimumFromRange(range_min, range_max);
421 *ma = max_rmq_.GetMinimumFromRange(range_min, range_max);
422}
423
// Tightens the bounds of index_ inside a RangeMinimumQueryExprElement member.
// `test` is a boolean expression over the local `value`; it must be true when
// `value` is NOT allowed. Indices with rejected values are skipped from both
// ends of [IndexMin(), IndexMax()]; the solver fails when none is acceptable.
// The macro declares locals (values, index_min, index_max, value), so it can
// be expanded only once per scope.
#define UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(test)          \
  const std::vector<int64_t>& values = min_rmq_.array();    \
  int64_t index_min = IndexMin();                           \
  int64_t index_max = IndexMax();                           \
  /* Empty window: fail before reading values[index_min], */ \
  /* which may lie past the end of the array.             */ \
  if (index_min > index_max) {                              \
    solver()->Fail();                                       \
  }                                                         \
  int64_t value = values[index_min];                        \
  while (index_min < index_max && (test)) {                 \
    index_min++;                                            \
    value = values[index_min];                              \
  }                                                         \
  if (index_min == index_max && (test)) {                   \
    solver()->Fail();                                       \
  }                                                         \
  value = values[index_max];                                \
  while (index_max >= index_min && (test)) {                \
    index_max--;                                            \
    value = values[index_max];                              \
  }                                                         \
  index_->SetRange(index_min, index_max);
442
443void RangeMinimumQueryExprElement::SetMin(int64_t m) {
445}
446
447void RangeMinimumQueryExprElement::SetMax(int64_t m) {
449}
450
451void RangeMinimumQueryExprElement::SetRange(int64_t mi, int64_t ma) {
452 if (mi > ma) {
453 solver()->Fail();
454 }
456}
457
458#undef UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS
459
460// ----- Increasing Element -----
461
462class IncreasingIntExprElement : public BaseIntExpr {
463 public:
464 IncreasingIntExprElement(Solver* s, const std::vector<int64_t>& values,
465 IntVar* index);
466 ~IncreasingIntExprElement() override {}
467
468 int64_t Min() const override;
469 void SetMin(int64_t m) override;
470 int64_t Max() const override;
471 void SetMax(int64_t m) override;
472 void SetRange(int64_t mi, int64_t ma) override;
473 bool Bound() const override { return (index_->Bound()); }
474 // TODO(user) : improve me, the previous test is not always true
475 std::string name() const override {
476 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
477 index_->name());
478 }
479 std::string DebugString() const override {
480 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
481 index_->DebugString());
482 }
483
484 void Accept(ModelVisitor* const visitor) const override {
485 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
486 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
487 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
488 index_);
489 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
490 }
491
492 void WhenRange(Demon* d) override { index_->WhenRange(d); }
493
494 IntVar* CastToVar() override {
495 Solver* const s = solver();
496 IntVar* const var = s->MakeIntVar(values_);
497 LinkVarExpr(s, this, var);
498 return var;
499 }
500
501 private:
502 const std::vector<int64_t> values_;
503 IntVar* const index_;
504};
505
506IncreasingIntExprElement::IncreasingIntExprElement(
507 Solver* const s, const std::vector<int64_t>& values, IntVar* const index)
508 : BaseIntExpr(s), values_(values), index_(index) {
509 DCHECK(index);
510 DCHECK(s);
511}
512
513int64_t IncreasingIntExprElement::Min() const {
514 const int64_t expression_min = std::max<int64_t>(0, index_->Min());
515 return (expression_min < values_.size()
516 ? values_[expression_min]
517 : std::numeric_limits<int64_t>::max());
518}
519
520void IncreasingIntExprElement::SetMin(int64_t m) {
521 const int64_t index_min = std::max<int64_t>(0, index_->Min());
522 const int64_t index_max =
523 std::min<int64_t>(values_.size() - 1, index_->Max());
524
525 if (index_min > index_max || m > values_[index_max]) {
526 solver()->Fail();
527 }
528
529 const std::vector<int64_t>::const_iterator first =
530 std::lower_bound(values_.begin(), values_.end(), m);
531 const int64_t new_index_min = first - values_.begin();
532 index_->SetMin(new_index_min);
533}
534
535int64_t IncreasingIntExprElement::Max() const {
536 const int64_t expression_max =
537 std::min<int64_t>(values_.size() - 1, index_->Max());
538 return (expression_max >= 0 ? values_[expression_max]
539 : std::numeric_limits<int64_t>::max());
540}
541
542void IncreasingIntExprElement::SetMax(int64_t m) {
543 int64_t index_min = std::max<int64_t>(0, index_->Min());
544 if (m < values_[index_min]) {
545 solver()->Fail();
546 }
547
548 const std::vector<int64_t>::const_iterator last_after =
549 std::upper_bound(values_.begin(), values_.end(), m);
550 const int64_t new_index_max = (last_after - values_.begin()) - 1;
551 index_->SetRange(0, new_index_max);
552}
553
554void IncreasingIntExprElement::SetRange(int64_t mi, int64_t ma) {
555 if (mi > ma) {
556 solver()->Fail();
557 }
558 const int64_t index_min = std::max<int64_t>(0, index_->Min());
559 const int64_t index_max =
560 std::min<int64_t>(values_.size() - 1, index_->Max());
561
562 if (mi > ma || ma < values_[index_min] || mi > values_[index_max]) {
563 solver()->Fail();
564 }
565
566 const std::vector<int64_t>::const_iterator first =
567 std::lower_bound(values_.begin(), values_.end(), mi);
568 const int64_t new_index_min = first - values_.begin();
569
570 const std::vector<int64_t>::const_iterator last_after =
571 std::upper_bound(first, values_.end(), ma);
572 const int64_t new_index_max = (last_after - values_.begin()) - 1;
573
574 // Assign.
575 index_->SetRange(new_index_min, new_index_max);
576}
577
578// ----- Solver::MakeElement(int array, int var) -----
579IntExpr* BuildElement(Solver* const solver, const std::vector<int64_t>& values,
580 IntVar* const index) {
581 // Various checks.
582 // Is array constant?
583 if (IsArrayConstant(values, values[0])) {
584 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
585 return solver->MakeIntConst(values[0]);
586 }
587 // Is array built with booleans only?
588 // TODO(user): We could maintain the index of the first one.
589 if (IsArrayBoolean(values)) {
590 std::vector<int64_t> ones;
591 int first_zero = -1;
592 for (int i = 0; i < values.size(); ++i) {
593 if (values[i] == 1) {
594 ones.push_back(i);
595 } else {
596 first_zero = i;
597 }
598 }
599 if (ones.size() == 1) {
600 DCHECK_EQ(int64_t{1}, values[ones.back()]);
601 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
602 return solver->MakeIsEqualCstVar(index, ones.back());
603 } else if (ones.size() == values.size() - 1) {
604 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
605 return solver->MakeIsDifferentCstVar(index, first_zero);
606 } else if (ones.size() == ones.back() - ones.front() + 1) { // contiguous.
607 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
608 IntVar* const b = solver->MakeBoolVar("ContiguousBooleanElementVar");
609 solver->AddConstraint(
610 solver->MakeIsBetweenCt(index, ones.front(), ones.back(), b));
611 return b;
612 } else {
613 IntVar* const b = solver->MakeBoolVar("NonContiguousBooleanElementVar");
614 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
615 solver->AddConstraint(solver->MakeIsMemberCt(index, ones, b));
616 return b;
617 }
618 }
619 IntExpr* cache = nullptr;
620 if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
621 cache = solver->Cache()->FindVarConstantArrayExpression(
622 index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
623 }
624 if (cache != nullptr) {
625 return cache;
626 } else {
627 IntExpr* result = nullptr;
628 if (values.size() >= 2 && index->Min() == 0 && index->Max() == 1) {
629 result = solver->MakeSum(solver->MakeProd(index, values[1] - values[0]),
630 values[0]);
631 } else if (values.size() == 2 && index->Contains(0) && index->Contains(1)) {
632 solver->AddConstraint(solver->MakeBetweenCt(index, 0, 1));
633 result = solver->MakeSum(solver->MakeProd(index, values[1] - values[0]),
634 values[0]);
635 } else if (IsIncreasingContiguous(values)) {
636 result = solver->MakeSum(index, values[0]);
637 } else if (IsIncreasing(values)) {
638 result = solver->RegisterIntExpr(solver->RevAlloc(
639 new IncreasingIntExprElement(solver, values, index)));
640 } else {
641 if (solver->parameters().use_element_rmq()) {
642 result = solver->RegisterIntExpr(solver->RevAlloc(
643 new RangeMinimumQueryExprElement(solver, values, index)));
644 } else {
645 result = solver->RegisterIntExpr(
646 solver->RevAlloc(new IntExprElement(solver, values, index)));
647 }
648 }
649 if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
650 solver->Cache()->InsertVarConstantArrayExpression(
651 result, index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
652 }
653 return result;
654 }
655}
656} // namespace
657
658IntExpr* Solver::MakeElement(const std::vector<int64_t>& values,
659 IntVar* const index) {
660 DCHECK(index);
661 DCHECK_EQ(this, index->solver());
662 if (index->Bound()) {
663 return MakeIntConst(values[index->Min()]);
664 }
665 return BuildElement(this, values, index);
666}
667
668IntExpr* Solver::MakeElement(const std::vector<int>& values,
669 IntVar* const index) {
670 DCHECK(index);
671 DCHECK_EQ(this, index->solver());
672 if (index->Bound()) {
673 return MakeIntConst(values[index->Min()]);
674 }
675 return BuildElement(this, ToInt64Vector(values), index);
676}
677
678// ----- IntExprFunctionElement -----
679
680namespace {
681class IntExprFunctionElement : public BaseIntExprElement {
682 public:
683 IntExprFunctionElement(Solver* s, Solver::IndexEvaluator1 values, IntVar* e);
684 ~IntExprFunctionElement() override;
685
686 std::string name() const override {
687 return absl::StrFormat("IntFunctionElement(%s)", expr_->name());
688 }
689
690 std::string DebugString() const override {
691 return absl::StrFormat("IntFunctionElement(%s)", expr_->DebugString());
692 }
693
694 void Accept(ModelVisitor* const visitor) const override {
695 // Warning: This will expand all values into a vector.
696 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
697 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
698 expr_);
699 visitor->VisitInt64ToInt64Extension(values_, expr_->Min(), expr_->Max());
700 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
701 }
702
703 protected:
704 int64_t ElementValue(int index) const override { return values_(index); }
705 int64_t ExprMin() const override { return expr_->Min(); }
706 int64_t ExprMax() const override { return expr_->Max(); }
707
708 private:
709 Solver::IndexEvaluator1 values_;
710};
711
712IntExprFunctionElement::IntExprFunctionElement(Solver* const s,
713 Solver::IndexEvaluator1 values,
714 IntVar* const e)
715 : BaseIntExprElement(s, e), values_(std::move(values)) {
716 CHECK(values_ != nullptr);
717}
718
719IntExprFunctionElement::~IntExprFunctionElement() {}
720
721// ----- Increasing Element -----
722
723class IncreasingIntExprFunctionElement : public BaseIntExpr {
724 public:
725 IncreasingIntExprFunctionElement(Solver* const s,
726 Solver::IndexEvaluator1 values,
727 IntVar* const index)
728 : BaseIntExpr(s), values_(std::move(values)), index_(index) {
729 DCHECK(values_ != nullptr);
730 DCHECK(index);
731 DCHECK(s);
732 }
733
734 ~IncreasingIntExprFunctionElement() override {}
735
736 int64_t Min() const override { return values_(index_->Min()); }
737
738 void SetMin(int64_t m) override {
739 const int64_t index_min = index_->Min();
740 const int64_t index_max = index_->Max();
741 if (m > values_(index_max)) {
742 solver()->Fail();
743 }
744 const int64_t new_index_min = FindNewIndexMin(index_min, index_max, m);
745 index_->SetMin(new_index_min);
746 }
747
748 int64_t Max() const override { return values_(index_->Max()); }
749
750 void SetMax(int64_t m) override {
751 int64_t index_min = index_->Min();
752 int64_t index_max = index_->Max();
753 if (m < values_(index_min)) {
754 solver()->Fail();
755 }
756 const int64_t new_index_max = FindNewIndexMax(index_min, index_max, m);
757 index_->SetMax(new_index_max);
758 }
759
760 void SetRange(int64_t mi, int64_t ma) override {
761 const int64_t index_min = index_->Min();
762 const int64_t index_max = index_->Max();
763 const int64_t value_min = values_(index_min);
764 const int64_t value_max = values_(index_max);
765 if (mi > ma || ma < value_min || mi > value_max) {
766 solver()->Fail();
767 }
768 if (mi <= value_min && ma >= value_max) {
769 // Nothing to do.
770 return;
771 }
772
773 const int64_t new_index_min = FindNewIndexMin(index_min, index_max, mi);
774 const int64_t new_index_max = FindNewIndexMax(new_index_min, index_max, ma);
775 // Assign.
776 index_->SetRange(new_index_min, new_index_max);
777 }
778
779 std::string name() const override {
780 return absl::StrFormat("IncreasingIntExprFunctionElement(values, %s)",
781 index_->name());
782 }
783
784 std::string DebugString() const override {
785 return absl::StrFormat("IncreasingIntExprFunctionElement(values, %s)",
786 index_->DebugString());
787 }
788
789 void WhenRange(Demon* d) override { index_->WhenRange(d); }
790
791 void Accept(ModelVisitor* const visitor) const override {
792 // Warning: This will expand all values into a vector.
793 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
794 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
795 index_);
796 if (index_->Min() == 0) {
797 visitor->VisitInt64ToInt64AsArray(values_, ModelVisitor::kValuesArgument,
798 index_->Max());
799 } else {
800 visitor->VisitInt64ToInt64Extension(values_, index_->Min(),
801 index_->Max());
802 }
803 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
804 }
805
806 private:
807 int64_t FindNewIndexMin(int64_t index_min, int64_t index_max, int64_t m) {
808 if (m <= values_(index_min)) {
809 return index_min;
810 }
811
812 DCHECK_LT(values_(index_min), m);
813 DCHECK_GE(values_(index_max), m);
814
815 int64_t index_lower_bound = index_min;
816 int64_t index_upper_bound = index_max;
817 while (index_upper_bound - index_lower_bound > 1) {
818 DCHECK_LT(values_(index_lower_bound), m);
819 DCHECK_GE(values_(index_upper_bound), m);
820 const int64_t pivot = (index_lower_bound + index_upper_bound) / 2;
821 const int64_t pivot_value = values_(pivot);
822 if (pivot_value < m) {
823 index_lower_bound = pivot;
824 } else {
825 index_upper_bound = pivot;
826 }
827 }
828 DCHECK(values_(index_upper_bound) >= m);
829 return index_upper_bound;
830 }
831
832 int64_t FindNewIndexMax(int64_t index_min, int64_t index_max, int64_t m) {
833 if (m >= values_(index_max)) {
834 return index_max;
835 }
836
837 DCHECK_LE(values_(index_min), m);
838 DCHECK_GT(values_(index_max), m);
839
840 int64_t index_lower_bound = index_min;
841 int64_t index_upper_bound = index_max;
842 while (index_upper_bound - index_lower_bound > 1) {
843 DCHECK_LE(values_(index_lower_bound), m);
844 DCHECK_GT(values_(index_upper_bound), m);
845 const int64_t pivot = (index_lower_bound + index_upper_bound) / 2;
846 const int64_t pivot_value = values_(pivot);
847 if (pivot_value > m) {
848 index_upper_bound = pivot;
849 } else {
850 index_lower_bound = pivot;
851 }
852 }
853 DCHECK(values_(index_lower_bound) <= m);
854 return index_lower_bound;
855 }
856
857 Solver::IndexEvaluator1 values_;
858 IntVar* const index_;
859};
860} // namespace
861
863 IntVar* const index) {
864 CHECK_EQ(this, index->solver());
865 return RegisterIntExpr(
866 RevAlloc(new IntExprFunctionElement(this, std::move(values), index)));
867}
868
870 bool increasing, IntVar* const index) {
871 CHECK_EQ(this, index->solver());
872 if (increasing) {
873 return RegisterIntExpr(
874 RevAlloc(new IncreasingIntExprFunctionElement(this, values, index)));
875 } else {
876 // You need to pass by copy such that opposite_value does not include a
877 // dandling reference when leaving this scope.
878 Solver::IndexEvaluator1 opposite_values = [values](int64_t i) {
879 return -values(i);
880 };
881 return RegisterIntExpr(MakeOpposite(RevAlloc(
882 new IncreasingIntExprFunctionElement(this, opposite_values, index))));
883 }
884}
885
886// ----- IntIntExprFunctionElement -----
887
888namespace {
889class IntIntExprFunctionElement : public BaseIntExpr {
890 public:
891 IntIntExprFunctionElement(Solver* s, Solver::IndexEvaluator2 values,
892 IntVar* expr1, IntVar* expr2);
893 ~IntIntExprFunctionElement() override;
894 std::string DebugString() const override {
895 return absl::StrFormat("IntIntFunctionElement(%s,%s)",
896 expr1_->DebugString(), expr2_->DebugString());
897 }
898 int64_t Min() const override;
899 int64_t Max() const override;
900 void Range(int64_t* lower_bound, int64_t* upper_bound) override;
901 void SetMin(int64_t lower_bound) override;
902 void SetMax(int64_t upper_bound) override;
903 void SetRange(int64_t lower_bound, int64_t upper_bound) override;
904 bool Bound() const override { return expr1_->Bound() && expr2_->Bound(); }
905 // TODO(user) : improve me, the previous test is not always true
906 void WhenRange(Demon* d) override {
907 expr1_->WhenRange(d);
908 expr2_->WhenRange(d);
909 }
910
911 void Accept(ModelVisitor* const visitor) const override {
912 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
913 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
914 expr1_);
915 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndex2Argument,
916 expr2_);
917 // Warning: This will expand all values into a vector.
918 const int64_t expr1_min = expr1_->Min();
919 const int64_t expr1_max = expr1_->Max();
920 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, expr1_min);
921 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, expr1_max);
922 for (int i = expr1_min; i <= expr1_max; ++i) {
923 visitor->VisitInt64ToInt64Extension(
924 [this, i](int64_t j) { return values_(i, j); }, expr2_->Min(),
925 expr2_->Max());
926 }
927 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
928 }
929
930 private:
931 int64_t ElementValue(int index1, int index2) const {
932 return values_(index1, index2);
933 }
934 void UpdateSupports() const;
935
936 IntVar* const expr1_;
937 IntVar* const expr2_;
938 mutable int64_t min_;
939 mutable int min_support1_;
940 mutable int min_support2_;
941 mutable int64_t max_;
942 mutable int max_support1_;
943 mutable int max_support2_;
944 mutable bool initial_update_;
945 Solver::IndexEvaluator2 values_;
946 IntVarIterator* const expr1_iterator_;
947 IntVarIterator* const expr2_iterator_;
948};
949
950IntIntExprFunctionElement::IntIntExprFunctionElement(
951 Solver* const s, Solver::IndexEvaluator2 values, IntVar* const expr1,
952 IntVar* const expr2)
953 : BaseIntExpr(s),
954 expr1_(expr1),
955 expr2_(expr2),
956 min_(0),
957 min_support1_(-1),
958 min_support2_(-1),
959 max_(0),
960 max_support1_(-1),
961 max_support2_(-1),
962 initial_update_(true),
963 values_(std::move(values)),
964 expr1_iterator_(expr1_->MakeDomainIterator(true)),
965 expr2_iterator_(expr2_->MakeDomainIterator(true)) {
966 CHECK(values_ != nullptr);
967}
968
969IntIntExprFunctionElement::~IntIntExprFunctionElement() {}
970
971int64_t IntIntExprFunctionElement::Min() const {
972 UpdateSupports();
973 return min_;
974}
975
976int64_t IntIntExprFunctionElement::Max() const {
977 UpdateSupports();
978 return max_;
979}
980
981void IntIntExprFunctionElement::Range(int64_t* lower_bound,
982 int64_t* upper_bound) {
983 UpdateSupports();
984 *lower_bound = min_;
985 *upper_bound = max_;
986}
987
// Tightens the ranges of expr1_ and expr2_ with respect to `test`, a boolean
// expression over the local `value` (and the arguments of the enclosing
// function).
//
// For each index, the lower bound is raised to the first row/column that
// contains at least one element satisfying `test`, and the upper bound is
// lowered to the last such row/column. The solver fails when no row or no
// column qualifies. The scan of the other index covers its whole [min, max]
// interval, holes included.
//
// Loop counters are int64_t, the type returned by Min()/Max(), so they are
// never compared against a wider bound.
//
// Keep comments out of the macro body: a `//` comment followed by a line
// continuation would swallow the rest of the macro.
#define UPDATE_ELEMENT_INDEX_BOUNDS(test)             \
  const int64_t emin1 = expr1_->Min();                \
  const int64_t emax1 = expr1_->Max();                \
  const int64_t emin2 = expr2_->Min();                \
  const int64_t emax2 = expr2_->Max();                \
  int64_t nmin1 = emin1;                              \
  bool found = false;                                 \
  while (nmin1 <= emax1 && !found) {                  \
    for (int64_t i = emin2; i <= emax2; ++i) {        \
      int64_t value = ElementValue(nmin1, i);         \
      if (test) {                                     \
        found = true;                                 \
        break;                                        \
      }                                               \
    }                                                 \
    if (!found) {                                     \
      nmin1++;                                        \
    }                                                 \
  }                                                   \
  if (nmin1 > emax1) {                                \
    solver()->Fail();                                 \
  }                                                   \
  int64_t nmin2 = emin2;                              \
  found = false;                                      \
  while (nmin2 <= emax2 && !found) {                  \
    for (int64_t i = emin1; i <= emax1; ++i) {        \
      int64_t value = ElementValue(i, nmin2);         \
      if (test) {                                     \
        found = true;                                 \
        break;                                        \
      }                                               \
    }                                                 \
    if (!found) {                                     \
      nmin2++;                                        \
    }                                                 \
  }                                                   \
  if (nmin2 > emax2) {                                \
    solver()->Fail();                                 \
  }                                                   \
  int64_t nmax1 = emax1;                              \
  found = false;                                      \
  while (nmax1 >= nmin1 && !found) {                  \
    for (int64_t i = emin2; i <= emax2; ++i) {        \
      int64_t value = ElementValue(nmax1, i);         \
      if (test) {                                     \
        found = true;                                 \
        break;                                        \
      }                                               \
    }                                                 \
    if (!found) {                                     \
      nmax1--;                                        \
    }                                                 \
  }                                                   \
  int64_t nmax2 = emax2;                              \
  found = false;                                      \
  while (nmax2 >= nmin2 && !found) {                  \
    for (int64_t i = emin1; i <= emax1; ++i) {        \
      int64_t value = ElementValue(i, nmax2);         \
      if (test) {                                     \
        found = true;                                 \
        break;                                        \
      }                                               \
    }                                                 \
    if (!found) {                                     \
      nmax2--;                                        \
    }                                                 \
  }                                                   \
  expr1_->SetRange(nmin1, nmax1);                     \
  expr2_->SetRange(nmin2, nmax2);
1057
1058void IntIntExprFunctionElement::SetMin(int64_t lower_bound) {
1060}
1061
1062void IntIntExprFunctionElement::SetMax(int64_t upper_bound) {
1064}
1065
1066void IntIntExprFunctionElement::SetRange(int64_t lower_bound,
1067 int64_t upper_bound) {
1068 if (lower_bound > upper_bound) {
1069 solver()->Fail();
1070 }
1072}
1073
1074#undef UPDATE_ELEMENT_INDEX_BOUNDS
1075
1076void IntIntExprFunctionElement::UpdateSupports() const {
1077 if (initial_update_ || !expr1_->Contains(min_support1_) ||
1078 !expr1_->Contains(max_support1_) || !expr2_->Contains(min_support2_) ||
1079 !expr2_->Contains(max_support2_)) {
1080 const int64_t emax1 = expr1_->Max();
1081 const int64_t emax2 = expr2_->Max();
1082 int64_t min_value = ElementValue(emax1, emax2);
1083 int64_t max_value = min_value;
1084 int min_support1 = emax1;
1085 int max_support1 = emax1;
1086 int min_support2 = emax2;
1087 int max_support2 = emax2;
1088 for (const int64_t index1 : InitAndGetValues(expr1_iterator_)) {
1089 for (const int64_t index2 : InitAndGetValues(expr2_iterator_)) {
1090 const int64_t value = ElementValue(index1, index2);
1091 if (value > max_value) {
1092 max_value = value;
1093 max_support1 = index1;
1094 max_support2 = index2;
1095 } else if (value < min_value) {
1096 min_value = value;
1097 min_support1 = index1;
1098 min_support2 = index2;
1099 }
1100 }
1101 }
1102 Solver* s = solver();
1103 s->SaveAndSetValue(&min_, min_value);
1104 s->SaveAndSetValue(&min_support1_, min_support1);
1105 s->SaveAndSetValue(&min_support2_, min_support2);
1106 s->SaveAndSetValue(&max_, max_value);
1107 s->SaveAndSetValue(&max_support1_, max_support1);
1108 s->SaveAndSetValue(&max_support2_, max_support2);
1109 s->SaveAndSetValue(&initial_update_, false);
1110 }
1111}
1112} // namespace
1113
1115 IntVar* const index1, IntVar* const index2) {
1116 CHECK_EQ(this, index1->solver());
1117 CHECK_EQ(this, index2->solver());
1118 return RegisterIntExpr(RevAlloc(
1119 new IntIntExprFunctionElement(this, std::move(values), index1, index2)));
1120}
1121
1122// ---------- Generalized element ----------
1123
1124// ----- IfThenElseCt -----
1125
1127 public:
1128 IfThenElseCt(Solver* const solver, IntVar* const condition,
1129 IntExpr* const one, IntExpr* const zero, IntVar* const target)
1130 : CastConstraint(solver, target),
1131 condition_(condition),
1132 zero_(zero),
1133 one_(one) {}
1134
1135 ~IfThenElseCt() override {}
1136
1137 void Post() override {
1138 Demon* const demon = solver()->MakeConstraintInitialPropagateCallback(this);
1139 condition_->WhenBound(demon);
1140 one_->WhenRange(demon);
1141 zero_->WhenRange(demon);
1142 target_var_->WhenRange(demon);
1143 }
1144
1145 void InitialPropagate() override {
1146 condition_->SetRange(0, 1);
1147 const int64_t target_var_min = target_var_->Min();
1148 const int64_t target_var_max = target_var_->Max();
1149 int64_t new_min = std::numeric_limits<int64_t>::min();
1150 int64_t new_max = std::numeric_limits<int64_t>::max();
1151 if (condition_->Max() == 0) {
1152 zero_->SetRange(target_var_min, target_var_max);
1153 zero_->Range(&new_min, &new_max);
1154 } else if (condition_->Min() == 1) {
1155 one_->SetRange(target_var_min, target_var_max);
1156 one_->Range(&new_min, &new_max);
1157 } else {
1158 if (target_var_max < zero_->Min() || target_var_min > zero_->Max()) {
1159 condition_->SetValue(1);
1160 one_->SetRange(target_var_min, target_var_max);
1161 one_->Range(&new_min, &new_max);
1162 } else if (target_var_max < one_->Min() || target_var_min > one_->Max()) {
1163 condition_->SetValue(0);
1164 zero_->SetRange(target_var_min, target_var_max);
1165 zero_->Range(&new_min, &new_max);
1166 } else {
1167 int64_t zl = 0;
1168 int64_t zu = 0;
1169 int64_t ol = 0;
1170 int64_t ou = 0;
1171 zero_->Range(&zl, &zu);
1172 one_->Range(&ol, &ou);
1173 new_min = std::min(zl, ol);
1174 new_max = std::max(zu, ou);
1175 }
1176 }
1177 target_var_->SetRange(new_min, new_max);
1178 }
1179
1180 std::string DebugString() const override {
1181 return absl::StrFormat("(%s ? %s : %s) == %s", condition_->DebugString(),
1182 one_->DebugString(), zero_->DebugString(),
1183 target_var_->DebugString());
1184 }
1185
1186 void Accept(ModelVisitor* const visitor) const override {}
1187
1188 private:
1189 IntVar* const condition_;
1190 IntExpr* const zero_;
1191 IntExpr* const one_;
1192};
1193
1194// ----- IntExprEvaluatorElementCt -----
1195
1196// This constraint implements evaluator(index) == var. It is delayed such
1197// that propagation only occurs when all variables have been touched.
1198// The range of the evaluator is [range_start, range_end).
1199
1200namespace {
1201class IntExprEvaluatorElementCt : public CastConstraint {
1202 public:
1203 IntExprEvaluatorElementCt(Solver* s, Solver::Int64ToIntVar evaluator,
1204 int64_t range_start, int64_t range_end,
1205 IntVar* index, IntVar* target_var);
1206 ~IntExprEvaluatorElementCt() override {}
1207
1208 void Post() override;
1209 void InitialPropagate() override;
1210
1211 void Propagate();
1212 void Update(int index);
1213 void UpdateExpr();
1214
1215 std::string DebugString() const override;
1216 void Accept(ModelVisitor* visitor) const override;
1217
1218 protected:
1219 IntVar* const index_;
1220
1221 private:
1222 const Solver::Int64ToIntVar evaluator_;
1223 const int64_t range_start_;
1224 const int64_t range_end_;
1225 int min_support_;
1226 int max_support_;
1227};
1228
1229IntExprEvaluatorElementCt::IntExprEvaluatorElementCt(
1230 Solver* const s, Solver::Int64ToIntVar evaluator, int64_t range_start,
1231 int64_t range_end, IntVar* const index, IntVar* const target_var)
1232 : CastConstraint(s, target_var),
1233 index_(index),
1234 evaluator_(std::move(evaluator)),
1235 range_start_(range_start),
1236 range_end_(range_end),
1237 min_support_(-1),
1238 max_support_(-1) {}
1239
1240void IntExprEvaluatorElementCt::Post() {
1241 Demon* const delayed_propagate_demon = MakeDelayedConstraintDemon0(
1242 solver(), this, &IntExprEvaluatorElementCt::Propagate, "Propagate");
1243 for (int i = range_start_; i < range_end_; ++i) {
1244 IntVar* const current_var = evaluator_(i);
1245 current_var->WhenRange(delayed_propagate_demon);
1246 Demon* const update_demon = MakeConstraintDemon1(
1247 solver(), this, &IntExprEvaluatorElementCt::Update, "Update", i);
1248 current_var->WhenRange(update_demon);
1249 }
1250 index_->WhenRange(delayed_propagate_demon);
1251 Demon* const update_expr_demon = MakeConstraintDemon0(
1252 solver(), this, &IntExprEvaluatorElementCt::UpdateExpr, "UpdateExpr");
1253 index_->WhenRange(update_expr_demon);
1254 Demon* const update_var_demon = MakeConstraintDemon0(
1255 solver(), this, &IntExprEvaluatorElementCt::Propagate, "UpdateVar");
1256
1257 target_var_->WhenRange(update_var_demon);
1258}
1259
1260void IntExprEvaluatorElementCt::InitialPropagate() { Propagate(); }
1261
1262void IntExprEvaluatorElementCt::Propagate() {
1263 const int64_t emin = std::max(range_start_, index_->Min());
1264 const int64_t emax = std::min<int64_t>(range_end_ - 1, index_->Max());
1265 const int64_t vmin = target_var_->Min();
1266 const int64_t vmax = target_var_->Max();
1267 if (emin == emax) {
1268 index_->SetValue(emin); // in case it was reduced by the above min/max.
1269 evaluator_(emin)->SetRange(vmin, vmax);
1270 } else {
1271 int64_t nmin = emin;
1272 for (; nmin <= emax; nmin++) {
1273 // break if the intersection of
1274 // [evaluator_(nmin)->Min(), evaluator_(nmin)->Max()] and [vmin, vmax]
1275 // is non-empty.
1276 IntVar* const nmin_var = evaluator_(nmin);
1277 if (nmin_var->Min() <= vmax && nmin_var->Max() >= vmin) break;
1278 }
1279 int64_t nmax = emax;
1280 for (; nmin <= nmax; nmax--) {
1281 // break if the intersection of
1282 // [evaluator_(nmin)->Min(), evaluator_(nmin)->Max()] and [vmin, vmax]
1283 // is non-empty.
1284 IntExpr* const nmax_var = evaluator_(nmax);
1285 if (nmax_var->Min() <= vmax && nmax_var->Max() >= vmin) break;
1286 }
1287 index_->SetRange(nmin, nmax);
1288 if (nmin == nmax) {
1289 evaluator_(nmin)->SetRange(vmin, vmax);
1290 }
1291 }
1292 if (min_support_ == -1 || max_support_ == -1) {
1293 int min_support = -1;
1294 int max_support = -1;
1295 int64_t gmin = std::numeric_limits<int64_t>::max();
1296 int64_t gmax = std::numeric_limits<int64_t>::min();
1297 for (int i = index_->Min(); i <= index_->Max(); ++i) {
1298 IntExpr* const var_i = evaluator_(i);
1299 const int64_t vmin = var_i->Min();
1300 if (vmin < gmin) {
1301 gmin = vmin;
1302 }
1303 const int64_t vmax = var_i->Max();
1304 if (vmax > gmax) {
1305 gmax = vmax;
1306 }
1307 }
1308 solver()->SaveAndSetValue(&min_support_, min_support);
1309 solver()->SaveAndSetValue(&max_support_, max_support);
1310 target_var_->SetRange(gmin, gmax);
1311 }
1312}
1313
1314void IntExprEvaluatorElementCt::Update(int index) {
1315 if (index == min_support_ || index == max_support_) {
1316 solver()->SaveAndSetValue(&min_support_, -1);
1317 solver()->SaveAndSetValue(&max_support_, -1);
1318 }
1319}
1320
1321void IntExprEvaluatorElementCt::UpdateExpr() {
1322 if (!index_->Contains(min_support_) || !index_->Contains(max_support_)) {
1323 solver()->SaveAndSetValue(&min_support_, -1);
1324 solver()->SaveAndSetValue(&max_support_, -1);
1325 }
1326}
1327
1328namespace {
1329std::string StringifyEvaluatorBare(const Solver::Int64ToIntVar& evaluator,
1330 int64_t range_start, int64_t range_end) {
1331 std::string out;
1332 for (int64_t i = range_start; i < range_end; ++i) {
1333 if (i != range_start) {
1334 out += ", ";
1335 }
1336 out += absl::StrFormat("%d -> %s", i, evaluator(i)->DebugString());
1337 }
1338 return out;
1339}
1340
1341std::string StringifyInt64ToIntVar(const Solver::Int64ToIntVar& evaluator,
1342 int64_t range_begin, int64_t range_end) {
1343 std::string out;
1344 if (range_end - range_begin > 10) {
1345 out = absl::StrFormat(
1346 "IntToIntVar(%s, ...%s)",
1347 StringifyEvaluatorBare(evaluator, range_begin, range_begin + 5),
1348 StringifyEvaluatorBare(evaluator, range_end - 5, range_end));
1349 } else {
1350 out = absl::StrFormat(
1351 "IntToIntVar(%s)",
1352 StringifyEvaluatorBare(evaluator, range_begin, range_end));
1353 }
1354 return out;
1355}
1356} // namespace
1357
1358std::string IntExprEvaluatorElementCt::DebugString() const {
1359 return StringifyInt64ToIntVar(evaluator_, range_start_, range_end_);
1360}
1361
1362void IntExprEvaluatorElementCt::Accept(ModelVisitor* const visitor) const {
1363 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1364 visitor->VisitIntegerVariableEvaluatorArgument(
1365 ModelVisitor::kEvaluatorArgument, evaluator_);
1366 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument, index_);
1367 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1368 target_var_);
1369 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1370}
1371
1372// ----- IntExprArrayElementCt -----
1373
1374// This constraint implements vars[index] == var. It is delayed such
1375// that propagation only occurs when all variables have been touched.
1376
1377class IntExprArrayElementCt : public IntExprEvaluatorElementCt {
1378 public:
1379 IntExprArrayElementCt(Solver* s, std::vector<IntVar*> vars, IntVar* index,
1380 IntVar* target_var);
1381
1382 std::string DebugString() const override;
1383 void Accept(ModelVisitor* visitor) const override;
1384
1385 private:
1386 const std::vector<IntVar*> vars_;
1387};
1388
1389IntExprArrayElementCt::IntExprArrayElementCt(Solver* const s,
1390 std::vector<IntVar*> vars,
1391 IntVar* const index,
1392 IntVar* const target_var)
1393 : IntExprEvaluatorElementCt(
1394 s, [this](int64_t idx) { return vars_[idx]; }, 0, vars.size(), index,
1395 target_var),
1396 vars_(std::move(vars)) {}
1397
1398std::string IntExprArrayElementCt::DebugString() const {
1399 int64_t size = vars_.size();
1400 if (size > 10) {
1401 return absl::StrFormat(
1402 "IntExprArrayElement(var array of size %d, %s) == %s", size,
1403 index_->DebugString(), target_var_->DebugString());
1404 } else {
1405 return absl::StrFormat("IntExprArrayElement([%s], %s) == %s",
1407 index_->DebugString(), target_var_->DebugString());
1408 }
1409}
1410
1411void IntExprArrayElementCt::Accept(ModelVisitor* const visitor) const {
1412 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1413 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1414 vars_);
1415 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument, index_);
1416 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1417 target_var_);
1418 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1419}
1420
1421// ----- IntExprArrayElementCstCt -----
1422
1423// This constraint implements vars[index] == constant.
1424
1425class IntExprArrayElementCstCt : public Constraint {
1426 public:
1427 IntExprArrayElementCstCt(Solver* const s, const std::vector<IntVar*>& vars,
1428 IntVar* const index, int64_t target)
1429 : Constraint(s),
1430 vars_(vars),
1431 index_(index),
1432 target_(target),
1433 demons_(vars.size()) {}
1434
1435 ~IntExprArrayElementCstCt() override {}
1436
1437 void Post() override {
1438 for (int i = 0; i < vars_.size(); ++i) {
1439 demons_[i] = MakeConstraintDemon1(
1440 solver(), this, &IntExprArrayElementCstCt::Propagate, "Propagate", i);
1441 vars_[i]->WhenDomain(demons_[i]);
1442 }
1443 Demon* const index_demon = MakeConstraintDemon0(
1444 solver(), this, &IntExprArrayElementCstCt::PropagateIndex,
1445 "PropagateIndex");
1446 index_->WhenBound(index_demon);
1447 }
1448
1449 void InitialPropagate() override {
1450 for (int i = 0; i < vars_.size(); ++i) {
1451 Propagate(i);
1452 }
1453 PropagateIndex();
1454 }
1455
1456 void Propagate(int index) {
1457 if (!vars_[index]->Contains(target_)) {
1458 index_->RemoveValue(index);
1459 demons_[index]->inhibit(solver());
1460 }
1461 }
1462
1463 void PropagateIndex() {
1464 if (index_->Bound()) {
1465 vars_[index_->Min()]->SetValue(target_);
1466 }
1467 }
1468
1469 std::string DebugString() const override {
1470 return absl::StrFormat("IntExprArrayElement([%s], %s) == %d",
1472 index_->DebugString(), target_);
1473 }
1474
1475 void Accept(ModelVisitor* const visitor) const override {
1476 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1477 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1478 vars_);
1479 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
1480 index_);
1481 visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument, target_);
1482 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1483 }
1484
1485 private:
1486 const std::vector<IntVar*> vars_;
1487 IntVar* const index_;
1488 const int64_t target_;
1489 std::vector<Demon*> demons_;
1490};
1491
1492// This constraint implements index == position(constant in vars).
1493
1494class IntExprIndexOfCt : public Constraint {
1495 public:
1496 IntExprIndexOfCt(Solver* const s, const std::vector<IntVar*>& vars,
1497 IntVar* const index, int64_t target)
1498 : Constraint(s),
1499 vars_(vars),
1500 index_(index),
1501 target_(target),
1502 demons_(vars_.size()),
1503 index_iterator_(index->MakeHoleIterator(true)) {}
1504
1505 ~IntExprIndexOfCt() override {}
1506
1507 void Post() override {
1508 for (int i = 0; i < vars_.size(); ++i) {
1509 demons_[i] = MakeConstraintDemon1(
1510 solver(), this, &IntExprIndexOfCt::Propagate, "Propagate", i);
1511 vars_[i]->WhenDomain(demons_[i]);
1512 }
1513 Demon* const index_demon = MakeConstraintDemon0(
1514 solver(), this, &IntExprIndexOfCt::PropagateIndex, "PropagateIndex");
1515 index_->WhenDomain(index_demon);
1516 }
1517
1518 void InitialPropagate() override {
1519 for (int i = 0; i < vars_.size(); ++i) {
1520 if (!index_->Contains(i)) {
1521 vars_[i]->RemoveValue(target_);
1522 } else if (!vars_[i]->Contains(target_)) {
1523 index_->RemoveValue(i);
1524 demons_[i]->inhibit(solver());
1525 } else if (vars_[i]->Bound()) {
1526 index_->SetValue(i);
1527 demons_[i]->inhibit(solver());
1528 }
1529 }
1530 }
1531
1532 void Propagate(int index) {
1533 if (!vars_[index]->Contains(target_)) {
1534 index_->RemoveValue(index);
1535 demons_[index]->inhibit(solver());
1536 } else if (vars_[index]->Bound()) {
1537 index_->SetValue(index);
1538 }
1539 }
1540
1541 void PropagateIndex() {
1542 const int64_t oldmax = index_->OldMax();
1543 const int64_t vmin = index_->Min();
1544 const int64_t vmax = index_->Max();
1545 for (int64_t value = index_->OldMin(); value < vmin; ++value) {
1546 vars_[value]->RemoveValue(target_);
1547 demons_[value]->inhibit(solver());
1548 }
1549 for (const int64_t value : InitAndGetValues(index_iterator_)) {
1550 vars_[value]->RemoveValue(target_);
1551 demons_[value]->inhibit(solver());
1552 }
1553 for (int64_t value = vmax + 1; value <= oldmax; ++value) {
1554 vars_[value]->RemoveValue(target_);
1555 demons_[value]->inhibit(solver());
1556 }
1557 if (index_->Bound()) {
1558 vars_[index_->Min()]->SetValue(target_);
1559 }
1560 }
1561
1562 std::string DebugString() const override {
1563 return absl::StrFormat("IntExprIndexOf([%s], %s) == %d",
1565 index_->DebugString(), target_);
1566 }
1567
1568 void Accept(ModelVisitor* const visitor) const override {
1569 visitor->BeginVisitConstraint(ModelVisitor::kIndexOf, this);
1570 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1571 vars_);
1572 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
1573 index_);
1574 visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument, target_);
1575 visitor->EndVisitConstraint(ModelVisitor::kIndexOf, this);
1576 }
1577
1578 private:
1579 const std::vector<IntVar*> vars_;
1580 IntVar* const index_;
1581 const int64_t target_;
1582 std::vector<Demon*> demons_;
1583 IntVarIterator* const index_iterator_;
1584};
1585
1586// Factory helper.
1587
1588Constraint* MakeElementEqualityFunc(Solver* const solver,
1589 const std::vector<int64_t>& vals,
1590 IntVar* const index, IntVar* const target) {
1591 if (index->Bound()) {
1592 const int64_t val = index->Min();
1593 if (val < 0 || val >= vals.size()) {
1594 return solver->MakeFalseConstraint();
1595 } else {
1596 return solver->MakeEquality(target, vals[val]);
1597 }
1598 } else {
1599 if (IsIncreasingContiguous(vals)) {
1600 return solver->MakeEquality(target, solver->MakeSum(index, vals[0]));
1601 } else {
1602 return solver->RevAlloc(
1603 new IntElementConstraint(solver, vals, index, target));
1604 }
1605 }
1606}
1607} // namespace
1608
1610 IntExpr* const then_expr,
1611 IntExpr* const else_expr,
1612 IntVar* const target_var) {
1613 return RevAlloc(
1614 new IfThenElseCt(this, condition, then_expr, else_expr, target_var));
1615}
1616
1617IntExpr* Solver::MakeElement(const std::vector<IntVar*>& vars,
1618 IntVar* const index) {
1619 if (index->Bound()) {
1620 return vars[index->Min()];
1621 }
1622 const int size = vars.size();
1623 if (AreAllBound(vars)) {
1624 std::vector<int64_t> values(size);
1625 for (int i = 0; i < size; ++i) {
1626 values[i] = vars[i]->Value();
1627 }
1628 return MakeElement(values, index);
1629 }
1630 if (index->Size() == 2 && index->Min() + 1 == index->Max() &&
1631 index->Min() >= 0 && index->Max() < vars.size()) {
1632 // Let's get the index between 0 and 1.
1633 IntVar* const scaled_index = MakeSum(index, -index->Min())->Var();
1634 IntVar* const zero = vars[index->Min()];
1635 IntVar* const one = vars[index->Max()];
1636 const std::string name = absl::StrFormat(
1637 "ElementVar([%s], %s)", JoinNamePtr(vars, ", "), index->name());
1638 IntVar* const target = MakeIntVar(std::min(zero->Min(), one->Min()),
1639 std::max(zero->Max(), one->Max()), name);
1640 AddConstraint(
1641 RevAlloc(new IfThenElseCt(this, scaled_index, one, zero, target)));
1642 return target;
1643 }
1644 int64_t emin = std::numeric_limits<int64_t>::max();
1645 int64_t emax = std::numeric_limits<int64_t>::min();
1646 std::unique_ptr<IntVarIterator> iterator(index->MakeDomainIterator(false));
1647 for (const int64_t index_value : InitAndGetValues(iterator.get())) {
1648 if (index_value >= 0 && index_value < size) {
1649 emin = std::min(emin, vars[index_value]->Min());
1650 emax = std::max(emax, vars[index_value]->Max());
1651 }
1652 }
1653 const std::string vname =
1654 size > 10 ? absl::StrFormat("ElementVar(var array of size %d, %s)", size,
1655 index->DebugString())
1656 : absl::StrFormat("ElementVar([%s], %s)",
1657 JoinNamePtr(vars, ", "), index->name());
1658 IntVar* const element_var = MakeIntVar(emin, emax, vname);
1659 AddConstraint(
1660 RevAlloc(new IntExprArrayElementCt(this, vars, index, element_var)));
1661 return element_var;
1662}
1663
1664IntExpr* Solver::MakeElement(Int64ToIntVar vars, int64_t range_start,
1665 int64_t range_end, IntVar* argument) {
1666 const std::string index_name =
1667 !argument->name().empty() ? argument->name() : argument->DebugString();
1668 const std::string vname = absl::StrFormat(
1669 "ElementVar(%s, %s)",
1670 StringifyInt64ToIntVar(vars, range_start, range_end), index_name);
1671 IntVar* const element_var =
1672 MakeIntVar(std::numeric_limits<int64_t>::min(),
1673 std::numeric_limits<int64_t>::max(), vname);
1674 IntExprEvaluatorElementCt* evaluation_ct = new IntExprEvaluatorElementCt(
1675 this, std::move(vars), range_start, range_end, argument, element_var);
1676 AddConstraint(RevAlloc(evaluation_ct));
1677 evaluation_ct->Propagate();
1678 return element_var;
1679}
1680
1681Constraint* Solver::MakeElementEquality(const std::vector<int64_t>& vals,
1682 IntVar* const index,
1683 IntVar* const target) {
1684 return MakeElementEqualityFunc(this, vals, index, target);
1685}
1686
1687Constraint* Solver::MakeElementEquality(const std::vector<int>& vals,
1688 IntVar* const index,
1689 IntVar* const target) {
1690 return MakeElementEqualityFunc(this, ToInt64Vector(vals), index, target);
1691}
1692
1693Constraint* Solver::MakeElementEquality(const std::vector<IntVar*>& vars,
1694 IntVar* const index,
1695 IntVar* const target) {
1696 if (AreAllBound(vars)) {
1697 std::vector<int64_t> values(vars.size());
1698 for (int i = 0; i < vars.size(); ++i) {
1699 values[i] = vars[i]->Value();
1700 }
1701 return MakeElementEquality(values, index, target);
1702 }
1703 if (index->Bound()) {
1704 const int64_t val = index->Min();
1705 if (val < 0 || val >= vars.size()) {
1706 return MakeFalseConstraint();
1707 } else {
1708 return MakeEquality(target, vars[val]);
1709 }
1710 } else {
1711 if (target->Bound()) {
1712 return RevAlloc(
1713 new IntExprArrayElementCstCt(this, vars, index, target->Min()));
1714 } else {
1715 return RevAlloc(new IntExprArrayElementCt(this, vars, index, target));
1716 }
1717 }
1718}
1719
1720Constraint* Solver::MakeElementEquality(const std::vector<IntVar*>& vars,
1721 IntVar* const index, int64_t target) {
1722 if (AreAllBound(vars)) {
1723 std::vector<int> valid_indices;
1724 for (int i = 0; i < vars.size(); ++i) {
1725 if (vars[i]->Value() == target) {
1726 valid_indices.push_back(i);
1727 }
1728 }
1729 return MakeMemberCt(index, valid_indices);
1730 }
1731 if (index->Bound()) {
1732 const int64_t pos = index->Min();
1733 if (pos >= 0 && pos < vars.size()) {
1734 IntVar* const var = vars[pos];
1735 return MakeEquality(var, target);
1736 } else {
1737 return MakeFalseConstraint();
1738 }
1739 } else {
1740 return RevAlloc(new IntExprArrayElementCstCt(this, vars, index, target));
1741 }
1742}
1743
1744Constraint* Solver::MakeIndexOfConstraint(const std::vector<IntVar*>& vars,
1745 IntVar* const index, int64_t target) {
1746 if (index->Bound()) {
1747 const int64_t pos = index->Min();
1748 if (pos >= 0 && pos < vars.size()) {
1749 IntVar* const var = vars[pos];
1750 return MakeEquality(var, target);
1751 } else {
1752 return MakeFalseConstraint();
1753 }
1754 } else {
1755 return RevAlloc(new IntExprIndexOfCt(this, vars, index, target));
1756 }
1757}
1758
1759IntExpr* Solver::MakeIndexExpression(const std::vector<IntVar*>& vars,
1760 int64_t value) {
1761 IntExpr* const cache = model_cache_->FindVarArrayConstantExpression(
1763 if (cache != nullptr) {
1764 return cache->Var();
1765 } else {
1766 const std::string name =
1767 absl::StrFormat("Index(%s, %d)", JoinNamePtr(vars, ", "), value);
1768 IntVar* const index = MakeIntVar(0, vars.size() - 1, name);
1769 AddConstraint(MakeIndexOfConstraint(vars, index, value));
1770 model_cache_->InsertVarArrayConstantExpression(
1772 return index;
1773 }
1774}
1775} // namespace operations_research
IntegerValue y
IntegerValue size
const std::vector< IntVar * > vars_
-------— Generalized element -------—
Definition element.cc:1126
void Accept(ModelVisitor *const visitor) const override
Accepts the given visitor.
Definition element.cc:1186
IfThenElseCt(Solver *const solver, IntVar *const condition, IntExpr *const one, IntExpr *const zero, IntVar *const target)
Definition element.cc:1128
std::string DebugString() const override
--------------— Constraint class ----------------—
Definition element.cc:1180
virtual void SetValue(int64_t v)
This method sets the value of the expression.
virtual bool Bound() const
Returns true if the min and the max of the expression are equal.
virtual void SetMax(int64_t m)=0
virtual void SetRange(int64_t l, int64_t u)
This method sets both the min and the max of the expression.
virtual int64_t Min() const =0
virtual void SetMin(int64_t m)=0
virtual void WhenRange(Demon *d)=0
Attach a demon that will watch the min or the max of the expression.
virtual IntVar * Var()=0
Creates a variable from the expression.
virtual int64_t Max() const =0
virtual void WhenBound(Demon *d)=0
virtual void WhenDomain(Demon *d)=0
virtual int64_t OldMax() const =0
Returns the previous max.
IntVar * Var() override
Creates a variable from the expression.
virtual bool Contains(int64_t v) const =0
virtual void RemoveValue(int64_t v)=0
This method removes the value 'v' from the domain of the variable.
virtual int64_t OldMin() const =0
Returns the previous min.
virtual std::string name() const
Object naming.
For the time being, Solver is neither MT_SAFE nor MT_HOT.
IntExpr * MakeElement(const std::vector< int64_t > &values, IntVar *index)
values[index]
Definition element.cc:658
Constraint * MakeIfThenElseCt(IntVar *condition, IntExpr *then_expr, IntExpr *else_expr, IntVar *target_var)
Special cases with arrays of size two.
Definition element.cc:1609
IntExpr * MakeMonotonicElement(IndexEvaluator1 values, bool increasing, IntVar *index)
Definition element.cc:869
Constraint * MakeElementEquality(const std::vector< int64_t > &vals, IntVar *index, IntVar *target)
Definition element.cc:1681
Constraint * MakeIndexOfConstraint(const std::vector< IntVar * > &vars, IntVar *index, int64_t target)
Definition element.cc:1744
IntExpr * MakeIndexExpression(const std::vector< IntVar * > &vars, int64_t value)
Definition element.cc:1759
std::function< IntVar *(int64_t)> Int64ToIntVar
std::function< int64_t(int64_t)> IndexEvaluator1
Callback typedefs.
std::function< int64_t(int64_t, int64_t)> IndexEvaluator2
int64_t b
Definition table.cc:45
const std::string name
A name for logging purposes.
int64_t value
#define UPDATE_ELEMENT_INDEX_BOUNDS(test)
Definition element.cc:988
IntVar *const expr_
Definition element.cc:88
ABSL_FLAG(bool, cp_disable_element_cache, true, "If true, caching for IntElement is disabled.")
#define UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(test)
Definition element.cc:424
IntVar * var
double upper_bound
double lower_bound
int index
std::pair< double, double > Range
A range of values, first is the minimum, second is the maximum.
Definition statistics.h:27
In SWIG mode, we don't want anything besides these top-level includes.
bool IsArrayConstant(const std::vector< T > &values, const T &value)
std::string JoinDebugStringPtr(const std::vector< T > &v, absl::string_view separator)
Join v[i]->DebugString().
bool IsIncreasing(const std::vector< T > &values)
bool IsArrayBoolean(const std::vector< T > &values)
Demon * MakeDelayedConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
void LinkVarExpr(Solver *s, IntExpr *expr, IntVar *var)
--— IntExprElement --—
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
std::vector< int64_t > ToInt64Vector(const std::vector< int > &input)
--— Vector manipulations --—
Definition utilities.cc:829
bool IsIncreasingContiguous(const std::vector< T > &values)
Demon * MakeConstraintDemon1(Solver *const s, T *const ct, void(T::*method)(P), const std::string &name, P param1)
bool AreAllBound(const std::vector< IntVar * > &vars)
std::string JoinNamePtr(const std::vector< T > &v, absl::string_view separator)
Join v[i]->name().
STL namespace.
false true
Definition numbers.cc:228
const Variable x
Definition qp_tests.cc:127