194#ifndef OR_TOOLS_GRAPH_LINEAR_ASSIGNMENT_H_
195#define OR_TOOLS_GRAPH_LINEAR_ASSIGNMENT_H_
207#include "absl/flags/declare.h"
208#include "absl/flags/flag.h"
209#include "absl/strings/str_format.h"
224template <
typename GraphType>
251 DCHECK(graph_ ==
nullptr);
284 inline const GraphType&
Graph()
const {
return *graph_; }
295 DCHECK_EQ(0, scaled_arc_cost_[
arc] % cost_scaling_factor_);
296 return scaled_arc_cost_[
arc] / cost_scaling_factor_;
327 if (graph_ ==
nullptr) {
332 return graph_->num_nodes();
342 DCHECK_LT(left_node, num_left_nodes_);
343 return matched_arc_[left_node];
354 DCHECK_LT(left_node, num_left_nodes_);
356 DCHECK_NE(GraphType::kNilArc, matching_arc);
357 return Head(matching_arc);
360 std::string
StatsString()
const {
return total_stats_.StatsString(); }
365 : num_left_nodes_(num_left_nodes), node_iterator_(0) {}
368 : num_left_nodes_(assignment.
NumLeftNodes()), node_iterator_(0) {}
372 bool Ok()
const {
return node_iterator_ < num_left_nodes_; }
374 void Next() { ++node_iterator_; }
378 typename GraphType::NodeIndex node_iterator_;
383 Stats() : pushes_(0), double_pushes_(0), relabelings_(0), refinements_(0) {}
390 void Add(
const Stats& that) {
391 pushes_ += that.pushes_;
392 double_pushes_ += that.double_pushes_;
393 relabelings_ += that.relabelings_;
394 refinements_ += that.refinements_;
396 std::string StatsString()
const {
397 return absl::StrFormat(
398 "%d refinements; %d relabelings; "
399 "%d double pushes; %d pushes",
400 refinements_, relabelings_, double_pushes_, pushes_);
403 int64_t double_pushes_;
404 int64_t relabelings_;
405 int64_t refinements_;
409 class ActiveNodeContainerInterface {
411 virtual ~ActiveNodeContainerInterface() {}
412 virtual bool Empty()
const = 0;
417 class ActiveNodeStack :
public ActiveNodeContainerInterface {
419 ~ActiveNodeStack()
override {}
421 bool Empty()
const override {
return v_.empty(); }
423 void Add(
NodeIndex node)
override { v_.push_back(node); }
433 std::vector<NodeIndex> v_;
436 class ActiveNodeQueue :
public ActiveNodeContainerInterface {
438 ~ActiveNodeQueue()
override {}
440 bool Empty()
const override {
return q_.empty(); }
442 void Add(
NodeIndex node)
override { q_.push_front(node); }
452 std::deque<NodeIndex> q_;
463 typedef std::pair<ArcIndex, CostValue> ImplicitPriceSummary;
467 bool EpsilonOptimal()
const;
471 bool AllMatched()
const;
478 inline ImplicitPriceSummary BestArcAndGap(
NodeIndex left_node)
const;
482 void ReportAndAccumulateStats();
493 bool UpdateEpsilon();
497 inline bool IsActive(
NodeIndex left_node)
const;
504 inline bool IsActiveForDebugging(
NodeIndex node)
const;
511 void InitializeActiveNodeContainer();
519 void SaturateNegativeArcs();
528 return scaled_arc_cost_[
arc] - price_[
Head(
arc)];
533 const GraphType* graph_;
542 bool incidence_precondition_satisfied_;
869 bool* in_range)
const {
882 const double result =
883 static_cast<double>(std::max<CostValue>(1, n / 2 - 1)) *
884 (
static_cast<double>(old_epsilon) +
static_cast<double>(new_epsilon));
886 static_cast<double>(std::numeric_limits<CostValue>::max());
887 if (result > limit) {
889 if (in_range !=
nullptr) *in_range =
false;
890 return std::numeric_limits<CostValue>::max();
910 CostValue largest_scaled_cost_magnitude_;
923 ZVector<CostValue> price_;
928 std::vector<ArcIndex> matched_arc_;
936 ZVector<NodeIndex> matched_node_;
941 std::vector<CostValue> scaled_arc_cost_;
946 std::unique_ptr<ActiveNodeContainerInterface> active_nodes_;
954 Stats iteration_stats_;
960template <
typename GraphType>
961const CostValue LinearSumAssignment<GraphType>::kMinEpsilon = 1;
963template <
typename GraphType>
967 num_left_nodes_(num_left_nodes),
969 cost_scaling_factor_(1 + num_left_nodes),
970 alpha_(
absl::GetFlag(FLAGS_assignment_alpha)),
972 price_lower_bound_(0),
973 slack_relabeling_price_(0),
974 largest_scaled_cost_magnitude_(0),
976 price_(num_left_nodes, 2 * num_left_nodes - 1),
977 matched_arc_(num_left_nodes, 0),
978 matched_node_(num_left_nodes, 2 * num_left_nodes - 1),
979 scaled_arc_cost_(
graph.max_end_arc_index(), 0),
980 active_nodes_(
absl::GetFlag(FLAGS_assignment_stack_order)
981 ? static_cast<ActiveNodeContainerInterface*>(
982 new ActiveNodeStack())
983 : static_cast<ActiveNodeContainerInterface*>(
984 new ActiveNodeQueue())) {}
986template <
typename GraphType>
990 num_left_nodes_(num_left_nodes),
992 cost_scaling_factor_(1 + num_left_nodes),
993 alpha_(
absl::GetFlag(FLAGS_assignment_alpha)),
995 price_lower_bound_(0),
996 slack_relabeling_price_(0),
997 largest_scaled_cost_magnitude_(0),
999 price_(num_left_nodes, 2 * num_left_nodes - 1),
1000 matched_arc_(num_left_nodes, 0),
1001 matched_node_(num_left_nodes, 2 * num_left_nodes - 1),
1002 scaled_arc_cost_(num_arcs, 0),
1003 active_nodes_(
absl::GetFlag(FLAGS_assignment_stack_order)
1004 ? static_cast<ActiveNodeContainerInterface*>(
1005 new ActiveNodeStack())
1006 : static_cast<ActiveNodeContainerInterface*>(
1007 new ActiveNodeQueue())) {}
1009template <
typename GraphType>
1011 if (graph_ !=
nullptr) {
1013 DCHECK_LT(
arc, graph_->num_arcs());
1015 DCHECK_LE(num_left_nodes_,
head);
1017 cost *= cost_scaling_factor_;
1018 const CostValue cost_magnitude = std::abs(cost);
1019 largest_scaled_cost_magnitude_ =
1020 std::max(largest_scaled_cost_magnitude_, cost_magnitude);
1021 scaled_arc_cost_[
arc] = cost;
1024template <
typename ArcIndexType>
1028 : temp_(0), cost_(cost) {}
1035 temp_ = (*cost_)[source];
1039 ArcIndexType destination)
const override {
1040 (*cost_)[destination] = (*cost_)[source];
1044 (*cost_)[destination] = temp_;
1051 std::vector<CostValue>*
const cost_;
1059template <
typename GraphType>
1068 typename GraphType::ArcIndex
b)
const {
1069 return ((graph_.Tail(
a) < graph_.Tail(
b)) ||
1070 ((graph_.Tail(
a) == graph_.Tail(
b)) &&
1071 (graph_.Head(
a) < graph_.Head(
b))));
1075 const GraphType& graph_;
1083template <
typename GraphType>
1084PermutationCycleHandler<typename GraphType::ArcIndex>*
1090template <
typename GraphType>
1095 DCHECK_EQ(graph_,
graph);
1101 graph->GroupForwardArcsByFunctor(compare, &cycle_handler);
1105template <
typename GraphType>
1107 const CostValue current_epsilon)
const {
1108 return std::max(current_epsilon / alpha_, kMinEpsilon);
1111template <
typename GraphType>
1112bool LinearSumAssignment<GraphType>::UpdateEpsilon() {
1113 CostValue new_epsilon = NewEpsilon(epsilon_);
1114 slack_relabeling_price_ = PriceChangeBound(epsilon_, new_epsilon,
nullptr);
1115 epsilon_ = new_epsilon;
1116 VLOG(3) <<
"Updated: epsilon_ == " << epsilon_;
1117 VLOG(4) <<
"slack_relabeling_price_ == " << slack_relabeling_price_;
1118 DCHECK_GT(slack_relabeling_price_, 0);
1126template <
typename GraphType>
1127inline bool LinearSumAssignment<GraphType>::IsActive(
1129 DCHECK_LT(left_node, num_left_nodes_);
1130 return matched_arc_[left_node] == GraphType::kNilArc;
1136template <
typename GraphType>
1137inline bool LinearSumAssignment<GraphType>::IsActiveForDebugging(
1139 if (node < num_left_nodes_) {
1140 return IsActive(node);
1142 return matched_node_[node] == GraphType::kNilNode;
1146template <
typename GraphType>
1147void LinearSumAssignment<GraphType>::InitializeActiveNodeContainer() {
1148 DCHECK(active_nodes_->Empty());
1149 for (BipartiteLeftNodeIterator node_it(*graph_, num_left_nodes_);
1150 node_it.Ok(); node_it.Next()) {
1152 if (IsActive(node)) {
1153 active_nodes_->Add(node);
1168template <
typename GraphType>
1169void LinearSumAssignment<GraphType>::SaturateNegativeArcs() {
1171 for (BipartiteLeftNodeIterator node_it(*graph_, num_left_nodes_);
1172 node_it.Ok(); node_it.Next()) {
1174 if (IsActive(node)) {
1182 matched_arc_[node] = GraphType::kNilArc;
1183 matched_node_[mate] = GraphType::kNilNode;
1189template <
typename GraphType>
1190bool LinearSumAssignment<GraphType>::DoublePush(
NodeIndex source) {
1191 DCHECK_GT(num_left_nodes_, source);
1192 DCHECK(IsActive(source)) <<
"Node " << source
1193 <<
"must be active (unmatched)!";
1194 ImplicitPriceSummary summary = BestArcAndGap(source);
1195 const ArcIndex best_arc = summary.first;
1200 if (best_arc == GraphType::kNilArc) {
1203 const NodeIndex new_mate = Head(best_arc);
1204 const NodeIndex to_unmatch = matched_node_[new_mate];
1205 if (to_unmatch != GraphType::kNilNode) {
1208 matched_arc_[to_unmatch] = GraphType::kNilArc;
1209 active_nodes_->Add(to_unmatch);
1211 iteration_stats_.double_pushes_ += 1;
1216 iteration_stats_.pushes_ += 1;
1218 matched_arc_[source] = best_arc;
1219 matched_node_[new_mate] = source;
1221 iteration_stats_.relabelings_ += 1;
1222 const CostValue new_price = price_[new_mate] - gap - epsilon_;
1223 price_[new_mate] = new_price;
1224 return new_price >= price_lower_bound_;
1227template <
typename GraphType>
1228bool LinearSumAssignment<GraphType>::Refine() {
1229 SaturateNegativeArcs();
1230 InitializeActiveNodeContainer();
1231 while (total_excess_ > 0) {
1234 const NodeIndex node = active_nodes_->Get();
1235 if (!DoublePush(node)) {
1243 LOG_IF(DFATAL, total_stats_.refinements_ > 0)
1244 <<
"Infeasibility detection triggered after first iteration found "
1245 <<
"a feasible assignment!";
1249 DCHECK(active_nodes_->Empty());
1250 iteration_stats_.refinements_ += 1;
1268template <
typename GraphType>
1269inline typename LinearSumAssignment<GraphType>::ImplicitPriceSummary
1270LinearSumAssignment<GraphType>::BestArcAndGap(
NodeIndex left_node)
const {
1271 DCHECK(IsActive(left_node))
1272 <<
"Node " << left_node <<
" must be active (unmatched)!";
1273 DCHECK_GT(epsilon_, 0);
1274 typename GraphType::OutgoingArcIterator arc_it(*graph_, left_node);
1275 ArcIndex best_arc = arc_it.Index();
1276 CostValue min_partial_reduced_cost = PartialReducedCost(best_arc);
1282 const CostValue max_gap = slack_relabeling_price_ - epsilon_;
1283 CostValue second_min_partial_reduced_cost =
1284 min_partial_reduced_cost + max_gap;
1285 for (arc_it.Next(); arc_it.Ok(); arc_it.Next()) {
1287 const CostValue partial_reduced_cost = PartialReducedCost(
arc);
1288 if (partial_reduced_cost < second_min_partial_reduced_cost) {
1289 if (partial_reduced_cost < min_partial_reduced_cost) {
1291 second_min_partial_reduced_cost = min_partial_reduced_cost;
1292 min_partial_reduced_cost = partial_reduced_cost;
1294 second_min_partial_reduced_cost = partial_reduced_cost;
1298 const CostValue gap = std::min<CostValue>(
1299 second_min_partial_reduced_cost - min_partial_reduced_cost, max_gap);
1301 return std::make_pair(best_arc, gap);
1308template <
typename GraphType>
1309inline CostValue LinearSumAssignment<GraphType>::ImplicitPrice(
1311 DCHECK_GT(num_left_nodes_, left_node);
1312 DCHECK_GT(epsilon_, 0);
1313 typename GraphType::OutgoingArcIterator arc_it(*graph_, left_node);
1315 DCHECK(arc_it.Ok());
1316 ArcIndex best_arc = arc_it.Index();
1317 if (best_arc == matched_arc_[left_node]) {
1320 best_arc = arc_it.Index();
1323 CostValue min_partial_reduced_cost = PartialReducedCost(best_arc);
1329 return -(min_partial_reduced_cost + slack_relabeling_price_);
1331 for (arc_it.Next(); arc_it.Ok(); arc_it.Next()) {
1333 if (
arc != matched_arc_[left_node]) {
1334 const CostValue partial_reduced_cost = PartialReducedCost(
arc);
1335 if (partial_reduced_cost < min_partial_reduced_cost) {
1336 min_partial_reduced_cost = partial_reduced_cost;
1340 return -min_partial_reduced_cost;
1344template <
typename GraphType>
1345bool LinearSumAssignment<GraphType>::AllMatched()
const {
1346 for (
NodeIndex node = 0; node < graph_->num_nodes(); ++node) {
1347 if (IsActiveForDebugging(node)) {
1355template <
typename GraphType>
1356bool LinearSumAssignment<GraphType>::EpsilonOptimal()
const {
1357 for (BipartiteLeftNodeIterator node_it(*graph_, num_left_nodes_);
1358 node_it.Ok(); node_it.Next()) {
1359 const NodeIndex left_node = node_it.Index();
1362 CostValue left_node_price = ImplicitPrice(left_node);
1363 for (
typename GraphType::OutgoingArcIterator arc_it(*graph_, left_node);
1364 arc_it.Ok(); arc_it.Next()) {
1366 const CostValue reduced_cost = left_node_price + PartialReducedCost(
arc);
1371 if (matched_arc_[left_node] ==
arc) {
1375 if (reduced_cost > epsilon_) {
1381 if (reduced_cost < 0) {
1390template <
typename GraphType>
1392 incidence_precondition_satisfied_ =
true;
1396 epsilon_ = std::max(largest_scaled_cost_magnitude_, kMinEpsilon + 1);
1397 VLOG(2) <<
"Largest given cost magnitude: "
1398 << largest_scaled_cost_magnitude_ / cost_scaling_factor_;
1401 for (
NodeIndex node = 0; node < num_left_nodes_; ++node) {
1402 matched_arc_[node] = GraphType::kNilArc;
1403 typename GraphType::OutgoingArcIterator arc_it(*graph_, node);
1405 incidence_precondition_satisfied_ =
false;
1410 for (
NodeIndex node = num_left_nodes_; node < graph_->num_nodes(); ++node) {
1412 matched_node_[node] = GraphType::kNilNode;
1414 bool in_range =
true;
1415 double double_price_lower_bound = 0.0;
1417 CostValue old_error_parameter = epsilon_;
1419 new_error_parameter = NewEpsilon(old_error_parameter);
1420 double_price_lower_bound -=
1421 2.0 *
static_cast<double>(PriceChangeBound(
1422 old_error_parameter, new_error_parameter, &in_range));
1423 old_error_parameter = new_error_parameter;
1424 }
while (new_error_parameter != kMinEpsilon);
1425 const double limit =
1426 -
static_cast<double>(std::numeric_limits<CostValue>::max());
1427 if (double_price_lower_bound < limit) {
1429 price_lower_bound_ = -std::numeric_limits<CostValue>::max();
1431 price_lower_bound_ =
static_cast<CostValue>(double_price_lower_bound);
1433 VLOG(4) <<
"price_lower_bound_ == " << price_lower_bound_;
1434 DCHECK_LE(price_lower_bound_, 0);
1436 LOG(WARNING) <<
"Price change bound exceeds range of representable "
1437 <<
"costs; arithmetic overflow is not ruled out and "
1438 <<
"infeasibility might go undetected.";
1443template <
typename GraphType>
1445 total_stats_.Add(iteration_stats_);
1446 VLOG(3) <<
"Iteration stats: " << iteration_stats_.
StatsString();
1447 iteration_stats_.Clear();
1450template <
typename GraphType>
1452 CHECK(graph_ !=
nullptr);
1453 bool ok = graph_->num_nodes() == 2 * num_left_nodes_;
1454 if (!ok)
return false;
1461 ok = ok && incidence_precondition_satisfied_;
1462 DCHECK(!ok || EpsilonOptimal());
1463 while (ok && epsilon_ > kMinEpsilon) {
1464 ok = ok && UpdateEpsilon();
1465 ok = ok && Refine();
1466 ReportAndAccumulateStats();
1467 DCHECK(!ok || EpsilonOptimal());
1468 DCHECK(!ok || AllMatched());
1471 VLOG(1) <<
"Overall stats: " << total_stats_.StatsString();
1475template <
typename GraphType>
1482 cost += GetAssignmentCost(node_it.Index());
bool operator()(typename GraphType::ArcIndex a, typename GraphType::ArcIndex b) const
ArcIndexOrderingByTailNode(const GraphType &graph)
CostValueCycleHandler(const CostValueCycleHandler &)=delete
This type is neither copyable nor movable.
CostValueCycleHandler(std::vector< CostValue > *cost)
void SetIndexFromTemp(ArcIndexType destination) const override
Sets a data element from the temporary.
void SetIndexFromIndex(ArcIndexType source, ArcIndexType destination) const override
Moves a data element one step along its cycle.
~CostValueCycleHandler() override
CostValueCycleHandler & operator=(const CostValueCycleHandler &)=delete
void SetTempFromIndex(ArcIndexType source) override
BipartiteLeftNodeIterator(const LinearSumAssignment &assignment)
BipartiteLeftNodeIterator(const GraphType &graph, NodeIndex num_left_nodes)
This class does not take ownership of its underlying graph.
NodeIndex Head(ArcIndex arc) const
NodeIndex NumLeftNodes() const
void OptimizeGraphLayout(GraphType *graph)
NodeIndex NumNodes() const
Returns the total number of nodes in the given problem.
LinearSumAssignment(const GraphType &graph, NodeIndex num_left_nodes)
std::string StatsString() const
GraphType::ArcIndex ArcIndex
void SetArcCost(ArcIndex arc, CostValue cost)
Sets the cost of an arc already present in the given graph.
void SetGraph(const GraphType *graph)
CostValue GetCost() const
LinearSumAssignment(const LinearSumAssignment &)=delete
This type is neither copyable nor movable.
GraphType::NodeIndex NodeIndex
void SetCostScalingDivisor(CostValue factor)
NodeIndex GetMate(NodeIndex left_node) const
Returns the node to which the given node is matched.
CostValue GetAssignmentCost(NodeIndex node) const
LinearSumAssignment & operator=(const LinearSumAssignment &)=delete
ArcIndex GetAssignmentArc(NodeIndex left_node) const
Returns the arc through which the given node is matched.
CostValue ArcCost(ArcIndex arc) const
const GraphType & Graph() const
Allows tests, iterators, etc., to inspect our underlying graph.
operations_research::PermutationCycleHandler< typename GraphType::ArcIndex > * ArcAnnotationCycleHandler()
Passes ownership of the cycle handler to the caller.
bool BuildTailArrayFromAdjacencyListsIfForwardGraph() const
void ReleaseTailArrayIfForwardGraph() const
ABSL_DECLARE_FLAG(int64_t, assignment_alpha)
In SWIG mode, we don't want anything besides these top-level includes.