29#include "absl/functional/function_ref.h"
30#include "absl/log/check.h"
31#include "absl/log/log.h"
32#include "absl/strings/str_cat.h"
33#include "absl/synchronization/mutex.h"
34#include "absl/types/span.h"
// Threshold compared against num_restarts_ in the manager's restart condition
// (`num_restarts_ < kNumInitialRestarts`).
55const int kNumInitialRestarts = 10;
// Computes a depth-style bound from `num_leaves`.
// While `a` is below `num_leaves`, the pair (a, b) is advanced to (b, a + b),
// i.e. each step applies the Fibonacci recurrence.
// NOTE(review): presumably the result is the number of steps taken - confirm
// against the initial values of a/b and the return statement.
64int MaxAllowedDiscrepancyPlusDepth(
int num_leaves) {
68 while (a < num_leaves) {
69 std::tie(a,
b) = std::make_pair(
b, a +
b);
79 const int num_leaves = params.shared_tree_open_leaves_per_worker() *
80 params.shared_tree_num_workers();
81 switch (params.shared_tree_split_strategy()) {
84 return MaxAllowedDiscrepancyPlusDepth(num_leaves) +
85 params.shared_tree_balance_tolerance();
87 return std::ceil(std::log2(num_leaves)) +
88 params.shared_tree_balance_tolerance();
99 return mapping->
Literal(proto_var_);
109 if (proto_var_ < 0) {
// Encodes an integer literal as a ProtoLiteral (optional result).
// The proto variable is looked up from `positive_var` through `mapping`; when
// `literal.var` is not `positive_var`, NegatedRef(model_var) is used instead
// of model_var.
// In debug builds, the result and its negation are checked to decode back to
// `literal` and `literal.Negated()`.
115std::optional<ProtoLiteral> ProtoLiteral::EncodeInteger(
118 const int model_var =
119 mapping->GetProtoVariableFromIntegerVariable(positive_var);
// -1 is special-cased here; presumably it means "no proto variable" and leads
// to an empty optional - TODO confirm.
120 if (model_var == -1) {
124 literal.var == positive_var ? model_var :
NegatedRef(model_var),
126 DCHECK_EQ(result.DecodeInteger(mapping), literal);
127 DCHECK_EQ(result.Negated().DecodeInteger(mapping), literal.Negated());
133 const std::optional<ProtoLiteral> result =
EncodeLiteral(literal, mapping);
134 if (result.has_value())
return result;
137 auto result = EncodeInteger(int_lit, mapping);
138 if (result.has_value()) {
139 DCHECK_EQ(result->DecodeInteger(mapping), int_lit);
140 DCHECK_EQ(result->Negated().DecodeInteger(mapping), int_lit.Negated());
154 if (model_var == -1) {
173 IntegerValue objective_lb,
int node_id) {
174 CHECK_GT(node_id, 0);
175 decision_indexes_.push_back(literals_.size());
176 assigned_at_level_[decision] = decision_indexes_.size();
177 literals_.push_back(decision);
178 node_ids_.push_back(node_id);
179 implications_.push_back({});
180 if (!level_to_objective_lbs_.empty()) {
181 objective_lb = std::max(level_to_objective_lbs_.back(), objective_lb);
183 level_to_objective_lbs_.push_back(objective_lb);
188 DCHECK_LE(level, decision_indexes_.size());
189 DCHECK_LE(level, implications_.size());
192 assigned_at_level_[decision] = level - 1;
196 MutableImplications(level - 1).push_back(decision);
199 assigned_at_level_[implication] = level - 1;
201 MutableImplications(level - 1).push_back(implication);
206 implications_.erase(implications_.begin() + level - 1);
207 decision_indexes_.erase(decision_indexes_.begin() + level - 1);
208 level_to_objective_lbs_.erase(level_to_objective_lbs_.begin() + level - 1);
212 assigned_at_level_.clear();
213 for (
int level = 1; level <=
MaxLevel(); ++level) {
214 assigned_at_level_[
Decision(level)] = level;
216 std::vector<ProtoLiteral>& implications = MutableImplications(level);
217 for (
int i = 0;
i < implications.size(); ++
i) {
219 if (!assigned_at_level_.contains(implication)) {
220 implications[new_size++] = implication;
221 assigned_at_level_[implication] = level;
224 implications.resize(new_size);
229 decision_indexes_.clear();
231 level_to_objective_lbs_.clear();
233 target_phase_.clear();
234 assigned_at_level_.clear();
235 implications_.clear();
239 if (level == 0)
return;
240 level_to_objective_lbs_[level - 1] =
241 std::max(objective_lb, level_to_objective_lbs_[level - 1]);
245 DCHECK_LE(level, decision_indexes_.size());
246 return node_ids_[decision_indexes_[level - 1]];
250 DCHECK_LE(level, decision_indexes_.size());
251 int start = level == 0 ? 0 : decision_indexes_[level - 1];
252 int end = level == decision_indexes_.size() ? node_ids_.size()
253 : decision_indexes_[level];
254 return absl::MakeSpan(node_ids_.data() + start,
end - start);
258 if (level > implications_.size() || level <= 0) {
259 return absl::MakeSpan(literals_.data(), 0);
261 return absl::MakeSpan(implications_[level - 1]);
266 num_workers_(
std::max(0, params_.shared_tree_num_workers())),
267 max_path_depth_(MaxPossibleLeafDepth(params_)),
270 params_, &clause_id_generator_,
274 num_workers_ * params_.shared_tree_open_leaves_per_worker() - 1),
276 params_.shared_tree_max_nodes_per_worker() >=
277 std::numeric_limits<int>::max() /
std::max(num_workers_, 1)
278 ?
std::numeric_limits<int>::max()
279 : num_workers_ * params_.shared_tree_max_nodes_per_worker()) {
283 .objective_lb = shared_response_manager_->GetInnerObjectiveLowerBound(),
284 .trail_info = std::make_unique<NodeTrailInfo>()});
285 unassigned_leaves_.push_back(&nodes_.back());
289 absl::MutexLock mutex_lock(mu_);
290 return nodes_.size();
294 absl::MutexLock mutex_lock(mu_);
295 std::vector<std::pair<Node*, int>> nodes = GetAssignedNodes(path);
296 if (!IsValid(path)) {
300 DCHECK(CheckLratInvariants());
302 DCHECK(to_close_.empty());
303 DCHECK(to_update_.empty());
305 for (
const auto& [node, level] : nodes) {
306 if (level == prev_level) {
311 Node* sibling = GetSibling(node);
313 if (lrat_proof_handler_ !=
nullptr) {
318 const std::vector<Literal> inferred_clause = ClosingClause(sibling);
319 std::vector<Literal> imported_clause;
320 std::vector<ClauseId> lrat_proof;
321 for (
int l = 1; l <= level + 1; ++l) {
323 const Literal decision = DecodeWithIdentityMapping(n->decision);
324 imported_clause.push_back(l <= level ? decision.
Negated() : decision);
325 if (n->implied_and_processed) {
326 lrat_proof.push_back(GetSibling(n)->closing_clause_id);
329 closing_clause_id = AddImportedAndInferredClauses(
330 imported_clause, inferred_clause, lrat_proof);
332 to_close_.emplace_back(sibling, closing_clause_id);
333 }
else if (level > 0 && node->objective_lb < path.
ObjectiveLb(level)) {
335 to_update_.push_back(node->parent);
337 if (level > 0 && !node->closed) {
338 NodeTrailInfo* trail_info = GetTrailInfo(node);
341 if (IsDecisionOfNodeOrAncestor(implication, node))
continue;
343 if (lrat_proof_handler_ !=
nullptr) {
349 const std::vector<Literal> inferred_clause =
350 ImplicationClause(node, implication);
351 std::vector<Literal> imported_clause;
352 std::vector<ClauseId> lrat_proof;
353 for (
int l = 1; l <= level; ++l) {
355 const Literal decision = DecodeWithIdentityMapping(n->decision);
356 imported_clause.push_back(decision.
Negated());
357 if (n->implied_and_processed) {
358 lrat_proof.push_back(GetSibling(n)->closing_clause_id);
361 imported_clause.push_back(DecodeWithIdentityMapping(implication));
362 implication_clause_id = AddImportedAndInferredClauses(
363 imported_clause, inferred_clause, lrat_proof);
365 auto it = trail_info->implications
366 .emplace(implication.proto_var(),
367 std::make_pair(implication.lb(),
368 implication_clause_id))
370 if (it->second.first < implication.lb()) {
371 it->second.first = implication.lb();
377 ProcessNodeChanges();
378 if (nodes.back().first->closed) {
385 if (num_leaves_assigned_since_restart_ >= num_workers_ &&
386 num_restarts_ < kNumInitialRestarts) {
392 AssignLeaf(path, nodes.back().first);
393 DCHECK(CheckLratInvariants());
399 decisions = decisions.subspan(0, max_path_depth_ - path.
MaxLevel());
400 if (decisions.empty())
return 0;
401 absl::MutexLock l(mu_);
402 for (
int i = 0;
i < decisions.size(); ++
i) {
403 if (!TrySplitTreeLockHeld(decisions[
i], path))
return i;
405 return decisions.size();
// Tries to split the last node assigned to `path` on `decision`.
// Called with mu_ already held (the caller takes an absl::MutexLock on mu_
// before looping over decisions).
//
// Returns false for an invalid path. The split is also refused (with a log
// message; presumably returning false too - confirm) when:
//   - the last node is closed,
//   - the last node already has children,
//   - two more nodes would exceed max_nodes_,
//   - num_splits_wanted_ is not positive,
//   - `decision` or its negation is already a decision on the path,
//   - the strategy-specific discrepancy / objective-bound test fails.
// On success the node is split via Split() and the new leaf is pushed onto
// `path`.
408bool SharedTreeManager::TrySplitTreeLockHeld(
ProtoLiteral decision,
410 if (!IsValid(path))
return false;
411 std::vector<std::pair<Node*, int>> nodes = GetAssignedNodes(path);
412 if (nodes.back().first->closed) {
413 VLOG(2) <<
"Cannot split closed node";
416 if (nodes.back().first->children[0] !=
nullptr) {
// Only warn when the path has more than just the root entry.
417 LOG_IF(WARNING, nodes.size() > 1)
418 <<
"Cannot resplit previously split node @ " << nodes.back().second
419 <<
"/" << nodes.size();
// A split creates two children, hence the "+ 2".
422 if (nodes_.size() + 2 > max_nodes_) {
423 VLOG(2) <<
"Too many nodes to accept split";
426 if (num_splits_wanted_ <= 0) {
427 VLOG(2) <<
"Enough splits for now";
430 for (
const auto& [node, level] : nodes) {
431 if (decision == node->decision || decision == node->decision.Negated()) {
432 VLOG(2) <<
"Cannot split on decision which is already in the tree";
436 if (params_.shared_tree_split_strategy() ==
438 params_.shared_tree_split_strategy() ==
// Discrepancy: for every non-root, non-implied node on the path, add 1 when
// its objective_lb is larger than its sibling's, or on a tie when the node is
// not its parent's first child.
441 for (
const auto& [node, level] : nodes) {
442 if (node->parent ==
nullptr || node->implied)
continue;
443 IntegerValue sibling_bound = GetSibling(node)->objective_lb;
444 discrepancy += (node->objective_lb == sibling_bound
445 ? node != node->parent->children[0]
446 : node->objective_lb > sibling_bound);
450 if (discrepancy + path.
MaxLevel() >= max_path_depth_) {
451 VLOG(2) <<
"Too high discrepancy to accept split";
454 }
else if (params_.shared_tree_split_strategy() ==
// nodes.front() is the root entry; the leaf must not have a larger lower
// bound than it.
456 if (nodes.back().first->objective_lb > nodes.front().first->objective_lb) {
457 VLOG(2) <<
"Can only split nodes with minimum objective lb, "
458 << nodes.back().first->objective_lb <<
" > "
459 << nodes.front().first->objective_lb;
463 VLOG_EVERY_N(2, 10) << unassigned_leaves_.size() <<
" unassigned leaves, "
464 << nodes_.size() <<
" subtrees, " << num_splits_wanted_
// Split() appends the first child to `nodes`, so nodes.back() is the new leaf.
466 Split(nodes, decision);
467 auto [new_leaf, level] = nodes.back();
468 path.
PushLevel(new_leaf->decision, new_leaf->objective_lb, new_leaf->id);
473 absl::MutexLock mutex_lock(mu_);
474 std::vector<std::pair<Node*, int>> nodes = GetAssignedNodes(path);
475 if (nodes.back().first->children[0] ==
nullptr &&
476 !nodes.back().first->closed && nodes.size() > 1) {
477 Node* leaf = nodes.back().first;
478 VLOG(2) <<
"Returning leaf to be replaced";
480 unassigned_leaves_.push_back(leaf);
483 while (!unassigned_leaves_.empty()) {
484 Node* leaf = unassigned_leaves_.front();
485 unassigned_leaves_.pop_front();
486 if (!leaf->closed && leaf->children[0] ==
nullptr) {
487 num_leaves_assigned_since_restart_ += 1;
488 AssignLeaf(path, leaf);
493 VLOG(2) <<
"Assigning root because no unassigned leaves are available";
// Returns the NodeTrailInfo reached from `node`, which must be non-null and
// not closed (CHECK-enforced).
// The loop continues while the current node has no trail_info.
// NOTE(review): presumably it steps to the parent each iteration - confirm.
498SharedTreeManager::NodeTrailInfo* SharedTreeManager::GetTrailInfo(Node* node) {
499 CHECK(node !=
nullptr && !node->closed);
500 while (node->trail_info ==
nullptr) {
503 CHECK_NE(node,
nullptr);
504 return node->trail_info.get();
// Discards the trail info stored on `node`; does nothing when there is none.
// With an LRAT proof handler, the clause recorded for each stored implication
// is deleted from the proof first.
// `implications_only`: when true, only the implications container is cleared.
// NOTE(review): the trail_info.reset() below is presumably the alternative
// (non-implications_only) branch - confirm.
507void SharedTreeManager::ClearTrailInfo(Node* node,
bool implications_only) {
508 if (node->trail_info ==
nullptr)
return;
509 if (lrat_proof_handler_ !=
nullptr) {
510 for (
const auto& [var, lb_and_clause] : node->trail_info->implications) {
// lb_and_clause.second is the clause id paired with the implication's bound.
511 lrat_proof_handler_->DeleteClause(lb_and_clause.second, {});
514 if (implications_only) {
515 node->trail_info->implications.clear();
517 node->trail_info.reset();
// Returns the other child of `node`'s parent.
// Returns nullptr when `node` is null or has no parent.
521SharedTreeManager::Node* SharedTreeManager::GetSibling(
const Node* node)
const {
522 if (node ==
nullptr || node->parent ==
nullptr)
return nullptr;
// If `node` is not children[0], the sibling is children[0]; else children[1].
523 if (node->parent->children[0] != node) {
524 return node->parent->children[0];
526 return node->parent->children[1];
// Splits the last node in `nodes` (which must have no children) into two
// subtrees: children[0] on `lit` and children[1] on `lit.Negated()`.
// children[0] is appended to `nodes` one level deeper, children[1] is queued
// in unassigned_leaves_, and num_splits_wanted_ is decremented.
529void SharedTreeManager::Split(std::vector<std::pair<Node*, int>>& nodes,
531 const auto [parent, level] = nodes.back();
532 DCHECK(parent->children[0] ==
nullptr);
533 DCHECK(parent->children[1] ==
nullptr);
534 parent->children[0] = MakeSubtree(parent, lit);
535 parent->children[1] = MakeSubtree(parent, lit.Negated());
536 NodeTrailInfo* trail_info = GetTrailInfo(parent);
537 if (trail_info !=
nullptr) {
// Both children get fresh trail info; the phase found via GetTrailInfo(parent)
// is moved into children[1] (the child left unassigned).
538 parent->children[0]->trail_info =
539 std::make_unique<NodeTrailInfo>(NodeTrailInfo{});
540 parent->children[1]->trail_info = std::make_unique<NodeTrailInfo>(
541 NodeTrailInfo{.phase = std::move(trail_info->phase)});
543 nodes.push_back(std::make_pair(parent->children[0], level + 1));
544 unassigned_leaves_.push_back(parent->children[1]);
545 --num_splits_wanted_;
// Builds a new Node for `decision` that inherits the parent's objective_lb,
// and returns a pointer to the last element of nodes_.
// The id is the current nodes_.size() plus node_id_offset_, so that
// `id - node_id_offset_` is the node's index in nodes_.
548SharedTreeManager::Node* SharedTreeManager::MakeSubtree(Node* parent,
551 Node{.decision = decision,
552 .objective_lb = parent->objective_lb,
554 .id =
static_cast<int>(nodes_.size() + node_id_offset_)});
555 return &nodes_.back();
// Drains the pending work lists to_close_ and to_update_.
//
// Phase 1 (to_close_): for each (node, closing clause id) entry, closes the
// node and keeps processing upward while the current node is non-null and not
// yet closed. Open children are queued for closing, and the sibling is marked
// implied. With an LRAT proof handler, closing clauses are inferred for the
// children and for the parent (from both children's closing clauses). If the
// walk ends on a null node, the response manager is notified that the
// improving problem is infeasible; otherwise the parent is queued in
// to_update_.
//
// Phase 2 (to_update_): walks upward from each queued node, processing implied
// children and raising node->objective_lb to the minimum of its two children's
// bounds; the walk stops at the first node whose bound does not increase.
//
// Finally, newly implied nodes that were not handled in phase 2 are processed,
// and the root's implications are cleared.
558void SharedTreeManager::ProcessNodeChanges() {
559 DCHECK(CheckLratInvariants());
560 int num_newly_closed = 0;
561 std::vector<Node*> newly_implied;
562 while (!to_close_.empty()) {
563 auto [node, closing_clause_id] = to_close_.back();
564 CHECK_NE(node,
nullptr);
565 to_close_.pop_back();
567 while (node !=
nullptr && !node->closed) {
571 node->closing_clause_id = closing_clause_id;
573 if (node->parent !=
nullptr) {
574 ClearTrailInfo(node);
// Adds 1 when the node has no first child (i.e. it is a leaf).
578 num_splits_wanted_ += (node->children[0] ==
nullptr);
579 for (Node* child : node->children) {
580 if (child ==
nullptr || child->closed)
continue;
582 if (lrat_proof_handler_ !=
nullptr) {
586 child_closing_clause_id = clause_id_generator_.GetNextId();
587 lrat_proof_handler_->AddInferredClause(
588 child_closing_clause_id, ClosingClause(child),
589 {closing_clause_id},
true);
591 to_close_.emplace_back(child, child_closing_clause_id);
593 Node* sibling = GetSibling(node);
594 if (sibling !=
nullptr) {
595 sibling->implied =
true;
596 if (lrat_proof_handler_ !=
nullptr) {
597 newly_implied.push_back(sibling);
599 if (!sibling->closed) {
603 Node* parent = node->parent;
604 if (lrat_proof_handler_ !=
nullptr && parent !=
nullptr &&
606 closing_clause_id = clause_id_generator_.GetNextId();
609 lrat_proof_handler_->AddInferredClause(
610 closing_clause_id, ClosingClause(parent),
611 {node->closing_clause_id, sibling->closing_clause_id},
616 DCHECK(node ==
nullptr || node->closed);
617 if (node ==
nullptr) {
618 shared_response_manager_->NotifyThatImprovingProblemIsInfeasible(
620 }
else if (node->parent !=
nullptr) {
621 to_update_.push_back(node->parent);
624 if (num_newly_closed > 0) {
625 shared_response_manager_->LogMessageWithThrottling(
626 "Tree", absl::StrCat(
"closed:", num_closed_nodes_,
"/", nodes_.size(),
627 " unassigned:", unassigned_leaves_.size(),
628 " restarts:", num_restarts_));
630 DCHECK(CheckLratInvariants());
// Phase 2: propagate objective lower bounds upward.
633 bool root_updated =
false;
634 while (!to_update_.empty()) {
635 Node* node = to_update_.back();
636 to_update_.pop_back();
638 while (node !=
nullptr && !node->closed) {
639 DCHECK(node->children[0] !=
nullptr);
640 DCHECK(node->children[1] !=
nullptr);
641 for (Node* child : node->children) {
642 if (child->implied) {
643 if (child->trail_info !=
nullptr) {
644 DCHECK(!child->implied_and_processed);
645 ProcessImpliedNode(child);
646 ClearTrailInfo(child);
648 child->implied_and_processed =
true;
651 IntegerValue child_bound = std::min(node->children[0]->objective_lb,
652 node->children[1]->objective_lb);
653 if (child_bound <= node->objective_lb)
break;
654 node->objective_lb = child_bound;
657 if (node ==
nullptr) root_updated =
true;
660 shared_response_manager_->UpdateInnerObjectiveBounds(
663 for (Node* node : newly_implied) {
664 if (!node->implied_and_processed) {
665 DCHECK_EQ(node->trail_info,
nullptr);
666 DCHECK_NE(lrat_proof_handler_,
nullptr);
667 ProcessImpliedNode(node);
668 node->implied_and_processed =
true;
// implications_only = true: keeps the root's trail_info object itself.
672 ClearTrailInfo(&nodes_[0],
true);
673 DCHECK(CheckLratInvariants());
// Handles a node that has become implied; `node` must have a parent.
// Finds the closest ancestor that carries trail_info and moves `node`'s
// implications there:
//   - without an LRAT proof handler, the implication containers are merged
//     directly;
//   - with one, closing-clause ids of siblings of implied-but-unprocessed
//     nodes are collected along the walk toward the root. For each implication
//     whose variable is not already recorded on the ancestor (existing entries
//     are special-cased by the contains() test), a new clause is inferred and
//     recorded with its new id.
// Finally, the LRAT clauses stored in the subtree rooted at `node` are
// rewritten through UpdateLratClausesInSubtree.
679void SharedTreeManager::ProcessImpliedNode(Node* node) {
680 CHECK(node->parent !=
nullptr);
681 Node* first_non_implied_ancestor = node->parent;
682 while (first_non_implied_ancestor->trail_info ==
nullptr) {
683 first_non_implied_ancestor = first_non_implied_ancestor->parent;
684 DCHECK_NE(first_non_implied_ancestor,
nullptr);
689 if (lrat_proof_handler_ ==
nullptr) {
690 first_non_implied_ancestor->trail_info->implications.merge(
691 node->trail_info->implications);
696 std::vector<ClauseId> clauses;
698 while (n->parent !=
nullptr) {
702 if (n->implied && !n->implied_and_processed) {
703 clauses.push_back(GetSibling(n)->closing_clause_id);
// Ids were collected bottom-up; reverse them into root-to-node order.
707 std::reverse(clauses.begin(), clauses.end());
709 if (node->trail_info !=
nullptr) {
710 for (
const auto& [var, lb_and_clause] : node->trail_info->implications) {
712 if (first_non_implied_ancestor->trail_info->implications.contains(var)) {
715 const auto [lb, clause_id] = lb_and_clause;
716 ClauseId new_clause_id = clause_id_generator_.GetNextId();
717 clauses.push_back(clause_id);
718 lrat_proof_handler_->AddInferredClause(
720 ImplicationClause(first_non_implied_ancestor, ProtoLiteral(var, lb),
724 first_non_implied_ancestor->trail_info->implications.insert(
725 {var, std::make_pair(lb, new_clause_id)});
728 UpdateLratClausesInSubtree(node, node, clauses);
// Recursive helper that re-derives the LRAT clauses stored in the subtree
// under `node`.
// `node` is the subtree root (constant across the recursion); `n` is the node
// currently visited; `clauses` accumulates the clause ids used as proof.
// For the visited node, a new closing clause is inferred from
// ClosingClause(n, true), the old one is deleted, and n->closing_clause_id is
// replaced. For nodes other than the subtree root that have trail_info, each
// stored implication clause is likewise re-inferred and its id replaced in
// place.
// The recursion descends into children unless `n` is a node other than the
// root that is implied and still has trail_info.
736void SharedTreeManager::UpdateLratClausesInSubtree(
737 Node* node, Node* n, std::vector<ClauseId>& clauses) {
738 const bool implied_and_not_processed =
739 n->implied && !n->implied_and_processed;
740 if (implied_and_not_processed) {
744 clauses.push_back(GetSibling(n)->closing_clause_id);
748 ClauseId new_clause_id = clause_id_generator_.GetNextId();
749 clauses.push_back(n->closing_clause_id);
750 lrat_proof_handler_->AddInferredClause(
752 ClosingClause(n,
true), clauses,
755 lrat_proof_handler_->DeleteClause(n->closing_clause_id, {});
756 n->closing_clause_id = new_clause_id;
758 if (n != node && n->trail_info !=
nullptr) {
759 for (
auto& [var, lb_and_clause] : n->trail_info->implications) {
// Bound by reference so the stored clause id can be overwritten below.
760 auto& [lb, clause_id] = lb_and_clause;
761 ClauseId new_clause_id = clause_id_generator_.GetNextId();
762 clauses.push_back(clause_id);
763 lrat_proof_handler_->AddInferredClause(
765 ImplicationClause(n, ProtoLiteral(var, lb),
768 lrat_proof_handler_->DeleteClause(clause_id, {});
769 clause_id = new_clause_id;
775 if (n == node || !(n->implied && n->trail_info !=
nullptr)) {
776 for (Node* child : n->children) {
777 if (child !=
nullptr && child->parent !=
nullptr) {
778 UpdateLratClausesInSubtree(node, child, clauses);
// NOTE(review): presumably undoes the push made at the top for this node -
// confirm.
782 if (implied_and_not_processed) {
// Returns the node with the given id.
// Ids are offset by node_id_offset_, so the index into nodes_ is
// `id - node_id_offset_`; the index is CHECKed to be below nodes_.size().
787SharedTreeManager::Node* SharedTreeManager::GetNode(
int id) {
788 const int index =
id - node_id_offset_;
790 CHECK_LT(index, nodes_.size());
791 return &nodes_[index];
// Maps the node ids stored in `path` to (Node*, level) pairs.
// The result always starts with the root (&nodes_[0]) at level 0. For each
// level 0..MaxLevel() and each id in path.NodeIds(level), the matching node is
// appended with that level.
// An invalid path is special-cased before the loop (presumably the path is
// reset and only the root is returned - confirm).
// Every index is CHECKed to lie in [0, nodes_.size()); debug builds also check
// that each node's parent is the previously appended node.
794std::vector<std::pair<SharedTreeManager::Node*, int>>
795SharedTreeManager::GetAssignedNodes(
const ProtoTrail& path) {
796 std::vector<std::pair<Node*, int>> nodes({std::make_pair(&nodes_[0], 0)});
797 if (!IsValid(path)) {
802 for (
int i = 0;
i <= path.MaxLevel(); ++
i) {
803 for (
int id : path.NodeIds(
i)) {
804 const int index =
id - node_id_offset_;
805 CHECK_GE(index, 0) <<
" in path.NodeIds(" <<
i
806 <<
"), max_level=" << path.MaxLevel();
807 CHECK_LT(index, nodes_.size());
808 DCHECK_EQ(nodes.back().first, nodes_[index].parent);
809 nodes.push_back(std::make_pair(&nodes_[index],
i));
816 absl::MutexLock mutex_lock(mu_);
817 DCHECK(CheckLratInvariants());
818 const int node_id_to_close = path.
NodeIds(level).front();
819 if (node_id_to_close < node_id_offset_) {
823 Node* node = &nodes_[node_id_to_close - node_id_offset_];
824 VLOG(2) <<
"Closing subtree at level " << level;
825 DCHECK(to_close_.empty());
828 if (lrat_proof_handler_ !=
nullptr) {
834 const std::vector<Literal> inferred_clause = ClosingClause(node);
835 std::vector<Literal> imported_clause;
836 std::vector<ClauseId> lrat_proof;
837 for (
int l = 1; l <= level; ++l) {
839 const Literal decision = DecodeWithIdentityMapping(n->decision);
840 imported_clause.push_back(decision.
Negated());
841 if (n->implied_and_processed) {
842 lrat_proof.push_back(GetSibling(n)->closing_clause_id);
845 closing_clause_id = AddImportedAndInferredClauses(
846 imported_clause, inferred_clause, lrat_proof);
849 to_close_.emplace_back(node, closing_clause_id);
850 ProcessNodeChanges();
851 DCHECK(CheckLratInvariants());
// Returns true when `literal` equals the decision of `node`.
// The loop runs while the current node has a parent, so the root is never
// compared. `node` must be non-null.
// NOTE(review): per the name, the loop presumably steps to node->parent so
// ancestors are compared too - confirm.
854bool SharedTreeManager::IsDecisionOfNodeOrAncestor(
ProtoLiteral literal,
855 const Node* node)
const {
856 CHECK_NE(node,
nullptr);
857 while (node->parent !=
nullptr) {
858 if (literal == node->decision)
return true;
// Builds the clause for an implication at `node`: the closing clause of `node`
// (computed with the same skip_unprocessed_implied_nodes flag) followed by the
// decoded `implied` literal. `node` must be non-null.
864std::vector<Literal> SharedTreeManager::ImplicationClause(
865 const Node* node, ProtoLiteral implied,
866 bool skip_unprocessed_implied_nodes)
const {
870 CHECK_NE(node,
nullptr);
871 std::vector<Literal> clause =
872 ClosingClause(node, skip_unprocessed_implied_nodes);
873 clause.push_back(DecodeWithIdentityMapping(implied));
// Builds the clause that closes `node`: the negated, decoded decisions found
// while the current node has a parent. `node` must be non-null.
// A node counts as implied here when node->implied is set and either it was
// already processed or skip_unprocessed_implied_nodes is true.
// NOTE(review): presumably decisions of such implied nodes are left out of the
// clause, and the walk steps to the parent each iteration - confirm.
877std::vector<Literal> SharedTreeManager::ClosingClause(
878 const Node* node,
bool skip_unprocessed_implied_nodes)
const {
882 CHECK_NE(node,
nullptr);
883 std::vector<Literal> clause;
884 while (node->parent !=
nullptr) {
891 const bool is_implied = node->implied && (node->implied_and_processed ||
892 skip_unprocessed_implied_nodes);
894 clause.push_back(DecodeWithIdentityMapping(node->decision).Negated());
// Returns true when `a` and `b` hold the same literals with the same
// multiplicities, regardless of order.
// Sizes are compared first; then sorted copies are compared, so the inputs are
// left untouched.
902bool UnorderedSpansAreEqual(absl::Span<const Literal> a,
903 absl::Span<const Literal>
b) {
904 if (a.size() !=
b.size())
return false;
905 std::vector<Literal> sorted_a(a.begin(), a.end());
906 std::vector<Literal> sorted_b(
b.begin(),
b.end());
907 std::sort(sorted_a.begin(), sorted_a.end());
908 std::sort(sorted_b.begin(), sorted_b.end());
909 return sorted_a == sorted_b;
// Registers `imported_clause` with the LRAT proof handler under a fresh id.
// If `lrat_proof` is non-empty, or the inferred clause differs from the
// imported one (ignoring literal order), then:
//   - the imported id is appended to `lrat_proof` (mutating the argument),
//   - `inferred_clause` is added under a second fresh id, with `lrat_proof` as
//     its proof,
//   - the imported clause is deleted.
// NOTE(review): presumably returns the id of the clause that stays alive
// (new_id in that branch, otherwise the imported id) - confirm.
// Assumes lrat_proof_handler_ is non-null; it is dereferenced without a check.
913ClauseId SharedTreeManager::AddImportedAndInferredClauses(
914 absl::Span<const Literal> imported_clause,
915 absl::Span<const Literal> inferred_clause,
916 std::vector<ClauseId>& lrat_proof) {
917 const ClauseId
id = clause_id_generator_.GetNextId();
918 lrat_proof_handler_->AddImportedClause(
id, imported_clause);
919 if (!lrat_proof.empty() ||
920 !UnorderedSpansAreEqual(inferred_clause, imported_clause)) {
921 lrat_proof.push_back(
id);
922 const ClauseId new_id = clause_id_generator_.GetNextId();
923 lrat_proof_handler_->AddInferredClause(new_id, inferred_clause, lrat_proof,
925 lrat_proof_handler_->DeleteClause(
id, {});
// Fills `path` with the decisions leading from the root to `leaf`.
// Nodes are first collected from `leaf` up to (excluding) the root, then
// pushed onto `path` in root-to-leaf order. Each push records the node's
// decision, objective lower bound and id.
// SetLevelImplied is applied to the level just pushed.
// NOTE(review): this is presumably only done for implied nodes - confirm.
// When trail sharing is enabled and `leaf` carries trail_info, its stored
// implications are added at the path's last level.
932void SharedTreeManager::AssignLeaf(
ProtoTrail& path, Node* leaf) {
934 std::vector<Node*> reversed_path;
935 while (leaf != &nodes_[0]) {
936 reversed_path.push_back(&nodes_[leaf->id - node_id_offset_]);
939 while (!reversed_path.empty()) {
// Note: this local `leaf` shadows the parameter inside the loop.
940 Node* leaf = reversed_path.back();
941 reversed_path.pop_back();
942 path.PushLevel(leaf->decision, leaf->objective_lb, leaf->id);
944 path.SetLevelImplied(path.MaxLevel());
946 if (params_.shared_tree_worker_enable_trail_sharing() &&
947 leaf->trail_info !=
nullptr) {
948 for (
const auto& [var, lb_and_clause] : leaf->trail_info->implications) {
949 const auto [lb, clause_id] = lb_and_clause;
950 path.AddImplication(path.MaxLevel(), ProtoLiteral(var, lb));
// Tells whether `path` refers to nodes of the current tree.
// A path with no node ids at its last level is valid. A path whose last node
// id is below node_id_offset_ is invalid (the offset is advanced by
// RestartLockHeld, so such ids predate it).
956bool SharedTreeManager::IsValid(
const ProtoTrail& path)
const {
957 auto node_ids = path.NodeIds(path.MaxLevel());
958 if (node_ids.empty())
return true;
959 if (node_ids.back() < node_id_offset_)
return false;
// Resets the tree state for a restart.
// NOTE(review): the name suggests mu_ must already be held by the caller -
// confirm.
// node_id_offset_ is advanced by the current node count, so ids from before
// the restart fall below the offset (see IsValid). With an LRAT proof handler,
// closing clauses of existing nodes are deleted. The root gets the new offset
// as its id and loses its children; the unassigned-leaf queue and the
// per-restart counters are reset. to_close_ and to_update_ are expected to be
// empty already.
963void SharedTreeManager::RestartLockHeld() {
964 node_id_offset_ += nodes_.size();
965 if (lrat_proof_handler_ !=
nullptr) {
966 for (
const Node& node : nodes_) {
968 lrat_proof_handler_->DeleteClause(node.closing_clause_id, {});
973 nodes_[0].id = node_id_offset_;
974 nodes_[0].children = {
nullptr,
nullptr};
975 unassigned_leaves_.clear();
976 DCHECK(to_close_.empty());
977 DCHECK(to_update_.empty());
979 num_workers_ * params_.shared_tree_open_leaves_per_worker() - 1;
980 num_closed_nodes_ = 0;
982 num_leaves_assigned_since_restart_ = 0;
986 absl::MutexLock l(mu_);
987 if (lrat_proof_handler_ !=
nullptr) {
988 lrat_proof_handler_->Close(
false);
989 lrat_proof_handler_.reset();
// Returns a one-line summary of the form "shared_tree_manager(r=<R> n=<N>)",
// where R is num_restarts_ and N is nodes_.size().
993std::string SharedTreeManager::ShortStatus()
const {
994 return absl::StrCat(
"shared_tree_manager(r=", num_restarts_,
995 " n=", nodes_.size(),
")");
// CHECK-fails unless `a` and `b` contain the same literals irrespective of
// order. The comparison is done on sorted copies; the inputs are not modified.
999void CheckEqual(absl::Span<const Literal> a, absl::Span<const Literal>
b) {
1000 std::vector<Literal> sorted_a(a.begin(), a.end());
1001 std::vector<Literal> sorted_b(
b.begin(),
b.end());
1002 std::sort(sorted_a.begin(), sorted_a.end());
1003 std::sort(sorted_b.begin(), sorted_b.end());
1004 CHECK_EQ(sorted_a, sorted_b);
// Debug consistency check between the tree and the LRAT proof handler.
// It only does work when a handler is set and its LRAT checking is enabled.
// For every non-root node:
//   - the clause stored under node.closing_clause_id is compared with
//     ClosingClause(&node);
//   - for each implication in the node's trail_info, the clause stored under
//     its clause id must equal ImplicationClause(&node, ...), order ignored.
// Mismatches abort through CheckEqual's CHECK.
// NOTE(review): presumably returns true so that it can be wrapped in DCHECK -
// confirm.
1008bool SharedTreeManager::CheckLratInvariants()
const {
1009 if (lrat_proof_handler_ !=
nullptr &&
1010 lrat_proof_handler_->lrat_check_enabled()) {
1011 for (
const Node& node : nodes_) {
1012 if (node.parent ==
nullptr)
continue;
1015 lrat_proof_handler_->GetLratClauseForDebug(node.closing_clause_id),
1016 ClosingClause(&node));
1018 if (node.trail_info !=
nullptr) {
1019 for (
const auto& [var, lb_and_clause] : node.trail_info->implications) {
1020 const auto [lb, clause_id] = lb_and_clause;
1021 CheckEqual(lrat_proof_handler_->GetLratClauseForDebug(clause_id),
1022 ImplicationClause(&node, ProtoLiteral(var, lb)));
1033 time_limit_(model->GetOrCreate<
TimeLimit>()),
1036 sat_solver_(model->GetOrCreate<
SatSolver>()),
1037 trail_(model->GetOrCreate<
Trail>()),
1052 assigned_tree_lbds_(8) {}
// Appends the negations of the first `level` assigned tree decisions to the
// member vector reason_. `level` is CHECKed not to exceed the number of
// assigned decisions.
// NOTE(review): presumably reason_ is cleared first and a reference to it is
// returned (callers push extra literals onto the result) - confirm.
1054std::vector<Literal>& SharedTreeWorker::DecisionReason(
int level) {
1055 CHECK_LE(level, assigned_tree_decisions_.size());
1057 for (
int i = 0;
i < level; ++
i) {
1058 reason_.push_back(assigned_tree_decisions_[
i].Negated());
// Enqueues `lit` as implied by the first `level` tree decisions; the reason is
// the negation of those decisions and `clause_id` is the supporting clause.
// Two outcomes are visible:
//   - "Closing subtree via impl": with an LRAT handler, a temporary closing
//     clause is inferred over DecisionReason(level) and then deleted. The
//     reason is stored as the trail's conflict, the manager closes the tree at
//     `level`, and the assigned decisions are cleared.
//     NOTE(review): presumably taken when `lit` is already false, and returns
//     false - confirm.
//   - otherwise the reason is stored on the trail and `lit` is enqueued with
//     it; the return value is that of EnqueueWithStoredReason.
1063bool SharedTreeWorker::AddDecisionImplication(Literal lit,
int level,
1064 ClauseId clause_id) {
1068 absl::Span<const Literal> reason = DecisionReason(level);
1070 VLOG(2) <<
"Closing subtree via impl at " << level + 1
1071 <<
" assigned=" << assigned_tree_.
MaxLevel();
1072 if (lrat_proof_handler_ !=
nullptr) {
1075 const ClauseId closing_clause_id = clause_id_generator_->
GetNextId();
1076 std::vector<ClauseId> clause_ids;
1078 clause_ids.push_back(clause_id);
1079 lrat_proof_handler_->AddInferredClause(closing_clause_id,
1080 DecisionReason(level), clause_ids,
1082 lrat_proof_handler_->DeleteClause(closing_clause_id, {});
1084 trail_->MutableConflict()->assign(reason.begin(), reason.end());
1085 manager_->CloseTree(assigned_tree_, level);
1086 assigned_tree_decisions_.clear();
1089 VLOG(2) <<
"Learned shared clause";
1090 trail_->GetEmptyVectorToStoreReason()->assign(reason.begin(), reason.end());
1091 return trail_->EnqueueWithStoredReason(clause_id, lit);
// Applies the not-yet-processed implications stored for the current decision
// level to the local solver.
// Returns false at level 0. Otherwise returns whether any implication that was
// not already true had to be added; it returns true right away when
// AddDecisionImplication fails.
// The per-level processed counter is saved in the reversible repository before
// being advanced.
// When an objective is present, the integer trail's lower bound for the
// objective variable is recorded on the assigned tree, and the matching
// objective literal is enqueued if not already true.
1094bool SharedTreeWorker::AddImplications() {
1095 const int level = sat_solver_->CurrentDecisionLevel();
1097 if (level == 0)
return false;
1098 if (level > assigned_tree_.MaxLevel()) {
1101 rev_num_processed_implications_.resize(level + 1, 0);
1102 auto& num_processed_implications = rev_num_processed_implications_[level];
1103 reversible_int_repository_->SaveState(&num_processed_implications);
// Only look at implications past the ones already processed at this level.
1104 absl::Span<const std::pair<Literal, ClauseId>> implied_literals =
1105 absl::MakeConstSpan(assigned_tree_implications_[level - 1])
1106 .subspan(num_processed_implications);
1107 bool added_clause =
false;
1108 for (
const auto& [implied, clause_id] : implied_literals) {
1109 ++num_processed_implications;
1110 if (sat_solver_->Assignment().LiteralIsTrue(implied))
continue;
1111 added_clause =
true;
1112 if (!AddDecisionImplication(implied, level, clause_id))
return true;
1114 if (objective_ !=
nullptr &&
1116 const IntegerValue obj_lb =
1117 integer_trail_->LowerBound(objective_->objective_var);
1118 assigned_tree_.SetObjectiveLb(level, obj_lb);
1119 const Literal obj_lit =
1121 objective_->objective_var, assigned_tree_.ObjectiveLb(level)));
1122 if (!sat_solver_->Assignment().LiteralIsTrue(obj_lit)) {
1123 AddDecisionImplication(obj_lit, level,
kNoClauseId);
1127 DCHECK(CheckLratInvariants());
1128 return added_clause;
// Empties the worker's copies of the assigned tree decisions and implications.
// With an LRAT proof handler, the clause of every stored implication is
// deleted from the proof first.
1131void SharedTreeWorker::ClearAssignedTreeDecisionsAndImplications() {
1139 if (lrat_proof_handler_ !=
nullptr) {
1140 for (
const auto& implications : assigned_tree_implications_) {
1141 for (
const auto& [literal, clause_id] : implications) {
1142 lrat_proof_handler_->DeleteClause(clause_id, {});
1146 assigned_tree_decisions_.clear();
1147 assigned_tree_implications_.clear();
// Reconciles the assigned tree with the local solver trail.
// Returns false when propagation or BeforeTakingDecision() fails.
// The `continue`/`break` statements imply an enclosing loop; each pass:
//   1. finishes propagation and applies pending tree implications (restarting
//      the pass if any was added);
//   2. when trail sharing is enabled and the level lies within the assigned
//      tree, encodes newly propagated literals as implications of the current
//      level (literals assigned by the binary propagator are skipped) and,
//      with LRAT, derives a clause for each new implication;
//   3. stops when the level reaches the tree depth or the next tree decision
//      is unassigned;
//   4. if the next tree decision is false, closes that subtree, clears the
//      local decisions/implications and backtracks to level 0;
//   5. otherwise the next decision is implied: with LRAT, the implication
//      clauses of deeper levels are re-derived without that decision. Then
//      level + 1 is marked implied on the tree, its implications are merged
//      into the previous level, and the decision is erased.
1150bool SharedTreeWorker::SyncWithLocalTrail() {
1151 DCHECK(CheckLratInvariants());
1152 std::vector<int> new_implication_trail_indices;
1154 if (lrat_proof_handler_ !=
nullptr) {
1155 trail_implication_clauses_.resize(reversible_trail_index_,
kNoClauseId);
1157 if (!sat_solver_->FinishPropagation())
return false;
1160 if (AddImplications())
continue;
1162 if (!helper_->BeforeTakingDecision())
return false;
1163 const int level = sat_solver_->CurrentDecisionLevel();
1164 if (parameters_->shared_tree_worker_enable_trail_sharing() && level > 0 &&
1165 level <= assigned_tree_.MaxLevel() &&
1166 reversible_trail_index_ < trail_->
Index()) {
1167 const int binary_propagator_id = binary_propagator_->PropagatorId();
1169 reversible_int_repository_->SaveState(&reversible_trail_index_);
1170 new_implication_trail_indices.clear();
// Scan the unprocessed part of the trail from newest to oldest.
1171 for (
int i = trail_->Index() - 1;
i >= reversible_trail_index_; --
i) {
1172 const Literal lit = (*trail_)[
i];
1173 const int assignment_type = trail_->AssignmentType(lit.Variable());
1177 if (assignment_type == binary_propagator_id)
continue;
1178 std::optional<ProtoLiteral> encoded = EncodeDecision(lit);
1179 if (!encoded.has_value())
continue;
1180 if (assigned_tree_.AddImplication(level, *encoded) &&
1181 lrat_proof_handler_ !=
nullptr) {
1182 new_implication_trail_indices.push_back(
i);
1189 if (lrat_proof_handler_ !=
nullptr) {
1190 trail_implication_clauses_.resize(trail_->Index(),
kNoClauseId);
// Indices were collected newest-first; iterate backwards to process them in
// increasing trail order.
1191 for (
int i = new_implication_trail_indices.size() - 1;
i >= 0; --
i) {
1192 const int new_trail_index = new_implication_trail_indices[
i];
1193 const Literal lit = (*trail_)[new_trail_index];
1194 trail_implication_clauses_[new_trail_index] =
1195 AddLratClauseAndProofForImplication(
1196 lit, level, [&](
int ,
int trail_index) {
1197 return trail_implication_clauses_[trail_index];
1201 reversible_trail_index_ = trail_->Index();
1203 if (level >= assigned_tree_.MaxLevel())
break;
1205 const Literal next_decision = assigned_tree_decisions_[level];
1206 if (!sat_solver_->Assignment().LiteralIsAssigned(next_decision))
break;
1207 if (sat_solver_->Assignment().LiteralIsFalse(next_decision)) {
1209 VLOG(2) <<
"Closing subtree at " << level + 1
1210 <<
" assigned=" << assigned_tree_.MaxLevel();
1214 const ClauseId clause_id =
1215 AddLratClauseAndProofForImplication(next_decision.Negated(), level);
1216 manager_->CloseTree(assigned_tree_, level + 1);
1217 if (lrat_proof_handler_ !=
nullptr) {
1218 lrat_proof_handler_->DeleteClause(clause_id, {});
1220 ClearAssignedTreeDecisionsAndImplications();
1221 sat_solver_->Backtrack(0);
1224 if (lrat_proof_handler_ !=
nullptr) {
1229 const ClauseId clause_id =
1230 AddLratClauseAndProofForImplication(next_decision, level);
1231 std::vector<Literal> implication;
1232 for (
int i = 0;
i < level; ++
i) {
1233 implication.push_back(assigned_tree_decisions_[
i].Negated());
1235 for (
int l = level; l < assigned_tree_decisions_.size(); ++l) {
1237 implication.push_back(assigned_tree_decisions_[l].Negated());
1239 for (
auto& [lit,
id] : assigned_tree_implications_[l]) {
// Replace each stored clause id with a freshly inferred clause that is proved
// from the new clause and the old one; then drop the old clause.
1240 const ClauseId old_id = id;
1241 id = clause_id_generator_->GetNextId();
1242 implication.push_back(lit);
1243 lrat_proof_handler_->AddInferredClause(
1244 id, implication, {clause_id, old_id},
true);
1245 lrat_proof_handler_->DeleteClause(old_id, {});
1246 implication.pop_back();
1249 lrat_proof_handler_->DeleteClause(clause_id, {});
1251 assigned_tree_.SetLevelImplied(level + 1);
1253 assigned_tree_implications_[level - 1].insert(
1254 assigned_tree_implications_[level - 1].
end(),
1255 assigned_tree_implications_[level].
begin(),
1256 assigned_tree_implications_[level].
end());
1258 assigned_tree_implications_.erase(assigned_tree_implications_.begin() +
1260 assigned_tree_decisions_.erase(assigned_tree_decisions_.begin() + level);
1263 DCHECK(CheckLratInvariants());
// Adds to the LRAT proof an inferred clause stating that the first `level`
// tree decisions imply `literal`: DecisionReason(level) plus `literal`.
// Returns kNoClauseId immediately when no LRAT proof handler is set.
// The proof's clause ids are gathered through
// clause_manager_->AppendClauseIdsFixing.
// NOTE(review): `root_literals` is an optional callback returning a ClauseId
// for two ints; presumably it supplies the clause ids of already-fixed
// literals, and the function returns the new clause id - confirm.
1267ClauseId SharedTreeWorker::AddLratClauseAndProofForImplication(
1269 std::optional<absl::FunctionRef<ClauseId(
int,
int)>> root_literals) {
1270 if (lrat_proof_handler_ ==
nullptr)
return kNoClauseId;
1272 CHECK_LE(level, assigned_tree_decisions_.size());
1273 const ClauseId clause_id = clause_id_generator_->GetNextId();
1274 std::vector<Literal>& implication = DecisionReason(level);
1275 implication.push_back(literal);
1276 std::vector<ClauseId> clause_ids;
1277 clause_manager_->AppendClauseIdsFixing(
1279 lrat_proof_handler_->AddInferredClause(clause_id, implication, clause_ids,
// Registers, as an imported clause (no proof supplied), the clause formed by
// DecisionReason(level) plus `literal`, under a freshly generated id.
// Returns kNoClauseId when no LRAT proof handler is set.
// NOTE(review): presumably returns the new clause id otherwise - confirm.
1284ClauseId SharedTreeWorker::ImportLratClauseForImplication(
Literal literal,
1286 if (lrat_proof_handler_ ==
nullptr)
return kNoClauseId;
1288 CHECK_LE(level, assigned_tree_decisions_.size());
1289 const ClauseId clause_id = clause_id_generator_->GetNextId();
1290 std::vector<Literal>& implication = DecisionReason(level);
1291 implication.push_back(literal);
1292 lrat_proof_handler_->AddImportedClause(clause_id, implication);
// Chooses the next branching literal and writes its index to *decision_index.
// While the next level is still within the assigned tree, the tree's decision
// for that level is used; it is CHECKed to be unassigned (neither false nor
// true). Beyond the tree depth, the choice is delegated to
// helper_->GetDecision with the currently selected decision policy.
// NOTE(review): the tree branch presumably returns true right after setting
// *decision_index - confirm.
1296bool SharedTreeWorker::NextDecision(LiteralIndex* decision_index) {
1297 const auto& decision_policy =
1298 heuristics_->decision_policies[heuristics_->policy_index];
1299 const int next_level = sat_solver_->CurrentDecisionLevel() + 1;
// The local decision list must mirror the assigned tree's depth.
1300 CHECK_EQ(assigned_tree_decisions_.size(), assigned_tree_.MaxLevel());
1301 if (next_level <= assigned_tree_.MaxLevel()) {
1302 VLOG(2) <<
"Following shared trail depth=" << next_level <<
" "
1303 << parameters_->name();
1304 const Literal decision = assigned_tree_decisions_[next_level - 1];
1305 CHECK(!sat_solver_->Assignment().LiteralIsFalse(decision))
1306 <<
" at depth " << next_level <<
" " << parameters_->name();
1307 CHECK(!sat_solver_->Assignment().LiteralIsTrue(decision));
1308 *decision_index = decision.Index();
1311 return helper_->GetDecision(decision_policy, decision_index);
// Offers the worker's decisions below the assigned tree to the manager as
// candidate splits, rate-limited by deterministic time.
// Nothing is proposed while the elapsed deterministic time has not passed
// next_split_dtime_ (presumably an early return - confirm). Otherwise the next
// allowed time is pushed forward by shared_tree_split_min_dtime().
// Candidates are the trail decisions from the tree's depth up to
// min(current decision level, manager's max path depth), stopping at the first
// decision that cannot be encoded.
// For each split the manager accepts, the decoded decision and an empty
// implication list are appended locally.
1314void SharedTreeWorker::MaybeProposeSplits() {
1315 if (time_limit_->GetElapsedDeterministicTime() <= next_split_dtime_) {
1318 next_split_dtime_ = time_limit_->GetElapsedDeterministicTime() +
1319 parameters_->shared_tree_split_min_dtime();
1320 tmp_splits_.clear();
1321 const int max_split_level =
1322 std::min<int>(trail_->CurrentDecisionLevel(), manager_->MaxPathDepth());
1323 for (
int i = assigned_tree_.MaxLevel();
i < max_split_level; ++
i) {
1324 const Literal split_decision = trail_->Decisions()[
i].literal;
1325 const std::optional<ProtoLiteral> encoded = EncodeDecision(split_decision);
1326 if (!encoded.has_value())
break;
1327 tmp_splits_.push_back(*encoded);
1329 const int splits_accepted =
1330 manager_->TrySplitTree(tmp_splits_, assigned_tree_);
1331 for (
int i = 0;
i < splits_accepted; ++
i) {
1332 assigned_tree_decisions_.push_back(DecodeDecision(tmp_splits_[
i]));
1333 assigned_tree_implications_.push_back({});
// Decides whether the worker should give up its current subtree.
1337bool SharedTreeWorker::ShouldReplaceSubtree() {
// A tree with no levels means no subtree is currently assigned.
1339 if (assigned_tree_.MaxLevel() == 0)
return true;
// Guard on the minimum number of restarts per subtree and on the earliest
// allowed replacement time.
// NOTE(review): the body of this `if` is missing from this listing;
// presumably it returns false -- confirm.
1340 if (restart_policy_->NumRestarts() <
1341 parameters_->shared_tree_worker_min_restarts_per_subtree() ||
1342 time_limit_->GetElapsedDeterministicTime() <
1343 earliest_replacement_dtime_) {
// True when the windowed average of recorded subtree LBDs is lower than the
// LBD average since the restart policy was last reset.
1346 return assigned_tree_lbds_.WindowAverage() <
1347 restart_policy_->LbdAverageSinceReset();
// Synchronizes this worker's assigned subtree with the shared tree manager.
// Must run at decision level 0. Visible steps: sync the tree with the
// manager, possibly replace it (publishing / importing a target phase when
// phase sharing is enabled), then rebuild assigned_tree_decisions_ and
// assigned_tree_implications_ from assigned_tree_.
// NOTE(review): many lines (closing braces, some statements, the return
// statements) are missing from this listing, so the exact nesting of the
// statements below cannot be read off here -- confirm against the full file.
1350bool SharedTreeWorker::SyncWithSharedTree() {
1351 DCHECK_EQ(trail_->CurrentDecisionLevel(), 0);
1352 DCHECK(CheckLratInvariants());
1353 manager_->SyncTree(assigned_tree_);
1354 assigned_tree_.NormalizeImplications();
1355 if (ShouldReplaceSubtree()) {
1357 VLOG(2) << parameters_->name() <<
" acquiring tree #" << num_trees_
1358 <<
" after " << restart_policy_->NumRestarts() <<
" restarts"
1359 <<
" prev depth: " << assigned_tree_.MaxLevel()
1360 <<
" target: " << assigned_tree_lbds_.WindowAverage()
1361 <<
" lbd: " << restart_policy_->LbdAverageSinceReset();
// With phase sharing on, the minimum restarts finished and a non-empty best
// partial assignment, the tree's target phase is cleared and refilled from
// that assignment.
1362 if (parameters_->shared_tree_worker_enable_phase_sharing() &&
1365 FinishedMinRestarts() &&
1366 !decision_policy_->GetBestPartialAssignment().empty()) {
1367 assigned_tree_.ClearTargetPhase();
1368 for (Literal lit : decision_policy_->GetBestPartialAssignment()) {
// Literals already assigned on the trail are skipped.
1370 if (trail_->Assignment().LiteralIsAssigned(lit))
continue;
1373 if (trail_->Info(lit.Variable()).level <= assigned_tree_.MaxLevel()) {
// `encoded` is declared on a line missing from this listing. Literals
// without an encoding are skipped; the loop stops once AddPhase() returns
// false.
1379 if (!encoded.has_value())
continue;
1380 if (!assigned_tree_.AddPhase(*encoded))
break;
// Exchange the tree with the manager, then record the LBD average of the
// search so far and reset the restart policy and the replacement deadline.
1383 manager_->ReplaceTree(assigned_tree_);
1384 assigned_tree_.NormalizeImplications();
1385 assigned_tree_lbds_.Add(restart_policy_->LbdAverageSinceReset());
1386 restart_policy_->Reset();
1387 earliest_replacement_dtime_ = 0;
// With a non-empty tree, the next split proposal is pushed into the future.
1388 if (assigned_tree_.MaxLevel() > 0) {
1389 next_split_dtime_ = time_limit_->GetElapsedDeterministicTime() +
1390 parameters_->shared_tree_split_min_dtime();
// Import the tree's target phase into the local decision policy.
1392 if (parameters_->shared_tree_worker_enable_phase_sharing()) {
1393 VLOG(2) <<
"Importing phase of length: "
1394 << assigned_tree_.TargetPhase().size();
1395 decision_policy_->ClearBestPartialAssignment();
1396 for (
const ProtoLiteral& lit : assigned_tree_.TargetPhase()) {
1397 decision_policy_->SetTargetPolarityIfUnassigned(DecodeDecision(lit));
1399 decision_policy_->ResetActivitiesToFollowBestPartialAssignment();
1405 decision_policy_->ClearBestPartialAssignment();
// When the minimum restarts are finished and the replacement deadline is not
// before the current deterministic time, the deadline becomes "now + 1".
1411 if (FinishedMinRestarts() && earliest_replacement_dtime_ >=
1412 time_limit_->GetElapsedDeterministicTime()) {
1413 earliest_replacement_dtime_ =
1414 time_limit_->GetElapsedDeterministicTime() + 1;
1416 assigned_tree_lbds_.Add(restart_policy_->LbdAverageSinceReset());
1418 VLOG(2) <<
"Assigned level: " << assigned_tree_.MaxLevel() <<
" "
1419 << parameters_->name();
// Rebuild the decoded view of the tree: one decision per level (levels are
// 1-based) plus, per level, each decoded implication paired with the clause
// id returned by ImportLratClauseForImplication().
1420 ClearAssignedTreeDecisionsAndImplications();
1421 for (
int level = 1; level <= assigned_tree_.MaxLevel(); ++level) {
1422 assigned_tree_decisions_.push_back(
1423 DecodeDecision(assigned_tree_.Decision(level)));
1424 std::vector<std::pair<Literal, ClauseId>> implications;
1425 for (
const ProtoLiteral& impl : assigned_tree_.Implications(level)) {
1426 const Literal lit = DecodeDecision(impl);
1427 implications.emplace_back(lit,
1428 ImportLratClauseForImplication(lit, level));
1430 assigned_tree_implications_.push_back(std::move(implications));
1432 DCHECK(CheckLratInvariants());
// Parameter list tail and body of the worker's main search loop.
// NOTE(review): the line carrying the function name is missing from this
// listing; presumably this is SharedTreeWorker::Search -- confirm. Several
// statements and all closing braces are also missing, so the comments below
// describe individual visible statements only.
1437 const std::function<
void()>& feasible_solution_observer) {
1442 sat_solver_->Backtrack(0);
// The return values of these two calls are discarded.
// NOTE(review): presumably called so the constant literals exist before the
// search starts -- confirm.
1443 encoder_->GetTrueLiteral();
1444 encoder_->GetFalseLiteral();
// Register SyncWithSharedTree() as a level-zero callback; the lambda
// captures `this`.
1445 level_zero_callbacks_->callbacks.push_back(
1446 [
this]() {
return SyncWithSharedTree(); });
1447 const bool has_objective =
1449 while (!time_limit_->LimitReached()) {
// A propagation failure ends the search with the solver's UNSAT status.
1450 if (!sat_solver_->FinishPropagation()) {
1451 return sat_solver_->UnsatStatus();
// When the restart policy of the current heuristic fires, the policy index
// is set to the restart count modulo the number of decision policies, and
// the solver backtracks to level 0.
1453 if (heuristics_->restart_policies[heuristics_->policy_index]()) {
1454 heuristics_->policy_index = restart_policy_->NumRestarts() %
1455 heuristics_->decision_policies.size();
1456 sat_solver_->Backtrack(0);
1458 if (!SyncWithLocalTrail())
return sat_solver_->UnsatStatus();
1459 LiteralIndex decision_index;
1460 if (!NextDecision(&decision_index))
continue;
// NOTE(review): the condition guarding the observer call and the objective
// handling below is on lines missing from this listing.
1463 feasible_solution_observer();
1465 const IntegerValue objective =
1466 integer_trail_->LowerBound(objective_->objective_var);
1467 sat_solver_->Backtrack(0);
1468 if (!integer_trail_->Enqueue(
// The chosen decision must be unassigned (neither false nor true).
1477 const Literal decision(decision_index);
1478 CHECK(!sat_solver_->Assignment().LiteralIsFalse(decision));
1479 CHECK(!sat_solver_->Assignment().LiteralIsTrue(decision));
// The second argument is true exactly when no LRAT proof handler is set.
1482 if (!helper_->TakeDecision(
1483 decision, lrat_proof_handler_ ==
nullptr)) {
1484 return sat_solver_->UnsatStatus();
1486 MaybeProposeSplits();
// Decodes `lit` through its Decode() method using this worker's mapping and
// encoder.
// NOTE(review): the signature line is missing from this listing; presumably
// this is the body of SharedTreeWorker::DecodeDecision -- confirm.
1493 return lit.
Decode(mapping_, encoder_);
// Converts a solver literal into an optional ProtoLiteral. Callers in this
// file treat an empty result as "cannot be encoded".
// NOTE(review): the body is missing from this listing.
1496std::optional<ProtoLiteral> SharedTreeWorker::EncodeDecision(Literal decision) {
// Consistency check over the stored per-level implications; it is invoked
// inside DCHECK() by SyncWithSharedTree().
// NOTE(review): part of the guarding condition, the comparison against the
// expected clause, and the return statements are missing from this listing.
1500bool SharedTreeWorker::CheckLratInvariants() {
1501 if (lrat_proof_handler_ !=
nullptr &&
// `level` is a 0-based index into assigned_tree_implications_, while
// DecisionReason() is called with level + 1.
1503 for (
int level = 0; level < assigned_tree_decisions_.size(); ++level) {
1504 for (
auto& [lit,
id] : assigned_tree_implications_[level]) {
// Expected clause: DecisionReason(level + 1) with the implied literal
// appended. `expected` is a reference, so the push_back extends that vector
// in place.
1505 std::vector<Literal>& expected = DecisionReason(level + 1);
1506 expected.push_back(lit);
void AppendClauseIdsFixing(absl::Span< const Literal > literals, std::vector< ClauseId > *clause_ids, LiteralIndex decision=kNoLiteralIndex, std::optional< absl::FunctionRef< ClauseId(int, int)> > root_literals={})
sat::Literal Literal(int ref) const
bool IsBoolean(int ref) const
bool IsInteger(int ref) const
int GetProtoVariableFromBooleanVariable(BooleanVariable var) const
int NumProtoVariables() const
IntegerVariable Integer(int ref) const
const InlinedIntegerLiteralVector & GetIntegerLiterals(Literal lit) const
Literal GetOrCreateAssociatedLiteral(IntegerLiteral i_lit)
LiteralIndex Index() const
BooleanVariable Variable() const
absl::Span< const Literal > GetLratClauseForDebug(ClauseId id) const
bool lrat_check_enabled() const
static std::optional< ProtoLiteral > Encode(Literal, CpModelMapping *, IntegerEncoder *)
Literal Decode(CpModelMapping *, IntegerEncoder *) const
static std::optional< ProtoLiteral > EncodeLiteral(Literal, CpModelMapping *)
std::vector< ProtoLiteral > TakeTargetPhase()
void SetTargetPhase(std::vector< ProtoLiteral > phase)
ProtoLiteral Decision(int level) const
void SetObjectiveLb(int level, IntegerValue objective_lb)
void NormalizeImplications()
int DecisionNodeId(int level) const
absl::Span< const ProtoLiteral > Implications(int level) const
absl::Span< const int > NodeIds(int level) const
IntegerValue ObjectiveLb(int level) const
void PushLevel(const ProtoLiteral &decision, IntegerValue objective_lb, int node_id)
void SetLevelImplied(int level)
static constexpr SharedTreeSplitStrategy SPLIT_STRATEGY_DISCREPANCY
static constexpr SharedTreeSplitStrategy SPLIT_STRATEGY_BALANCED_TREE
static constexpr SharedTreeSplitStrategy SPLIT_STRATEGY_AUTO
static constexpr SharedTreeSplitStrategy SPLIT_STRATEGY_OBJECTIVE_LB
const VariablesAssignment & Assignment() const
void CloseTree(ProtoTrail &path, int level)
SharedTreeManager(Model *model)
void ReplaceTree(ProtoTrail &path)
bool SyncTree(ProtoTrail &path) ABSL_LOCKS_EXCLUDED(mu_)
int NumNodes() const ABSL_LOCKS_EXCLUDED(mu_)
int TrySplitTree(absl::Span< const ProtoLiteral > decisions, ProtoTrail &path) ABSL_LOCKS_EXCLUDED(mu_)
SharedTreeWorker(Model *model)
SatSolver::Status Search(const std::function< void()> &feasible_solution_observer)
bool LiteralIsFalse(Literal literal) const
bool LiteralIsTrue(Literal literal) const
constexpr IntegerValue kMaxIntegerValue(std::numeric_limits< IntegerValue::ValueType >::max() - 1)
bool RefIsPositive(int ref)
const LiteralIndex kNoLiteralIndex(-1)
const IntegerVariable kNoIntegerVariable(-1)
IntegerVariable PositiveVariable(IntegerVariable i)
constexpr ClauseId kNoClauseId(0)
ClosedInterval::Iterator end(ClosedInterval interval)
ClosedInterval::Iterator begin(ClosedInterval interval)
static constexpr int kSearchDecision
static IntegerLiteral GreaterOrEqual(IntegerVariable i, IntegerValue bound)
static IntegerLiteral LowerOrEqual(IntegerVariable i, IntegerValue bound)