26#include "absl/container/btree_map.h"
27#include "absl/container/flat_hash_map.h"
28#include "absl/container/flat_hash_set.h"
29#include "absl/log/check.h"
30#include "absl/meta/type_traits.h"
31#include "absl/strings/str_cat.h"
32#include "absl/types/span.h"
41#include "ortools/sat/sat_parameters.pb.h"
53 if (!VLOG_IS_ON(1))
return;
54 if (shared_stats_ ==
nullptr)
return;
55 std::vector<std::pair<std::string, int64_t>> stats;
56 stats.push_back({
"implied_bound/num_deductions", num_deductions_});
57 stats.push_back({
"implied_bound/num_stored", bounds_.size()});
59 {
"implied_bound/num_stored_with_view", num_enqueued_in_var_to_bounds_});
64 if (!parameters_.use_implied_bounds())
return true;
65 const IntegerVariable
var = integer_literal.
var;
70 if (integer_literal.
bound <= root_lb)
return true;
79 const auto key = std::make_pair(
literal.Index(),
var);
80 auto insert_result = bounds_.insert({key, integer_literal.
bound});
81 if (!insert_result.second) {
82 if (insert_result.first->second < integer_literal.
bound) {
83 insert_result.first->second = integer_literal.
bound;
96 if (it != bounds_.end() && it->second == -integer_literal.
bound) {
108 const auto it = bounds_.find(std::make_pair(
literal.NegatedIndex(),
var));
109 if (it != bounds_.end()) {
110 if (it->second <= root_lb) {
115 const IntegerValue deduction =
116 std::min(integer_literal.
bound, it->second);
117 DCHECK_GT(deduction, root_lb);
125 VLOG(2) <<
"Deduction old: "
132 if (it->second == deduction) {
135 if (integer_literal.
bound == deduction) {
136 bounds_.erase(std::make_pair(
literal.Index(),
var));
146 if (parameters_.linearization_level() == 0)
return true;
147 if (parameters_.cut_level() == 0)
return true;
153 if (var_to_bounds_.size() <=
var) {
154 var_to_bounds_.resize(
var.value() + 1);
157 ++num_enqueued_in_var_to_bounds_;
158 has_implied_bounds_.
Set(
var);
160 integer_literal.
bound,
true});
163 if (var_to_bounds_.size() <=
var) {
164 var_to_bounds_.resize(
var.value() + 1);
167 ++num_enqueued_in_var_to_bounds_;
168 has_implied_bounds_.
Set(
var);
169 var_to_bounds_[
var].push_back(
171 integer_literal.
bound,
false});
177 IntegerVariable
var) {
178 if (
var >= var_to_bounds_.size())
return empty_implied_bounds_;
185 std::vector<ImpliedBoundEntry>& ref = var_to_bounds_[
var];
188 if (entry.lower_bound <= root_lb)
continue;
189 ref[new_size++] = entry;
191 ref.resize(new_size);
198 IntegerValue
value) {
207 if (!parameters_.use_implied_bounds())
return true;
210 tmp_integer_literals_.clear();
213 if (!
Add(first_decision,
lit))
return false;
219 const std::vector<ValueLiteralPair>& encoding,
220 int exactly_one_index) {
221 if (!var_to_index_to_element_encodings_.contains(
var)) {
222 element_encoded_variables_.push_back(
var);
224 var_to_index_to_element_encodings_[
var][exactly_one_index] = encoding;
227const absl::btree_map<int, std::vector<ValueLiteralPair>>&
229 const auto& it = var_to_index_to_element_encodings_.find(
var);
230 if (it == var_to_index_to_element_encodings_.end()) {
231 return empty_element_encoding_;
237const std::vector<IntegerVariable>&
239 return element_encoded_variables_;
259 absl::Span<const ValueLiteralPair> affine_var_encoding,
260 bool put_affine_left_in_result,
IntegerEncoder* integer_encoder) {
261 IntegerVariable binary = size2_affine.
var;
262 std::vector<LiteralValueValue> terms;
264 const std::vector<ValueLiteralPair>& size2_enc =
269 if (size2_enc.size() != 2)
return terms;
271 Literal lit0 = size2_enc[0].literal;
272 IntegerValue value0 = size2_affine.
ValueAt(size2_enc[0].
value);
273 Literal lit1 = size2_enc[1].literal;
274 IntegerValue value1 = size2_affine.
ValueAt(size2_enc[1].
value);
276 for (
const auto& [unused, candidate_literal] : affine_var_encoding) {
277 if (candidate_literal == lit1) {
278 std::swap(lit0, lit1);
279 std::swap(value0, value1);
281 if (candidate_literal != lit0)
continue;
284 for (
const auto& [
value,
literal] : affine_var_encoding) {
285 const IntegerValue size_2_value =
literal == lit0 ? value0 : value1;
286 const IntegerValue affine_value = affine.
ValueAt(
value);
287 if (put_affine_left_in_result) {
288 terms.push_back({
literal, affine_value, size_2_value});
290 terms.push_back({
literal, size_2_value, affine_value});
304 std::vector<LiteralValueValue> terms;
309 const std::vector<ValueLiteralPair> left_enc =
311 const std::vector<ValueLiteralPair> right_enc =
313 if (left_enc.size() != 2 || right_enc.size() != 2) {
314 VLOG(2) <<
"encodings are not fully propagated";
318 const Literal left_lit0 = left_enc[0].literal;
319 const IntegerValue left_value0 = left.
ValueAt(left_enc[0].
value);
320 const Literal left_lit1 = left_enc[1].literal;
321 const IntegerValue left_value1 = left.
ValueAt(left_enc[1].
value);
323 const Literal right_lit0 = right_enc[0].literal;
324 const IntegerValue right_value0 = right.
ValueAt(right_enc[0].
value);
325 const Literal right_lit1 = right_enc[1].literal;
326 const IntegerValue right_value1 = right.
ValueAt(right_enc[1].
value);
328 if (left_lit0 == right_lit0 || left_lit0 == right_lit1.
Negated()) {
329 terms.push_back({left_lit0, left_value0, right_value0});
330 terms.push_back({left_lit0.
Negated(), left_value1, right_value1});
331 }
else if (left_lit0 == right_lit1 || left_lit0 == right_lit0.
Negated()) {
332 terms.push_back({left_lit0, left_value0, right_value1});
333 terms.push_back({left_lit0.
Negated(), left_value1, right_value0});
334 }
else if (left_lit1 == right_lit1 || left_lit1 == right_lit0.
Negated()) {
335 terms.push_back({left_lit1.
Negated(), left_value0, right_value0});
336 terms.push_back({left_lit1, left_value1, right_value1});
337 }
else if (left_lit1 == right_lit0 || left_lit1 == right_lit1.
Negated()) {
338 terms.push_back({left_lit1.
Negated(), left_value0, right_value1});
339 terms.push_back({left_lit1, left_value1, right_value0});
341 VLOG(3) <<
"Complex size 2 encoding case, need to scan exactly_ones";
349 if (integer_trail_->
IsFixed(left) || integer_trail_->
IsFixed(right)) {
354 const absl::btree_map<int, std::vector<ValueLiteralPair>>& left_encodings =
355 element_encodings_->
Get(left.
var);
358 const absl::btree_map<int, std::vector<ValueLiteralPair>>& right_encodings =
359 element_encodings_->
Get(right.
var);
361 std::vector<int> compatible_keys;
362 for (
const auto& [
index, encoding] : left_encodings) {
363 if (right_encodings.contains(
index)) {
364 compatible_keys.push_back(
index);
368 if (compatible_keys.empty()) {
370 for (
const auto& [
index, right_encoding] : right_encodings) {
372 left, right, right_encoding,
373 false, integer_encoder_);
374 if (!result.empty()) {
380 for (
const auto& [
index, left_encoding] : left_encodings) {
382 right, left, left_encoding,
383 true, integer_encoder_);
384 if (!result.empty()) {
391 const std::vector<LiteralValueValue> result =
393 if (!result.empty()) {
400 if (compatible_keys.size() > 1) {
401 VLOG(3) <<
"More than one exactly_one involved in the encoding of the two "
406 const int min_index =
407 *std::min_element(compatible_keys.begin(), compatible_keys.end());
410 const std::vector<ValueLiteralPair>& left_encoding =
411 left_encodings.at(min_index);
412 const std::vector<ValueLiteralPair>& right_encoding =
413 right_encodings.at(min_index);
414 DCHECK_EQ(left_encoding.size(), right_encoding.size());
417 std::vector<LiteralValueValue> terms;
418 for (
int i = 0;
i < left_encoding.size(); ++
i) {
433 DCHECK(builder !=
nullptr);
436 if (integer_trail_->
IsFixed(left)) {
437 if (integer_trail_->
IsFixed(right)) {
446 if (integer_trail_->
IsFixed(right)) {
456 const IntegerValue left_coeff =
458 const IntegerValue right_coeff =
461 left_coeff * right_coeff + left.
constant * right_coeff +
467 const std::vector<LiteralValueValue> decomposition =
469 if (decomposition.empty())
return false;
474 std::min(min_coefficient, term.left_value * term.right_value);
478 term.left_value * term.right_value - min_coefficient;
490 model->GetOrCreate<SatParameters>()->detect_linearized_product() &&
491 model->GetOrCreate<SatParameters>()->linearization_level() > 1),
492 rlt_enabled_(
model->GetOrCreate<SatParameters>()->add_rlt_cuts() &&
493 model->GetOrCreate<SatParameters>()->linearization_level() >
502 if (!VLOG_IS_ON(1))
return;
503 if (shared_stats_ ==
nullptr)
return;
504 std::vector<std::pair<std::string, int64_t>> stats;
506 {
"product_detector/num_processed_binary", num_processed_binary_});
508 {
"product_detector/num_processed_exactly_one", num_processed_exo_});
510 {
"product_detector/num_processed_ternary", num_processed_ternary_});
511 stats.push_back({
"product_detector/num_trail_updates", num_trail_updates_});
512 stats.push_back({
"product_detector/num_products", num_products_});
513 stats.push_back({
"product_detector/num_conditional_equalities",
514 num_conditional_equalities_});
516 {
"product_detector/num_conditional_zeros", num_conditional_zeros_});
517 stats.push_back({
"product_detector/num_int_products", num_int_products_});
522 absl::Span<const Literal> ternary_clause) {
523 if (ternary_clause.size() != 3)
return;
524 ++num_processed_ternary_;
526 if (rlt_enabled_) ProcessTernaryClauseForRLT(ternary_clause);
527 if (!enabled_)
return;
529 candidates_[GetKey(ternary_clause[0].Index(), ternary_clause[1].Index())]
530 .push_back(ternary_clause[2].Index());
531 candidates_[GetKey(ternary_clause[0].Index(), ternary_clause[2].Index())]
532 .push_back(ternary_clause[1].Index());
533 candidates_[GetKey(ternary_clause[1].Index(), ternary_clause[2].Index())]
534 .push_back(ternary_clause[0].Index());
538 for (
const Literal l : ternary_clause) {
539 if (l.Index() >= seen_.size()) seen_.
resize(l.Index() + 1);
540 seen_[l.Index()] =
true;
545void ProductDetector::ProcessTernaryClauseForRLT(
546 absl::Span<const Literal> ternary_clause) {
547 const int old_size = ternary_clauses_with_view_.size();
548 for (
const Literal l : ternary_clause) {
549 const IntegerVariable
var =
552 ternary_clauses_with_view_.resize(old_size);
555 ternary_clauses_with_view_.push_back(l.IsPositive() ?
var
561 absl::Span<const Literal> ternary_exo) {
562 if (ternary_exo.size() != 3)
return;
563 ++num_processed_exo_;
565 if (rlt_enabled_) ProcessTernaryClauseForRLT(ternary_exo);
566 if (!enabled_)
return;
568 ProcessNewProduct(ternary_exo[0].Index(), ternary_exo[1].NegatedIndex(),
569 ternary_exo[2].NegatedIndex());
570 ProcessNewProduct(ternary_exo[1].Index(), ternary_exo[0].NegatedIndex(),
571 ternary_exo[2].NegatedIndex());
572 ProcessNewProduct(ternary_exo[2].Index(), ternary_exo[0].NegatedIndex(),
573 ternary_exo[1].NegatedIndex());
579 absl::Span<const Literal> binary_clause) {
580 if (!enabled_)
return;
581 if (binary_clause.size() != 2)
return;
582 ++num_processed_binary_;
583 const std::array<LiteralIndex, 2> key =
584 GetKey(binary_clause[0].NegatedIndex(), binary_clause[1].NegatedIndex());
585 std::array<LiteralIndex, 3> ternary;
586 for (
const LiteralIndex l : candidates_[key]) {
590 std::sort(ternary.begin(), ternary.end());
591 const int l_index = ternary[0] == l ? 0 : ternary[1] == l ? 1 : 2;
592 std::bitset<3>& bs = detector_[ternary];
593 if (bs[l_index])
continue;
595 if (bs[0] && bs[1] && l_index != 2) {
596 ProcessNewProduct(ternary[2],
Literal(ternary[0]).NegatedIndex(),
597 Literal(ternary[1]).NegatedIndex());
599 if (bs[0] && bs[2] && l_index != 1) {
600 ProcessNewProduct(ternary[1],
Literal(ternary[0]).NegatedIndex(),
601 Literal(ternary[2]).NegatedIndex());
603 if (bs[1] && bs[2] && l_index != 0) {
604 ProcessNewProduct(ternary[0],
Literal(ternary[1]).NegatedIndex(),
605 Literal(ternary[2]).NegatedIndex());
611 if (!enabled_)
return;
612 for (LiteralIndex
a(0);
a < seen_.size(); ++
a) {
613 if (!seen_[
a])
continue;
623 if (!enabled_)
return;
625 ++num_trail_updates_;
633 const int current_index = trail_->
Index();
641 const auto it = products_.find(GetKey(
a.Index(),
b.Index()));
646std::array<LiteralIndex, 2> ProductDetector::GetKey(LiteralIndex
a,
647 LiteralIndex
b)
const {
648 std::array<LiteralIndex, 2> key{
a,
b};
649 if (key[0] > key[1]) std::swap(key[0], key[1]);
653void ProductDetector::ProcessNewProduct(LiteralIndex p, LiteralIndex
a,
657 products_[GetKey(
a,
b)] = p;
661 GetKey(Literal(
a).IsPositive() ?
a : Literal(
a).NegatedIndex(),
662 Literal(
b).IsPositive() ?
b : Literal(
b).NegatedIndex()));
666 IntegerVariable
b)
const {
667 if (
a ==
b)
return true;
675 const LiteralIndex la =
680 const LiteralIndex lb =
686 return has_product_.contains(
687 GetKey(
Literal(la).IsPositive() ? la :
Literal(la).NegatedIndex(),
692 IntegerVariable
b)
const {
698void ProductDetector::ProcessNewProduct(IntegerVariable p,
Literal l,
707 int_products_[{l.
Index(),
x}] = p;
712 ++num_conditional_equalities_;
716 for (
int i = 0;
i < 2; ++
i) {
725 std::vector<IntegerVariable>& others =
726 conditional_equalities_[{l.
Index(),
x}];
727 for (
const IntegerVariable o : others) {
736 if (conditional_zeros_.contains({l.NegatedIndex(), x})) {
737 ProcessNewProduct(
x, l,
y);
745 ++num_conditional_zeros_;
747 auto [_, inserted] = conditional_zeros_.insert({l.
Index(), p});
749 const auto it = conditional_equalities_.find({l.
NegatedIndex(), p});
750 if (it != conditional_equalities_.end()) {
751 for (
const IntegerVariable
x : it->second) {
752 ProcessNewProduct(p, l.
Negated(),
x);
760std::pair<IntegerVariable, IntegerVariable> Canonicalize(IntegerVariable
a,
762 if (
a <
b)
return {
a,
b};
766double GetLiteralLpValue(
775void ProductDetector::UpdateRLTMaps(
777 IntegerVariable var1,
double lp1, IntegerVariable var2,
double lp2,
778 IntegerVariable bound_var,
double bound_lp) {
781 if (bound_lp > lp1 && bound_lp > lp2)
return;
783 const auto [it, inserted] =
784 bool_rlt_ubs_.insert({Canonicalize(var1, var2), bound_var});
787 if (!inserted && bound_lp < GetLiteralLpValue(it->second, lp_values)) {
788 it->second = bound_var;
792 if (lp1 * lp2 > bound_lp + 1e-4) {
793 bool_rlt_candidates_[var1].push_back(var2);
794 bool_rlt_candidates_[var2].push_back(var1);
800 const absl::flat_hash_map<IntegerVariable, glop::ColIndex>& lp_vars,
805 bool_rlt_ubs_.clear();
809 bool_rlt_candidates_.clear();
810 const int size = ternary_clauses_with_view_.size();
811 for (
int i = 0;
i <
size;
i += 3) {
812 const IntegerVariable var1 = ternary_clauses_with_view_[
i];
813 const IntegerVariable var2 = ternary_clauses_with_view_[
i + 1];
814 const IntegerVariable var3 = ternary_clauses_with_view_[
i + 2];
822 const double lp1 = GetLiteralLpValue(var1, lp_values);
823 const double lp2 = GetLiteralLpValue(var2, lp_values);
824 const double lp3 = GetLiteralLpValue(var3, lp_values);
827 1.0 - lp2, var3, lp3);
829 1.0 - lp3, var2, lp2);
831 1.0 - lp3, var1, lp1);
void Set(IntegerType index)
void Resize(IntegerType size)
void Add(IntegerVariable var, const std::vector< ValueLiteralPair > &encoding, int exactly_one_index)
const absl::btree_map< int, std::vector< ValueLiteralPair > > & Get(IntegerVariable var)
Returns an empty map if there is no such encoding.
const std::vector< IntegerVariable > & GetElementEncodedVariables() const
Get an unsorted set of variables appearing in element encodings.
bool ProcessIntegerTrail(Literal first_decision)
~ImpliedBounds()
Just display some global statistics on destruction.
void AddLiteralImpliesVarEqValue(Literal literal, IntegerVariable var, IntegerValue value)
Adds literal => var == value.
bool Add(Literal literal, IntegerLiteral integer_literal)
const std::vector< ImpliedBoundEntry > & GetImpliedBounds(IntegerVariable var)
const std::vector< ValueLiteralPair > & FullDomainEncoding(IntegerVariable var) const
IntegerVariable GetLiteralView(Literal lit) const
bool VariableIsFullyEncoded(IntegerVariable var) const
LiteralIndex GetAssociatedLiteral(IntegerLiteral i_lit) const
IntegerValue LowerBound(IntegerVariable i) const
Returns the current lower/upper bound of the given integer variable.
IntegerValue UpperBound(IntegerVariable i) const
IntegerValue FixedValue(IntegerVariable i) const
Checks that the variable is fixed and returns its value.
void AppendNewBounds(std::vector< IntegerLiteral > *output) const
bool IsFixed(IntegerVariable i) const
Checks if the variable is fixed.
ABSL_MUST_USE_RESULT bool RootLevelEnqueue(IntegerLiteral i_lit)
IntegerValue LevelZeroUpperBound(IntegerVariable var) const
const Domain & InitialVariableDomain(IntegerVariable var) const
IntegerValue LevelZeroLowerBound(IntegerVariable var) const
Returns globally valid lower/upper bound on the given integer variable.
ABSL_MUST_USE_RESULT bool AddLiteralTerm(Literal lit, IntegerValue coeff=IntegerValue(1))
void AddTerm(IntegerVariable var, IntegerValue coeff)
void AddConstant(IntegerValue value)
Adds the corresponding term to the current linear expression.
void Clear()
Clears all added terms and constants. Keeps the original bounds.
LiteralIndex NegatedIndex() const
LiteralIndex Index() const
bool TryToLinearize(const AffineExpression &left, const AffineExpression &right, LinearConstraintBuilder *builder)
std::vector< LiteralValueValue > TryToDecompose(const AffineExpression &left, const AffineExpression &right)
void ProcessTernaryClause(absl::Span< const Literal > ternary_clause)
void ProcessTernaryExactlyOne(absl::Span< const Literal > ternary_exo)
void ProcessImplicationGraph(BinaryImplicationGraph *graph)
Utility function to process a bunch of implication all at once.
LiteralIndex GetProduct(Literal a, Literal b) const
Query function mainly used for testing.
bool ProductIsLinearizable(IntegerVariable a, IntegerVariable b) const
void ProcessConditionalEquality(Literal l, IntegerVariable x, IntegerVariable y)
void ProcessConditionalZero(Literal l, IntegerVariable p)
void ProcessTrailAtLevelOne()
void InitializeBooleanRLTCuts(const absl::flat_hash_map< IntegerVariable, glop::ColIndex > &lp_vars, const util_intops::StrongVector< IntegerVariable, double > &lp_values)
ProductDetector(Model *model)
void ProcessBinaryClause(absl::Span< const Literal > binary_clause)
int CurrentDecisionLevel() const
const std::vector< Decision > & Decisions() const
Simple class to add statistics by name and print them at the end.
void AddStats(absl::Span< const std::pair< std::string, int64_t > > stats)
Adds a bunch of stats, adding count for the same key together.
int CurrentDecisionLevel() const
const VariablesAssignment & Assignment() const
bool LiteralIsAssigned(Literal literal) const
void resize(size_type new_size)
constexpr IntegerValue kMaxIntegerValue(std::numeric_limits< IntegerValue::ValueType >::max() - 1)
std::vector< LiteralValueValue > TryToReconcileSize2Encodings(const AffineExpression &left, const AffineExpression &right, IntegerEncoder *integer_encoder)
const LiteralIndex kNoLiteralIndex(-1)
std::vector< IntegerVariable > NegationOf(const std::vector< IntegerVariable > &vars)
Returns the vector of the negated variables.
const IntegerVariable kNoIntegerVariable(-1)
IntegerVariable PositiveVariable(IntegerVariable i)
std::vector< LiteralValueValue > TryToReconcileEncodings(const AffineExpression &size2_affine, const AffineExpression &affine, absl::Span< const ValueLiteralPair > affine_var_encoding, bool put_affine_left_in_result, IntegerEncoder *integer_encoder)
bool VariableIsPositive(IntegerVariable i)
In SWIG mode, we don't want anything besides these top-level includes.
IntegerValue ValueAt(IntegerValue var_value) const
Returns the value of this affine expression given its variable value.
static IntegerLiteral GreaterOrEqual(IntegerVariable i, IntegerValue bound)