Google OR-Tools v9.15
a fast and portable software suite for combinatorial optimization
Loading...
Searching...
No Matches
work_assignment.cc
Go to the documentation of this file.
1// Copyright 2010-2025 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
15
16#include <algorithm>
17#include <array>
18#include <cmath>
19#include <deque>
20#include <functional>
21#include <limits>
22#include <memory>
23#include <optional>
24#include <string>
25#include <tuple>
26#include <utility>
27#include <vector>
28
29#include "absl/functional/function_ref.h"
30#include "absl/log/check.h"
31#include "absl/log/log.h"
32#include "absl/strings/str_cat.h"
33#include "absl/synchronization/mutex.h"
34#include "absl/types/span.h"
35#include "ortools/sat/clause.h"
38#include "ortools/sat/integer.h"
42#include "ortools/sat/model.h"
43#include "ortools/sat/restart.h"
49#include "ortools/sat/util.h"
52
54namespace {
55const int kNumInitialRestarts = 10;
56
// If you build a tree by expanding the nodes with minimal depth+discrepancy,
// the number of leaves when all nodes less than a given value have been split
// follows the fibonacci sequence:
// num_leaves(0) := 1;
// num_leaves(1) := 2;
// num_leaves(n) := num_leaves(n-1) + num_leaves(n-2)
// This function returns f(n) := min({i | num_leaves(i) >= n}).
//
// The sequence is tracked in `long long` rather than `int`: the loop keeps
// running while `a < num_leaves`, so for `num_leaves` close to the maximum
// `int` the next term `a + b` no longer fits in an `int` and would be a signed
// overflow (undefined behaviour).
int MaxAllowedDiscrepancyPlusDepth(int num_leaves) {
  int i = 0;
  long long a = 1;
  long long b = 2;
  while (a < num_leaves) {
    const long long next = a + b;
    a = b;
    b = next;
    ++i;
  }
  return i;
}
74
75// Returns the maximum depth of any leaf in the shared tree.
76// This is an upper bound that can be computed without needing a lock on the
77// shared tree.
78int MaxPossibleLeafDepth(const SatParameters& params) {
79 const int num_leaves = params.shared_tree_open_leaves_per_worker() *
80 params.shared_tree_num_workers();
81 switch (params.shared_tree_split_strategy()) {
84 return MaxAllowedDiscrepancyPlusDepth(num_leaves) +
85 params.shared_tree_balance_tolerance();
87 return std::ceil(std::log2(num_leaves)) +
88 params.shared_tree_balance_tolerance();
89 default:
90 return num_leaves;
91 }
92}
93} // namespace
94
96 IntegerEncoder* encoder) const {
97 DCHECK_LT(proto_var_, mapping->NumProtoVariables());
98 if (mapping->IsBoolean(proto_var_)) {
99 return mapping->Literal(proto_var_);
100 }
101 return encoder->GetOrCreateAssociatedLiteral(DecodeInteger(mapping));
102}
103
104IntegerLiteral ProtoLiteral::DecodeInteger(CpModelMapping* mapping) const {
105 const int positive_var = PositiveRef(proto_var_);
106 if (!mapping->IsInteger(positive_var)) {
107 return IntegerLiteral();
108 }
109 if (proto_var_ < 0) {
110 return IntegerLiteral::LowerOrEqual(mapping->Integer(positive_var), -lb_);
111 }
112 return IntegerLiteral::GreaterOrEqual(mapping->Integer(positive_var), lb_);
113}
114
115std::optional<ProtoLiteral> ProtoLiteral::EncodeInteger(
116 IntegerLiteral literal, CpModelMapping* mapping) {
117 IntegerVariable positive_var = PositiveVariable(literal.var);
118 const int model_var =
119 mapping->GetProtoVariableFromIntegerVariable(positive_var);
120 if (model_var == -1) {
121 return std::nullopt;
122 }
123 ProtoLiteral result{
124 literal.var == positive_var ? model_var : NegatedRef(model_var),
125 literal.bound};
126 DCHECK_EQ(result.DecodeInteger(mapping), literal);
127 DCHECK_EQ(result.Negated().DecodeInteger(mapping), literal.Negated());
128 return result;
129}
130std::optional<ProtoLiteral> ProtoLiteral::Encode(Literal literal,
131 CpModelMapping* mapping,
132 IntegerEncoder* encoder) {
133 const std::optional<ProtoLiteral> result = EncodeLiteral(literal, mapping);
134 if (result.has_value()) return result;
135
136 for (auto int_lit : encoder->GetIntegerLiterals(literal)) {
137 auto result = EncodeInteger(int_lit, mapping);
138 if (result.has_value()) {
139 DCHECK_EQ(result->DecodeInteger(mapping), int_lit);
140 DCHECK_EQ(result->Negated().DecodeInteger(mapping), int_lit.Negated());
141 return result;
142 }
143 }
144 return std::nullopt;
145}
146
147std::optional<ProtoLiteral> ProtoLiteral::EncodeLiteral(
148 Literal literal, CpModelMapping* mapping) {
149 if (literal.Index() == kNoLiteralIndex) {
150 return std::nullopt;
151 }
152 int model_var =
154 if (model_var == -1) {
155 return std::nullopt;
156 }
157 DCHECK(mapping->IsBoolean(model_var));
158 ProtoLiteral result{literal.IsPositive() ? model_var : NegatedRef(model_var),
159 literal.IsPositive() ? 1 : 0};
160 return result;
161}
162
163namespace {
164Literal DecodeWithIdentityMapping(const ProtoLiteral& literal) {
165 const int ref = literal.proto_var();
166 return Literal(BooleanVariable(PositiveRef(ref)), RefIsPositive(ref));
167}
168} // namespace
169
170ProtoTrail::ProtoTrail() { target_phase_.reserve(kMaxPhaseSize); }
171
173 IntegerValue objective_lb, int node_id) {
174 CHECK_GT(node_id, 0);
175 decision_indexes_.push_back(literals_.size());
176 assigned_at_level_[decision] = decision_indexes_.size();
177 literals_.push_back(decision);
178 node_ids_.push_back(node_id);
179 implications_.push_back({});
180 if (!level_to_objective_lbs_.empty()) {
181 objective_lb = std::max(level_to_objective_lbs_.back(), objective_lb);
182 }
183 level_to_objective_lbs_.push_back(objective_lb);
184}
185
187 DCHECK_GE(level, 1);
188 DCHECK_LE(level, decision_indexes_.size());
189 DCHECK_LE(level, implications_.size());
190 SetObjectiveLb(level - 1, ObjectiveLb(level));
191 const ProtoLiteral decision = Decision(level);
192 assigned_at_level_[decision] = level - 1;
193 // We don't store implications for level 0, so only move implications up to
194 // the parent if we are removing level 2 or greater.
195 if (level >= 2) {
196 MutableImplications(level - 1).push_back(decision);
197 }
198 for (const ProtoLiteral& implication : Implications(level)) {
199 assigned_at_level_[implication] = level - 1;
200 if (level >= 2) {
201 MutableImplications(level - 1).push_back(implication);
202 }
203 }
204 // implications_[level-1] stores the implications for level, which are now
205 // stored in the parent's implications, so we can delete them.
206 implications_.erase(implications_.begin() + level - 1);
207 decision_indexes_.erase(decision_indexes_.begin() + level - 1);
208 level_to_objective_lbs_.erase(level_to_objective_lbs_.begin() + level - 1);
209}
210
212 assigned_at_level_.clear();
213 for (int level = 1; level <= MaxLevel(); ++level) {
214 assigned_at_level_[Decision(level)] = level;
215 int new_size = 0;
216 std::vector<ProtoLiteral>& implications = MutableImplications(level);
217 for (int i = 0; i < implications.size(); ++i) {
218 const ProtoLiteral& implication = implications[i];
219 if (!assigned_at_level_.contains(implication)) {
220 implications[new_size++] = implication;
221 assigned_at_level_[implication] = level;
222 }
223 }
224 implications.resize(new_size);
225 }
226}
227
229 decision_indexes_.clear();
230 literals_.clear();
231 level_to_objective_lbs_.clear();
232 node_ids_.clear();
233 target_phase_.clear();
234 assigned_at_level_.clear();
235 implications_.clear();
236}
237
238void ProtoTrail::SetObjectiveLb(int level, IntegerValue objective_lb) {
239 if (level == 0) return;
240 level_to_objective_lbs_[level - 1] =
241 std::max(objective_lb, level_to_objective_lbs_[level - 1]);
242}
243
244int ProtoTrail::DecisionNodeId(int level) const {
245 DCHECK_LE(level, decision_indexes_.size());
246 return node_ids_[decision_indexes_[level - 1]];
247}
248
249absl::Span<const int> ProtoTrail::NodeIds(int level) const {
250 DCHECK_LE(level, decision_indexes_.size());
251 int start = level == 0 ? 0 : decision_indexes_[level - 1];
252 int end = level == decision_indexes_.size() ? node_ids_.size()
253 : decision_indexes_[level];
254 return absl::MakeSpan(node_ids_.data() + start, end - start);
255}
256
257absl::Span<const ProtoLiteral> ProtoTrail::Implications(int level) const {
258 if (level > implications_.size() || level <= 0) {
259 return absl::MakeSpan(literals_.data(), 0);
260 }
261 return absl::MakeSpan(implications_[level - 1]);
262}
263
265 : params_(*model->GetOrCreate<SatParameters>()),
266 num_workers_(std::max(0, params_.shared_tree_num_workers())),
267 max_path_depth_(MaxPossibleLeafDepth(params_)),
268 shared_response_manager_(model->GetOrCreate<SharedResponseManager>()),
269 lrat_proof_handler_(LratProofHandler::MaybeCreate(
270 params_, &clause_id_generator_,
271 model->GetOrCreate<SharedLratProofStatus>(),
272 model->GetOrCreate<SharedStatistics>())),
273 num_splits_wanted_(
274 num_workers_ * params_.shared_tree_open_leaves_per_worker() - 1),
275 max_nodes_(
276 params_.shared_tree_max_nodes_per_worker() >=
277 std::numeric_limits<int>::max() / std::max(num_workers_, 1)
278 ? std::numeric_limits<int>::max()
279 : num_workers_ * params_.shared_tree_max_nodes_per_worker()) {
280 // Create the root node with a fake decision.
281 nodes_.push_back(
282 {.decision = ProtoLiteral(),
283 .objective_lb = shared_response_manager_->GetInnerObjectiveLowerBound(),
284 .trail_info = std::make_unique<NodeTrailInfo>()});
285 unassigned_leaves_.push_back(&nodes_.back());
286}
287
289 absl::MutexLock mutex_lock(mu_);
290 return nodes_.size();
291}
292
294 absl::MutexLock mutex_lock(mu_);
295 std::vector<std::pair<Node*, int>> nodes = GetAssignedNodes(path);
296 if (!IsValid(path)) {
297 path.Clear();
298 return false;
299 }
300 DCHECK(CheckLratInvariants());
301 // We don't rely on these being empty, but we expect them to be.
302 DCHECK(to_close_.empty());
303 DCHECK(to_update_.empty());
304 int prev_level = -1;
305 for (const auto& [node, level] : nodes) {
306 if (level == prev_level) {
307 // `node` is implied by the previous decisions in `path`, hence its
308 // sibling can be closed (using this implication as proof; the implication
309 // proved by the worker providing `path` must be imported and a new one,
310 // adapted for the manager, must be inferred from it).
311 Node* sibling = GetSibling(node);
312 ClauseId closing_clause_id = kNoClauseId;
313 if (lrat_proof_handler_ != nullptr) {
314 // For the worker, `node` is implied by all the previous decisions in
315 // `path`, but for the manager we need an implication clause using the
316 // non-implied ancestors of `node` in the tree (they can be different
317 // because the manager and the worker have different views of the tree).
318 const std::vector<Literal> inferred_clause = ClosingClause(sibling);
319 std::vector<Literal> imported_clause;
320 std::vector<ClauseId> lrat_proof;
321 for (int l = 1; l <= level + 1; ++l) {
322 Node* n = l <= level ? GetNode(path.DecisionNodeId(l)) : node;
323 const Literal decision = DecodeWithIdentityMapping(n->decision);
324 imported_clause.push_back(l <= level ? decision.Negated() : decision);
325 if (n->implied_and_processed) {
326 lrat_proof.push_back(GetSibling(n)->closing_clause_id);
327 }
328 }
329 closing_clause_id = AddImportedAndInferredClauses(
330 imported_clause, inferred_clause, lrat_proof);
331 }
332 to_close_.emplace_back(sibling, closing_clause_id);
333 } else if (level > 0 && node->objective_lb < path.ObjectiveLb(level)) {
334 node->objective_lb = path.ObjectiveLb(level);
335 to_update_.push_back(node->parent);
336 }
337 if (level > 0 && !node->closed) {
338 NodeTrailInfo* trail_info = GetTrailInfo(node);
339 for (const ProtoLiteral& implication : path.Implications(level)) {
340 // Trivial implication, can be ignored.
341 if (IsDecisionOfNodeOrAncestor(implication, node)) continue;
342 ClauseId implication_clause_id = kNoClauseId;
343 if (lrat_proof_handler_ != nullptr) {
344 // For the worker, 'implication' is implied by all the previous
345 // decisions in `path`, but for the manager we need an implication
346 // clause using the non-implied ancestors of `node` in the tree (they
347 // can be different because the manager and the worker have different
348 // views of the tree).
349 const std::vector<Literal> inferred_clause =
350 ImplicationClause(node, implication);
351 std::vector<Literal> imported_clause;
352 std::vector<ClauseId> lrat_proof;
353 for (int l = 1; l <= level; ++l) {
354 Node* n = GetNode(path.DecisionNodeId(l));
355 const Literal decision = DecodeWithIdentityMapping(n->decision);
356 imported_clause.push_back(decision.Negated());
357 if (n->implied_and_processed) {
358 lrat_proof.push_back(GetSibling(n)->closing_clause_id);
359 }
360 }
361 imported_clause.push_back(DecodeWithIdentityMapping(implication));
362 implication_clause_id = AddImportedAndInferredClauses(
363 imported_clause, inferred_clause, lrat_proof);
364 }
365 auto it = trail_info->implications
366 .emplace(implication.proto_var(),
367 std::make_pair(implication.lb(),
368 implication_clause_id))
369 .first;
370 if (it->second.first < implication.lb()) {
371 it->second.first = implication.lb();
372 }
373 }
374 }
375 prev_level = level;
376 }
377 ProcessNodeChanges();
378 if (nodes.back().first->closed) {
379 path.Clear();
380 return false;
381 }
382 // Restart after processing updates - we might learn a new objective bound.
383 // Do initial restarts once each worker has had the chance to be assigned a
384 // leaf.
385 if (num_leaves_assigned_since_restart_ >= num_workers_ &&
386 num_restarts_ < kNumInitialRestarts) {
387 RestartLockHeld();
388 path.Clear();
389 return false;
390 }
391 // Sync lower bounds and implications from the shared tree to `path`.
392 AssignLeaf(path, nodes.back().first);
393 DCHECK(CheckLratInvariants());
394 return true;
395}
396
397int SharedTreeManager::TrySplitTree(absl::Span<const ProtoLiteral> decisions,
398 ProtoTrail& path) {
399 decisions = decisions.subspan(0, max_path_depth_ - path.MaxLevel());
400 if (decisions.empty()) return 0;
401 absl::MutexLock l(mu_);
402 for (int i = 0; i < decisions.size(); ++i) {
403 if (!TrySplitTreeLockHeld(decisions[i], path)) return i;
404 }
405 return decisions.size();
406}
407
408bool SharedTreeManager::TrySplitTreeLockHeld(ProtoLiteral decision,
409 ProtoTrail& path) {
410 if (!IsValid(path)) return false;
411 std::vector<std::pair<Node*, int>> nodes = GetAssignedNodes(path);
412 if (nodes.back().first->closed) {
413 VLOG(2) << "Cannot split closed node";
414 return false;
415 }
416 if (nodes.back().first->children[0] != nullptr) {
417 LOG_IF(WARNING, nodes.size() > 1)
418 << "Cannot resplit previously split node @ " << nodes.back().second
419 << "/" << nodes.size();
420 return false;
421 }
422 if (nodes_.size() + 2 > max_nodes_) {
423 VLOG(2) << "Too many nodes to accept split";
424 return false;
425 }
426 if (num_splits_wanted_ <= 0) {
427 VLOG(2) << "Enough splits for now";
428 return false;
429 }
430 for (const auto& [node, level] : nodes) {
431 if (decision == node->decision || decision == node->decision.Negated()) {
432 VLOG(2) << "Cannot split on decision which is already in the tree";
433 return false;
434 }
435 }
436 if (params_.shared_tree_split_strategy() ==
438 params_.shared_tree_split_strategy() ==
440 int discrepancy = 0;
441 for (const auto& [node, level] : nodes) {
442 if (node->parent == nullptr || node->implied) continue;
443 IntegerValue sibling_bound = GetSibling(node)->objective_lb;
444 discrepancy += (node->objective_lb == sibling_bound
445 ? node != node->parent->children[0]
446 : node->objective_lb > sibling_bound);
447 }
448 // TODO(user): Need to write up the shape this creates.
449 // This rule will allow twice as many leaves in the preferred subtree.
450 if (discrepancy + path.MaxLevel() >= max_path_depth_) {
451 VLOG(2) << "Too high discrepancy to accept split";
452 return false;
453 }
454 } else if (params_.shared_tree_split_strategy() ==
456 if (nodes.back().first->objective_lb > nodes.front().first->objective_lb) {
457 VLOG(2) << "Can only split nodes with minimum objective lb, "
458 << nodes.back().first->objective_lb << " > "
459 << nodes.front().first->objective_lb;
460 return false;
461 }
462 }
463 VLOG_EVERY_N(2, 10) << unassigned_leaves_.size() << " unassigned leaves, "
464 << nodes_.size() << " subtrees, " << num_splits_wanted_
465 << " splits wanted";
466 Split(nodes, decision);
467 auto [new_leaf, level] = nodes.back();
468 path.PushLevel(new_leaf->decision, new_leaf->objective_lb, new_leaf->id);
469 return true;
470}
471
473 absl::MutexLock mutex_lock(mu_);
474 std::vector<std::pair<Node*, int>> nodes = GetAssignedNodes(path);
475 if (nodes.back().first->children[0] == nullptr &&
476 !nodes.back().first->closed && nodes.size() > 1) {
477 Node* leaf = nodes.back().first;
478 VLOG(2) << "Returning leaf to be replaced";
479 GetTrailInfo(leaf)->phase = path.TakeTargetPhase();
480 unassigned_leaves_.push_back(leaf);
481 }
482 path.Clear();
483 while (!unassigned_leaves_.empty()) {
484 Node* leaf = unassigned_leaves_.front();
485 unassigned_leaves_.pop_front();
486 if (!leaf->closed && leaf->children[0] == nullptr) {
487 num_leaves_assigned_since_restart_ += 1;
488 AssignLeaf(path, leaf);
489 path.SetTargetPhase(std::move(GetTrailInfo(leaf)->phase));
490 return;
491 }
492 }
493 VLOG(2) << "Assigning root because no unassigned leaves are available";
494 // TODO(user): Investigate assigning a random leaf so workers can still
495 // improve shared tree bounds.
496}
497
498SharedTreeManager::NodeTrailInfo* SharedTreeManager::GetTrailInfo(Node* node) {
499 CHECK(node != nullptr && !node->closed);
500 while (node->trail_info == nullptr) {
501 node = node->parent;
502 }
503 CHECK_NE(node, nullptr);
504 return node->trail_info.get();
505}
506
507void SharedTreeManager::ClearTrailInfo(Node* node, bool implications_only) {
508 if (node->trail_info == nullptr) return;
509 if (lrat_proof_handler_ != nullptr) {
510 for (const auto& [var, lb_and_clause] : node->trail_info->implications) {
511 lrat_proof_handler_->DeleteClause(lb_and_clause.second, {});
512 }
513 }
514 if (implications_only) {
515 node->trail_info->implications.clear();
516 } else {
517 node->trail_info.reset();
518 }
519}
520
521SharedTreeManager::Node* SharedTreeManager::GetSibling(const Node* node) const {
522 if (node == nullptr || node->parent == nullptr) return nullptr;
523 if (node->parent->children[0] != node) {
524 return node->parent->children[0];
525 }
526 return node->parent->children[1];
527}
528
529void SharedTreeManager::Split(std::vector<std::pair<Node*, int>>& nodes,
530 ProtoLiteral lit) {
531 const auto [parent, level] = nodes.back();
532 DCHECK(parent->children[0] == nullptr);
533 DCHECK(parent->children[1] == nullptr);
534 parent->children[0] = MakeSubtree(parent, lit);
535 parent->children[1] = MakeSubtree(parent, lit.Negated());
536 NodeTrailInfo* trail_info = GetTrailInfo(parent);
537 if (trail_info != nullptr) {
538 parent->children[0]->trail_info =
539 std::make_unique<NodeTrailInfo>(NodeTrailInfo{});
540 parent->children[1]->trail_info = std::make_unique<NodeTrailInfo>(
541 NodeTrailInfo{.phase = std::move(trail_info->phase)});
542 }
543 nodes.push_back(std::make_pair(parent->children[0], level + 1));
544 unassigned_leaves_.push_back(parent->children[1]);
545 --num_splits_wanted_;
546}
547
548SharedTreeManager::Node* SharedTreeManager::MakeSubtree(Node* parent,
549 ProtoLiteral decision) {
550 nodes_.push_back(
551 Node{.decision = decision,
552 .objective_lb = parent->objective_lb,
553 .parent = parent,
554 .id = static_cast<int>(nodes_.size() + node_id_offset_)});
555 return &nodes_.back();
556}
557
// Applies the pending work queued in `to_close_` and `to_update_`.
//
// Phase 1 drains `to_close_`: each queued node is marked closed, its open
// children are queued for closing, its sibling is marked implied, and the
// walk continues to the parent for as long as the sibling is closed too.
// Walking past the root reports infeasibility to the response manager.
// Phase 2 drains `to_update_`: implied children are processed and the minimum
// of the two children's objective lower bounds is propagated upwards while it
// improves the node's bound. If the walk passes the root, the root bound is
// reported to the response manager.
// Finally, newly implied nodes that were not processed in phase 2 are
// processed (LRAT only) and the root's implications are cleared.
558void SharedTreeManager::ProcessNodeChanges() {
559 DCHECK(CheckLratInvariants());
560 int num_newly_closed = 0;
561 std::vector<Node*> newly_implied;
562 while (!to_close_.empty()) {
563 auto [node, closing_clause_id] = to_close_.back();
564 CHECK_NE(node, nullptr);
565 to_close_.pop_back();
566 // Iterate over open parents while each sibling is closed.
567 while (node != nullptr && !node->closed) {
568 ++num_newly_closed;
569 ++num_closed_nodes_;
570 node->closed = true;
571 node->closing_clause_id = closing_clause_id;
572 // Keep the root trail_info so GetTrailInfo never returns nullptr.
573 if (node->parent != nullptr) {
574 ClearTrailInfo(node);
575 }
576 node->objective_lb = kMaxIntegerValue;
577 // If we are closing a leaf, try to maintain the same number of leaves;
578 num_splits_wanted_ += (node->children[0] == nullptr);
579 for (Node* child : node->children) {
580 if (child == nullptr || child->closed) continue;
581 ClauseId child_closing_clause_id = kNoClauseId;
582 if (lrat_proof_handler_ != nullptr) {
583 // The node's closing clause is sufficient to prove that `child` can
584 // be closed. We use a new clause only to avoid double deletes in
585 // RestartLockHeld().
586 child_closing_clause_id = clause_id_generator_.GetNextId();
587 lrat_proof_handler_->AddInferredClause(
588 child_closing_clause_id, ClosingClause(child),
589 {closing_clause_id}, /*exported=*/true);
590 }
591 to_close_.emplace_back(child, child_closing_clause_id);
592 }
593 Node* sibling = GetSibling(node);
594 if (sibling != nullptr) {
595 sibling->implied = true;
596 if (lrat_proof_handler_ != nullptr) {
597 newly_implied.push_back(sibling);
598 }
 // An open sibling keeps the parent open: stop walking up here.
599 if (!sibling->closed) {
600 break;
601 }
602 }
603 Node* parent = node->parent;
 // GetSibling() only returns nullptr for a null node or a node without a
 // parent, so `sibling` is non-null whenever `parent` is non-null below.
604 if (lrat_proof_handler_ != nullptr && parent != nullptr &&
605 !parent->closed) {
606 closing_clause_id = clause_id_generator_.GetNextId();
607 // Combine the clauses proving that the node and its sibling could be
608 // closed to prove that the parent can be closed.
609 lrat_proof_handler_->AddInferredClause(
610 closing_clause_id, ClosingClause(parent),
611 {node->closing_clause_id, sibling->closing_clause_id},
612 /*exported=*/true);
613 }
614 node = parent;
615 }
616 DCHECK(node == nullptr || node->closed);
 // node == nullptr means the walk went past the root, i.e. the root itself
 // was closed.
617 if (node == nullptr) {
618 shared_response_manager_->NotifyThatImprovingProblemIsInfeasible(
619 ShortStatus());
620 } else if (node->parent != nullptr) {
621 to_update_.push_back(node->parent);
622 }
623 }
624 if (num_newly_closed > 0) {
625 shared_response_manager_->LogMessageWithThrottling(
626 "Tree", absl::StrCat("closed:", num_closed_nodes_, "/", nodes_.size(),
627 " unassigned:", unassigned_leaves_.size(),
628 " restarts:", num_restarts_));
629 }
630 DCHECK(CheckLratInvariants());
631 // TODO(user): We could do resolution here by moving implications that
632 // are true in each child to the parent.
633 bool root_updated = false;
634 while (!to_update_.empty()) {
635 Node* node = to_update_.back();
636 to_update_.pop_back();
637 // Iterate over parents while the lower bound can be improved.
638 while (node != nullptr && !node->closed) {
639 DCHECK(node->children[0] != nullptr);
640 DCHECK(node->children[1] != nullptr);
641 for (Node* child : node->children) {
642 if (child->implied) {
 // Only implied children that still hold a trail info are processed;
 // either way the child is then flagged as processed.
643 if (child->trail_info != nullptr) {
644 DCHECK(!child->implied_and_processed);
645 ProcessImpliedNode(child);
646 ClearTrailInfo(child);
647 }
648 child->implied_and_processed = true;
649 }
650 }
651 IntegerValue child_bound = std::min(node->children[0]->objective_lb,
652 node->children[1]->objective_lb);
653 if (child_bound <= node->objective_lb) break;
654 node->objective_lb = child_bound;
655 node = node->parent;
656 }
 // Reaching nullptr means the bound improvement propagated through the root.
657 if (node == nullptr) root_updated = true;
658 }
659 if (root_updated) {
660 shared_response_manager_->UpdateInnerObjectiveBounds(
661 ShortStatus(), nodes_[0].objective_lb, kMaxIntegerValue);
662 }
 // `newly_implied` is only filled when LRAT is enabled (see the first loop).
663 for (Node* node : newly_implied) {
664 if (!node->implied_and_processed) {
665 DCHECK_EQ(node->trail_info, nullptr);
666 DCHECK_NE(lrat_proof_handler_, nullptr);
667 ProcessImpliedNode(node);
668 node->implied_and_processed = true;
669 }
670 }
671 // These are shared via SharedBoundsManager, don't duplicate here.
672 ClearTrailInfo(&nodes_[0], /*implications_only=*/true);
673 DCHECK(CheckLratInvariants());
674}
675
// Preconditions visible in the body: `node` must have a parent (CHECKed), and
// some ancestor must hold a trail_info (the ancestor walk only DCHECKs this).
// NOTE(review): the non-LRAT fast path dereferences `node->trail_info` without
// a null check; ProcessNodeChanges() only reaches it with a non-null
// trail_info -- confirm for any new caller.
676// Moves the trail_info implications of `node` to its first non-implied
677// ancestor, and removes the newly implied literal from the closing and
678// implication clauses of `node` and its descendants.
679void SharedTreeManager::ProcessImpliedNode(Node* node) {
680 CHECK(node->parent != nullptr);
 // The "first non-implied ancestor" is found as the closest ancestor that
 // still holds a trail_info.
681 Node* first_non_implied_ancestor = node->parent;
682 while (first_non_implied_ancestor->trail_info == nullptr) {
683 first_non_implied_ancestor = first_non_implied_ancestor->parent;
684 DCHECK_NE(first_non_implied_ancestor, nullptr);
685 }
686 // Fast path for the common case where there is no need to add LRAT clauses.
687 // The rest of the code is only executed when LRAT is enabled, and assumes a
688 // pure SAT problem.
689 if (lrat_proof_handler_ == nullptr) {
690 first_non_implied_ancestor->trail_info->implications.merge(
691 node->trail_info->implications);
692 return;
693 }
694 // Gather the clauses needed to prove the new implications and closing
695 // clauses.
696 std::vector<ClauseId> clauses;
697 Node* n = node;
698 while (n->parent != nullptr) {
699 // Newly implied nodes must be removed from the closing and implication
700 // clauses, which requires a proof (already implied nodes are no longer in
701 // these clauses, so we don't need a proof for them).
702 if (n->implied && !n->implied_and_processed) {
703 clauses.push_back(GetSibling(n)->closing_clause_id);
704 }
705 n = n->parent;
706 }
 // The walk above collected clauses from `node` up towards the root; reverse
 // so that they are ordered from the root side down to `node`.
707 std::reverse(clauses.begin(), clauses.end());
708 // Move the implications of `node` to the first non-implied ancestor.
709 if (node->trail_info != nullptr) {
710 for (const auto& [var, lb_and_clause] : node->trail_info->implications) {
711 // This is OK because we assume a pure SAT problem.
712 if (first_non_implied_ancestor->trail_info->implications.contains(var)) {
713 continue;
714 }
715 const auto [lb, clause_id] = lb_and_clause;
716 ClauseId new_clause_id = clause_id_generator_.GetNextId();
 // `clause_id` is pushed only for this inference and popped right after,
 // so `clauses` is unchanged from one implication to the next.
717 clauses.push_back(clause_id);
718 lrat_proof_handler_->AddInferredClause(
719 new_clause_id,
720 ImplicationClause(first_non_implied_ancestor, ProtoLiteral(var, lb),
721 /*skip_unprocessed_implied_nodes=*/true),
722 clauses, /*exported=*/true);
723 clauses.pop_back();
724 first_non_implied_ancestor->trail_info->implications.insert(
725 {var, std::make_pair(lb, new_clause_id)});
726 }
727 }
728 UpdateLratClausesInSubtree(node, node, clauses);
729}
730
// `clauses` is used as a stack shared across the recursion: every push in this
// function is matched by a pop before returning, so the caller gets it back
// with the same contents. The function dereferences `lrat_proof_handler_`
// unconditionally, so it must only be called when LRAT is enabled.
731// Updates the closing clauses and the trail implication clauses of all the
732// nodes in the subtree rooted at `node`, to maintain the LRAT invariants.
733// Recursive method where `n` is a node of the subtree, and `clauses` are the
734// clauses needed to infer its updated closing and implication clauses.
735// TODO(user): change to a non-recursive implementation?
736void SharedTreeManager::UpdateLratClausesInSubtree(
737 Node* node, Node* n, std::vector<ClauseId>& clauses) {
738 const bool implied_and_not_processed =
739 n->implied && !n->implied_and_processed;
740 if (implied_and_not_processed) {
741 // Newly implied nodes must be removed from the closing and implication
742 // clauses of `n`, which requires a proof (already implied nodes are no
743 // longer in these clauses, so we don't need a proof for them).
744 clauses.push_back(GetSibling(n)->closing_clause_id);
745 }
746 if (n->closed) {
 // Re-derive the closing clause of `n` from its old clause plus `clauses`,
 // then delete the old clause and store the new id.
747 DCHECK_NE(n->closing_clause_id, kNoClauseId);
748 ClauseId new_clause_id = clause_id_generator_.GetNextId();
749 clauses.push_back(n->closing_clause_id);
750 lrat_proof_handler_->AddInferredClause(
751 new_clause_id,
752 ClosingClause(n, /*skip_unprocessed_implied_nodes=*/true), clauses,
753 /*exported=*/true);
754 clauses.pop_back();
755 lrat_proof_handler_->DeleteClause(n->closing_clause_id, {});
756 n->closing_clause_id = new_clause_id;
757 }
 // The implications of the subtree root itself are not rewritten here.
758 if (n != node && n->trail_info != nullptr) {
759 for (auto& [var, lb_and_clause] : n->trail_info->implications) {
 // `clause_id` is bound by reference, so assigning it below updates the
 // stored implication in place.
760 auto& [lb, clause_id] = lb_and_clause;
761 ClauseId new_clause_id = clause_id_generator_.GetNextId();
762 clauses.push_back(clause_id);
763 lrat_proof_handler_->AddInferredClause(
764 new_clause_id,
765 ImplicationClause(n, ProtoLiteral(var, lb),
766 /*skip_unprocessed_implied_nodes=*/true),
767 clauses, /*exported=*/true);
768 lrat_proof_handler_->DeleteClause(clause_id, {});
769 clause_id = new_clause_id;
770 clauses.pop_back();
771 }
772 }
773 // We can stop at implied but not yet processed nodes (they will be processed
774 // with further calls to ProcessImpliedNode()).
775 if (n == node || !(n->implied && n->trail_info != nullptr)) {
776 for (Node* child : n->children) {
777 if (child != nullptr && child->parent != nullptr) {
778 UpdateLratClausesInSubtree(node, child, clauses);
779 }
780 }
781 }
 // Undo the push made at the top of this call.
782 if (implied_and_not_processed) {
783 clauses.pop_back();
784 }
785}
786
787SharedTreeManager::Node* SharedTreeManager::GetNode(int id) {
788 const int index = id - node_id_offset_;
789 CHECK_GE(index, 0);
790 CHECK_LT(index, nodes_.size());
791 return &nodes_[index];
792}
793
794std::vector<std::pair<SharedTreeManager::Node*, int>>
795SharedTreeManager::GetAssignedNodes(const ProtoTrail& path) {
796 std::vector<std::pair<Node*, int>> nodes({std::make_pair(&nodes_[0], 0)});
797 if (!IsValid(path)) {
798 // Restart has happened, nodes in this path are no longer valid, but the
799 // root is equivalent.
800 return nodes;
801 }
802 for (int i = 0; i <= path.MaxLevel(); ++i) {
803 for (int id : path.NodeIds(i)) {
804 const int index = id - node_id_offset_;
805 CHECK_GE(index, 0) << " in path.NodeIds(" << i
806 << "), max_level=" << path.MaxLevel();
807 CHECK_LT(index, nodes_.size());
808 DCHECK_EQ(nodes.back().first, nodes_[index].parent);
809 nodes.push_back(std::make_pair(&nodes_[index], i));
810 }
811 }
812 return nodes;
813}
814
816 absl::MutexLock mutex_lock(mu_);
817 DCHECK(CheckLratInvariants());
818 const int node_id_to_close = path.NodeIds(level).front();
819 if (node_id_to_close < node_id_offset_) {
820 path.Clear();
821 return;
822 }
823 Node* node = &nodes_[node_id_to_close - node_id_offset_];
824 VLOG(2) << "Closing subtree at level " << level;
825 DCHECK(to_close_.empty());
826
827 ClauseId closing_clause_id = kNoClauseId;
828 if (lrat_proof_handler_ != nullptr) {
829 // For the worker providing `path`, `node` is implied by all the previous
830 // decisions in `path`, but for the manager we need a closing clause using
831 // `node` and its ancestors in the tree (with implied ones filtered out --
832 // they can be different because the manager and the worker have different
833 // views of the tree).
834 const std::vector<Literal> inferred_clause = ClosingClause(node);
835 std::vector<Literal> imported_clause;
836 std::vector<ClauseId> lrat_proof;
837 for (int l = 1; l <= level; ++l) {
838 Node* n = GetNode(path.DecisionNodeId(l));
839 const Literal decision = DecodeWithIdentityMapping(n->decision);
840 imported_clause.push_back(decision.Negated());
841 if (n->implied_and_processed) {
842 lrat_proof.push_back(GetSibling(n)->closing_clause_id);
843 }
844 }
845 closing_clause_id = AddImportedAndInferredClauses(
846 imported_clause, inferred_clause, lrat_proof);
847 }
848 path.Clear();
849 to_close_.emplace_back(node, closing_clause_id);
850 ProcessNodeChanges();
851 DCHECK(CheckLratInvariants());
852}
853
854bool SharedTreeManager::IsDecisionOfNodeOrAncestor(ProtoLiteral literal,
855 const Node* node) const {
856 CHECK_NE(node, nullptr);
857 while (node->parent != nullptr) {
858 if (literal == node->decision) return true;
859 node = node->parent;
860 }
861 return false;
862}
863
864std::vector<Literal> SharedTreeManager::ImplicationClause(
865 const Node* node, ProtoLiteral implied,
866 bool skip_unprocessed_implied_nodes) const {
867 // This is only used for LRAT, which only works for pure SAT, without
868 // presolve. In this case all workers should have the same identity mapping
869 // from the proto variables.
870 CHECK_NE(node, nullptr);
871 std::vector<Literal> clause =
872 ClosingClause(node, skip_unprocessed_implied_nodes);
873 clause.push_back(DecodeWithIdentityMapping(implied));
874 return clause;
875}
876
877std::vector<Literal> SharedTreeManager::ClosingClause(
878 const Node* node, bool skip_unprocessed_implied_nodes) const {
879 // This is only used for LRAT, which only works for pure SAT, without
880 // presolve. In this case all workers should have the same identity mapping
881 // from the proto variables.
882 CHECK_NE(node, nullptr);
883 std::vector<Literal> clause;
884 while (node->parent != nullptr) {
885 // When a node is implied its implications are moved to its first
886 // non-implied ancestor, instead of to its parent. Proving this with the
887 // clause that the node is implied requires the implication clauses to
888 // exclude the decisions of implied nodes. And since the clause that a node
889 // is implied is the closing clause of its sibling, closing clauses should
890 // also exclude the decisions of implied nodes.
891 const bool is_implied = node->implied && (node->implied_and_processed ||
892 skip_unprocessed_implied_nodes);
893 if (!is_implied) {
894 clause.push_back(DecodeWithIdentityMapping(node->decision).Negated());
895 }
896 node = node->parent;
897 }
898 return clause;
899}
900
901namespace {
902bool UnorderedSpansAreEqual(absl::Span<const Literal> a,
903 absl::Span<const Literal> b) {
904 if (a.size() != b.size()) return false;
905 std::vector<Literal> sorted_a(a.begin(), a.end());
906 std::vector<Literal> sorted_b(b.begin(), b.end());
907 std::sort(sorted_a.begin(), sorted_a.end());
908 std::sort(sorted_b.begin(), sorted_b.end());
909 return sorted_a == sorted_b;
910}
911} // namespace
912
913ClauseId SharedTreeManager::AddImportedAndInferredClauses(
914 absl::Span<const Literal> imported_clause,
915 absl::Span<const Literal> inferred_clause,
916 std::vector<ClauseId>& lrat_proof) {
917 const ClauseId id = clause_id_generator_.GetNextId();
918 lrat_proof_handler_->AddImportedClause(id, imported_clause);
919 if (!lrat_proof.empty() ||
920 !UnorderedSpansAreEqual(inferred_clause, imported_clause)) {
921 lrat_proof.push_back(id);
922 const ClauseId new_id = clause_id_generator_.GetNextId();
923 lrat_proof_handler_->AddInferredClause(new_id, inferred_clause, lrat_proof,
924 /*exported=*/true);
925 lrat_proof_handler_->DeleteClause(id, {});
926 return new_id;
927 } else {
928 return id;
929 }
930}
931
932void SharedTreeManager::AssignLeaf(ProtoTrail& path, Node* leaf) {
933 path.Clear();
934 std::vector<Node*> reversed_path;
935 while (leaf != &nodes_[0]) {
936 reversed_path.push_back(&nodes_[leaf->id - node_id_offset_]);
937 leaf = leaf->parent;
938 }
939 while (!reversed_path.empty()) {
940 Node* leaf = reversed_path.back();
941 reversed_path.pop_back();
942 path.PushLevel(leaf->decision, leaf->objective_lb, leaf->id);
943 if (leaf->implied) {
944 path.SetLevelImplied(path.MaxLevel());
945 }
946 if (params_.shared_tree_worker_enable_trail_sharing() &&
947 leaf->trail_info != nullptr) {
948 for (const auto& [var, lb_and_clause] : leaf->trail_info->implications) {
949 const auto [lb, clause_id] = lb_and_clause;
950 path.AddImplication(path.MaxLevel(), ProtoLiteral(var, lb));
951 }
952 }
953 }
954}
955
956bool SharedTreeManager::IsValid(const ProtoTrail& path) const {
957 auto node_ids = path.NodeIds(path.MaxLevel());
958 if (node_ids.empty()) return true;
959 if (node_ids.back() < node_id_offset_) return false;
960 return true;
961}
962
963void SharedTreeManager::RestartLockHeld() {
964 node_id_offset_ += nodes_.size();
965 if (lrat_proof_handler_ != nullptr) {
966 for (const Node& node : nodes_) {
967 if (node.closing_clause_id != kNoClauseId) {
968 lrat_proof_handler_->DeleteClause(node.closing_clause_id, {});
969 }
970 }
971 }
972 nodes_.resize(1);
973 nodes_[0].id = node_id_offset_;
974 nodes_[0].children = {nullptr, nullptr};
975 unassigned_leaves_.clear();
976 DCHECK(to_close_.empty());
977 DCHECK(to_update_.empty());
978 num_splits_wanted_ =
979 num_workers_ * params_.shared_tree_open_leaves_per_worker() - 1;
980 num_closed_nodes_ = 0;
981 num_restarts_ += 1;
982 num_leaves_assigned_since_restart_ = 0;
983}
984
// Hold `mu_` for the remainder of this scope.
986 absl::MutexLock l(mu_);
// If an LRAT proof handler is attached, close it with `model_is_unsat` set to
// false and then release it. Once reset, the null check below fails, so a
// repeated call skips this branch.
987 if (lrat_proof_handler_ != nullptr) {
988 lrat_proof_handler_->Close(/*model_is_unsat=*/false);
989 lrat_proof_handler_.reset();
990 }
991}
992
993std::string SharedTreeManager::ShortStatus() const {
994 return absl::StrCat("shared_tree_manager(r=", num_restarts_,
995 " n=", nodes_.size(), ")");
996}
997
998namespace {
999void CheckEqual(absl::Span<const Literal> a, absl::Span<const Literal> b) {
1000 std::vector<Literal> sorted_a(a.begin(), a.end());
1001 std::vector<Literal> sorted_b(b.begin(), b.end());
1002 std::sort(sorted_a.begin(), sorted_a.end());
1003 std::sort(sorted_b.begin(), sorted_b.end());
1004 CHECK_EQ(sorted_a, sorted_b);
1005}
1006} // namespace
1007
1008bool SharedTreeManager::CheckLratInvariants() const {
1009 if (lrat_proof_handler_ != nullptr &&
1010 lrat_proof_handler_->lrat_check_enabled()) {
1011 for (const Node& node : nodes_) {
1012 if (node.parent == nullptr) continue;
1013 if (node.closed) {
1014 CheckEqual(
1015 lrat_proof_handler_->GetLratClauseForDebug(node.closing_clause_id),
1016 ClosingClause(&node));
1017 }
1018 if (node.trail_info != nullptr) {
1019 for (const auto& [var, lb_and_clause] : node.trail_info->implications) {
1020 const auto [lb, clause_id] = lb_and_clause;
1021 CheckEqual(lrat_proof_handler_->GetLratClauseForDebug(clause_id),
1022 ImplicationClause(&node, ProtoLiteral(var, lb)));
1023 }
1024 }
1025 }
1026 }
1027 return true;
1028}
1029
// Member initializers: every collaborator is looked up in `model`. Most use
// GetOrCreate<>(); the LRAT proof handler uses Mutable<>() and the objective
// definition uses Get<>(). Both of those members are compared against nullptr
// before use elsewhere in this file.
1031 : parameters_(model->GetOrCreate<SatParameters>()),
1032 shared_response_(model->GetOrCreate<SharedResponseManager>()),
1033 time_limit_(model->GetOrCreate<TimeLimit>()),
1034 manager_(model->GetOrCreate<SharedTreeManager>()),
1035 mapping_(model->GetOrCreate<CpModelMapping>()),
1036 sat_solver_(model->GetOrCreate<SatSolver>()),
1037 trail_(model->GetOrCreate<Trail>()),
1038 binary_propagator_(model->GetOrCreate<BinaryImplicationGraph>()),
1039 clause_manager_(model->GetOrCreate<ClauseManager>()),
1040 clause_id_generator_(model->GetOrCreate<ClauseIdGenerator>()),
1041 lrat_proof_handler_(model->Mutable<LratProofHandler>()),
1042 integer_trail_(model->GetOrCreate<IntegerTrail>()),
1043 encoder_(model->GetOrCreate<IntegerEncoder>()),
1044 objective_(model->Get<ObjectiveDefinition>()),
1045 random_(model->GetOrCreate<ModelRandomGenerator>()),
1046 helper_(model->GetOrCreate<IntegerSearchHelper>()),
1047 heuristics_(model->GetOrCreate<SearchHeuristics>()),
1048 decision_policy_(model->GetOrCreate<SatDecisionPolicy>()),
1049 restart_policy_(model->GetOrCreate<RestartPolicy>()),
1050 level_zero_callbacks_(model->GetOrCreate<LevelZeroCallbackHelper>()),
1051 reversible_int_repository_(model->GetOrCreate<RevIntRepository>()),
// The LBD statistics of assigned trees are tracked over a window of size 8.
1052 assigned_tree_lbds_(/*window_size=*/8) {}
1053
1054std::vector<Literal>& SharedTreeWorker::DecisionReason(int level) {
1055 CHECK_LE(level, assigned_tree_decisions_.size());
1056 reason_.clear();
1057 for (int i = 0; i < level; ++i) {
1058 reason_.push_back(assigned_tree_decisions_[i].Negated());
1059 }
1060 return reason_;
1061}
1062
1063bool SharedTreeWorker::AddDecisionImplication(Literal lit, int level,
1064 ClauseId clause_id) {
1065 CHECK_GT(level, 0);
1066 CHECK_NE(lit.Index(), kNoLiteralIndex);
1067 CHECK(!sat_solver_->Assignment().LiteralIsTrue(lit));
1068 absl::Span<const Literal> reason = DecisionReason(level);
1069 if (sat_solver_->Assignment().LiteralIsFalse(lit)) {
1070 VLOG(2) << "Closing subtree via impl at " << level + 1
1071 << " assigned=" << assigned_tree_.MaxLevel();
1072 if (lrat_proof_handler_ != nullptr) {
1073 // Use the fact that `reason` implies both `lit` and not(`lit`) to prove
1074 // that the tree can be closed.
1075 const ClauseId closing_clause_id = clause_id_generator_->GetNextId();
1076 std::vector<ClauseId> clause_ids;
1077 clause_manager_->AppendClauseIdsFixing({lit}, &clause_ids);
1078 clause_ids.push_back(clause_id);
1079 lrat_proof_handler_->AddInferredClause(closing_clause_id,
1080 DecisionReason(level), clause_ids,
1081 /*exported=*/true);
1082 lrat_proof_handler_->DeleteClause(closing_clause_id, {});
1083 }
1084 trail_->MutableConflict()->assign(reason.begin(), reason.end());
1085 manager_->CloseTree(assigned_tree_, level);
1086 assigned_tree_decisions_.clear();
1087 return false;
1088 }
1089 VLOG(2) << "Learned shared clause";
1090 trail_->GetEmptyVectorToStoreReason()->assign(reason.begin(), reason.end());
1091 return trail_->EnqueueWithStoredReason(clause_id, lit);
1092}
1093
1094bool SharedTreeWorker::AddImplications() {
1095 const int level = sat_solver_->CurrentDecisionLevel();
1096 // Level 0 implications are unit clauses and are synced elsewhere.
1097 if (level == 0) return false;
1098 if (level > assigned_tree_.MaxLevel()) {
1099 return false;
1100 }
1101 rev_num_processed_implications_.resize(level + 1, 0);
1102 auto& num_processed_implications = rev_num_processed_implications_[level];
1103 reversible_int_repository_->SaveState(&num_processed_implications);
1104 absl::Span<const std::pair<Literal, ClauseId>> implied_literals =
1105 absl::MakeConstSpan(assigned_tree_implications_[level - 1])
1106 .subspan(num_processed_implications);
1107 bool added_clause = false;
1108 for (const auto& [implied, clause_id] : implied_literals) {
1109 ++num_processed_implications;
1110 if (sat_solver_->Assignment().LiteralIsTrue(implied)) continue;
1111 added_clause = true;
1112 if (!AddDecisionImplication(implied, level, clause_id)) return true;
1113 }
1114 if (objective_ != nullptr &&
1115 objective_->objective_var != kNoIntegerVariable) {
1116 const IntegerValue obj_lb =
1117 integer_trail_->LowerBound(objective_->objective_var);
1118 assigned_tree_.SetObjectiveLb(level, obj_lb);
1119 const Literal obj_lit =
1120 encoder_->GetOrCreateAssociatedLiteral(IntegerLiteral::GreaterOrEqual(
1121 objective_->objective_var, assigned_tree_.ObjectiveLb(level)));
1122 if (!sat_solver_->Assignment().LiteralIsTrue(obj_lit)) {
1123 AddDecisionImplication(obj_lit, level, kNoClauseId);
1124 return true;
1125 }
1126 }
1127 DCHECK(CheckLratInvariants());
1128 return added_clause;
1129}
1130
1131void SharedTreeWorker::ClearAssignedTreeDecisionsAndImplications() {
1132 // Delete all LRAT clauses corresponding to the assigned tree implications,
1133 // which are deleted too. Note that there is one LRAT proof per worker. Each
1134 // proof uses its local clause IDs, and there is no global clause ID space.
1135 // Individual proofs can be merged at the end of the solve, if UNSAT. In this
1136 // case clause deletions of individual proofs are ignored until the clause is
1137 // no longer needed by any other partial proof. Hence it is safe to delete the
1138 // clauses here, even if they are still needed in the SharedTreeManager.
1139 if (lrat_proof_handler_ != nullptr) {
1140 for (const auto& implications : assigned_tree_implications_) {
1141 for (const auto& [literal, clause_id] : implications) {
1142 lrat_proof_handler_->DeleteClause(clause_id, {});
1143 }
1144 }
1145 }
1146 assigned_tree_decisions_.clear();
1147 assigned_tree_implications_.clear();
1148}
1149
// Brings the solver and the assigned tree to a common fixed point at the
// current decision level. Each pass of the loop:
//  1. finishes propagation and applies tree implications (AddImplications()),
//     restarting the pass whenever something was added;
//  2. when trail sharing is on and the level belongs to the assigned tree,
//     records new trail literals as implications of this level (with LRAT
//     clauses when a proof handler is attached);
//  3. looks at the next assigned decision: if it is unassigned, or there is
//     none, the loop stops; if it is false, the subtree at `level + 1` is
//     closed, the local tree copy is cleared and the solver backtracks to
//     level 0; if it is true, the next level is marked implied and removed
//     from the local copy.
// Returns false if FinishPropagation() or BeforeTakingDecision() fails, true
// otherwise.
1150bool SharedTreeWorker::SyncWithLocalTrail() {
1151 DCHECK(CheckLratInvariants());
1152 std::vector<int> new_implication_trail_indices;
1153 while (true) {
1154 if (lrat_proof_handler_ != nullptr) {
// Resize the per-trail-index clause IDs to `reversible_trail_index_` entries
// (new entries get kNoClauseId).
1155 trail_implication_clauses_.resize(reversible_trail_index_, kNoClauseId);
1156 }
1157 if (!sat_solver_->FinishPropagation()) return false;
1158 // Ensure we are at fixed point w.r.t. implications in the tree up to the
1159 // current level.
1160 if (AddImplications()) continue;
1161
1162 if (!helper_->BeforeTakingDecision()) return false;
1163 const int level = sat_solver_->CurrentDecisionLevel();
1164 if (parameters_->shared_tree_worker_enable_trail_sharing() && level > 0 &&
1165 level <= assigned_tree_.MaxLevel() &&
1166 reversible_trail_index_ < trail_->Index()) {
1167 const int binary_propagator_id = binary_propagator_->PropagatorId();
1168 // Add implications from the local trail to share with other workers.
1169 reversible_int_repository_->SaveState(&reversible_trail_index_);
1170 new_implication_trail_indices.clear();
// Scan the trail backwards from its end down to the last shared index,
// stopping at the first search decision encountered.
1171 for (int i = trail_->Index() - 1; i >= reversible_trail_index_; --i) {
1172 const Literal lit = (*trail_)[i];
1173 const int assignment_type = trail_->AssignmentType(lit.Variable());
1174 if (assignment_type == AssignmentType::kSearchDecision) break;
1175 // Avoid sharing implications from binary clauses - these are always
1176 // shared, so the implication will be propagated anyway.
1177 if (assignment_type == binary_propagator_id) continue;
1178 std::optional<ProtoLiteral> encoded = EncodeDecision(lit);
1179 if (!encoded.has_value()) continue;
1180 if (assigned_tree_.AddImplication(level, *encoded) &&
1181 lrat_proof_handler_ != nullptr) {
1182 new_implication_trail_indices.push_back(i);
1183 }
1184 }
1185 // Add LRAT inferred clauses for the new implications, so that other
1186 // workers can import them without proof. Do this in increasing trail
1187 // index order, and reuse the previously added clauses to prove the new
1188 // ones (to avoid a quadratic complexity).
1189 if (lrat_proof_handler_ != nullptr) {
1190 trail_implication_clauses_.resize(trail_->Index(), kNoClauseId);
// The indices were collected in decreasing order, hence the reverse loop.
1191 for (int i = new_implication_trail_indices.size() - 1; i >= 0; --i) {
1192 const int new_trail_index = new_implication_trail_indices[i];
1193 const Literal lit = (*trail_)[new_trail_index];
1194 trail_implication_clauses_[new_trail_index] =
1195 AddLratClauseAndProofForImplication(
1196 lit, level, [&](int /*level*/, int trail_index) {
1197 return trail_implication_clauses_[trail_index];
1198 });
1199 }
1200 }
1201 reversible_trail_index_ = trail_->Index();
1202 }
1203 if (level >= assigned_tree_.MaxLevel()) break;
1204 // The next decision is assigned, make sure it makes sense.
1205 const Literal next_decision = assigned_tree_decisions_[level];
1206 if (!sat_solver_->Assignment().LiteralIsAssigned(next_decision)) break;
1207 if (sat_solver_->Assignment().LiteralIsFalse(next_decision)) {
1208 // Next assigned decision is impossible.
1209 VLOG(2) << "Closing subtree at " << level + 1
1210 << " assigned=" << assigned_tree_.MaxLevel();
1211 // Add the LRAT inferred clause "current decisions => not(next_decision)"
1212 // so that it can be imported in SharedTreeManager to close the tree. We
1213 // can delete it right away since we don't need it in the worker itself.
1214 const ClauseId clause_id =
1215 AddLratClauseAndProofForImplication(next_decision.Negated(), level);
1216 manager_->CloseTree(assigned_tree_, level + 1);
1217 if (lrat_proof_handler_ != nullptr) {
1218 lrat_proof_handler_->DeleteClause(clause_id, {});
1219 }
1220 ClearAssignedTreeDecisionsAndImplications();
1221 sat_solver_->Backtrack(0);
1222 } else {
1223 // The next level is implied by the current one.
1224 if (lrat_proof_handler_ != nullptr) {
1225 // Update the LRAT clause of each implied literal at any next level, in
1226 // order to remove `next_decision` from these implications. Each new
1227 // clause is proved with the old one, combined with the clause that the
1228 // current decisions imply the next one.
1229 const ClauseId clause_id =
1230 AddLratClauseAndProofForImplication(next_decision, level);
// `implication` starts as the negation of the first `level` decisions. The
// loop below appends the negation of each deeper decision except
// `next_decision` itself (index `level`).
1231 std::vector<Literal> implication;
1232 for (int i = 0; i < level; ++i) {
1233 implication.push_back(assigned_tree_decisions_[i].Negated());
1234 }
1235 for (int l = level; l < assigned_tree_decisions_.size(); ++l) {
1236 if (l != level) {
1237 implication.push_back(assigned_tree_decisions_[l].Negated());
1238 }
// `id` is a reference into the stored pair, so assigning to it replaces the
// clause ID kept for this implication.
1239 for (auto& [lit, id] : assigned_tree_implications_[l]) {
1240 const ClauseId old_id = id;
1241 id = clause_id_generator_->GetNextId();
1242 implication.push_back(lit);
1243 lrat_proof_handler_->AddInferredClause(
1244 id, implication, {clause_id, old_id}, /*exported=*/true);
1245 lrat_proof_handler_->DeleteClause(old_id, {});
1246 implication.pop_back();
1247 }
1248 }
1249 lrat_proof_handler_->DeleteClause(clause_id, {});
1250 }
1251 assigned_tree_.SetLevelImplied(level + 1);
// Mirror the change in the local copy: the implications of the removed level
// are appended to those of the previous level (when there is one), then the
// level's implication list and decision are erased.
1252 if (level > 0) {
1253 assigned_tree_implications_[level - 1].insert(
1254 assigned_tree_implications_[level - 1].end(),
1255 assigned_tree_implications_[level].begin(),
1256 assigned_tree_implications_[level].end());
1257 }
1258 assigned_tree_implications_.erase(assigned_tree_implications_.begin() +
1259 level);
1260 assigned_tree_decisions_.erase(assigned_tree_decisions_.begin() + level);
1261 }
1262 }
1263 DCHECK(CheckLratInvariants());
1264 return true;
1265}
1266
1267ClauseId SharedTreeWorker::AddLratClauseAndProofForImplication(
1268 Literal literal, int level,
1269 std::optional<absl::FunctionRef<ClauseId(int, int)>> root_literals) {
1270 if (lrat_proof_handler_ == nullptr) return kNoClauseId;
1271
1272 CHECK_LE(level, assigned_tree_decisions_.size());
1273 const ClauseId clause_id = clause_id_generator_->GetNextId();
1274 std::vector<Literal>& implication = DecisionReason(level);
1275 implication.push_back(literal);
1276 std::vector<ClauseId> clause_ids;
1277 clause_manager_->AppendClauseIdsFixing(
1278 {literal}, &clause_ids, /*decision=*/kNoLiteralIndex, root_literals);
1279 lrat_proof_handler_->AddInferredClause(clause_id, implication, clause_ids,
1280 /*exported=*/true);
1281 return clause_id;
1282}
1283
1284ClauseId SharedTreeWorker::ImportLratClauseForImplication(Literal literal,
1285 int level) {
1286 if (lrat_proof_handler_ == nullptr) return kNoClauseId;
1287
1288 CHECK_LE(level, assigned_tree_decisions_.size());
1289 const ClauseId clause_id = clause_id_generator_->GetNextId();
1290 std::vector<Literal>& implication = DecisionReason(level);
1291 implication.push_back(literal);
1292 lrat_proof_handler_->AddImportedClause(clause_id, implication);
1293 return clause_id;
1294}
1295
1296bool SharedTreeWorker::NextDecision(LiteralIndex* decision_index) {
1297 const auto& decision_policy =
1298 heuristics_->decision_policies[heuristics_->policy_index];
1299 const int next_level = sat_solver_->CurrentDecisionLevel() + 1;
1300 CHECK_EQ(assigned_tree_decisions_.size(), assigned_tree_.MaxLevel());
1301 if (next_level <= assigned_tree_.MaxLevel()) {
1302 VLOG(2) << "Following shared trail depth=" << next_level << " "
1303 << parameters_->name();
1304 const Literal decision = assigned_tree_decisions_[next_level - 1];
1305 CHECK(!sat_solver_->Assignment().LiteralIsFalse(decision))
1306 << " at depth " << next_level << " " << parameters_->name();
1307 CHECK(!sat_solver_->Assignment().LiteralIsTrue(decision));
1308 *decision_index = decision.Index();
1309 return true;
1310 }
1311 return helper_->GetDecision(decision_policy, decision_index);
1312}
1313
1314void SharedTreeWorker::MaybeProposeSplits() {
1315 if (time_limit_->GetElapsedDeterministicTime() <= next_split_dtime_) {
1316 return;
1317 }
1318 next_split_dtime_ = time_limit_->GetElapsedDeterministicTime() +
1319 parameters_->shared_tree_split_min_dtime();
1320 tmp_splits_.clear();
1321 const int max_split_level =
1322 std::min<int>(trail_->CurrentDecisionLevel(), manager_->MaxPathDepth());
1323 for (int i = assigned_tree_.MaxLevel(); i < max_split_level; ++i) {
1324 const Literal split_decision = trail_->Decisions()[i].literal;
1325 const std::optional<ProtoLiteral> encoded = EncodeDecision(split_decision);
1326 if (!encoded.has_value()) break;
1327 tmp_splits_.push_back(*encoded);
1328 }
1329 const int splits_accepted =
1330 manager_->TrySplitTree(tmp_splits_, assigned_tree_);
1331 for (int i = 0; i < splits_accepted; ++i) {
1332 assigned_tree_decisions_.push_back(DecodeDecision(tmp_splits_[i]));
1333 assigned_tree_implications_.push_back({});
1334 }
1335}
1336
1337bool SharedTreeWorker::ShouldReplaceSubtree() {
1338 // If we have no assignment, try to get one.
1339 if (assigned_tree_.MaxLevel() == 0) return true;
1340 if (restart_policy_->NumRestarts() <
1341 parameters_->shared_tree_worker_min_restarts_per_subtree() ||
1342 time_limit_->GetElapsedDeterministicTime() <
1343 earliest_replacement_dtime_) {
1344 return false;
1345 }
1346 return assigned_tree_lbds_.WindowAverage() <
1347 restart_policy_->LbdAverageSinceReset();
1348}
1349
// Synchronizes `assigned_tree_` with the manager; expected to run at decision
// level 0 (see the DCHECK below). In order:
//  1. syncs the current path with the manager;
//  2. if ShouldReplaceSubtree(), optionally exports the best partial
//     assignment as a target phase, asks the manager for another tree, resets
//     the restart policy, and optionally imports the phase stored with the
//     new tree;
//  3. rebuilds `assigned_tree_decisions_` and `assigned_tree_implications_`
//     from `assigned_tree_`, importing one LRAT clause per implication.
// Always returns true.
1350bool SharedTreeWorker::SyncWithSharedTree() {
1351 DCHECK_EQ(trail_->CurrentDecisionLevel(), 0);
1352 DCHECK(CheckLratInvariants());
1353 manager_->SyncTree(assigned_tree_);
1354 assigned_tree_.NormalizeImplications();
1355 if (ShouldReplaceSubtree()) {
1356 ++num_trees_;
1357 VLOG(2) << parameters_->name() << " acquiring tree #" << num_trees_
1358 << " after " << restart_policy_->NumRestarts() << " restarts"
1359 << " prev depth: " << assigned_tree_.MaxLevel()
1360 << " target: " << assigned_tree_lbds_.WindowAverage()
1361 << " lbd: " << restart_policy_->LbdAverageSinceReset();
1362 if (parameters_->shared_tree_worker_enable_phase_sharing() &&
1363 // Only save the phase if we've done a non-trivial amount of work on
1364 // this subtree.
1365 FinishedMinRestarts() &&
1366 !decision_policy_->GetBestPartialAssignment().empty()) {
1367 assigned_tree_.ClearTargetPhase();
1368 for (Literal lit : decision_policy_->GetBestPartialAssignment()) {
1369 // Skip anything assigned at level 0.
1370 if (trail_->Assignment().LiteralIsAssigned(lit)) continue;
1371 // If `lit` was last assigned at a shared level, it is implied in the
1372 // tree, no need to share its phase.
1373 if (trail_->Info(lit.Variable()).level <= assigned_tree_.MaxLevel()) {
1374 continue;
1375 }
1376 // Only set the phase for booleans to avoid creating literals on other
1377 // workers.
1378 auto encoded = ProtoLiteral::EncodeLiteral(lit, mapping_);
1379 if (!encoded.has_value()) continue;
// Stop at the first literal that AddPhase() refuses.
1380 if (!assigned_tree_.AddPhase(*encoded)) break;
1381 }
1382 }
1383 manager_->ReplaceTree(assigned_tree_);
1384 assigned_tree_.NormalizeImplications();
1385 assigned_tree_lbds_.Add(restart_policy_->LbdAverageSinceReset());
1386 restart_policy_->Reset();
1387 earliest_replacement_dtime_ = 0;
// Only push back the next split time when a non-empty tree was obtained.
1388 if (assigned_tree_.MaxLevel() > 0) {
1389 next_split_dtime_ = time_limit_->GetElapsedDeterministicTime() +
1390 parameters_->shared_tree_split_min_dtime();
1391 }
1392 if (parameters_->shared_tree_worker_enable_phase_sharing()) {
1393 VLOG(2) << "Importing phase of length: "
1394 << assigned_tree_.TargetPhase().size();
1395 decision_policy_->ClearBestPartialAssignment();
1396 for (const ProtoLiteral& lit : assigned_tree_.TargetPhase()) {
1397 decision_policy_->SetTargetPolarityIfUnassigned(DecodeDecision(lit));
1398 }
1399 decision_policy_->ResetActivitiesToFollowBestPartialAssignment();
1400 // This seems bizarre after just setting the best partial assignment,
1401 // but this makes phase sharing work even when there is no stable phase in
1402 // the restart strategy, and makes no real difference if there is, since
1403 // the first dive will still try to follow this assignment until the first
1404 // conflict regardless of the restart strategy.
1405 decision_policy_->ClearBestPartialAssignment();
1406 }
1407 }
1408 // If we commit to this subtree, keep it for at least 1s of dtime.
1409 // This allows us to replace obviously bad subtrees quickly, and not replace
1410 // too frequently overall.
1411 if (FinishedMinRestarts() && earliest_replacement_dtime_ >=
1412 time_limit_->GetElapsedDeterministicTime()) {
1413 earliest_replacement_dtime_ =
1414 time_limit_->GetElapsedDeterministicTime() + 1;
1415 // Treat this as reassigning the same tree.
1416 assigned_tree_lbds_.Add(restart_policy_->LbdAverageSinceReset());
1417 }
1418 VLOG(2) << "Assigned level: " << assigned_tree_.MaxLevel() << " "
1419 << parameters_->name();
// Rebuild the local decoded copy of the tree from scratch. Tree levels are
// 1-based, while the local vectors are filled in order starting at index 0.
1420 ClearAssignedTreeDecisionsAndImplications();
1421 for (int level = 1; level <= assigned_tree_.MaxLevel(); ++level) {
1422 assigned_tree_decisions_.push_back(
1423 DecodeDecision(assigned_tree_.Decision(level)));
1424 std::vector<std::pair<Literal, ClauseId>> implications;
1425 for (const ProtoLiteral& impl : assigned_tree_.Implications(level)) {
1426 const Literal lit = DecodeDecision(impl);
1427 implications.emplace_back(lit,
1428 ImportLratClauseForImplication(lit, level));
1429 }
1430 assigned_tree_implications_.push_back(std::move(implications));
1431 }
1432 DCHECK(CheckLratInvariants());
1433 return true;
1434}
1435
1437 const std::function<void()>& feasible_solution_observer) {
1438 // Inside GetAssociatedLiteral if a literal becomes fixed at level 0 during
1439 // Search, the code CHECKs it is at level 0 when decoding the literal, but
1440 // the fixed literals are cached, so we can create them now to avoid a
1441 // crash.
1442 sat_solver_->Backtrack(0);
1443 encoder_->GetTrueLiteral();
1444 encoder_->GetFalseLiteral();
// Register SyncWithSharedTree() as a level-zero callback.
1445 level_zero_callbacks_->callbacks.push_back(
1446 [this]() { return SyncWithSharedTree(); });
1447 const bool has_objective =
1448 objective_ != nullptr && objective_->objective_var != kNoIntegerVariable;
// Main loop: runs until the time limit is reached or one of the returns
// below is taken.
1449 while (!time_limit_->LimitReached()) {
1450 if (!sat_solver_->FinishPropagation()) {
1451 return sat_solver_->UnsatStatus();
1452 }
// When the active restart policy fires, switch to the decision policy indexed
// by the restart count (modulo the number of policies) and go back to level 0.
1453 if (heuristics_->restart_policies[heuristics_->policy_index]()) {
1454 heuristics_->policy_index = restart_policy_->NumRestarts() %
1455 heuristics_->decision_policies.size();
1456 sat_solver_->Backtrack(0);
1457 }
1458 if (!SyncWithLocalTrail()) return sat_solver_->UnsatStatus();
1459 LiteralIndex decision_index;
// A false return from NextDecision() restarts the loop iteration.
1460 if (!NextDecision(&decision_index)) continue;
1461 if (time_limit_->LimitReached()) return SatSolver::LIMIT_REACHED;
// No decision left: report the solution through the observer. Without an
// objective the search stops with FEASIBLE. Otherwise it backtracks to level 0
// and requires the objective variable to be at most its current lower bound
// minus one; if that enqueue fails, the search returns INFEASIBLE.
1462 if (decision_index == kNoLiteralIndex) {
1463 feasible_solution_observer();
1464 if (!has_objective) return SatSolver::FEASIBLE;
1465 const IntegerValue objective =
1466 integer_trail_->LowerBound(objective_->objective_var);
1467 sat_solver_->Backtrack(0);
1468 if (!integer_trail_->Enqueue(
1469 IntegerLiteral::LowerOrEqual(objective_->objective_var,
1470 objective - 1),
1471 {}, {})) {
1472 return SatSolver::INFEASIBLE;
1473 }
1474
1475 continue;
1476 }
1477 const Literal decision(decision_index);
1478 CHECK(!sat_solver_->Assignment().LiteralIsFalse(decision));
1479 CHECK(!sat_solver_->Assignment().LiteralIsTrue(decision));
1480 // The LRAT proofs assume that an assigned tree decision is the actual one
1481 // which is taken here.
// Hence `use_representative` is only true when no LRAT proof handler is set.
1482 if (!helper_->TakeDecision(
1483 decision, /*use_representative=*/lrat_proof_handler_ == nullptr)) {
1484 return sat_solver_->UnsatStatus();
1485 }
1486 MaybeProposeSplits();
1487 }
1488
1490}
1491
1492Literal SharedTreeWorker::DecodeDecision(ProtoLiteral lit) {
1493 return lit.Decode(mapping_, encoder_);
1494}
1495
1496std::optional<ProtoLiteral> SharedTreeWorker::EncodeDecision(Literal decision) {
1497 return ProtoLiteral::Encode(decision, mapping_, encoder_);
1498}
1499
1500bool SharedTreeWorker::CheckLratInvariants() {
1501 if (lrat_proof_handler_ != nullptr &&
1502 lrat_proof_handler_->lrat_check_enabled()) {
1503 for (int level = 0; level < assigned_tree_decisions_.size(); ++level) {
1504 for (auto& [lit, id] : assigned_tree_implications_[level]) {
1505 std::vector<Literal>& expected = DecisionReason(level + 1);
1506 expected.push_back(lit);
1507 CheckEqual(lrat_proof_handler_->GetLratClauseForDebug(id), expected);
1508 }
1509 }
1510 }
1511 return true;
1512}
1513
1514} // namespace operations_research::sat
void AppendClauseIdsFixing(absl::Span< const Literal > literals, std::vector< ClauseId > *clause_ids, LiteralIndex decision=kNoLiteralIndex, std::optional< absl::FunctionRef< ClauseId(int, int)> > root_literals={})
Definition clause.cc:756
int GetProtoVariableFromBooleanVariable(BooleanVariable var) const
IntegerVariable Integer(int ref) const
const InlinedIntegerLiteralVector & GetIntegerLiterals(Literal lit) const
Definition integer.h:229
Literal GetOrCreateAssociatedLiteral(IntegerLiteral i_lit)
Definition integer.cc:274
LiteralIndex Index() const
Definition sat_base.h:92
BooleanVariable Variable() const
Definition sat_base.h:88
absl::Span< const Literal > GetLratClauseForDebug(ClauseId id) const
static std::optional< ProtoLiteral > Encode(Literal, CpModelMapping *, IntegerEncoder *)
Literal Decode(CpModelMapping *, IntegerEncoder *) const
static std::optional< ProtoLiteral > EncodeLiteral(Literal, CpModelMapping *)
std::vector< ProtoLiteral > TakeTargetPhase()
void SetTargetPhase(std::vector< ProtoLiteral > phase)
ProtoLiteral Decision(int level) const
void SetObjectiveLb(int level, IntegerValue objective_lb)
absl::Span< const ProtoLiteral > Implications(int level) const
absl::Span< const int > NodeIds(int level) const
IntegerValue ObjectiveLb(int level) const
void PushLevel(const ProtoLiteral &decision, IntegerValue objective_lb, int node_id)
static constexpr SharedTreeSplitStrategy SPLIT_STRATEGY_DISCREPANCY
static constexpr SharedTreeSplitStrategy SPLIT_STRATEGY_BALANCED_TREE
static constexpr SharedTreeSplitStrategy SPLIT_STRATEGY_AUTO
static constexpr SharedTreeSplitStrategy SPLIT_STRATEGY_OBJECTIVE_LB
const VariablesAssignment & Assignment() const
Definition sat_solver.h:420
void CloseTree(ProtoTrail &path, int level)
bool SyncTree(ProtoTrail &path) ABSL_LOCKS_EXCLUDED(mu_)
int NumNodes() const ABSL_LOCKS_EXCLUDED(mu_)
int TrySplitTree(absl::Span< const ProtoLiteral > decisions, ProtoTrail &path) ABSL_LOCKS_EXCLUDED(mu_)
SatSolver::Status Search(const std::function< void()> &feasible_solution_observer)
bool LiteralIsFalse(Literal literal) const
Definition sat_base.h:203
bool LiteralIsTrue(Literal literal) const
Definition sat_base.h:206
constexpr IntegerValue kMaxIntegerValue(std::numeric_limits< IntegerValue::ValueType >::max() - 1)
const LiteralIndex kNoLiteralIndex(-1)
const IntegerVariable kNoIntegerVariable(-1)
IntegerVariable PositiveVariable(IntegerVariable i)
constexpr ClauseId kNoClauseId(0)
ClosedInterval::Iterator end(ClosedInterval interval)
ClosedInterval::Iterator begin(ClosedInterval interval)
STL namespace.
static IntegerLiteral GreaterOrEqual(IntegerVariable i, IntegerValue bound)
static IntegerLiteral LowerOrEqual(IntegerVariable i, IntegerValue bound)