26#include "absl/container/btree_map.h"
27#include "absl/container/flat_hash_map.h"
28#include "absl/container/flat_hash_set.h"
29#include "absl/log/check.h"
30#include "absl/meta/type_traits.h"
31#include "absl/types/span.h"
41#include "ortools/sat/sat_parameters.pb.h"
// Body of the ImpliedBounds destructor (signature line not in this listing).
// It only publishes statistics: nothing is reported unless VLOG level 1 is on
// and a shared statistics sink (shared_stats_) is present.
53 if (!VLOG_IS_ON(1))
return;
54 if (shared_stats_ ==
nullptr)
return;
55 std::vector<std::pair<std::string, int64_t>> stats;
56 stats.push_back({
"implied_bound/num_deductions", num_deductions_});
57 stats.push_back({
"implied_bound/num_stored", bounds_.size()});
// NOTE(review): the start of this push_back (listing line 58) is missing from
// this excerpt; only its brace-initializer is visible.
59 {
"implied_bound/num_stored_with_view", num_enqueued_in_var_to_bounds_});
60 shared_stats_->AddStats(stats);
// Body of ImpliedBounds::Add(Literal literal, IntegerLiteral integer_literal).
// The signature line is not in this listing.
// Presumably this records "literal => var >= bound"; confirm against the
// header. Every early exit visible here returns true. The false-returning
// path is on lines not shown (presumably a failed RootLevelEnqueue below).
64 if (!parameters_.use_implied_bounds())
return true;
65 const IntegerVariable var = integer_literal.
var;
// Skip bounds that are not stronger than the level-zero lower bound.
69 const IntegerValue root_lb = integer_trail_->LevelZeroLowerBound(var);
70 if (integer_literal.
bound <= root_lb)
return true;
// Also skip variables whose level-zero domain [root_lb, root_ub] has at most
// two values.
76 if (root_lb + 1 >= integer_trail_->LevelZeroUpperBound(var))
return true;
// bounds_ is keyed by (literal index, var). If the key already exists, only
// the larger (stronger) bound is kept.
79 const auto key = std::make_pair(literal.
Index(), var);
80 auto insert_result = bounds_.insert({key, integer_literal.
bound});
81 if (!insert_result.second) {
82 if (insert_result.first->second < integer_literal.
bound) {
83 insert_result.first->second = integer_literal.
bound;
// Special case: the bound equals the level-zero upper bound. The body of this
// branch is not visible in this excerpt.
91 if (integer_trail_->LevelZeroUpperBound(var) == integer_literal.
bound) {
// NOTE(review): `it` is declared on a line outside this excerpt. This checks
// whether the stored entry equals the negated bound.
96 if (it != bounds_.end() && it->second == -integer_literal.
bound) {
// If the negated literal also implies a lower bound on var, then
// var >= min(both bounds) holds whatever the literal's value.
// That deduction is pushed at the root level.
108 const auto it = bounds_.find(std::make_pair(literal.
NegatedIndex(), var));
109 if (it != bounds_.end()) {
// The handling of a negated-literal entry that is no stronger than root_lb is
// not visible here.
110 if (it->second <= root_lb) {
115 const IntegerValue deduction =
116 std::min(integer_literal.
bound, it->second);
117 DCHECK_GT(deduction, root_lb);
// The enqueued literal is on lines not shown (presumably var >= deduction).
120 if (!integer_trail_->RootLevelEnqueue(
125 VLOG(2) <<
"Deduction old: "
127 var, integer_trail_->LevelZeroLowerBound(var))
// Clean up entries that equal the deduced root bound. The body for the
// negated-literal entry is not visible. For `literal` itself, the
// (literal, var) entry is erased from bounds_.
132 if (it->second == deduction) {
135 if (integer_literal.
bound == deduction) {
136 bounds_.erase(std::make_pair(literal.
Index(), var));
// Entries are only copied into var_to_bounds_ when both linearization_level
// and cut_level are non-zero.
146 if (parameters_.linearization_level() == 0)
return true;
147 if (parameters_.cut_level() == 0)
return true;
// The guarding condition (listing line 152) is not shown. Presumably it checks
// that `literal` has an integer view; if so, {view, bound, true} is appended.
// Otherwise, if the negated literal has a view, {negated view, bound, false}
// is appended. The last field presumably records which polarity was used;
// confirm against ImpliedBoundEntry.
153 if (var_to_bounds_.size() <= var) {
154 var_to_bounds_.resize(var.value() + 1);
155 has_implied_bounds_.Resize(var + 1);
157 ++num_enqueued_in_var_to_bounds_;
158 has_implied_bounds_.Set(var);
159 var_to_bounds_[var].push_back({integer_encoder_->GetLiteralView(literal),
160 integer_literal.
bound,
true});
161 }
else if (integer_encoder_->GetLiteralView(literal.
Negated()) !=
163 if (var_to_bounds_.size() <= var) {
164 var_to_bounds_.resize(var.value() + 1);
165 has_implied_bounds_.Resize(var + 1);
167 ++num_enqueued_in_var_to_bounds_;
168 has_implied_bounds_.Set(var);
169 var_to_bounds_[var].push_back(
170 {integer_encoder_->GetLiteralView(literal.
Negated()),
171 integer_literal.
bound,
false});
// Tail of the GetImpliedBounds(IntegerVariable var) signature and its body.
// Returns the stored entries for var, or empty_implied_bounds_ for a variable
// that was never registered. Before returning, the vector is compacted in
// place, dropping entries whose lower_bound is no longer above the current
// level-zero lower bound. The loop header (listing line 187) is not shown.
177 IntegerVariable var) {
178 if (var >= var_to_bounds_.size())
return empty_implied_bounds_;
185 std::vector<ImpliedBoundEntry>& ref = var_to_bounds_[var];
186 const IntegerValue root_lb = integer_trail_->LevelZeroLowerBound(var);
188 if (entry.lower_bound <= root_lb)
continue;
189 ref[new_size++] = entry;
191 ref.resize(new_size);
// Tail of AddLiteralImpliesVarEqValue(literal, var, value), which adds
// "literal => var == value". The value is stored in
// literal_to_var_to_value_[literal][var], overwriting any earlier value for
// the same pair. Lines 199-202 are not shown.
198 IntegerValue value) {
203 literal_to_var_to_value_[literal.
Index()][var] = value;
// Body of ProcessIntegerTrail(Literal first_decision); signature not shown.
// Must be called at decision level 1 (CHECKed). It collects the bounds newly
// appended on the integer trail and records each one as implied by
// first_decision through Add(). It returns false as soon as Add() fails. The
// loop header over tmp_integer_literals_ (listing line 212) is not shown.
207 if (!parameters_.use_implied_bounds())
return true;
209 CHECK_EQ(sat_solver_->CurrentDecisionLevel(), 1);
210 tmp_integer_literals_.clear();
211 integer_trail_->AppendNewBounds(&tmp_integer_literals_);
213 if (!
Add(first_decision, lit))
return false;
// ElementEncodings::Add(var, encoding, exactly_one_index): the first time a
// variable is seen, it is appended to element_encoded_variables_. The
// encoding is then stored under (var, exactly_one_index), replacing any
// previous one.
219 const std::vector<ValueLiteralPair>& encoding,
220 int exactly_one_index) {
221 if (!var_to_index_to_element_encodings_.contains(var)) {
222 element_encoded_variables_.push_back(var);
224 var_to_index_to_element_encodings_[var][exactly_one_index] = encoding;
// ElementEncodings::Get(var): returns an empty map (empty_element_encoding_)
// when var has no element encoding. The found-case return is not shown.
227const absl::btree_map<int, std::vector<ValueLiteralPair>>&
229 const auto& it = var_to_index_to_element_encodings_.find(var);
230 if (it == var_to_index_to_element_encodings_.end()) {
231 return empty_element_encoding_;
// GetElementEncodedVariables(): variables in first-insertion order (not
// sorted).
237const std::vector<IntegerVariable>&
239 return element_encoded_variables_;
// TryToReconcileEncodings(size2_affine, affine, affine_var_encoding,
// put_affine_left_in_result, integer_encoder). The first signature lines are
// not in this listing.
// Builds one (literal, value, value) term per entry of affine_var_encoding.
// Each term pairs the value of `affine` with the value size2_affine takes
// under that literal. put_affine_left_in_result chooses which value goes
// first. Returns an empty vector when the reconciliation is not possible.
259 absl::Span<const ValueLiteralPair> affine_var_encoding,
260 bool put_affine_left_in_result,
IntegerEncoder* integer_encoder) {
261 IntegerVariable binary = size2_affine.
var;
262 std::vector<LiteralValueValue> terms;
// size2_enc's initializer is on lines not shown (presumably the full domain
// encoding of `binary`). Only exactly-two-entry encodings are handled.
264 const std::vector<ValueLiteralPair>& size2_enc =
269 if (size2_enc.size() != 2)
return terms;
271 Literal lit0 = size2_enc[0].literal;
272 IntegerValue value0 = size2_affine.
ValueAt(size2_enc[0].value);
273 Literal lit1 = size2_enc[1].literal;
274 IntegerValue value1 = size2_affine.
ValueAt(size2_enc[1].value);
// Look for a literal shared by both encodings. (lit0, value0) is swapped so
// that lit0 is the shared one. The rest of this loop (presumably a break
// after a match) and the not-found handling are not in this excerpt.
276 for (
const auto& [unused, candidate_literal] : affine_var_encoding) {
277 if (candidate_literal == lit1) {
278 std::swap(lit0, lit1);
279 std::swap(value0, value1);
281 if (candidate_literal != lit0)
continue;
// Every literal other than lit0 is given value1 for size2_affine.
284 for (
const auto& [value, literal] : affine_var_encoding) {
285 const IntegerValue size_2_value = literal == lit0 ? value0 : value1;
286 const IntegerValue affine_value = affine.
ValueAt(value);
287 if (put_affine_left_in_result) {
288 terms.push_back({literal, affine_value, size_2_value});
290 terms.push_back({literal, size_2_value, affine_value});
// Body of TryToReconcileSize2Encodings(left, right, integer_encoder);
// signature not shown. When both operands have two-entry encodings linked by
// one literal (the same literal or its negation), it returns two
// (literal, left value, right value) terms: one for the literal being true,
// one for it being false.
304 std::vector<LiteralValueValue> terms;
// The initializers of left_enc/right_enc are on lines not shown.
// NOTE(review): these are copies. If the initializer returns a reference
// (as size2_enc's does in TryToReconcileEncodings), a const reference would
// avoid the copy. Confirm.
309 const std::vector<ValueLiteralPair> left_enc =
311 const std::vector<ValueLiteralPair> right_enc =
// The return after this log line is not visible in this excerpt.
313 if (left_enc.size() != 2 || right_enc.size() != 2) {
314 VLOG(2) <<
"encodings are not fully propagated";
318 const Literal left_lit0 = left_enc[0].literal;
319 const IntegerValue left_value0 = left.
ValueAt(left_enc[0].value);
320 const Literal left_lit1 = left_enc[1].literal;
321 const IntegerValue left_value1 = left.
ValueAt(left_enc[1].value);
323 const Literal right_lit0 = right_enc[0].literal;
324 const IntegerValue right_value0 = right.
ValueAt(right_enc[0].value);
325 const Literal right_lit1 = right_enc[1].literal;
326 const IntegerValue right_value1 = right.
ValueAt(right_enc[1].value);
// Case: left_lit0 selects right's value0.
328 if (left_lit0 == right_lit0 || left_lit0 == right_lit1.
Negated()) {
329 terms.push_back({left_lit0, left_value0, right_value0});
330 terms.push_back({left_lit0.
Negated(), left_value1, right_value1});
331 }
// Case: left_lit0 selects right's value1.
else if (left_lit0 == right_lit1 || left_lit0 == right_lit0.
Negated()) {
332 terms.push_back({left_lit0, left_value0, right_value1});
333 terms.push_back({left_lit0.
Negated(), left_value1, right_value0});
334 }
// Case: left_lit1 selects right's value1.
else if (left_lit1 == right_lit1 || left_lit1 == right_lit0.
Negated()) {
335 terms.push_back({left_lit1.
Negated(), left_value0, right_value0});
336 terms.push_back({left_lit1, left_value1, right_value1});
337 }
// Case: left_lit1 selects right's value0.
else if (left_lit1 == right_lit0 || left_lit1 == right_lit1.
Negated()) {
338 terms.push_back({left_lit1.
Negated(), left_value0, right_value1});
339 terms.push_back({left_lit1, left_value1, right_value0});
// No direct literal link: only logged, no terms are added in this branch.
341 VLOG(3) <<
"Complex size 2 encoding case, need to scan exactly_ones";
// Body of ProductDecomposer::TryToDecompose(left, right); signature not
// shown. It tries to express left * right as (literal, left value,
// right value) terms whose literals come from one shared element encoding
// (an exactly_one). When no decomposition is found it presumably returns an
// empty vector; the final return is not shown.
// The handling of fixed operands (next line) is not visible in this excerpt.
349 if (integer_trail_->IsFixed(left) || integer_trail_->IsFixed(right)) {
354 const absl::btree_map<int, std::vector<ValueLiteralPair>>& left_encodings =
355 element_encodings_->Get(left.
var);
358 const absl::btree_map<int, std::vector<ValueLiteralPair>>& right_encodings =
359 element_encodings_->Get(right.
var);
// Collect the exactly_one indices used by both variables' element encodings.
361 std::vector<int> compatible_keys;
362 for (
const auto& [index, encoding] : left_encodings) {
363 if (right_encodings.contains(index)) {
364 compatible_keys.push_back(index);
// No shared exactly_one. Fall back to an operand whose initial domain has
// exactly two values, reconciling it with each encoding of the other operand.
// The last branch (both two-valued) computes `result` on a line not shown.
368 if (compatible_keys.empty()) {
369 if (integer_trail_->InitialVariableDomain(left.
var).Size() == 2) {
370 for (
const auto& [index, right_encoding] : right_encodings) {
372 left, right, right_encoding,
373 false, integer_encoder_);
374 if (!result.empty()) {
379 if (integer_trail_->InitialVariableDomain(right.
var).Size() == 2) {
380 for (
const auto& [index, left_encoding] : left_encodings) {
382 right, left, left_encoding,
383 true, integer_encoder_);
384 if (!result.empty()) {
389 if (integer_trail_->InitialVariableDomain(left.
var).Size() == 2 &&
390 integer_trail_->InitialVariableDomain(right.
var).Size() == 2) {
391 const std::vector<LiteralValueValue> result =
393 if (!result.empty()) {
// With several shared exactly_ones, only the smallest index is used.
// NOTE(review): left_encodings is a btree_map iterated in key order, so
// compatible_keys is already ascending and min_element equals front().
400 if (compatible_keys.size() > 1) {
401 VLOG(3) <<
"More than one exactly_one involved in the encoding of the two "
406 const int min_index =
407 *std::min_element(compatible_keys.begin(), compatible_keys.end());
410 const std::vector<ValueLiteralPair>& left_encoding =
411 left_encodings.at(min_index);
412 const std::vector<ValueLiteralPair>& right_encoding =
413 right_encodings.at(min_index);
// Both encodings come from the same exactly_one, so they are expected to line
// up entry by entry. This is checked with DCHECKs only.
414 DCHECK_EQ(left_encoding.size(), right_encoding.size());
417 std::vector<LiteralValueValue> terms;
418 for (
int i = 0;
i < left_encoding.size(); ++
i) {
419 const Literal literal = left_encoding[
i].literal;
420 DCHECK_EQ(literal, right_encoding[
i].literal);
421 terms.push_back({literal, left.
ValueAt(left_encoding[
i].value),
422 right.
ValueAt(right_encoding[
i].value)});
// Body of ProductDecomposer::TryToLinearize(left, right, builder); signature
// not shown. It tries to add a linear expression for left * right to
// builder. The visible failure path (returns false) is an empty
// decomposition.
433 DCHECK(builder !=
nullptr);
// Fixed operands: a constant product, or the other operand scaled by the
// fixed value. The returns after these cases are on lines not shown.
436 if (integer_trail_->IsFixed(left)) {
437 if (integer_trail_->IsFixed(right)) {
438 builder->
AddConstant(integer_trail_->FixedValue(left) *
439 integer_trail_->FixedValue(right));
442 builder->
AddTerm(right, integer_trail_->FixedValue(left));
446 if (integer_trail_->IsFixed(right)) {
447 builder->
AddTerm(left, integer_trail_->FixedValue(right));
// The initializers of left_coeff/right_coeff, and the statement that uses the
// expression on listing line 461, are not fully shown.
456 const IntegerValue left_coeff =
458 const IntegerValue right_coeff =
461 left_coeff * right_coeff + left.
constant * right_coeff +
// The decomposition's initializer is not shown (presumably TryToDecompose).
467 const std::vector<LiteralValueValue> decomposition =
469 if (decomposition.empty())
return false;
// min_coefficient is the smallest left_value * right_value over the terms.
// Each literal then gets only its excess over that minimum, and zero excesses
// are skipped. Presumably the minimum is added once as a constant, which is
// exact when exactly one decomposition literal is true. Confirm on the lines
// not shown.
474 std::min(min_coefficient, term.left_value * term.right_value);
477 const IntegerValue coefficient =
478 term.left_value * term.right_value - min_coefficient;
479 if (coefficient == 0)
continue;
// Fragment of the ProductDetector constructor's initializer list. The member
// set on the first two lines is named on a line not shown (presumably
// enabled_). It requires detect_linearized_product and
// linearization_level > 1. rlt_enabled_ requires add_rlt_cuts and a
// linearization_level above a threshold that is on a line not shown.
490 model->GetOrCreate<SatParameters>()->detect_linearized_product() &&
491 model->GetOrCreate<SatParameters>()->linearization_level() > 1),
492 rlt_enabled_(model->GetOrCreate<SatParameters>()->add_rlt_cuts() &&
493 model->GetOrCreate<SatParameters>()->linearization_level() >
495 sat_solver_(model->GetOrCreate<
SatSolver>()),
496 trail_(model->GetOrCreate<
Trail>()),
// Body of the ProductDetector destructor; signature not shown. It only
// publishes detection statistics to shared_stats_, and only when VLOG
// level 1 is on and a shared sink exists.
// NOTE(review): several push_back openings (listing lines 505, 507, 509, 515)
// are missing from this excerpt; only their initializers are visible.
502 if (!VLOG_IS_ON(1))
return;
503 if (shared_stats_ ==
nullptr)
return;
504 std::vector<std::pair<std::string, int64_t>> stats;
506 {
"product_detector/num_processed_binary", num_processed_binary_});
508 {
"product_detector/num_processed_exactly_one", num_processed_exo_});
510 {
"product_detector/num_processed_ternary", num_processed_ternary_});
511 stats.push_back({
"product_detector/num_trail_updates", num_trail_updates_});
512 stats.push_back({
"product_detector/num_products", num_products_});
513 stats.push_back({
"product_detector/num_conditional_equalities",
514 num_conditional_equalities_});
516 {
"product_detector/num_conditional_zeros", num_conditional_zeros_});
517 stats.push_back({
"product_detector/num_int_products", num_int_products_});
518 shared_stats_->AddStats(stats);
// Tail of ProcessTernaryClause(ternary_clause) and its body. Clauses that are
// not exactly ternary are ignored. RLT bookkeeping runs before the enabled_
// check, so it happens even when product detection is disabled.
522 absl::Span<const Literal> ternary_clause) {
523 if (ternary_clause.size() != 3)
return;
524 ++num_processed_ternary_;
526 if (rlt_enabled_) ProcessTernaryClauseForRLT(ternary_clause);
527 if (!enabled_)
return;
// For each (sorted) pair of the clause's literals, remember the third literal
// as a candidate. ProcessBinaryClause later looks these candidates up.
529 candidates_[GetKey(ternary_clause[0].Index(), ternary_clause[1].Index())]
530 .push_back(ternary_clause[2].Index());
531 candidates_[GetKey(ternary_clause[0].Index(), ternary_clause[2].Index())]
532 .push_back(ternary_clause[1].Index());
533 candidates_[GetKey(ternary_clause[1].Index(), ternary_clause[2].Index())]
534 .push_back(ternary_clause[0].Index());
// Mark the clause's literals in seen_, growing it on demand. Listing lines
// 535-537 are not shown.
538 for (
const Literal l : ternary_clause) {
539 if (l.Index() >= seen_.size()) seen_.resize(l.Index() + 1);
540 seen_[l.Index()] =
true;
// Appends the three integer views of a ternary clause's literals to
// ternary_clauses_with_view_, three consecutive entries per clause.
// If a literal cannot be used, the entries pushed so far for this clause are
// rolled back with resize(old_size). The triggering condition (presumably
// "no view") and the following return are on lines not shown. The view's
// initializer and the negative-literal branch of the ternary (presumably the
// negated variable) are also not shown.
545void ProductDetector::ProcessTernaryClauseForRLT(
546 absl::Span<const Literal> ternary_clause) {
547 const int old_size = ternary_clauses_with_view_.size();
548 for (
const Literal l : ternary_clause) {
549 const IntegerVariable var =
552 ternary_clauses_with_view_.resize(old_size);
555 ternary_clauses_with_view_.push_back(l.IsPositive() ? var
// Tail of ProcessTernaryExactlyOne(ternary_exo) and its body. For an
// exactly_one over {a, b, c}, a is true exactly when both b and c are false:
// a == (not b) AND (not c), and symmetrically for b and c. Each of these is
// recorded as a product through ProcessNewProduct(p, x, y), i.e. p = x AND y.
// RLT bookkeeping runs before the enabled_ check.
561 absl::Span<const Literal> ternary_exo) {
562 if (ternary_exo.size() != 3)
return;
563 ++num_processed_exo_;
565 if (rlt_enabled_) ProcessTernaryClauseForRLT(ternary_exo);
566 if (!enabled_)
return;
568 ProcessNewProduct(ternary_exo[0].Index(), ternary_exo[1].NegatedIndex(),
569 ternary_exo[2].NegatedIndex());
570 ProcessNewProduct(ternary_exo[1].Index(), ternary_exo[0].NegatedIndex(),
571 ternary_exo[2].NegatedIndex());
572 ProcessNewProduct(ternary_exo[2].Index(), ternary_exo[0].NegatedIndex(),
573 ternary_exo[1].NegatedIndex());
// Tail of ProcessBinaryClause(binary_clause) and its body. It combines binary
// clauses with previously seen ternary clauses to detect p = a AND b.
// NOTE(review): unlike ProcessTernaryClause, the counter below is only
// incremented when enabled_ is true.
579 absl::Span<const Literal> binary_clause) {
580 if (!enabled_)
return;
581 if (binary_clause.size() != 2)
return;
582 ++num_processed_binary_;
// The binary clause (x OR y) is matched against ternary clauses containing
// both not-x and not-y. Their third literal is stored in candidates_.
// NOTE(review): if candidates_ is a hash map, operator[] default-inserts an
// empty entry for unseen keys. Confirm whether find() was intended.
583 const std::array<LiteralIndex, 2> key =
584 GetKey(binary_clause[0].NegatedIndex(), binary_clause[1].NegatedIndex());
585 std::array<LiteralIndex, 3> ternary;
// The filling of `ternary` (listing lines 587-589) is not shown. Presumably it
// holds the two key literals plus l, and it is sorted so that each ternary
// clause has a single detector_ entry.
586 for (
const LiteralIndex l : candidates_[key]) {
590 std::sort(ternary.begin(), ternary.end());
591 const int l_index = ternary[0] == l ? 0 : ternary[1] == l ? 1 : 2;
// Bit i of bs records that the binary clause made of the negations of the two
// literals other than ternary[i] has been seen. The bit is presumably set on
// listing line 594, which is not shown.
592 std::bitset<3>& bs = detector_[ternary];
593 if (bs[l_index])
continue;
// Once two bits i, j are set by this update, the remaining literal k
// satisfies ternary[k] == not(ternary[i]) AND not(ternary[j]).
595 if (bs[0] && bs[1] && l_index != 2) {
596 ProcessNewProduct(ternary[2],
Literal(ternary[0]).NegatedIndex(),
597 Literal(ternary[1]).NegatedIndex());
599 if (bs[0] && bs[2] && l_index != 1) {
600 ProcessNewProduct(ternary[1],
Literal(ternary[0]).NegatedIndex(),
601 Literal(ternary[2]).NegatedIndex());
603 if (bs[1] && bs[2] && l_index != 0) {
604 ProcessNewProduct(ternary[0],
Literal(ternary[1]).NegatedIndex(),
605 Literal(ternary[2]).NegatedIndex());
// Fragment of a ProductDetector method whose signature is not shown
// (presumably ProcessImplicationGraph). It walks every literal index marked
// in seen_ and skips literals already assigned on the trail. The processing
// of the remaining literals is not in this excerpt.
611 if (!enabled_)
return;
612 for (LiteralIndex a(0); a < seen_.size(); ++a) {
613 if (!seen_[a])
continue;
614 if (trail_->Assignment().LiteralIsAssigned(
Literal(a)))
continue;
// Body of ProcessTrailAtLevelOne(); signature not shown. It only does work
// when product detection is enabled and the trail is at decision level 1.
// The processing that uses current_index is on lines not shown.
623 if (!enabled_)
return;
624 if (trail_->CurrentDecisionLevel() != 1)
return;
625 ++num_trail_updates_;
633 const int current_index = trail_->
Index();
// Body fragment of GetProduct(Literal a, Literal b), a query mainly used for
// testing. It looks up products_ with the order-independent key of (a, b). The
// returns are on lines not shown (presumably kNoLiteralIndex when missing).
641 const auto it = products_.find(GetKey(a.
Index(),
b.Index()));
// Returns {a, b} sorted in ascending order, so the key is independent of
// argument order. The return statement is on a line not shown.
646std::array<LiteralIndex, 2> ProductDetector::GetKey(LiteralIndex a,
647 LiteralIndex
b)
const {
648 std::array<LiteralIndex, 2> key{a,
b};
649 if (key[0] > key[1]) std::swap(key[0], key[1]);
// Records p = a AND b. products_ maps the unordered pair {a, b} to p,
// overwriting any earlier product for that pair. A second container, named on
// a line not shown (presumably has_product_), is keyed by the positive
// versions of a and b, so either polarity of the pair can be queried. The
// rest of the signature (b) is not shown.
653void ProductDetector::ProcessNewProduct(LiteralIndex p, LiteralIndex a,
657 products_[GetKey(a,
b)] = p;
661 GetKey(Literal(a).IsPositive() ? a : Literal(a).NegatedIndex(),
662 Literal(
b).IsPositive() ?
b : Literal(
b).NegatedIndex()));
// Tail of the signature and body of (presumably) ProductIsLinearizable(a, b).
// a == b is always accepted. Otherwise both variables need an initial domain
// of exactly two values, and the product of their associated literals must
// have been recorded. The literal lookups la/lb (for "var at its level-zero
// upper bound") are only partially visible.
666 IntegerVariable
b)
const {
667 if (a ==
b)
return true;
672 if (integer_trail_->InitialVariableDomain(a).Size() != 2)
return false;
673 if (integer_trail_->InitialVariableDomain(
b).Size() != 2)
return false;
675 const LiteralIndex la =
677 a, integer_trail_->LevelZeroUpperBound(a)));
680 const LiteralIndex lb =
682 b, integer_trail_->LevelZeroUpperBound(
b)));
// Products are looked up by the positive versions of the literals.
686 return has_product_.contains(
687 GetKey(
Literal(la).IsPositive() ? la :
Literal(la).NegatedIndex(),
// Signature fragment (listing line 692) of a const ProductDetector method
// whose last parameter is IntegerVariable b. Its name and body are not in
// this excerpt.
692 IntegerVariable
b)
const {
// Records an integer product for the pair (literal l, variable x) in
// int_products_, overwriting any earlier entry. Presumably p == l * x; the
// rest of the signature and the other lines of the body are not shown.
698void ProductDetector::ProcessNewProduct(IntegerVariable p,
Literal l,
707 int_products_[{l.
Index(), x}] = p;
// Body fragment of ProcessConditionalEquality(l, x, y), presumably
// "l => x == y".
// The loop runs twice. Presumably x and y are swapped between iterations on
// lines not shown.
712 ++num_conditional_equalities_;
716 for (
int i = 0;
i < 2; ++
i) {
// Equalities already known under l for x (operator[] creates the entry).
725 std::vector<IntegerVariable>& others =
726 conditional_equalities_[{l.
Index(), x}];
727 for (
const IntegerVariable o : others) {
// If not(l) forces x to zero and l forces x == y, then x == l * y.
736 if (conditional_zeros_.contains({l.NegatedIndex(), x})) {
737 ProcessNewProduct(x, l, y);
// Body fragment of ProcessConditionalZero(l, p), presumably "l => p == 0".
// The pair is stored in conditional_zeros_. For every x with
// not(l) => p == x, p == not(l) * x, which is recorded as a product.
// How `inserted` is used is not visible in this excerpt.
745 ++num_conditional_zeros_;
747 auto [_, inserted] = conditional_zeros_.insert({l.
Index(), p});
749 const auto it = conditional_equalities_.find({l.
NegatedIndex(), p});
750 if (it != conditional_equalities_.end()) {
751 for (
const IntegerVariable x : it->second) {
752 ProcessNewProduct(p, l.
Negated(), x);
// Returns the two variables as an ordered pair, smaller first, so it can serve
// as an order-independent map key. The second parameter line and the
// a >= b branch (presumably {b, a}) are not shown.
760std::pair<IntegerVariable, IntegerVariable> Canonicalize(IntegerVariable a,
762 if (a <
b)
return {a,
b};
// Returns the LP value associated with a literal's integer view, read from
// lp_values. The first parameter and the body are not shown. Presumably a
// negated view maps to one minus its positive variable's value; confirm.
766double GetLiteralLpValue(
768 const util_intops::StrongVector<IntegerVariable, double>& lp_values) {
// Registers bound_var as an upper bound for the product var1 * var2 (keyed by
// the canonical pair) in bool_rlt_ubs_. The bound with the smallest LP value
// is kept. When the LP product lp1 * lp2 exceeds bound_lp by more than 1e-4,
// each variable is added to the other's bool_rlt_candidates_ list.
775void ProductDetector::UpdateRLTMaps(
776 const util_intops::StrongVector<IntegerVariable, double>& lp_values,
777 IntegerVariable var1,
double lp1, IntegerVariable var2,
double lp2,
778 IntegerVariable bound_var,
double bound_lp) {
// A bound above both factors' LP values is ignored. Presumably this is
// because the product is already at most min(lp1, lp2).
781 if (bound_lp > lp1 && bound_lp > lp2)
return;
783 const auto [it, inserted] =
784 bool_rlt_ubs_.
insert({Canonicalize(var1, var2), bound_var});
// Keep the tighter (smaller LP value) bound for this pair.
787 if (!inserted && bound_lp < GetLiteralLpValue(it->second, lp_values)) {
788 it->second = bound_var;
// The LP point violates product <= bound by more than the 1e-4 tolerance.
792 if (lp1 * lp2 > bound_lp + 1e-4) {
793 bool_rlt_candidates_[var1].push_back(var2);
794 bool_rlt_candidates_[var2].push_back(var1);
// Body fragment of InitializeBooleanRLTCuts(lp_vars, lp_values). It rebuilds
// the RLT maps from the ternary clauses stored in ternary_clauses_with_view_
// (three views per clause). It does nothing when no ternary clause was
// recorded.
800 absl::Span<const IntegerVariable> lp_vars,
805 bool_rlt_ubs_.clear();
809 bool_rlt_candidates_.clear();
810 const int size = ternary_clauses_with_view_.size();
811 if (size == 0)
return;
// Mark the variables in lp_vars. The marks are cleared again at the end so
// the bitset can be reused.
813 is_in_lp_vars_.resize(integer_trail_->NumIntegerVariables().value());
814 for (
const IntegerVariable var : lp_vars) is_in_lp_vars_.Set(var);
// Listing lines 820-826 are not shown. Presumably they filter clauses whose
// variables are not in the LP.
816 for (
int i = 0;
i < size;
i += 3) {
817 const IntegerVariable var1 = ternary_clauses_with_view_[
i];
818 const IntegerVariable var2 = ternary_clauses_with_view_[
i + 1];
819 const IntegerVariable var3 = ternary_clauses_with_view_[
i + 2];
827 const double lp1 = GetLiteralLpValue(var1, lp_values);
828 const double lp2 = GetLiteralLpValue(var2, lp_values);
829 const double lp3 = GetLiteralLpValue(var3, lp_values);
// Three UpdateRLTMaps calls whose leading arguments are on lines not shown.
// The visible "1.0 - lp" arguments suggest each call pairs the negations of
// two literals, bounded by the third. That matches the clause
// (a OR b OR c) implying (not a) * (not b) <= c. Confirm.
832 1.0 - lp2, var3, lp3);
834 1.0 - lp3, var2, lp2);
836 1.0 - lp3, var1, lp1);
841 for (
const IntegerVariable var : lp_vars) is_in_lp_vars_.ClearBucket(var);
const std::vector< Literal > & DirectImplications(Literal literal)
void Add(IntegerVariable var, const std::vector< ValueLiteralPair > &encoding, int exactly_one_index)
const absl::btree_map< int, std::vector< ValueLiteralPair > > & Get(IntegerVariable var)
Returns an empty map if there is no such encoding.
const std::vector< IntegerVariable > & GetElementEncodedVariables() const
Get an unsorted set of variables appearing in element encodings.
bool ProcessIntegerTrail(Literal first_decision)
~ImpliedBounds()
Just display some global statistics on destruction.
void AddLiteralImpliesVarEqValue(Literal literal, IntegerVariable var, IntegerValue value)
Adds literal => var == value.
bool Add(Literal literal, IntegerLiteral integer_literal)
const std::vector< ImpliedBoundEntry > & GetImpliedBounds(IntegerVariable var)
const std::vector< ValueLiteralPair > & FullDomainEncoding(IntegerVariable var) const
IntegerVariable GetLiteralView(Literal lit) const
bool VariableIsFullyEncoded(IntegerVariable var) const
ABSL_MUST_USE_RESULT bool AddLiteralTerm(Literal lit, IntegerValue coeff=IntegerValue(1))
void AddTerm(IntegerVariable var, IntegerValue coeff)
void AddConstant(IntegerValue value)
Adds the corresponding term to the current linear expression.
void Clear()
Clears all added terms and constants. Keeps the original bounds.
LiteralIndex NegatedIndex() const
LiteralIndex Index() const
bool TryToLinearize(const AffineExpression &left, const AffineExpression &right, LinearConstraintBuilder *builder)
std::vector< LiteralValueValue > TryToDecompose(const AffineExpression &left, const AffineExpression &right)
void InitializeBooleanRLTCuts(absl::Span< const IntegerVariable > lp_vars, const util_intops::StrongVector< IntegerVariable, double > &lp_values)
void ProcessTernaryClause(absl::Span< const Literal > ternary_clause)
void ProcessTernaryExactlyOne(absl::Span< const Literal > ternary_exo)
void ProcessImplicationGraph(BinaryImplicationGraph *graph)
Utility function to process a batch of implications all at once.
LiteralIndex GetProduct(Literal a, Literal b) const
Query function mainly used for testing.
bool ProductIsLinearizable(IntegerVariable a, IntegerVariable b) const
void ProcessConditionalEquality(Literal l, IntegerVariable x, IntegerVariable y)
void ProcessConditionalZero(Literal l, IntegerVariable p)
void ProcessTrailAtLevelOne()
ProductDetector(Model *model)
void ProcessBinaryClause(absl::Span< const Literal > binary_clause)
Simple class to add statistics by name and print them at the end.
iterator insert(const_iterator pos, const value_type &x)
constexpr IntegerValue kMaxIntegerValue(std::numeric_limits< IntegerValue::ValueType >::max() - 1)
std::vector< LiteralValueValue > TryToReconcileSize2Encodings(const AffineExpression &left, const AffineExpression &right, IntegerEncoder *integer_encoder)
const LiteralIndex kNoLiteralIndex(-1)
std::vector< IntegerVariable > NegationOf(absl::Span< const IntegerVariable > vars)
Returns the vector of the negated variables.
const IntegerVariable kNoIntegerVariable(-1)
IntegerVariable PositiveVariable(IntegerVariable i)
std::vector< LiteralValueValue > TryToReconcileEncodings(const AffineExpression &size2_affine, const AffineExpression &affine, absl::Span< const ValueLiteralPair > affine_var_encoding, bool put_affine_left_in_result, IntegerEncoder *integer_encoder)
bool VariableIsPositive(IntegerVariable i)
In SWIG mode, we don't want anything besides these top-level includes.
IntegerValue ValueAt(IntegerValue var_value) const
Returns the value of this affine expression given its variable value.
static IntegerLiteral GreaterOrEqual(IntegerVariable i, IntegerValue bound)