26#include "absl/container/btree_map.h"
27#include "absl/container/flat_hash_map.h"
28#include "absl/container/flat_hash_set.h"
29#include "absl/log/check.h"
30#include "absl/log/log.h"
31#include "absl/log/vlog_is_on.h"
32#include "absl/meta/type_traits.h"
33#include "absl/types/span.h"
55 if (!VLOG_IS_ON(1))
return;
56 if (shared_stats_ ==
nullptr)
return;
57 std::vector<std::pair<std::string, int64_t>> stats;
58 stats.push_back({
"implied_bound/num_deductions", num_deductions_});
59 stats.push_back({
"implied_bound/num_stored", bounds_.size()});
61 {
"implied_bound/num_stored_with_view", num_enqueued_in_var_to_bounds_});
62 shared_stats_->AddStats(stats);
66 if (!parameters_.use_implied_bounds())
return true;
67 const IntegerVariable var = integer_literal.
var;
71 const IntegerValue root_lb = integer_trail_->LevelZeroLowerBound(var);
72 if (integer_literal.
bound <= root_lb)
return true;
78 if (root_lb + 1 >= integer_trail_->LevelZeroUpperBound(var))
return true;
81 const auto key = std::make_pair(literal.
Index(), var);
82 auto insert_result = bounds_.insert({key, integer_literal.
bound});
83 if (!insert_result.second) {
84 if (insert_result.first->second < integer_literal.
bound) {
85 insert_result.first->second = integer_literal.
bound;
93 if (integer_trail_->LevelZeroUpperBound(var) == integer_literal.
bound) {
98 if (it != bounds_.end() && it->second == -integer_literal.
bound) {
110 const auto it = bounds_.find(std::make_pair(literal.
NegatedIndex(), var));
111 if (it != bounds_.end()) {
112 if (it->second <= root_lb) {
117 const IntegerValue deduction =
118 std::min(integer_literal.
bound, it->second);
119 DCHECK_GT(deduction, root_lb);
122 if (!integer_trail_->RootLevelEnqueue(
127 VLOG(2) <<
"Deduction old: "
129 var, integer_trail_->LevelZeroLowerBound(var))
134 if (it->second == deduction) {
137 if (integer_literal.
bound == deduction) {
138 bounds_.erase(std::make_pair(literal.
Index(), var));
148 if (parameters_.linearization_level() == 0)
return true;
149 if (parameters_.cut_level() == 0)
return true;
155 if (var_to_bounds_.size() <= var) {
156 var_to_bounds_.resize(var.value() + 1);
157 has_implied_bounds_.Resize(var + 1);
159 ++num_enqueued_in_var_to_bounds_;
160 has_implied_bounds_.Set(var);
161 var_to_bounds_[var].emplace_back(integer_encoder_->GetLiteralView(literal),
162 integer_literal.
bound);
163 }
else if (integer_encoder_->GetLiteralView(literal.
Negated()) !=
165 if (var_to_bounds_.size() <= var) {
166 var_to_bounds_.resize(var.value() + 1);
167 has_implied_bounds_.Resize(var + 1);
169 ++num_enqueued_in_var_to_bounds_;
170 has_implied_bounds_.Set(var);
171 var_to_bounds_[var].emplace_back(
173 integer_literal.
bound);
179 IntegerVariable var) {
180 if (var >= var_to_bounds_.size())
return empty_implied_bounds_;
187 std::vector<ImpliedBoundEntry>& ref = var_to_bounds_[var];
188 const IntegerValue root_lb = integer_trail_->LevelZeroLowerBound(var);
190 if (entry.lower_bound <= root_lb)
continue;
191 ref[new_size++] = entry;
193 ref.resize(new_size);
200 IntegerValue value) {
205 literal_to_var_to_value_[literal.
Index()][var] = value;
209 if (!parameters_.use_implied_bounds())
return true;
211 CHECK_EQ(sat_solver_->CurrentDecisionLevel(), 1);
212 tmp_integer_literals_.clear();
213 integer_trail_->AppendNewBounds(&tmp_integer_literals_);
215 if (!
Add(first_decision, lit))
return false;
221 const std::vector<ValueLiteralPair>& encoding,
222 int exactly_one_index) {
223 if (!var_to_index_to_element_encodings_.contains(var)) {
224 element_encoded_variables_.push_back(var);
226 var_to_index_to_element_encodings_[var][exactly_one_index] = encoding;
229const absl::btree_map<int, std::vector<ValueLiteralPair>>&
231 const auto& it = var_to_index_to_element_encodings_.find(var);
232 if (it == var_to_index_to_element_encodings_.end()) {
233 return empty_element_encoding_;
239const std::vector<IntegerVariable>&
241 return element_encoded_variables_;
261 absl::Span<const ValueLiteralPair> affine_var_encoding,
262 bool put_affine_left_in_result,
IntegerEncoder* integer_encoder) {
263 IntegerVariable binary = size2_affine.
var;
264 std::vector<LiteralValueValue> terms;
266 const std::vector<ValueLiteralPair>& size2_enc =
271 if (size2_enc.size() != 2)
return terms;
273 Literal lit0 = size2_enc[0].literal;
274 IntegerValue value0 = size2_affine.
ValueAt(size2_enc[0].value);
275 Literal lit1 = size2_enc[1].literal;
276 IntegerValue value1 = size2_affine.
ValueAt(size2_enc[1].value);
278 for (
const auto& [unused, candidate_literal] : affine_var_encoding) {
279 if (candidate_literal == lit1) {
280 std::swap(lit0, lit1);
281 std::swap(value0, value1);
283 if (candidate_literal != lit0)
continue;
286 for (
const auto& [value, literal] : affine_var_encoding) {
287 const IntegerValue size_2_value = literal == lit0 ? value0 : value1;
288 const IntegerValue affine_value = affine.
ValueAt(value);
289 if (put_affine_left_in_result) {
290 terms.push_back({literal, affine_value, size_2_value});
292 terms.push_back({literal, size_2_value, affine_value});
306 std::vector<LiteralValueValue> terms;
311 const std::vector<ValueLiteralPair> left_enc =
313 const std::vector<ValueLiteralPair> right_enc =
315 if (left_enc.size() != 2 || right_enc.size() != 2) {
316 VLOG(2) <<
"encodings are not fully propagated";
320 const Literal left_lit0 = left_enc[0].literal;
321 const IntegerValue left_value0 = left.
ValueAt(left_enc[0].value);
322 const Literal left_lit1 = left_enc[1].literal;
323 const IntegerValue left_value1 = left.
ValueAt(left_enc[1].value);
325 const Literal right_lit0 = right_enc[0].literal;
326 const IntegerValue right_value0 = right.
ValueAt(right_enc[0].value);
327 const Literal right_lit1 = right_enc[1].literal;
328 const IntegerValue right_value1 = right.
ValueAt(right_enc[1].value);
330 if (left_lit0 == right_lit0 || left_lit0 == right_lit1.
Negated()) {
331 terms.push_back({left_lit0, left_value0, right_value0});
332 terms.push_back({left_lit0.
Negated(), left_value1, right_value1});
333 }
else if (left_lit0 == right_lit1 || left_lit0 == right_lit0.
Negated()) {
334 terms.push_back({left_lit0, left_value0, right_value1});
335 terms.push_back({left_lit0.
Negated(), left_value1, right_value0});
336 }
else if (left_lit1 == right_lit1 || left_lit1 == right_lit0.
Negated()) {
337 terms.push_back({left_lit1.
Negated(), left_value0, right_value0});
338 terms.push_back({left_lit1, left_value1, right_value1});
339 }
else if (left_lit1 == right_lit0 || left_lit1 == right_lit1.
Negated()) {
340 terms.push_back({left_lit1.
Negated(), left_value0, right_value1});
341 terms.push_back({left_lit1, left_value1, right_value0});
343 VLOG(3) <<
"Complex size 2 encoding case, need to scan exactly_ones";
351 if (integer_trail_->IsFixed(left) || integer_trail_->IsFixed(right)) {
356 const absl::btree_map<int, std::vector<ValueLiteralPair>>& left_encodings =
357 element_encodings_->Get(left.
var);
360 const absl::btree_map<int, std::vector<ValueLiteralPair>>& right_encodings =
361 element_encodings_->Get(right.
var);
363 std::vector<int> compatible_keys;
364 for (
const auto& [index, encoding] : left_encodings) {
365 if (right_encodings.contains(index)) {
366 compatible_keys.push_back(index);
370 if (compatible_keys.empty()) {
371 if (integer_trail_->InitialVariableDomain(left.
var).Size() == 2) {
372 for (
const auto& [index, right_encoding] : right_encodings) {
374 left, right, right_encoding,
375 false, integer_encoder_);
376 if (!result.empty()) {
381 if (integer_trail_->InitialVariableDomain(right.
var).Size() == 2) {
382 for (
const auto& [index, left_encoding] : left_encodings) {
384 right, left, left_encoding,
385 true, integer_encoder_);
386 if (!result.empty()) {
391 if (integer_trail_->InitialVariableDomain(left.
var).Size() == 2 &&
392 integer_trail_->InitialVariableDomain(right.
var).Size() == 2) {
393 const std::vector<LiteralValueValue> result =
395 if (!result.empty()) {
402 if (compatible_keys.size() > 1) {
403 VLOG(3) <<
"More than one exactly_one involved in the encoding of the two "
408 const int min_index =
409 *std::min_element(compatible_keys.begin(), compatible_keys.end());
412 const std::vector<ValueLiteralPair>& left_encoding =
413 left_encodings.at(min_index);
414 const std::vector<ValueLiteralPair>& right_encoding =
415 right_encodings.at(min_index);
416 DCHECK_EQ(left_encoding.size(), right_encoding.size());
419 std::vector<LiteralValueValue> terms;
420 for (
int i = 0;
i < left_encoding.size(); ++
i) {
421 const Literal literal = left_encoding[
i].literal;
422 DCHECK_EQ(literal, right_encoding[
i].literal);
423 terms.push_back({literal, left.
ValueAt(left_encoding[
i].value),
424 right.
ValueAt(right_encoding[
i].value)});
435 DCHECK(builder !=
nullptr);
438 if (integer_trail_->IsFixed(left)) {
439 if (integer_trail_->IsFixed(right)) {
440 builder->
AddConstant(integer_trail_->FixedValue(left) *
441 integer_trail_->FixedValue(right));
444 builder->
AddTerm(right, integer_trail_->FixedValue(left));
448 if (integer_trail_->IsFixed(right)) {
449 builder->
AddTerm(left, integer_trail_->FixedValue(right));
458 const IntegerValue left_coeff =
460 const IntegerValue right_coeff =
463 left_coeff * right_coeff + left.
constant * right_coeff +
469 const std::vector<LiteralValueValue> decomposition =
471 if (decomposition.empty())
return false;
476 std::min(min_coefficient, term.left_value * term.right_value);
479 const IntegerValue coefficient =
480 term.left_value * term.right_value - min_coefficient;
481 if (coefficient == 0)
continue;
492 model->GetOrCreate<
SatParameters>()->detect_linearized_product() &&
493 model->GetOrCreate<
SatParameters>()->linearization_level() > 1),
494 rlt_enabled_(model->GetOrCreate<
SatParameters>()->add_rlt_cuts() &&
497 sat_solver_(model->GetOrCreate<
SatSolver>()),
498 trail_(model->GetOrCreate<
Trail>()),
504 if (!VLOG_IS_ON(1))
return;
505 if (shared_stats_ ==
nullptr)
return;
506 std::vector<std::pair<std::string, int64_t>> stats;
508 {
"product_detector/num_processed_binary", num_processed_binary_});
510 {
"product_detector/num_processed_exactly_one", num_processed_exo_});
512 {
"product_detector/num_processed_ternary", num_processed_ternary_});
513 stats.push_back({
"product_detector/num_trail_updates", num_trail_updates_});
514 stats.push_back({
"product_detector/num_products", num_products_});
515 stats.push_back({
"product_detector/num_conditional_equalities",
516 num_conditional_equalities_});
518 {
"product_detector/num_conditional_zeros", num_conditional_zeros_});
519 stats.push_back({
"product_detector/num_int_products", num_int_products_});
520 shared_stats_->AddStats(stats);
524 absl::Span<const Literal> ternary_clause) {
525 if (ternary_clause.size() != 3)
return;
526 ++num_processed_ternary_;
528 if (rlt_enabled_) ProcessTernaryClauseForRLT(ternary_clause);
529 if (!enabled_)
return;
531 candidates_[GetKey(ternary_clause[0].Index(), ternary_clause[1].Index())]
532 .push_back(ternary_clause[2].Index());
533 candidates_[GetKey(ternary_clause[0].Index(), ternary_clause[2].Index())]
534 .push_back(ternary_clause[1].Index());
535 candidates_[GetKey(ternary_clause[1].Index(), ternary_clause[2].Index())]
536 .push_back(ternary_clause[0].Index());
540 for (
const Literal l : ternary_clause) {
541 if (l.Index() >= seen_.size()) seen_.resize(l.Index() + 1);
542 seen_[l.Index()] =
true;
547void ProductDetector::ProcessTernaryClauseForRLT(
548 absl::Span<const Literal> ternary_clause) {
549 const int old_size = ternary_clauses_with_view_.size();
550 for (
const Literal l : ternary_clause) {
551 const IntegerVariable var =
554 ternary_clauses_with_view_.resize(old_size);
557 ternary_clauses_with_view_.push_back(l.IsPositive() ? var
563 absl::Span<const Literal> ternary_exo) {
564 if (ternary_exo.size() != 3)
return;
565 ++num_processed_exo_;
567 if (rlt_enabled_) ProcessTernaryClauseForRLT(ternary_exo);
568 if (!enabled_)
return;
570 ProcessNewProduct(ternary_exo[0].Index(), ternary_exo[1].NegatedIndex(),
571 ternary_exo[2].NegatedIndex());
572 ProcessNewProduct(ternary_exo[1].Index(), ternary_exo[0].NegatedIndex(),
573 ternary_exo[2].NegatedIndex());
574 ProcessNewProduct(ternary_exo[2].Index(), ternary_exo[0].NegatedIndex(),
575 ternary_exo[1].NegatedIndex());
581 absl::Span<const Literal> binary_clause) {
582 if (!enabled_)
return;
583 if (binary_clause.size() != 2)
return;
584 ++num_processed_binary_;
585 const std::array<LiteralIndex, 2> key =
586 GetKey(binary_clause[0].NegatedIndex(), binary_clause[1].NegatedIndex());
587 std::array<LiteralIndex, 3> ternary;
588 for (
const LiteralIndex l : candidates_[key]) {
592 std::sort(ternary.begin(), ternary.end());
593 const int l_index = ternary[0] == l ? 0 : ternary[1] == l ? 1 : 2;
594 std::bitset<3>& bs = detector_[ternary];
595 if (bs[l_index])
continue;
597 if (bs[0] && bs[1] && l_index != 2) {
598 ProcessNewProduct(ternary[2],
Literal(ternary[0]).NegatedIndex(),
599 Literal(ternary[1]).NegatedIndex());
601 if (bs[0] && bs[2] && l_index != 1) {
602 ProcessNewProduct(ternary[1],
Literal(ternary[0]).NegatedIndex(),
603 Literal(ternary[2]).NegatedIndex());
605 if (bs[1] && bs[2] && l_index != 0) {
606 ProcessNewProduct(ternary[0],
Literal(ternary[1]).NegatedIndex(),
607 Literal(ternary[2]).NegatedIndex());
613 if (!enabled_)
return;
614 for (LiteralIndex a(0); a < seen_.size(); ++a) {
615 if (!seen_[a])
continue;
616 if (trail_->Assignment().LiteralIsAssigned(
Literal(a)))
continue;
625 if (!enabled_)
return;
626 if (trail_->CurrentDecisionLevel() != 1)
return;
627 ++num_trail_updates_;
635 const int current_index = trail_->
Index();
643 const auto it = products_.find(GetKey(a.
Index(),
b.Index()));
648std::array<LiteralIndex, 2> ProductDetector::GetKey(LiteralIndex a,
649 LiteralIndex
b)
const {
650 std::array<LiteralIndex, 2> key{a,
b};
651 if (key[0] > key[1]) std::swap(key[0], key[1]);
655void ProductDetector::ProcessNewProduct(LiteralIndex p, LiteralIndex a,
659 products_[GetKey(a,
b)] = p;
663 GetKey(Literal(a).IsPositive() ? a : Literal(a).NegatedIndex(),
664 Literal(
b).IsPositive() ?
b : Literal(
b).NegatedIndex()));
668 IntegerVariable
b)
const {
669 if (a ==
b)
return true;
674 if (integer_trail_->InitialVariableDomain(a).Size() != 2)
return false;
675 if (integer_trail_->InitialVariableDomain(
b).Size() != 2)
return false;
677 const LiteralIndex la =
679 a, integer_trail_->LevelZeroUpperBound(a)));
682 const LiteralIndex lb =
684 b, integer_trail_->LevelZeroUpperBound(
b)));
688 return has_product_.contains(
689 GetKey(
Literal(la).IsPositive() ? la :
Literal(la).NegatedIndex(),
694 IntegerVariable
b)
const {
700void ProductDetector::ProcessNewProduct(IntegerVariable p,
Literal l,
709 int_products_[{l.
Index(), x}] = p;
714 ++num_conditional_equalities_;
718 for (
int i = 0;
i < 2; ++
i) {
727 std::vector<IntegerVariable>& others =
728 conditional_equalities_[{l.
Index(), x}];
729 for (
const IntegerVariable o : others) {
738 if (conditional_zeros_.contains({l.NegatedIndex(), x})) {
739 ProcessNewProduct(x, l, y);
747 ++num_conditional_zeros_;
749 auto [_, inserted] = conditional_zeros_.insert({l.
Index(), p});
751 const auto it = conditional_equalities_.find({l.
NegatedIndex(), p});
752 if (it != conditional_equalities_.end()) {
753 for (
const IntegerVariable x : it->second) {
754 ProcessNewProduct(p, l.
Negated(), x);
762std::pair<IntegerVariable, IntegerVariable> Canonicalize(IntegerVariable a,
764 if (a <
b)
return {a,
b};
768double GetLiteralLpValue(
770 const util_intops::StrongVector<IntegerVariable, double>& lp_values) {
777void ProductDetector::UpdateRLTMaps(
778 const util_intops::StrongVector<IntegerVariable, double>& lp_values,
779 IntegerVariable var1,
double lp1, IntegerVariable var2,
double lp2,
780 IntegerVariable bound_var,
double bound_lp) {
783 if (bound_lp > lp1 && bound_lp > lp2)
return;
785 const auto [it, inserted] =
786 bool_rlt_ubs_.
insert({Canonicalize(var1, var2), bound_var});
789 if (!inserted && bound_lp < GetLiteralLpValue(it->second, lp_values)) {
790 it->second = bound_var;
794 if (lp1 * lp2 > bound_lp + 1e-4) {
795 bool_rlt_candidates_[var1].push_back(var2);
796 bool_rlt_candidates_[var2].push_back(var1);
802 absl::Span<const IntegerVariable> lp_vars,
807 bool_rlt_ubs_.clear();
811 bool_rlt_candidates_.clear();
812 const int size = ternary_clauses_with_view_.size();
813 if (size == 0)
return;
815 is_in_lp_vars_.resize(integer_trail_->NumIntegerVariables().value());
816 for (
const IntegerVariable var : lp_vars) is_in_lp_vars_.Set(var);
818 for (
int i = 0;
i < size;
i += 3) {
819 const IntegerVariable var1 = ternary_clauses_with_view_[
i];
820 const IntegerVariable var2 = ternary_clauses_with_view_[
i + 1];
821 const IntegerVariable var3 = ternary_clauses_with_view_[
i + 2];
829 const double lp1 = GetLiteralLpValue(var1, lp_values);
830 const double lp2 = GetLiteralLpValue(var2, lp_values);
831 const double lp3 = GetLiteralLpValue(var3, lp_values);
834 1.0 - lp2, var3, lp3);
836 1.0 - lp3, var2, lp2);
838 1.0 - lp3, var1, lp1);
843 for (
const IntegerVariable var : lp_vars) is_in_lp_vars_.ClearBucket(var);
const std::vector< Literal > & DirectImplications(Literal literal)
void Add(IntegerVariable var, const std::vector< ValueLiteralPair > &encoding, int exactly_one_index)
const absl::btree_map< int, std::vector< ValueLiteralPair > > & Get(IntegerVariable var)
Returns an empty map if there is no such encoding.
const std::vector< IntegerVariable > & GetElementEncodedVariables() const
Get an unsorted set of variables appearing in element encodings.
bool ProcessIntegerTrail(Literal first_decision)
~ImpliedBounds()
Just display some global statistics on destruction.
void AddLiteralImpliesVarEqValue(Literal literal, IntegerVariable var, IntegerValue value)
Adds literal => var == value.
bool Add(Literal literal, IntegerLiteral integer_literal)
const std::vector< ImpliedBoundEntry > & GetImpliedBounds(IntegerVariable var)
const std::vector< ValueLiteralPair > & FullDomainEncoding(IntegerVariable var) const
IntegerVariable GetLiteralView(Literal lit) const
bool VariableIsFullyEncoded(IntegerVariable var) const
ABSL_MUST_USE_RESULT bool AddLiteralTerm(Literal lit, IntegerValue coeff=IntegerValue(1))
void AddTerm(IntegerVariable var, IntegerValue coeff)
void AddConstant(IntegerValue value)
Adds the corresponding term to the current linear expression.
void Clear()
Clears all added terms and constants. Keeps the original bounds.
LiteralIndex NegatedIndex() const
LiteralIndex Index() const
bool TryToLinearize(const AffineExpression &left, const AffineExpression &right, LinearConstraintBuilder *builder)
std::vector< LiteralValueValue > TryToDecompose(const AffineExpression &left, const AffineExpression &right)
void InitializeBooleanRLTCuts(absl::Span< const IntegerVariable > lp_vars, const util_intops::StrongVector< IntegerVariable, double > &lp_values)
void ProcessTernaryClause(absl::Span< const Literal > ternary_clause)
void ProcessTernaryExactlyOne(absl::Span< const Literal > ternary_exo)
void ProcessImplicationGraph(BinaryImplicationGraph *graph)
Utility function to process a bunch of implication all at once.
LiteralIndex GetProduct(Literal a, Literal b) const
Query function mainly used for testing.
bool ProductIsLinearizable(IntegerVariable a, IntegerVariable b) const
void ProcessConditionalEquality(Literal l, IntegerVariable x, IntegerVariable y)
void ProcessConditionalZero(Literal l, IntegerVariable p)
void ProcessTrailAtLevelOne()
ProductDetector(Model *model)
void ProcessBinaryClause(absl::Span< const Literal > binary_clause)
Simple class to add statistics by name and print them at the end.
iterator insert(const_iterator pos, const value_type &x)
constexpr IntegerValue kMaxIntegerValue(std::numeric_limits< IntegerValue::ValueType >::max() - 1)
std::vector< LiteralValueValue > TryToReconcileSize2Encodings(const AffineExpression &left, const AffineExpression &right, IntegerEncoder *integer_encoder)
const LiteralIndex kNoLiteralIndex(-1)
std::vector< IntegerVariable > NegationOf(absl::Span< const IntegerVariable > vars)
Returns the vector of the negated variables.
const IntegerVariable kNoIntegerVariable(-1)
IntegerVariable PositiveVariable(IntegerVariable i)
std::vector< LiteralValueValue > TryToReconcileEncodings(const AffineExpression &size2_affine, const AffineExpression &affine, absl::Span< const ValueLiteralPair > affine_var_encoding, bool put_affine_left_in_result, IntegerEncoder *integer_encoder)
bool VariableIsPositive(IntegerVariable i)
In SWIG mode, we don't want anything besides these top-level includes.
IntegerValue ValueAt(IntegerValue var_value) const
Returns the value of this affine expression given its variable value.
static IntegerLiteral GreaterOrEqual(IntegerVariable i, IntegerValue bound)