24#include "absl/container/flat_hash_map.h"
25#include "absl/log/check.h"
26#include "absl/strings/str_format.h"
27#include "absl/strings/str_join.h"
28#include "absl/types/span.h"
39struct AffineTransformation {
40 AffineTransformation() :
a(1),
b(0) {}
41 AffineTransformation(int64_t aa, int64_t bb) :
a(aa),
b(bb) {
47 bool Reverse(int64_t
value, int64_t*
const reverse)
const {
48 const int64_t temp =
value -
b;
51 DCHECK_EQ(Forward(*reverse),
value);
58 int64_t Forward(int64_t
value)
const {
return value *
a +
b; }
60 int64_t UnsafeReverse(int64_t
value)
const {
return (
value - b) /
a; }
67 std::string DebugString()
const {
68 return absl::StrFormat(
"(%d * x + %d)", a, b);
73class VarLinearizer :
public ModelParser {
75 VarLinearizer() : target_var_(nullptr), transformation_(nullptr) {}
76 ~VarLinearizer()
override {}
78 void VisitIntegerVariable(
const IntVar*
const variable,
79 const std::string& operation, int64_t
value,
80 IntVar*
const delegate)
override {
83 delegate->Accept(
this);
87 delegate->Accept(
this);
90 PushMultiplier(
value);
91 delegate->Accept(
this);
94 *target_var_ =
const_cast<IntVar*
>(variable);
95 transformation_->a = multipliers_.back();
99 void VisitIntegerVariable(
const IntVar*
const variable,
100 IntExpr*
const delegate)
override {
101 *target_var_ =
const_cast<IntVar*
>(variable);
102 transformation_->a = multipliers_.back();
105 void Visit(
const IntVar*
const var, IntVar**
const target_var,
106 AffineTransformation*
const transformation) {
107 target_var_ = target_var;
108 transformation_ = transformation;
109 transformation->Clear();
113 CHECK(multipliers_.empty());
116 std::string DebugString()
const override {
return "VarLinearizer"; }
119 void AddConstant(int64_t constant) {
120 transformation_->b += constant * multipliers_.back();
123 void PushMultiplier(int64_t multiplier) {
124 if (multipliers_.empty()) {
125 multipliers_.push_back(multiplier);
127 multipliers_.push_back(multiplier * multipliers_.back());
// Removes the most recently pushed entry of multipliers_. Must only be called
// when the stack is non-empty (pop_back on an empty vector is undefined).
131 void PopMultiplier() { multipliers_.pop_back(); }
// Stack of cumulative multipliers: PushMultiplier() stores the new factor
// multiplied by the current back(), and back() scales constants and becomes
// the `a` coefficient when a variable is reached.
133 std::vector<int64_t> multipliers_;
// Output slot installed by Visit(); the visited variable is written through it.
134 IntVar** target_var_;
// Output transformation installed by Visit(); its a/b coefficients are filled
// in while visiting.
135 AffineTransformation* transformation_;
// Number of bits in a uint64_t word. SmallCompactPositiveTableConstraint
// CHECKs that its tuple count does not exceed this value.
static constexpr int kBitsInUint64 = 64;
154class BasePositiveTableConstraint :
public Constraint {
156 BasePositiveTableConstraint(Solver*
const s,
const std::vector<IntVar*>& vars,
157 const IntTupleSet& tuples)
165 transformations_(
arity_) {
175 VarLinearizer linearizer;
177 linearizer.Visit(vars[i], &vars_[i], &transformations_[i]);
181 holes_[
i] =
vars_[
i]->MakeHoleIterator(
true);
182 iterators_[
i] =
vars_[
i]->MakeDomainIterator(
true);
186 ~BasePositiveTableConstraint()
override {}
188 std::string DebugString()
const override {
189 return absl::StrFormat(
"AllowedAssignments(arity = %d, tuple_count = %d)",
190 arity_, tuple_count_);
193 void Accept(ModelVisitor*
const visitor)
const override {
202 bool TupleValue(
int tuple_index,
int var_index, int64_t*
const value)
const {
203 return transformations_[
var_index].Reverse(
207 int64_t UnsafeTupleValue(
int tuple_index,
int var_index)
const {
208 return transformations_[
var_index].UnsafeReverse(
212 bool IsTupleSupported(
int tuple_index) {
// One hole iterator per variable, obtained from MakeHoleIterator(true).
226 std::vector<IntVarIterator*> holes_;
// One domain iterator per variable, obtained from MakeDomainIterator(true).
227 std::vector<IntVarIterator*> iterators_;
// Scratch buffer: values are collected here and then removed in one
// RemoveValues() call.
228 std::vector<int64_t> to_remove_;
// Tuple set held by value. NOTE(review): presumably initialized from the
// constructor's `tuples` argument -- the initializer is not visible here.
232 const IntTupleSet tuples_;
// One affine transformation per variable, filled by VarLinearizer::Visit() in
// the constructor and used by TupleValue()/UnsafeTupleValue() to map tuple
// entries back through Reverse()/UnsafeReverse().
235 std::vector<AffineTransformation> transformations_;
238class PositiveTableConstraint :
public BasePositiveTableConstraint {
240 typedef absl::flat_hash_map<int, std::vector<uint64_t>> ValueBitset;
242 PositiveTableConstraint(Solver*
const s,
const std::vector<IntVar*>& vars,
243 const IntTupleSet& tuples)
244 : BasePositiveTableConstraint(s, vars, tuples),
248 ~PositiveTableConstraint()
override {}
250 void Post()
override {
252 solver(),
this, &PositiveTableConstraint::Propagate,
"Propagate");
256 solver(),
this, &PositiveTableConstraint::Update,
"Update", i);
266 std::vector<uint64_t> actives(word_length_, 0);
267 for (
int tuple_index = 0; tuple_index <
tuple_count_; ++tuple_index) {
268 if (IsTupleSupported(tuple_index)) {
269 SetBit64(actives.data(), tuple_index);
275 void InitialPropagate()
override {
278 for (
const auto& it : masks_[
var_index]) {
292 for (
const int64_t
value : InitAndGetValues(iterators_[
var_index])) {
293 if (!mask.contains(
value)) {
294 to_remove_.push_back(
value);
297 if (!to_remove_.empty()) {
298 var->RemoveValues(to_remove_);
307 for (
const int64_t
value : InitAndGetValues(iterators_[
var_index])) {
309 to_remove_.push_back(
value);
312 if (!to_remove_.empty()) {
313 var->RemoveValues(to_remove_);
318 void Update(
int index) {
321 const int64_t old_max =
var->OldMax();
322 const int64_t vmin =
var->Min();
323 const int64_t vmax =
var->Max();
325 const auto& it = var_masks.find(
value);
326 if (it != var_masks.end()) {
327 BlankActives(it->second);
330 for (
const int64_t
value : InitAndGetValues(holes_[
index])) {
331 const auto& it = var_masks.find(
value);
332 if (it != var_masks.end()) {
333 BlankActives(it->second);
337 const auto& it = var_masks.find(
value);
338 if (it != var_masks.end()) {
339 BlankActives(it->second);
344 void BlankActives(
const std::vector<uint64_t>& mask) {
362 std::string DebugString()
const override {
363 return absl::StrFormat(
"PositiveTableConstraint([%s], %d tuples)",
368 void InitializeMask(
int tuple_index) {
369 std::vector<int64_t> cache(
arity_);
379 mask.assign(word_length_, 0);
393class CompactPositiveTableConstraint :
public BasePositiveTableConstraint {
395 CompactPositiveTableConstraint(Solver*
const s,
396 const std::vector<IntVar*>& vars,
397 const IntTupleSet& tuples)
398 : BasePositiveTableConstraint(s, vars, tuples),
411 ~CompactPositiveTableConstraint()
override {}
413 void Post()
override {
415 solver(),
this, &CompactPositiveTableConstraint::Propagate,
419 solver(),
this, &CompactPositiveTableConstraint::Update,
"Update", i);
423 var_sizes_.SetValue(solver(), i,
vars_[i]->Size());
427 void InitialPropagate()
override {
429 FillMasksAndActiveTuples();
430 ComputeMasksBoundaries();
432 RemoveUnsupportedValues();
439 if (touched_var_ == -2) {
456 const int64_t original_min = original_min_[
var_index];
457 const int64_t var_size =
var->Size();
468 const int64_t var_min =
var->Min();
469 const int64_t var_max =
var->Max();
470 const bool min_support = Supported(
var_index, var_min - original_min);
471 const bool max_support = Supported(
var_index, var_max - original_min);
476 var->SetValue(var_max);
477 var_sizes_.SetValue(solver(),
var_index, 1);
479 }
else if (!max_support) {
480 var->SetValue(var_min);
481 var_sizes_.SetValue(solver(),
var_index, 1);
487 const int64_t var_min =
var->Min();
488 const int64_t var_max =
var->Max();
489 int64_t new_min = var_min;
490 int64_t new_max = var_max;
494 if (var_max - var_min + 1 == var_size) {
495 for (; new_min <= var_max; ++new_min) {
496 if (Supported(
var_index, new_min - original_min)) {
500 for (; new_max >= new_min; --new_max) {
501 if (Supported(
var_index, new_max - original_min)) {
505 var->SetRange(new_min, new_max);
508 to_remove_.push_back(
value);
515 new_min = std::numeric_limits<int64_t>::max();
516 for (
const int64_t
value :
517 InitAndGetValues(iterators_[
var_index])) {
519 to_remove_.push_back(
value);
521 if (new_min == std::numeric_limits<int64_t>::max()) {
529 var->SetRange(new_min, new_max);
531 int index = to_remove_.size() - 1;
532 while (
index >= 0 && to_remove_[
index] > new_max) {
535 to_remove_.resize(
index + 1);
537 var->RemoveValues(to_remove_);
553 bool changed =
false;
554 const int64_t omin = original_min_[
var_index];
555 const int64_t var_size =
var->Size();
556 const int64_t var_min =
var->Min();
557 const int64_t var_max =
var->Max();
561 changed = AndMaskWithActive(masks_[
var_index][var_min - omin]);
567 changed = AndMaskWithActive(temp_mask_);
571 const int64_t estimated_hole_size =
573 const int64_t old_min =
var->OldMin();
574 const int64_t old_max =
var->OldMax();
577 const int64_t number_of_operations =
578 estimated_hole_size + var_min - old_min + old_max - var_max;
579 if (number_of_operations < var_size) {
582 changed |= SubtractMaskFromActive(masks_[
var_index][
value - omin]);
585 changed |= SubtractMaskFromActive(masks_[
var_index][
value - omin]);
588 changed |= SubtractMaskFromActive(masks_[
var_index][
value - omin]);
594 if (var_max - var_min + 1 == var_size) {
599 for (
const int64_t
value :
600 InitAndGetValues(iterators_[
var_index])) {
605 changed = AndMaskWithActive(temp_mask_);
609 var_sizes_.SetValue(solver(),
var_index, var_size);
614 if (touched_var_ == -1 || touched_var_ ==
var_index) {
619 EnqueueDelayedDemon(demon_);
623 std::string DebugString()
const override {
624 return absl::StrFormat(
"CompactPositiveTableConstraint([%s], %d tuples)",
634 original_min_[
i] =
vars_[
i]->Min();
635 const int64_t span =
vars_[
i]->Max() - original_min_[
i] + 1;
640 void FillMasksAndActiveTuples() {
641 std::vector<uint64_t> actives(word_length_, 0);
642 for (
int tuple_index = 0; tuple_index <
tuple_count_; ++tuple_index) {
643 if (IsTupleSupported(tuple_index)) {
644 SetBit64(actives.data(), tuple_index);
649 DCHECK_GE(value_index, 0);
651 if (masks_[
var_index][value_index].empty()) {
661 void RemoveUnsupportedValues() {
666 for (
const int64_t
value : InitAndGetValues(iterators_[
var_index])) {
668 to_remove_.push_back(
value);
671 if (!to_remove_.empty()) {
672 var->RemoveValues(to_remove_);
677 void ComputeMasksBoundaries() {
688 while (
start < word_length_ && mask[
start] == 0) {
691 DCHECK_LT(
start, word_length_);
697 DCHECK_NE(mask[
start], 0);
698 DCHECK_NE(mask[
end], 0);
705 void BuildSupports() {
713 bool AndMaskWithActive(
const std::vector<uint64_t>& mask) {
721 bool SubtractMaskFromActive(
const std::vector<uint64_t>& mask) {
729 bool Supported(
int var_index, int64_t value_index) {
732 DCHECK_GE(value_index, 0);
735 DCHECK(!mask.empty());
739 void OrTempMask(
int var_index, int64_t value_index) {
742 const int mask_span = mask_ends_[
var_index][value_index] -
743 mask_starts_[
var_index][value_index] + 1;
749 for (
int i = mask_starts_[
var_index][value_index];
757 void SetTempMask(
int var_index, int64_t value_index) {
772 void ClearTempMask() {
// masks_[var_index][value - original_min_[var_index]] is the word vector
// associated with that value; an empty vector means no mask was built for it.
788 std::vector<std::vector<std::vector<uint64_t>>>
masks_;
// mask_starts_[var_index][value_index] is the first word index used when
// walking a mask (see OrTempMask).
790 std::vector<std::vector<int>> mask_starts_;
// mask_ends_[var_index][value_index] is the last word index of that range;
// the span is mask_ends_ - mask_starts_ + 1 words.
791 std::vector<std::vector<int>> mask_ends_;
// original_min_[i] is vars_[i]->Min() captured at setup; it is the offset
// subtracted from a value to obtain its value_index.
793 std::vector<int64_t> original_min_;
// Reversible record of each variable's domain size, written with
// var_sizes_.SetValue(solver(), index, size).
801 RevArray<int64_t> var_sizes_;
808class SmallCompactPositiveTableConstraint :
public BasePositiveTableConstraint {
810 SmallCompactPositiveTableConstraint(Solver*
const s,
811 const std::vector<IntVar*>& vars,
812 const IntTupleSet& tuples)
813 : BasePositiveTableConstraint(s, vars, tuples),
822 CHECK_LE(tuples.NumTuples(), kBitsInUint64);
825 ~SmallCompactPositiveTableConstraint()
override {}
827 void Post()
override {
829 solver(),
this, &SmallCompactPositiveTableConstraint::Propagate,
832 if (!
vars_[i]->Bound()) {
834 solver(),
this, &SmallCompactPositiveTableConstraint::Update,
836 vars_[
i]->WhenDomain(update_demon);
845 original_min_[
i] =
vars_[
i]->Min();
846 const int64_t span =
vars_[
i]->Max() - original_min_[
i] + 1;
851 bool IsTupleSupported(
int tuple_index) {
862 void ComputeActiveTuples() {
865 for (
int tuple_index = 0; tuple_index <
tuple_count_; ++tuple_index) {
866 if (IsTupleSupported(tuple_index)) {
867 const uint64_t local_mask =
OneBit64(tuple_index);
875 if (!active_tuples_) {
880 void RemoveUnsupportedValues() {
884 const int64_t original_min = original_min_[
var_index];
886 for (
const int64_t
value : InitAndGetValues(iterators_[
var_index])) {
888 to_remove_.push_back(
value);
891 if (!to_remove_.empty()) {
892 var->RemoveValues(to_remove_);
897 void InitialPropagate()
override {
899 ComputeActiveTuples();
900 RemoveUnsupportedValues();
910 if (touched_var_ == -2) {
928 const int64_t original_min = original_min_[
var_index];
930 const int64_t var_size =
var->Size();
933 if ((var_mask[
var->Min() - original_min] & actives) == 0) {
943 const int64_t var_min =
var->Min();
944 const int64_t var_max =
var->Max();
945 const bool min_support =
946 (var_mask[var_min - original_min] & actives) != 0;
947 const bool max_support =
948 (var_mask[var_max - original_min] & actives) != 0;
949 if (!min_support && !max_support) {
951 }
else if (!min_support) {
952 var->SetValue(var_max);
953 }
else if (!max_support) {
954 var->SetValue(var_min);
960 const int64_t var_min =
var->Min();
961 const int64_t var_max =
var->Max();
962 int64_t new_min = var_min;
963 int64_t new_max = var_max;
964 if (var_max - var_min + 1 == var_size) {
966 for (; new_min <= var_max; ++new_min) {
967 if ((var_mask[new_min - original_min] & actives) != 0) {
971 for (; new_max >= new_min; --new_max) {
972 if ((var_mask[new_max - original_min] & actives) != 0) {
976 var->SetRange(new_min, new_max);
978 if ((var_mask[
value - original_min] & actives) == 0) {
979 to_remove_.push_back(
value);
983 bool min_set =
false;
985 for (
const int64_t
value :
986 InitAndGetValues(iterators_[
var_index])) {
989 if ((var_mask[
value - original_min] & actives) == 0) {
991 to_remove_.push_back(
value);
999 last_size = to_remove_.size();
1003 var->SetRange(new_min, new_max);
1007 to_remove_.resize(last_size);
1009 var->RemoveValues(to_remove_);
1020 const int64_t original_min = original_min_[
var_index];
1021 const int64_t var_size =
var->Size();
1036 const int64_t old_min =
var->OldMin();
1037 const int64_t old_max =
var->OldMax();
1038 const int64_t var_min =
var->Min();
1039 const int64_t var_max =
var->Max();
1040 const bool contiguous = var_size == var_max - var_min + 1;
1041 const bool nearly_contiguous =
1042 var_size > (var_max - var_min + 1) * 7 / 10;
1049 uint64_t hole_mask = 0;
1051 for (
const int64_t
value : InitAndGetValues(holes_[
var_index])) {
1052 hole_mask |= var_mask[
value - original_min];
1055 const int64_t hole_operations = var_min - old_min + old_max - var_max;
1057 const int64_t domain_operations = contiguous ? var_size : 4 * var_size;
1058 if (hole_operations < domain_operations) {
1060 hole_mask |= var_mask[
value - original_min];
1063 hole_mask |= var_mask[
value - original_min];
1068 uint64_t domain_mask = 0;
1071 domain_mask |= var_mask[
value - original_min];
1073 }
else if (nearly_contiguous) {
1076 domain_mask |= var_mask[
value - original_min];
1080 for (
const int64_t
value :
1081 InitAndGetValues(iterators_[
var_index])) {
1082 domain_mask |= var_mask[
value - original_min];
1091 std::string DebugString()
const override {
1092 return absl::StrFormat(
1093 "SmallCompactPositiveTableConstraint([%s], %d tuples)",
1098 void ApplyMask(
int var_index, uint64_t mask) {
1099 if ((~mask & active_tuples_) != 0) {
1101 const uint64_t current_stamp = solver()->stamp();
1102 if (stamp_ < current_stamp) {
1103 stamp_ = current_stamp;
1104 solver()->SaveValue(&active_tuples_);
1107 if (active_tuples_) {
1109 if (touched_var_ == -1 || touched_var_ ==
var_index) {
1114 EnqueueDelayedDemon(demon_);
// One 64-bit word per (variable, value). NOTE(review): presumably
// masks_[var_index][value - original_min_[var_index]] is the tuple bitmask
// that the propagation code reads as `var_mask[...]` -- confirm.
1128 std::vector<std::vector<uint64_t>>
masks_;
// original_min_[i] is vars_[i]->Min() captured at setup; it is the offset
// subtracted from a value before indexing a mask.
1130 std::vector<int64_t> original_min_;
1135bool HasCompactDomains(
const std::vector<IntVar*>& vars) {
1146class TransitionConstraint :
public Constraint {
1151 TransitionConstraint(Solver*
const s,
const std::vector<IntVar*>& vars,
1152 const IntTupleSet& transition_table,
1153 int64_t initial_state,
1154 const std::vector<int64_t>& final_states)
1157 transition_table_(transition_table),
1158 initial_state_(initial_state),
1159 final_states_(final_states) {}
1161 TransitionConstraint(Solver*
const s,
const std::vector<IntVar*>& vars,
1162 const IntTupleSet& transition_table,
1163 int64_t initial_state,
1164 absl::Span<const int> final_states)
1167 transition_table_(transition_table),
1168 initial_state_(initial_state),
1169 final_states_(final_states.
size()) {
1170 for (
int i = 0;
i < final_states.size(); ++
i) {
1171 final_states_[
i] = final_states[
i];
1175 ~TransitionConstraint()
override {}
1177 void Post()
override {
1178 Solver*
const s = solver();
1179 int64_t state_min = std::numeric_limits<int64_t>::max();
1180 int64_t state_max = std::numeric_limits<int64_t>::min();
1181 const int nb_vars =
vars_.size();
1182 for (
int i = 0;
i < transition_table_.NumTuples(); ++
i) {
1184 std::max(state_max, transition_table_.Value(i, kStatePosition));
1186 std::max(state_max, transition_table_.Value(i, kNextStatePosition));
1188 std::min(state_min, transition_table_.Value(i, kStatePosition));
1190 std::min(state_min, transition_table_.Value(i, kNextStatePosition));
1193 std::vector<IntVar*> states;
1194 states.push_back(s->MakeIntConst(initial_state_));
1196 states.push_back(s->MakeIntVar(state_min, state_max));
1198 states.push_back(s->MakeIntVar(final_states_));
1199 CHECK_EQ(nb_vars + 1, states.size());
1201 const int num_tuples = transition_table_.NumTuples();
1204 std::vector<IntVar*> tmp_vars(3);
1209 if (num_tuples <= kBitsInUint64) {
1210 s->AddConstraint(s->RevAlloc(
new SmallCompactPositiveTableConstraint(
1211 s, tmp_vars, transition_table_)));
1213 s->AddConstraint(s->RevAlloc(
new CompactPositiveTableConstraint(
1214 s, tmp_vars, transition_table_)));
// Intentionally a no-op. NOTE(review): presumably all filtering is delegated
// to the table constraints that Post() adds to the solver -- confirm.
1219 void InitialPropagate()
override {}
1221 void Accept(ModelVisitor*
const visitor)
const override {
1233 std::string DebugString()
const override {
1234 return absl::StrFormat(
1235 "TransitionConstraint([%s], %d transitions, initial = %d, final = "
1238 initial_state_, absl::StrJoin(final_states_,
", "));
// The transition variables; Post() creates vars_.size() + 1 state variables
// around them.
1243 const std::vector<IntVar*>
vars_;
// Allowed transitions, stored by value from the constructor argument.
1245 const IntTupleSet transition_table_;
// Value of the first state; Post() turns it into a constant variable.
1247 const int64_t initial_state_;
// Admissible values for the last state; Post() builds the final state
// variable from this list.
1249 std::vector<int64_t> final_states_;
// Column of a transition tuple read as the current state; Post() scans it via
// transition_table_.Value(i, kStatePosition) to bound the state variables.
1252const int TransitionConstraint::kStatePosition = 0;
// Column of a transition tuple read as the next state; also scanned in Post().
1253const int TransitionConstraint::kNextStatePosition = 2;
// Number of columns in a transition tuple. NOTE(review): presumably
// (state, transition value, next state) -- the middle column's use is not
// visible here.
1254const int TransitionConstraint::kTransitionTupleSize = 3;
1261 if (HasCompactDomains(vars)) {
1262 if (tuples.
NumTuples() < kBitsInUint64 && parameters_.use_small_table()) {
1264 new SmallCompactPositiveTableConstraint(
this, vars, tuples));
1266 return RevAlloc(
new CompactPositiveTableConstraint(
this, vars, tuples));
1269 return RevAlloc(
new PositiveTableConstraint(
this, vars, tuples));
1273 const std::vector<IntVar*>& vars,
const IntTupleSet& transition_table,
1274 int64_t initial_state,
const std::vector<int64_t>& final_states) {
1275 return RevAlloc(
new TransitionConstraint(
this, vars, transition_table,
1276 initial_state, final_states));
1280 const std::vector<IntVar*>& vars,
const IntTupleSet& transition_table,
1281 int64_t initial_state,
const std::vector<int>& final_states) {
1282 return RevAlloc(
new TransitionConstraint(
this, vars, transition_table,
1283 initial_state, final_states));
const std::vector< IntVar * > vars_
----- Main IntTupleSet class -----
int NumTuples() const
Returns the number of tuples.
static const char kDifferenceOperation[]
static const char kInitialState[]
static const char kTraceOperation[]
static const char kSumOperation[]
static const char kFinalStatesArgument[]
static const char kTuplesArgument[]
static const char kTransition[]
static const char kProductOperation[]
static const char kVarsArgument[]
static const char kAllowedAssignments[]
Constraint * MakeAllowedAssignments(const std::vector< IntVar * > &vars, const IntTupleSet &tuples)
--------- API ----------
Constraint * MakeTransitionConstraint(const std::vector< IntVar * > &vars, const IntTupleSet &transition_table, int64_t initial_state, const std::vector< int64_t > &final_states)
static const int kTransitionTupleSize
const int word_length_
The length in 64 bit words of the number of tuples.
static const int kStatePosition
static const int kNextStatePosition
std::vector< ValueBitset > masks_
The masks per value per variable.
std::vector< uint64_t > temp_mask_
A temporary mask used for computation.
UnsortedNullableRevBitset active_tuples_
The active bitset.
std::vector< int > supports_
In SWIG mode, we don't want anything besides these top-level includes.
std::string JoinDebugStringPtr(const std::vector< T > &v, absl::string_view separator)
Join v[i]->DebugString().
Demon * MakeDelayedConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
uint64_t OneBit64(int pos)
Returns a word with only bit pos set.
uint64_t BitLength64(uint64_t size)
Returns the number of words needed to store size bits.
Demon * MakeConstraintDemon1(Solver *const s, T *const ct, void(T::*method)(P), const std::string &name, P param1)
void SetBit64(uint64_t *const bitset, uint64_t pos)
Sets the bit pos to true in bitset.
std::optional< int64_t > end