Google OR-Tools v9.12
a fast and portable software suite for combinatorial optimization
Loading...
Searching...
No Matches
work_assignment.cc
Go to the documentation of this file.
1// Copyright 2010-2025 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
15
16#include <algorithm>
17#include <array>
18#include <cmath>
19#include <deque>
20#include <functional>
21#include <limits>
22#include <memory>
23#include <optional>
24#include <string>
25#include <tuple>
26#include <utility>
27#include <vector>
28
29#include "absl/container/flat_hash_set.h"
30#include "absl/log/check.h"
31#include "absl/strings/str_cat.h"
32#include "absl/synchronization/mutex.h"
33#include "absl/types/span.h"
37#include "ortools/sat/integer.h"
40#include "ortools/sat/model.h"
41#include "ortools/sat/restart.h"
44#include "ortools/sat/sat_parameters.pb.h"
47#include "ortools/sat/util.h"
50
namespace {

// Shared-tree restart schedule (see SharedTreeManager::SyncTree): the tree is
// restarted when the number of syncs since the last restart, divided by the
// number of workers, exceeds kSyncsPerWorkerPerRestart, for at most
// kNumInitialRestarts restarts.
// NOTE(review): an earlier version of this comment also described restarting
// when the tree reaches its maximum number of nodes; that logic is not visible
// here — confirm where (and whether) it lives.
const int kSyncsPerWorkerPerRestart = 2;
const int kNumInitialRestarts = 10;

// If you build a tree by expanding the nodes with minimal depth+discrepancy,
// the number of leaves when all nodes with a given value have been split
// follows the Fibonacci sequence:
// num_leaves(0) := 2;
// num_leaves(1) := 3;
// num_leaves(n) := num_leaves(n-1) + num_leaves(n-2)
// This function returns f(n) := min({i | num_leaves(i) >= n}).
//
// The sequence is computed in `int`. If the next term would overflow, it is
// necessarily larger than any `int` num_leaves, so the next index is returned
// directly instead of evaluating an overflowing (undefined) addition.
int MaxAllowedDiscrepancyPlusDepth(int num_leaves) {
  int i = 0;
  int a = 1;
  int b = 2;
  while (b < num_leaves) {
    if (a > std::numeric_limits<int>::max() - b) return i + 1;
    std::tie(a, b) = std::make_pair(b, a + b);
    ++i;
  }
  return i;
}

}  // namespace
78
80 IntegerEncoder* encoder) const {
81 DCHECK_LT(proto_var_, mapping->NumProtoVariables());
82 if (mapping->IsBoolean(proto_var_)) {
83 return mapping->Literal(proto_var_);
84 }
85 return encoder->GetOrCreateAssociatedLiteral(DecodeInteger(mapping));
86}
87
88IntegerLiteral ProtoLiteral::DecodeInteger(CpModelMapping* mapping) const {
89 const int positive_var = PositiveRef(proto_var_);
90 if (!mapping->IsInteger(positive_var)) {
91 return IntegerLiteral();
92 }
93 if (proto_var_ < 0) {
94 return IntegerLiteral::LowerOrEqual(mapping->Integer(positive_var), -lb_);
95 }
96 return IntegerLiteral::GreaterOrEqual(mapping->Integer(positive_var), lb_);
97}
98
99std::optional<ProtoLiteral> ProtoLiteral::EncodeInteger(
100 IntegerLiteral literal, CpModelMapping* mapping) {
101 IntegerVariable positive_var = PositiveVariable(literal.var);
102 const int model_var =
103 mapping->GetProtoVariableFromIntegerVariable(positive_var);
104 if (model_var == -1) {
105 return std::nullopt;
106 }
107 ProtoLiteral result{
108 literal.var == positive_var ? model_var : NegatedRef(model_var),
109 literal.bound};
110 DCHECK_EQ(result.DecodeInteger(mapping), literal);
111 DCHECK_EQ(result.Negated().DecodeInteger(mapping), literal.Negated());
112 return result;
113}
114std::optional<ProtoLiteral> ProtoLiteral::Encode(Literal literal,
115 CpModelMapping* mapping,
116 IntegerEncoder* encoder) {
117 const std::optional<ProtoLiteral> result = EncodeLiteral(literal, mapping);
118 if (result.has_value()) return result;
119
120 for (auto int_lit : encoder->GetIntegerLiterals(literal)) {
121 auto result = EncodeInteger(int_lit, mapping);
122 if (result.has_value()) {
123 DCHECK_EQ(result->DecodeInteger(mapping), int_lit);
124 DCHECK_EQ(result->Negated().DecodeInteger(mapping), int_lit.Negated());
125 return result;
126 }
127 }
128 return std::nullopt;
129}
130
131std::optional<ProtoLiteral> ProtoLiteral::EncodeLiteral(
132 Literal literal, CpModelMapping* mapping) {
133 if (literal.Index() == kNoLiteralIndex) {
134 return std::nullopt;
135 }
136 int model_var =
138 if (model_var == -1) {
139 return std::nullopt;
140 }
141 DCHECK(mapping->IsBoolean(model_var));
142 ProtoLiteral result{literal.IsPositive() ? model_var : NegatedRef(model_var),
143 literal.IsPositive() ? 1 : 0};
144 return result;
145}
146
147ProtoTrail::ProtoTrail() { target_phase_.reserve(kMaxPhaseSize); }
148
150 IntegerValue objective_lb, int node_id) {
151 CHECK_GT(node_id, 0);
152 decision_indexes_.push_back(literals_.size());
153 literals_.push_back(decision);
154 node_ids_.push_back(node_id);
155 implications_.push_back({});
156 if (!level_to_objective_lbs_.empty()) {
157 objective_lb = std::max(level_to_objective_lbs_.back(), objective_lb);
158 }
159 level_to_objective_lbs_.push_back(objective_lb);
160}
161
163 DCHECK_GE(level, 1);
164 DCHECK_LE(level, decision_indexes_.size());
165 DCHECK_LE(level, implications_.size());
166 SetObjectiveLb(level - 1, ObjectiveLb(level));
167 const ProtoLiteral decision = Decision(level);
168 implication_level_[decision] = level - 1;
169 // We don't store implications for level 0, so only move implications up to
170 // the parent if we are removing level 2 or greater.
171 if (level >= 2) {
172 MutableImplications(level - 1).push_back(decision);
173 }
174 for (const ProtoLiteral& implication : Implications(level)) {
175 implication_level_[implication] = level - 1;
176 if (level >= 2) {
177 MutableImplications(level - 1).push_back(implication);
178 }
179 }
180 // implications_[level-1] stores the implications for level, which are now
181 // stored in the parent's implications, so we can delete them.
182 implications_.erase(implications_.begin() + level - 1);
183 decision_indexes_.erase(decision_indexes_.begin() + level - 1);
184 level_to_objective_lbs_.erase(level_to_objective_lbs_.begin() + level - 1);
185}
186
188 decision_indexes_.clear();
189 literals_.clear();
190 level_to_objective_lbs_.clear();
191 node_ids_.clear();
192 target_phase_.clear();
193 implication_level_.clear();
194 implications_.clear();
195}
196
197void ProtoTrail::SetObjectiveLb(int level, IntegerValue objective_lb) {
198 if (level == 0) return;
199 level_to_objective_lbs_[level - 1] =
200 std::max(objective_lb, level_to_objective_lbs_[level - 1]);
201}
202
203absl::Span<const int> ProtoTrail::NodeIds(int level) const {
204 DCHECK_LE(level, decision_indexes_.size());
205 int start = level == 0 ? 0 : decision_indexes_[level - 1];
206 int end = level == decision_indexes_.size() ? node_ids_.size()
207 : decision_indexes_[level];
208 return absl::MakeSpan(node_ids_.data() + start, end - start);
209}
210
211absl::Span<const ProtoLiteral> ProtoTrail::Implications(int level) const {
212 if (level > implications_.size() || level <= 0) {
213 return absl::MakeSpan(literals_.data(), 0);
214 }
215 return absl::MakeSpan(implications_[level - 1]);
216}
217
219 : params_(*model->GetOrCreate<SatParameters>()),
220 num_workers_(params_.shared_tree_num_workers()),
221 shared_response_manager_(model->GetOrCreate<SharedResponseManager>()),
222 num_splits_wanted_(
223 num_workers_ * params_.shared_tree_open_leaves_per_worker() - 1),
224 max_nodes_(
225 params_.shared_tree_max_nodes_per_worker() >=
226 std::numeric_limits<int>::max() / std::max(num_workers_, 1)
227 ? std::numeric_limits<int>::max()
228 : num_workers_ * params_.shared_tree_max_nodes_per_worker()) {
229 CHECK_GE(num_workers_, 0);
230 // Create the root node with a fake literal.
231 nodes_.push_back(
232 {.literal = ProtoLiteral(),
233 .objective_lb = shared_response_manager_->GetInnerObjectiveLowerBound(),
234 .trail_info = std::make_unique<NodeTrailInfo>()});
235 unassigned_leaves_.reserve(num_workers_);
236 unassigned_leaves_.push_back(&nodes_.back());
237}
238
240 absl::MutexLock mutex_lock(&mu_);
241 return nodes_.size();
242}
243
245 absl::MutexLock mutex_lock(&mu_);
246 std::vector<std::pair<Node*, int>> nodes = GetAssignedNodes(path);
247 if (!IsValid(path)) {
248 path.Clear();
249 return false;
250 }
251 // We don't rely on these being empty, but we expect them to be.
252 DCHECK(to_close_.empty());
253 DCHECK(to_update_.empty());
254 int prev_level = -1;
255 for (const auto& [node, level] : nodes) {
256 if (level == prev_level) {
257 to_close_.push_back(GetSibling(node));
258 } else if (level > 0 && node->objective_lb < path.ObjectiveLb(level)) {
259 node->objective_lb = path.ObjectiveLb(level);
260 to_update_.push_back(node->parent);
261 }
262 if (level > 0 && !node->closed) {
263 NodeTrailInfo* trail_info = GetTrailInfo(node);
264 for (const ProtoLiteral& implication : path.Implications(level)) {
265 auto it = trail_info->implications
266 .emplace(implication.proto_var(), implication.lb())
267 .first;
268 if (it->second < implication.lb()) {
269 it->second = implication.lb();
270 }
271 }
272 }
273 prev_level = level;
274 }
275 ProcessNodeChanges();
276 if (nodes.back().first->closed) {
277 path.Clear();
278 return false;
279 }
280 // Restart after processing updates - we might learn a new objective bound.
281 if (++num_syncs_since_restart_ / num_workers_ > kSyncsPerWorkerPerRestart &&
282 num_restarts_ < kNumInitialRestarts) {
283 RestartLockHeld();
284 path.Clear();
285 return false;
286 }
287 // Sync lower bounds and implications from the shared tree to `path`.
288 AssignLeaf(path, nodes.back().first);
289 return true;
290}
291
293 absl::MutexLock mutex_lock(&mu_);
294 if (!IsValid(path)) return;
295 std::vector<std::pair<Node*, int>> nodes = GetAssignedNodes(path);
296 if (nodes.back().first->closed) {
297 VLOG(2) << "Cannot split closed node";
298 return;
299 }
300 if (nodes.back().first->children[0] != nullptr) {
301 LOG_IF(WARNING, nodes.size() > 1)
302 << "Cannot resplit previously split node @ " << nodes.back().second
303 << "/" << nodes.size();
304 return;
305 }
306 if (nodes_.size() + 2 > max_nodes_) {
307 VLOG(2) << "Too many nodes to accept split";
308 return;
309 }
310 if (num_splits_wanted_ <= 0) {
311 VLOG(2) << "Enough splits for now";
312 return;
313 }
314 const int num_desired_leaves =
315 params_.shared_tree_open_leaves_per_worker() * num_workers_;
316 if (params_.shared_tree_split_strategy() ==
317 SatParameters::SPLIT_STRATEGY_DISCREPANCY ||
318 params_.shared_tree_split_strategy() ==
319 SatParameters::SPLIT_STRATEGY_AUTO) {
320 int discrepancy = 0;
321 for (const auto& [node, level] : nodes) {
322 if (node->parent == nullptr || node->implied) continue;
323 IntegerValue sibling_bound = GetSibling(node)->objective_lb;
324 discrepancy += (node->objective_lb == sibling_bound
325 ? node != node->parent->children[0]
326 : node->objective_lb > sibling_bound);
327 }
328 // TODO(user): Need to write up the shape this creates.
329 // This rule will allow twice as many leaves in the preferred subtree.
330 if (discrepancy + path.MaxLevel() >
331 MaxAllowedDiscrepancyPlusDepth(num_desired_leaves) +
332 params_.shared_tree_balance_tolerance()) {
333 VLOG(2) << "Too high discrepancy to accept split";
334 return;
335 }
336 } else if (params_.shared_tree_split_strategy() ==
337 SatParameters::SPLIT_STRATEGY_OBJECTIVE_LB) {
338 if (nodes.back().first->objective_lb > nodes.front().first->objective_lb) {
339 VLOG(2) << "Can only split nodes with minimum objective lb, "
340 << nodes.back().first->objective_lb << " > "
341 << nodes.front().first->objective_lb;
342 return;
343 }
344 } else if (params_.shared_tree_split_strategy() ==
345 SatParameters::SPLIT_STRATEGY_BALANCED_TREE) {
346 if (path.MaxLevel() + 1 >
347 log2(num_desired_leaves) + params_.shared_tree_balance_tolerance()) {
348 VLOG(2) << "Tree too unbalanced to accept split";
349 return;
350 }
351 }
352 VLOG_EVERY_N(2, 10) << unassigned_leaves_.size() << " unassigned leaves, "
353 << nodes_.size() << " subtrees, " << num_splits_wanted_
354 << " splits wanted";
355 Split(nodes, decision);
356 auto [new_leaf, level] = nodes.back();
357 path.PushLevel(new_leaf->literal, new_leaf->objective_lb, new_leaf->id);
358}
359
361 absl::MutexLock mutex_lock(&mu_);
362 std::vector<std::pair<Node*, int>> nodes = GetAssignedNodes(path);
363 if (nodes.back().first->children[0] == nullptr &&
364 !nodes.back().first->closed && nodes.size() > 1) {
365 Node* leaf = nodes.back().first;
366 VLOG(2) << "Returning leaf to be replaced";
367 GetTrailInfo(leaf)->phase.assign(path.TargetPhase().begin(),
368 path.TargetPhase().end());
369 unassigned_leaves_.push_back(leaf);
370 }
371 path.Clear();
372 while (!unassigned_leaves_.empty()) {
373 const int i = num_leaves_assigned_++ % unassigned_leaves_.size();
374 std::swap(unassigned_leaves_[i], unassigned_leaves_.back());
375 Node* leaf = unassigned_leaves_.back();
376 unassigned_leaves_.pop_back();
377 if (!leaf->closed && leaf->children[0] == nullptr) {
378 AssignLeaf(path, leaf);
379 path.SetTargetPhase(GetTrailInfo(leaf)->phase);
380 return;
381 }
382 }
383 VLOG(2) << "Assigning root because no unassigned leaves are available";
384 // TODO(user): Investigate assigning a random leaf so workers can still
385 // improve shared tree bounds.
386}
387
388SharedTreeManager::NodeTrailInfo* SharedTreeManager::GetTrailInfo(Node* node) {
389 CHECK(node != nullptr && !node->closed);
390 while (node->trail_info == nullptr) {
391 node = node->parent;
392 }
393 CHECK_NE(node, nullptr);
394 return node->trail_info.get();
395}
396
397SharedTreeManager::Node* SharedTreeManager::GetSibling(Node* node) {
398 if (node == nullptr || node->parent == nullptr) return nullptr;
399 if (node->parent->children[0] != node) {
400 return node->parent->children[0];
401 }
402 return node->parent->children[1];
403}
404
405void SharedTreeManager::Split(std::vector<std::pair<Node*, int>>& nodes,
406 ProtoLiteral lit) {
407 const auto [parent, level] = nodes.back();
408 DCHECK(parent->children[0] == nullptr);
409 DCHECK(parent->children[1] == nullptr);
410 parent->children[0] = MakeSubtree(parent, lit);
411 parent->children[1] = MakeSubtree(parent, lit.Negated());
412 NodeTrailInfo* trail_info = GetTrailInfo(parent);
413 if (trail_info != nullptr) {
414 parent->children[0]->trail_info = std::make_unique<NodeTrailInfo>(
415 NodeTrailInfo{.phase = trail_info->phase});
416 parent->children[1]->trail_info = std::make_unique<NodeTrailInfo>(
417 NodeTrailInfo{.phase = std::move(trail_info->phase)});
418 }
419 nodes.push_back(std::make_pair(parent->children[0], level + 1));
420 unassigned_leaves_.push_back(parent->children[1]);
421 --num_splits_wanted_;
422}
423
424SharedTreeManager::Node* SharedTreeManager::MakeSubtree(Node* parent,
425 ProtoLiteral literal) {
426 nodes_.push_back(
427 Node{.literal = literal,
428 .objective_lb = parent->objective_lb,
429 .parent = parent,
430 .id = static_cast<int>(nodes_.size() + node_id_offset_)});
431 return &nodes_.back();
432}
433
// Drains the pending work queues `to_close_` and `to_update_`.
//
// Phase 1 (closing): each queued node is marked closed. Its objective lower
// bound becomes kMaxIntegerValue and its open children are queued for closing.
// Closing then continues upward to the parent for as long as the node's
// sibling is also closed. The sibling of every closed node is marked
// `implied`. If closing walks past the root, the whole tree is closed and the
// response manager is told that no improving solution exists. Otherwise the
// parent of the topmost closed node is queued for a bound update.
//
// Phase 2 (bound updates): for each queued node, walking upward while nodes
// are open:
//   - implications recorded on implied children are merged into the trail
//     info that applies to the node;
//   - the node's lower bound is raised to the min of its children's bounds,
//     stopping at the first node whose bound does not improve.
// If the root's bound improved, it is published as the inner objective lower
// bound.
void SharedTreeManager::ProcessNodeChanges() {
  int num_newly_closed = 0;
  while (!to_close_.empty()) {
    Node* node = to_close_.back();
    CHECK_NE(node, nullptr);
    to_close_.pop_back();
    // Iterate over open parents while each sibling is closed.
    while (node != nullptr && !node->closed) {
      ++num_newly_closed;
      ++num_closed_nodes_;
      node->closed = true;
      // Keep the root trail_info so GetTrailInfo never returns nullptr.
      if (node->parent != nullptr) node->trail_info.reset();
      // A closed node never bounds its parent: min() over children skips it.
      node->objective_lb = kMaxIntegerValue;
      // If we are closing a leaf, try to maintain the same number of leaves.
      num_splits_wanted_ += (node->children[0] == nullptr);
      for (Node* child : node->children) {
        if (child == nullptr || child->closed) continue;
        to_close_.push_back(child);
      }
      Node* sibling = GetSibling(node);
      if (sibling != nullptr) {
        // With `node` closed, the sibling's literal follows from the parent.
        sibling->implied = true;
        if (!sibling->closed) {
          break;
        }
      }
      node = node->parent;
    }
    DCHECK(node == nullptr || node->closed);
    if (node == nullptr) {
      // Closing propagated past the root: the whole tree is closed.
      shared_response_manager_->NotifyThatImprovingProblemIsInfeasible(
          ShortStatus());
    } else if (node->parent != nullptr) {
      to_update_.push_back(node->parent);
    }
  }
  if (num_newly_closed > 0) {
    shared_response_manager_->LogMessageWithThrottling(
        "Tree", absl::StrCat("nodes:", nodes_.size(), "/", max_nodes_,
                             " closed:", num_closed_nodes_,
                             " unassigned:", unassigned_leaves_.size(),
                             " restarts:", num_restarts_));
  }
  // TODO(user): We could do resolution here by moving implications that
  // are true in each child to the parent.
  bool root_updated = false;
  while (!to_update_.empty()) {
    Node* node = to_update_.back();
    to_update_.pop_back();
    // Iterate over parents while the lower bound can be improved.
    while (node != nullptr && !node->closed) {
      DCHECK(node->children[0] != nullptr);
      DCHECK(node->children[1] != nullptr);
      NodeTrailInfo* trail_info = GetTrailInfo(node);
      for (Node* child : node->children) {
        // An implied child's implications also hold at this node, so they
        // are hoisted up and the child's own copy is dropped.
        // NOTE(review): map `merge` leaves keys already present in the
        // target untouched, so a tighter child bound for such a key is lost
        // when the child's trail_info is reset below — confirm intended.
        if (child->implied && child->trail_info != nullptr) {
          trail_info->implications.merge(child->trail_info->implications);
          child->trail_info.reset();
        }
      }
      IntegerValue child_bound = std::min(node->children[0]->objective_lb,
                                          node->children[1]->objective_lb);
      if (child_bound <= node->objective_lb) break;
      node->objective_lb = child_bound;
      node = node->parent;
    }
    // Only reachable by improving the root's bound and stepping to its parent.
    if (node == nullptr) root_updated = true;
  }
  if (root_updated) {
    shared_response_manager_->UpdateInnerObjectiveBounds(
        ShortStatus(), nodes_[0].objective_lb, kMaxIntegerValue);
  }
  // These are shared via SharedBoundsManager, don't duplicate here.
  nodes_[0].trail_info->implications.clear();
}
510
511std::vector<std::pair<SharedTreeManager::Node*, int>>
512SharedTreeManager::GetAssignedNodes(const ProtoTrail& path) {
513 std::vector<std::pair<Node*, int>> nodes({std::make_pair(&nodes_[0], 0)});
514 if (!IsValid(path)) {
515 // Restart has happened, nodes in this path are no longer valid, but the
516 // root is equivalent.
517 return nodes;
518 }
519 for (int i = 0; i <= path.MaxLevel(); ++i) {
520 for (int id : path.NodeIds(i)) {
521 const int index = id - node_id_offset_;
522 CHECK_GE(index, 0) << " in path.NodeIds(" << i
523 << "), max_level=" << path.MaxLevel();
524 CHECK_LT(index, nodes_.size());
525 DCHECK_EQ(nodes.back().first, nodes_[index].parent);
526 nodes.push_back(std::make_pair(&nodes_[index], i));
527 }
528 }
529 return nodes;
530}
531
533 absl::MutexLock mutex_lock(&mu_);
534 const int node_id_to_close = path.NodeIds(level).front();
535 path.Clear();
536 if (node_id_to_close < node_id_offset_) return;
537 Node* node = &nodes_[node_id_to_close - node_id_offset_];
538 VLOG(2) << "Closing subtree at level " << level;
539 DCHECK(to_close_.empty());
540 to_close_.push_back(node);
541 ProcessNodeChanges();
542}
543
544void SharedTreeManager::AssignLeaf(ProtoTrail& path, Node* leaf) {
545 path.Clear();
546 std::vector<Node*> reversed_path;
547 while (leaf != &nodes_[0]) {
548 reversed_path.push_back(&nodes_[leaf->id - node_id_offset_]);
549 leaf = leaf->parent;
550 }
551 while (!reversed_path.empty()) {
552 Node* leaf = reversed_path.back();
553 reversed_path.pop_back();
554 path.PushLevel(leaf->literal, leaf->objective_lb, leaf->id);
555 if (leaf->implied) {
556 path.SetLevelImplied(path.MaxLevel());
557 }
558 if (params_.shared_tree_worker_enable_trail_sharing() &&
559 leaf->trail_info != nullptr) {
560 for (const auto& [var, lb] : leaf->trail_info->implications) {
561 path.AddImplication(path.MaxLevel(), ProtoLiteral(var, lb));
562 }
563 }
564 }
565}
566
567bool SharedTreeManager::IsValid(const ProtoTrail& path) const {
568 auto node_ids = path.NodeIds(path.MaxLevel());
569 if (node_ids.empty()) return true;
570 if (node_ids.back() < node_id_offset_) return false;
571 return true;
572}
573
574void SharedTreeManager::RestartLockHeld() {
575 node_id_offset_ += nodes_.size();
576 nodes_.resize(1);
577 nodes_[0].id = node_id_offset_;
578 nodes_[0].children = {nullptr, nullptr};
579 unassigned_leaves_.clear();
580 num_splits_wanted_ =
581 num_workers_ * params_.shared_tree_open_leaves_per_worker() - 1;
582 num_closed_nodes_ = 0;
583 num_restarts_ += 1;
584 num_syncs_since_restart_ = 0;
585}
586
587std::string SharedTreeManager::ShortStatus() const {
588 return absl::StrCat("shared_tree_manager(r=", num_restarts_,
589 " n=", nodes_.size(), ")");
590}
591
593 : parameters_(model->GetOrCreate<SatParameters>()),
594 shared_response_(model->GetOrCreate<SharedResponseManager>()),
595 time_limit_(model->GetOrCreate<TimeLimit>()),
596 manager_(model->GetOrCreate<SharedTreeManager>()),
597 mapping_(model->GetOrCreate<CpModelMapping>()),
598 sat_solver_(model->GetOrCreate<SatSolver>()),
599 trail_(model->GetOrCreate<Trail>()),
600 integer_trail_(model->GetOrCreate<IntegerTrail>()),
601 encoder_(model->GetOrCreate<IntegerEncoder>()),
602 objective_(model->Get<ObjectiveDefinition>()),
603 random_(model->GetOrCreate<ModelRandomGenerator>()),
604 helper_(model->GetOrCreate<IntegerSearchHelper>()),
605 heuristics_(model->GetOrCreate<SearchHeuristics>()),
606 decision_policy_(model->GetOrCreate<SatDecisionPolicy>()),
607 restart_policy_(model->GetOrCreate<RestartPolicy>()),
608 level_zero_callbacks_(model->GetOrCreate<LevelZeroCallbackHelper>()),
609 reversible_int_repository_(model->GetOrCreate<RevIntRepository>()),
610 assigned_tree_lbds_(/*window_size=*/8) {}
611
612const std::vector<Literal>& SharedTreeWorker::DecisionReason(int level) {
613 CHECK_LE(level, assigned_tree_literals_.size());
614 reason_.clear();
615 for (int i = 0; i < level; ++i) {
616 reason_.push_back(assigned_tree_literals_[i].Negated());
617 }
618 return reason_;
619}
620
621bool SharedTreeWorker::AddDecisionImplication(Literal lit, int level) {
622 CHECK_NE(lit.Index(), kNoLiteralIndex);
623 CHECK(!sat_solver_->Assignment().LiteralIsTrue(lit));
624 if (sat_solver_->Assignment().LiteralIsFalse(lit)) {
625 VLOG(2) << "Closing subtree via impl at " << level + 1
626 << " assigned=" << assigned_tree_.MaxLevel();
627 integer_trail_->ReportConflict(DecisionReason(level), {});
628 manager_->CloseTree(assigned_tree_, level);
629 assigned_tree_literals_.clear();
630 return false;
631 }
632 integer_trail_->EnqueueLiteral(lit, DecisionReason(level), {});
633 VLOG(2) << "Learned shared clause";
634 return true;
635}
636
// Applies the shared-tree implications for the current decision level.
//
// Every tree implication of this level that is not already true is enqueued
// via AddDecisionImplication(). Then, if there is an objective, the level's
// objective lower bound is synced with the tree and enforced.
//
// Returns true if at least one literal was enqueued or caused a conflict, in
// which case the caller should propagate again. Returns false at level 0, for
// levels deeper than the assigned tree, or when nothing new was added.
bool SharedTreeWorker::AddImplications() {
  const int level = sat_solver_->CurrentDecisionLevel();
  // Level 0 implications are unit clauses and are synced elsewhere.
  if (level == 0) return false;
  if (level > assigned_tree_.MaxLevel()) {
    return false;
  }
  // Per-level count of implications already processed. It is registered with
  // the reversible repository so that (presumably) it is restored on
  // backtrack.
  // NOTE(review): SaveState() keeps a pointer into this container, and it is
  // resized here. That is only safe if growing it never relocates existing
  // elements (e.g. std::deque, not std::vector). Confirm the member's type.
  rev_num_processed_implications_.resize(level + 1, 0);
  auto& num_processed_implications = rev_num_processed_implications_[level];
  reversible_int_repository_->SaveState(&num_processed_implications);
  absl::Span<const Literal> implied_literals =
      absl::MakeConstSpan(assigned_tree_implications_[level - 1])
          .subspan(num_processed_implications);
  bool added_clause = false;
  for (Literal impl : implied_literals) {
    ++num_processed_implications;
    if (sat_solver_->Assignment().LiteralIsTrue(impl)) continue;
    added_clause = true;
    // On conflict the subtree was closed; stop using `implied_literals`.
    if (!AddDecisionImplication(impl, level)) return true;
  }
  if (objective_ != nullptr &&
      objective_->objective_var != kNoIntegerVariable) {
    // Publish the local bound to the trail, then enforce the trail's bound,
    // which can be higher than the local one.
    const IntegerValue obj_lb =
        integer_trail_->LowerBound(objective_->objective_var);
    assigned_tree_.SetObjectiveLb(level, obj_lb);
    const Literal obj_lit =
        encoder_->GetOrCreateAssociatedLiteral(IntegerLiteral::GreaterOrEqual(
            objective_->objective_var, assigned_tree_.ObjectiveLb(level)));
    if (!sat_solver_->Assignment().LiteralIsTrue(obj_lit)) {
      AddDecisionImplication(obj_lit, level);
      return true;
    }
  }
  return added_clause;
}
672
673bool SharedTreeWorker::SyncWithLocalTrail() {
674 while (true) {
675 if (!sat_solver_->FinishPropagation()) return false;
676 // Ensure we are at fixed point w.r.t. implications in the tree up to the
677 // current level.
678 if (AddImplications()) continue;
679
680 if (!helper_->BeforeTakingDecision()) return false;
681 const int level = sat_solver_->CurrentDecisionLevel();
682 if (parameters_->shared_tree_worker_enable_trail_sharing() && level > 0 &&
683 level <= assigned_tree_.MaxLevel()) {
684 // Add implications from the local trail to share with other workers.
685 reversible_int_repository_->SaveState(&reversible_trail_index_);
686 for (int i = trail_->Index() - 1; i >= reversible_trail_index_; --i) {
687 const Literal lit = (*trail_)[i];
688 if (trail_->AssignmentType(lit.Variable()) ==
690 break;
691 }
692 std::optional<ProtoLiteral> encoded = EncodeDecision(lit);
693 if (!encoded.has_value()) continue;
694 assigned_tree_.AddImplication(level, *encoded);
695 }
696 reversible_trail_index_ = trail_->Index();
697 }
698 if (level >= assigned_tree_.MaxLevel()) break;
699 // The next decision is assigned, make sure it makes sense.
700 const Literal next_decision = assigned_tree_literals_[level];
701 if (!sat_solver_->Assignment().LiteralIsAssigned(next_decision)) break;
702 if (sat_solver_->Assignment().LiteralIsFalse(next_decision)) {
703 // Next assigned decision is impossible.
704 VLOG(2) << "Closing subtree at " << level + 1
705 << " assigned=" << assigned_tree_.MaxLevel();
706 manager_->CloseTree(assigned_tree_, level + 1);
707 assigned_tree_literals_.clear();
708 assigned_tree_implications_.clear();
709 sat_solver_->Backtrack(0);
710 } else {
711 // The next level is implied by the current one.
712 assigned_tree_.SetLevelImplied(level + 1);
713 if (level > 0) {
714 assigned_tree_implications_[level - 1].insert(
715 assigned_tree_implications_[level - 1].end(),
716 assigned_tree_implications_[level].begin(),
717 assigned_tree_implications_[level].end());
718 }
719 assigned_tree_implications_.erase(assigned_tree_implications_.begin() +
720 level);
721 assigned_tree_literals_.erase(assigned_tree_literals_.begin() + level);
722 }
723 }
724 return true;
725}
726
727bool SharedTreeWorker::NextDecision(LiteralIndex* decision_index) {
728 const auto& decision_policy =
729 heuristics_->decision_policies[heuristics_->policy_index];
730 const int next_level = sat_solver_->CurrentDecisionLevel() + 1;
731 new_split_available_ = next_level == assigned_tree_.MaxLevel() + 1;
732
733 CHECK_EQ(assigned_tree_literals_.size(), assigned_tree_.MaxLevel());
734 if (next_level <= assigned_tree_.MaxLevel()) {
735 VLOG(2) << "Following shared trail depth=" << next_level << " "
736 << parameters_->name();
737 const Literal decision = assigned_tree_literals_[next_level - 1];
738 CHECK(!sat_solver_->Assignment().LiteralIsFalse(decision))
739 << " at depth " << next_level << " " << parameters_->name();
740 CHECK(!sat_solver_->Assignment().LiteralIsTrue(decision));
741 *decision_index = decision.Index();
742 return true;
743 }
744 return helper_->GetDecision(decision_policy, decision_index);
745}
746
747void SharedTreeWorker::MaybeProposeSplit() {
748 if (!new_split_available_ ||
749 sat_solver_->CurrentDecisionLevel() != assigned_tree_.MaxLevel() + 1) {
750 return;
751 }
752 new_split_available_ = false;
753 const Literal split_decision =
754 sat_solver_->Decisions()[assigned_tree_.MaxLevel()].literal;
755 const std::optional<ProtoLiteral> encoded = EncodeDecision(split_decision);
756 if (encoded.has_value()) {
757 CHECK_EQ(assigned_tree_literals_.size(), assigned_tree_.MaxLevel());
758 manager_->ProposeSplit(assigned_tree_, *encoded);
759 if (assigned_tree_.MaxLevel() > assigned_tree_literals_.size()) {
760 assigned_tree_literals_.push_back(split_decision);
761 assigned_tree_implications_.push_back({});
762 }
763 CHECK_EQ(assigned_tree_literals_.size(), assigned_tree_.MaxLevel());
764 }
765}
766
767bool SharedTreeWorker::ShouldReplaceSubtree() {
768 // If we have no assignment, try to get one.
769 if (assigned_tree_.MaxLevel() == 0) return true;
770 if (restart_policy_->NumRestarts() <
771 parameters_->shared_tree_worker_min_restarts_per_subtree() ||
772 time_limit_->GetElapsedDeterministicTime() <
773 earliest_replacement_dtime_) {
774 return false;
775 }
776 return assigned_tree_lbds_.WindowAverage() <
777 restart_policy_->LbdAverageSinceReset();
778}
779
780bool SharedTreeWorker::SyncWithSharedTree() {
781 manager_->SyncTree(assigned_tree_);
782 if (ShouldReplaceSubtree()) {
783 ++num_trees_;
784 VLOG(2) << parameters_->name() << " acquiring tree #" << num_trees_
785 << " after " << restart_policy_->NumRestarts() << " restarts"
786 << " prev depth: " << assigned_tree_.MaxLevel()
787 << " target: " << assigned_tree_lbds_.WindowAverage()
788 << " lbd: " << restart_policy_->LbdAverageSinceReset();
789 if (parameters_->shared_tree_worker_enable_phase_sharing() &&
790 // Only save the phase if we've done a non-trivial amount of work on
791 // this subtree.
792 FinishedMinRestarts() &&
793 !decision_policy_->GetBestPartialAssignment().empty()) {
794 assigned_tree_.ClearTargetPhase();
795 for (Literal lit : decision_policy_->GetBestPartialAssignment()) {
796 // Only set the phase for booleans to avoid creating literals on other
797 // workers.
798 auto encoded = ProtoLiteral::EncodeLiteral(lit, mapping_);
799 if (!encoded.has_value()) continue;
800 if (!assigned_tree_.AddPhase(*encoded)) break;
801 }
802 }
803 manager_->ReplaceTree(assigned_tree_);
804 assigned_tree_lbds_.Add(restart_policy_->LbdAverageSinceReset());
805 restart_policy_->Reset();
806 earliest_replacement_dtime_ = 0;
807 if (parameters_->shared_tree_worker_enable_phase_sharing()) {
808 VLOG(2) << "Importing phase of length: "
809 << assigned_tree_.TargetPhase().size();
810 decision_policy_->ClearBestPartialAssignment();
811 for (const ProtoLiteral& lit : assigned_tree_.TargetPhase()) {
812 decision_policy_->SetTargetPolarity(DecodeDecision(lit));
813 }
814 }
815 }
816 // If we commit to this subtree, keep it for at least 1s of dtime.
817 // This allows us to replace obviously bad subtrees quickly, and not replace
818 // too frequently overall.
819 if (FinishedMinRestarts() && earliest_replacement_dtime_ >=
820 time_limit_->GetElapsedDeterministicTime()) {
821 earliest_replacement_dtime_ =
822 time_limit_->GetElapsedDeterministicTime() + 1;
823 // Treat this as reassigning the same tree.
824 assigned_tree_lbds_.Add(restart_policy_->LbdAverageSinceReset());
825 }
826 VLOG(2) << "Assigned level: " << assigned_tree_.MaxLevel() << " "
827 << parameters_->name();
828 assigned_tree_literals_.clear();
829 assigned_tree_implications_.clear();
830 for (int i = 1; i <= assigned_tree_.MaxLevel(); ++i) {
831 assigned_tree_literals_.push_back(
832 DecodeDecision(assigned_tree_.Decision(i)));
833 std::vector<Literal> implications;
834 for (const ProtoLiteral& impl : assigned_tree_.Implications(i)) {
835 implications.push_back(DecodeDecision(impl));
836 }
837 assigned_tree_implications_.push_back(std::move(implications));
838 }
839 return true;
840}
841
843 const std::function<void()>& feasible_solution_observer) {
844 // Inside GetAssociatedLiteral if a literal becomes fixed at level 0 during
845 // Search, the code CHECKs it is at level 0 when decoding the literal, but
846 // the fixed literals are cached, so we can create them now to avoid a
847 // crash.
848 sat_solver_->Backtrack(0);
849 encoder_->GetTrueLiteral();
850 encoder_->GetFalseLiteral();
851 level_zero_callbacks_->callbacks.push_back(
852 [this]() { return SyncWithSharedTree(); });
853 const bool has_objective =
854 objective_ != nullptr && objective_->objective_var != kNoIntegerVariable;
855 while (!time_limit_->LimitReached()) {
856 if (!sat_solver_->FinishPropagation()) {
857 return sat_solver_->UnsatStatus();
858 }
859 if (heuristics_->restart_policies[heuristics_->policy_index]()) {
860 heuristics_->policy_index = restart_policy_->NumRestarts() %
861 heuristics_->decision_policies.size();
862 sat_solver_->Backtrack(0);
863 }
864 if (!SyncWithLocalTrail()) return sat_solver_->UnsatStatus();
865 LiteralIndex decision_index;
866 if (!NextDecision(&decision_index)) continue;
867 if (time_limit_->LimitReached()) return SatSolver::LIMIT_REACHED;
868 if (decision_index == kNoLiteralIndex) {
869 feasible_solution_observer();
870 if (!has_objective) return SatSolver::FEASIBLE;
871 const IntegerValue objective =
872 integer_trail_->LowerBound(objective_->objective_var);
873 sat_solver_->Backtrack(0);
874 if (!integer_trail_->Enqueue(
875 IntegerLiteral::LowerOrEqual(objective_->objective_var,
876 objective - 1),
877 {}, {})) {
879 }
880
881 continue;
882 }
883 const Literal decision(decision_index);
884 CHECK(!sat_solver_->Assignment().LiteralIsFalse(decision));
885 CHECK(!sat_solver_->Assignment().LiteralIsTrue(decision));
886 if (!helper_->TakeDecision(decision)) {
887 return sat_solver_->UnsatStatus();
888 }
889 MaybeProposeSplit();
890 }
891
893}
894
895Literal SharedTreeWorker::DecodeDecision(ProtoLiteral lit) {
896 return lit.Decode(mapping_, encoder_);
897}
898
899std::optional<ProtoLiteral> SharedTreeWorker::EncodeDecision(Literal decision) {
900 return ProtoLiteral::Encode(decision, mapping_, encoder_);
901}
902
903} // namespace operations_research::sat
int GetProtoVariableFromBooleanVariable(BooleanVariable var) const
int NumProtoVariables() const
Returns the number of variables in the loaded proto.
IntegerVariable Integer(int ref) const
const InlinedIntegerLiteralVector & GetIntegerLiterals(Literal lit) const
Returns the IntegerLiterals that were associated with the given Literal.
Definition integer.h:213
Literal GetOrCreateAssociatedLiteral(IntegerLiteral i_lit)
Definition integer.cc:274
A helper class to share the code used by the different kinds of search.
bool ReportConflict(absl::Span< const Literal > literal_reason, absl::Span< const IntegerLiteral > integer_reason)
Definition integer.h:705
LiteralIndex Index() const
Definition sat_base.h:91
BooleanVariable Variable() const
Definition sat_base.h:87
static std::optional< ProtoLiteral > Encode(Literal, CpModelMapping *, IntegerEncoder *)
Literal Decode(CpModelMapping *, IntegerEncoder *) const
Note you should only decode integer literals at the root level.
static std::optional< ProtoLiteral > EncodeLiteral(Literal, CpModelMapping *)
int MaxLevel() const
Returns the maximum decision level stored in the trail.
void Clear()
Clear the trail, removing all levels.
void SetTargetPhase(absl::Span< const ProtoLiteral > phase)
ProtoLiteral Decision(int level) const
Returns the decision assigned at level.
void SetObjectiveLb(int level, IntegerValue objective_lb)
Set a lower bound on the objective at level.
absl::Span< const ProtoLiteral > Implications(int level) const
absl::Span< const int > NodeIds(int level) const
Returns the node ids for decisions and implications at level.
void AddImplication(int level, ProtoLiteral implication)
IntegerValue ObjectiveLb(int level) const
void PushLevel(const ProtoLiteral &decision, IntegerValue objective_lb, int node_id)
Adds a new assigned level to the trail.
const std::vector< ProtoLiteral > & TargetPhase() const
void SetLevelImplied(int level)
Asserts that the decision at level is implied by earlier decisions.
Contains the logic to decide when to restart a SAT tree search.
Definition restart.h:32
const VariablesAssignment & Assignment() const
Definition sat_solver.h:393
void CloseTree(ProtoTrail &path, int level)
void ProposeSplit(ProtoTrail &path, ProtoLiteral decision)
void ReplaceTree(ProtoTrail &path)
Assigns a path prefix that the worker should explore.
bool SyncTree(ProtoTrail &path) ABSL_LOCKS_EXCLUDED(mu_)
int NumNodes() const ABSL_LOCKS_EXCLUDED(mu_)
SatSolver::Status Search(const std::function< void()> &feasible_solution_observer)
bool LiteralIsFalse(Literal literal) const
Definition sat_base.h:185
bool LiteralIsTrue(Literal literal) const
Definition sat_base.h:188
constexpr IntegerValue kMaxIntegerValue(std::numeric_limits< IntegerValue::ValueType >::max() - 1)
const LiteralIndex kNoLiteralIndex(-1)
const IntegerVariable kNoIntegerVariable(-1)
IntegerVariable PositiveVariable(IntegerVariable i)
int NegatedRef(int ref)
Small utility functions to deal with negative variable/literal references.
STL namespace.
static IntegerLiteral GreaterOrEqual(IntegerVariable i, IntegerValue bound)
static IntegerLiteral LowerOrEqual(IntegerVariable i, IntegerValue bound)