194#ifndef OR_TOOLS_GRAPH_LINEAR_ASSIGNMENT_H_
195#define OR_TOOLS_GRAPH_LINEAR_ASSIGNMENT_H_
207#include "absl/flags/declare.h"
208#include "absl/flags/flag.h"
209#include "absl/strings/str_format.h"
// Template parameters, presumably of the LinearSumAssignment class template
// (the class head itself is not visible in this extract). GraphType is the
// graph type; CostValue is the arc-cost type and defaults to int64_t.
// NOTE(review): the DCHECK below requires that no graph has been set yet. It
// presumably belongs to SetGraph(); confirm against the full header.
225template <
typename GraphType,
typename CostValue =
int64_t>
253 DCHECK(graph_ ==
nullptr);
271 inline const GraphType&
Graph()
const {
return *graph_; }
// Body fragment (the header is not visible; presumably ArcCost(arc)).
// scaled_arc_cost_ holds costs multiplied by cost_scaling_factor_, so dividing
// recovers the caller-supplied cost. The DCHECK verifies that the stored value
// is an exact multiple of the factor.
282 DCHECK_EQ(0, scaled_arc_cost_[arc] % cost_scaling_factor_);
283 return scaled_arc_cost_[arc] / cost_scaling_factor_;
// Body fragment (presumably NumNodes()). The header and the body of the
// graph_ == nullptr branch are not visible. When a graph is set, the node
// count is taken from it.
314 if (graph_ ==
nullptr) {
319 return graph_->num_nodes();
// Body fragment (presumably GetAssignmentArc(left_node)). Returns the arc
// through which the left node is matched, or GraphType::kNilArc while it is
// unmatched. left_node must be a left-side node.
329 DCHECK_LT(left_node, num_left_nodes_);
330 return matched_arc_[left_node];
// Body fragment (presumably GetMate(left_node); the line that computes
// matching_arc is not visible). Returns the head of the matching arc, i.e.
// the right node this left node is matched to. The node must be matched.
341 DCHECK_LT(left_node, num_left_nodes_);
343 DCHECK_NE(GraphType::kNilArc, matching_arc);
344 return Head(matching_arc);
347 std::string
StatsString()
const {
return total_stats_.StatsString(); }
// Body fragment (presumably BipartiteLeftNodes()): the left nodes are the
// indices starting at 0 and below num_left_nodes_.
351 return ::util::IntegerRange<NodeIndex>(0, num_left_nodes_);
362 Stats() : pushes_(0), double_pushes_(0), relabelings_(0), refinements_(0) {}
// Stats::Add(that): adds each of `that`'s counters (pushes, double pushes,
// relabelings, refinements) into this object's counters.
// The method's closing brace is not visible in this extract.
369 void Add(
const Stats& that) {
370 pushes_ += that.pushes_;
371 double_pushes_ += that.double_pushes_;
372 relabelings_ += that.relabelings_;
373 refinements_ += that.refinements_;
// Stats::StatsString(): one-line, human-readable summary of the counters,
// e.g. "3 refinements; 10 relabelings; 2 double pushes; 12 pushes".
375 std::string StatsString()
const {
376 return absl::StrFormat(
377 "%d refinements; %d relabelings; "
378 "%d double pushes; %d pushes",
379 refinements_, relabelings_, double_pushes_, pushes_);
// Solver counters. double_pushes_ and relabelings_ are incremented in
// DoublePush(), and refinements_ in Refine(). The declaration of pushes_ is
// not visible in this extract.
382 int64_t double_pushes_;
383 int64_t relabelings_;
384 int64_t refinements_;
// Interface for the container of active left nodes, i.e. unmatched ones (see
// IsActive()), that still have to be processed. It is implemented by
// ActiveNodeStack (std::vector-backed) and ActiveNodeQueue (std::deque-backed).
// Only part of the interface is visible in this extract.
388 class ActiveNodeContainerInterface {
390 virtual ~ActiveNodeContainerInterface() {}
391 virtual bool Empty()
const = 0;
// Active-node container backed by a std::vector (v_). Add() appends with
// push_back. Its Get() is not visible in this extract.
396 class ActiveNodeStack :
public ActiveNodeContainerInterface {
398 ~ActiveNodeStack()
override {}
400 bool Empty()
const override {
return v_.empty(); }
402 void Add(
NodeIndex node)
override { v_.push_back(node); }
// Storage for the active nodes held by this container.
412 std::vector<NodeIndex> v_;
// Active-node container backed by a std::deque (q_). Add() inserts at the
// front with push_front. Its Get() is not visible in this extract.
415 class ActiveNodeQueue :
public ActiveNodeContainerInterface {
417 ~ActiveNodeQueue()
override {}
419 bool Empty()
const override {
return q_.empty(); }
421 void Add(
NodeIndex node)
override { q_.push_front(node); }
// Storage for the active nodes held by this container.
431 std::deque<NodeIndex> q_;
// (best arc, gap) pair returned by BestArcAndGap().
442 typedef std::pair<ArcIndex, CostValue> ImplicitPriceSummary;
// Debug-only check, used in a DCHECK after the main epsilon-scaling loop.
// It tests every node with IsActiveForDebugging(); presumably it returns false
// if any node is still active.
446 bool AllMatched()
const;
// For an active (unmatched) left node, returns the outgoing arc with the
// smallest partial reduced cost, together with the gap to the second-smallest
// one. The gap is capped at slack_relabeling_price_ - epsilon_.
453 inline ImplicitPriceSummary BestArcAndGap(
NodeIndex left_node)
const;
// Adds iteration_stats_ into total_stats_, logs it at VLOG(3), then clears it.
457 void ReportAndAccumulateStats();
// Lowers epsilon_ via NewEpsilon() and recomputes slack_relabeling_price_.
// Returns a success flag; its computation is not visible in this extract.
468 bool UpdateEpsilon();
// True iff the left node is unmatched (its matched_arc_ entry is kNilArc).
472 inline bool IsActive(
NodeIndex left_node)
const;
// Debug variant of IsActive() that accepts any node. Left nodes defer to
// IsActive(); right nodes are active when their matched_node_ is kNilNode.
479 inline bool IsActiveForDebugging(
NodeIndex node)
const;
// Adds every currently active left node to active_nodes_, which must be empty.
486 void InitializeActiveNodeContainer();
// Run at the start of each Refine(). It appears to release existing matches
// so the left nodes become active again.
// NOTE(review): most of its body is not visible; confirm the exact semantics.
494 void SaturateNegativeArcs();
// Body fragment (presumably PartialReducedCost(arc)). Computes the arc's scaled
// cost minus the price of its head (right) node. The tail node's price is not
// included; callers add it separately (e.g. left_node_price + ...).
503 return scaled_arc_cost_[arc] - price_[
Head(arc)];
// The graph being solved. Per the class documentation, it is not owned.
508 const GraphType* graph_;
// Cleared during setup when a left node fails a check on its outgoing arcs
// (the exact check is not visible here). The main solve loop then reports
// failure.
517 bool incidence_precondition_satisfied_;
// Smallest epsilon. The scaling loop stops once epsilon_ is no longer greater
// than this value.
534 static constexpr CostValue kMinEpsilon = 1;
// Fragment of PriceChangeBound(old_epsilon, new_epsilon, in_range). The start
// of the signature and some body lines are not visible in this extract.
// It computes max(1, n / 2 - 1) * (old_epsilon + new_epsilon) in double
// precision, so the bound can be compared with the largest CostValue without
// integer overflow. If the bound exceeds that value, *in_range is set to false
// (only when in_range is non-null) and the largest CostValue is returned.
// NOTE(review): the definitions of `n` and `limit` are not visible here.
// `limit` is presumably the largest CostValue converted to double; `n`
// presumably the node count. Confirm both.
844 bool* in_range)
const {
857 const double result =
858 static_cast<double>(std::max<CostValue>(1, n / 2 - 1)) *
859 (
static_cast<double>(old_epsilon) +
static_cast<double>(new_epsilon));
861 static_cast<double>(std::numeric_limits<CostValue>::max());
862 if (result > limit) {
864 if (in_range !=
nullptr) *in_range =
false;
865 return std::numeric_limits<CostValue>::max();
// Largest |cost * cost_scaling_factor_| seen by SetArcCost(). Used to seed
// epsilon_ during setup.
885 CostValue largest_scaled_cost_magnitude_;
// Prices of the right-side nodes, indexed by node id (constructed over
// num_left_nodes .. 2 * num_left_nodes - 1).
898 ZVector<CostValue> price_;
// For each left node, the arc it is matched by, or GraphType::kNilArc.
903 std::vector<ArcIndex> matched_arc_;
// For each right node, the left node it is matched to, or GraphType::kNilNode.
911 ZVector<NodeIndex> matched_node_;
// Per-arc cost multiplied by cost_scaling_factor_.
916 std::vector<CostValue> scaled_arc_cost_;
// Active left nodes still to process. This is an ActiveNodeStack or an
// ActiveNodeQueue, chosen by --assignment_stack_order at construction.
921 std::unique_ptr<ActiveNodeContainerInterface> active_nodes_;
// Statistics for the current refinement. Folded into total_stats_ and reset
// by ReportAndAccumulateStats().
929 Stats iteration_stats_;
// Constructor taking the graph by reference. Its name line and several
// initializers are not visible in this extract.
// - Costs are scaled by 1 + num_left_nodes.
// - price_ and matched_node_ are indexed by the right-side node ids
//   (num_left_nodes .. 2 * num_left_nodes - 1).
// - scaled_arc_cost_ is sized by graph.arc_capacity().
// - The active-node container is a stack or a queue, depending on
//   --assignment_stack_order.
// NOTE(review): the raw `new` handed straight to the unique_ptr could be
// std::make_unique; left as-is here.
935template <
typename GraphType,
typename CostValue>
937 const GraphType& graph,
const NodeIndex num_left_nodes)
939 num_left_nodes_(num_left_nodes),
941 cost_scaling_factor_(1 + num_left_nodes),
942 alpha_(
absl::GetFlag(FLAGS_assignment_alpha)),
944 price_lower_bound_(0),
945 slack_relabeling_price_(0),
946 largest_scaled_cost_magnitude_(0),
948 price_(num_left_nodes, 2 * num_left_nodes - 1),
949 matched_arc_(num_left_nodes, 0),
950 matched_node_(num_left_nodes, 2 * num_left_nodes - 1),
951 scaled_arc_cost_(graph.arc_capacity(), 0),
952 active_nodes_(
absl::GetFlag(FLAGS_assignment_stack_order)
953 ? static_cast<ActiveNodeContainerInterface*>(
954 new ActiveNodeStack())
955 : static_cast<ActiveNodeContainerInterface*>(
956 new ActiveNodeQueue())) {}
// Constructor taking the node and arc counts instead of a graph; the graph is
// presumably supplied later (e.g. via SetGraph()). The signature and some
// initializers are not visible in this extract. The initialization matches
// the graph-taking constructor, except that scaled_arc_cost_ is sized by
// num_arcs.
958template <
typename GraphType,
typename CostValue>
962 num_left_nodes_(num_left_nodes),
964 cost_scaling_factor_(1 + num_left_nodes),
965 alpha_(
absl::GetFlag(FLAGS_assignment_alpha)),
967 price_lower_bound_(0),
968 slack_relabeling_price_(0),
969 largest_scaled_cost_magnitude_(0),
971 price_(num_left_nodes, 2 * num_left_nodes - 1),
972 matched_arc_(num_left_nodes, 0),
973 matched_node_(num_left_nodes, 2 * num_left_nodes - 1),
974 scaled_arc_cost_(num_arcs, 0),
975 active_nodes_(
absl::GetFlag(FLAGS_assignment_stack_order)
976 ? static_cast<ActiveNodeContainerInterface*>(
977 new ActiveNodeStack())
978 : static_cast<ActiveNodeContainerInterface*>(
979 new ActiveNodeQueue())) {}
// Fragment of SetArcCost(arc, cost); the signature and some lines are not
// visible. When a graph is set, it DCHECKs that `arc` is in range and that
// its head is a right-side node. It then stores cost * cost_scaling_factor_
// and tracks the largest scaled magnitude, which later seeds epsilon_.
// NOTE(review): neither the multiplication nor std::abs() is overflow-checked
// here. Very large costs, or the most negative CostValue, could overflow.
981template <
typename GraphType,
typename CostValue>
984 if (graph_ !=
nullptr) {
986 DCHECK_LT(arc, graph_->num_arcs());
988 DCHECK_LE(num_left_nodes_, head);
990 cost *= cost_scaling_factor_;
991 const CostValue cost_magnitude = std::abs(cost);
992 largest_scaled_cost_magnitude_ =
993 std::max(largest_scaled_cost_magnitude_, cost_magnitude);
994 scaled_arc_cost_[arc] = cost;
// Fragments of CostValueCycleHandler<ArcIndexType, CostValue>. This cycle
// handler moves the entries of a cost vector (held through the raw pointer
// cost_) along permutation cycles, using temp_ as a one-element scratch slot:
// - SetTempFromIndex copies an entry into temp_.
// - SetIndexFromIndex copies one entry onto another.
// - SetIndexFromTemp writes temp_ back.
997template <
typename ArcIndexType,
typename CostValue>
1001 : temp_(0), cost_(cost) {}
1008 temp_ = (*cost_)[source];
1012 ArcIndexType destination)
const override {
1013 (*cost_)[destination] = (*cost_)[source];
1017 (*cost_)[destination] = temp_;
1024 std::vector<CostValue>*
const cost_;
// Fragments of ArcIndexOrderingByTailNode<GraphType>. This comparator orders
// arcs lexicographically by (tail, head) in the referenced graph_, which is
// held by reference.
1032template <
typename GraphType>
1041 typename GraphType::ArcIndex b)
const {
1042 return ((graph_.Tail(a) < graph_.Tail(b)) ||
1043 ((graph_.Tail(a) == graph_.Tail(b)) &&
1044 (graph_.Head(a) < graph_.Head(b))));
1048 const GraphType& graph_;
// Start of an out-of-line definition returning a PermutationCycleHandler for
// arc indices, presumably ArcAnnotationCycleHandler(). Per the class
// documentation, ownership of the handler passes to the caller. The remainder
// is not visible in this extract.
1056template <
typename GraphType,
typename CostValue>
1057PermutationCycleHandler<typename GraphType::ArcIndex>*
// NewEpsilon(current_epsilon): the next epsilon in the scaling schedule.
// It is current_epsilon divided by alpha_ (initialized from
// --assignment_alpha), and never below kMinEpsilon.
// The closing brace is not visible in this extract.
1063template <
typename GraphType,
typename CostValue>
1064CostValue LinearSumAssignment<GraphType, CostValue>::NewEpsilon(
1065 const CostValue current_epsilon)
const {
1066 return std::max(current_epsilon / alpha_, kMinEpsilon);
// UpdateEpsilon(): lowers epsilon_ to NewEpsilon(epsilon_). It first sets
// slack_relabeling_price_ to PriceChangeBound(old epsilon, new epsilon), which
// must be positive. The return statement is not visible in this extract.
1069template <
typename GraphType,
typename CostValue>
1070bool LinearSumAssignment<GraphType, CostValue>::UpdateEpsilon() {
1071 CostValue new_epsilon = NewEpsilon(epsilon_);
1072 slack_relabeling_price_ = PriceChangeBound(epsilon_, new_epsilon,
nullptr);
1073 epsilon_ = new_epsilon;
1074 VLOG(3) <<
"Updated: epsilon_ == " << epsilon_;
1075 VLOG(4) <<
"slack_relabeling_price_ == " << slack_relabeling_price_;
1076 DCHECK_GT(slack_relabeling_price_, 0);
// IsActive(left_node): a left node is active iff it is unmatched, i.e. its
// matched_arc_ entry is GraphType::kNilArc. The parameter line and closing
// brace are not visible in this extract.
1084template <
typename GraphType,
typename CostValue>
1085inline bool LinearSumAssignment<GraphType, CostValue>::IsActive(
1087 DCHECK_LT(left_node, num_left_nodes_);
1088 return matched_arc_[left_node] == GraphType::kNilArc;
// IsActiveForDebugging(node): like IsActive(), but accepts any node.
// Left nodes defer to IsActive(); right nodes are active when their
// matched_node_ entry is GraphType::kNilNode.
1094template <
typename GraphType,
typename CostValue>
1095inline bool LinearSumAssignment<GraphType, CostValue>::IsActiveForDebugging(
1097 if (node < num_left_nodes_) {
1098 return IsActive(node);
1100 return matched_node_[node] == GraphType::kNilNode;
// InitializeActiveNodeContainer(): adds every active (unmatched) left node to
// active_nodes_, which must be empty on entry.
1104template <
typename GraphType,
typename CostValue>
1106 CostValue>::InitializeActiveNodeContainer() {
1107 DCHECK(active_nodes_->Empty());
1108 for (
const NodeIndex node : BipartiteLeftNodes()) {
1109 if (IsActive(node)) {
1110 active_nodes_->Add(node);
// SaturateNegativeArcs(): run at the start of each Refine(). For left nodes it
// appears to release the current match: matched_arc_ is set to kNilArc and
// the mate's matched_node_ to kNilNode, making the node active again.
// NOTE(review): the lines between the IsActive() test and the unmatching,
// including how `mate` is obtained, are not visible. Confirm how
// already-active nodes are handled against the full source.
1125template <
typename GraphType,
typename CostValue>
1126void LinearSumAssignment<GraphType, CostValue>::SaturateNegativeArcs() {
1128 for (
const NodeIndex node : BipartiteLeftNodes()) {
1129 if (IsActive(node)) {
1137 matched_arc_[node] = GraphType::kNilArc;
1138 matched_node_[mate] = GraphType::kNilNode;
// DoublePush(source): processes the active (unmatched) left node `source`.
// 1. Picks the best outgoing arc with BestArcAndGap().
// 2. If that arc's head (new_mate) was already matched, unmatches the previous
//    left mate and puts it back into active_nodes_ (a "double push").
// 3. Matches `source` to new_mate and lowers new_mate's price by
//    gap + epsilon_.
// Returns false when the new price falls below price_lower_bound_; Refine()
// treats this as infeasibility.
// Several lines of this definition (e.g. where `gap` is read and the body of
// the kNilArc branch) are not visible in this extract.
1144template <
typename GraphType,
typename CostValue>
1145bool LinearSumAssignment<GraphType, CostValue>::DoublePush(
NodeIndex source) {
1146 DCHECK_GT(num_left_nodes_, source);
// A space is needed before "must" so the message reads "Node 7 must be ...",
// matching the equivalent DCHECK in BestArcAndGap().
1147 DCHECK(IsActive(source)) <<
"Node " << source
1148 <<
" must be active (unmatched)!";
1149 ImplicitPriceSummary summary = BestArcAndGap(source);
1150 const ArcIndex best_arc = summary.first;
1155 if (best_arc == GraphType::kNilArc) {
1158 const NodeIndex new_mate = Head(best_arc);
1159 const NodeIndex to_unmatch = matched_node_[new_mate];
1160 if (to_unmatch != GraphType::kNilNode) {
1163 matched_arc_[to_unmatch] = GraphType::kNilArc;
1164 active_nodes_->Add(to_unmatch);
1166 iteration_stats_.double_pushes_ += 1;
1171 iteration_stats_.pushes_ += 1;
1173 matched_arc_[source] = best_arc;
1174 matched_node_[new_mate] = source;
1176 iteration_stats_.relabelings_ += 1;
1177 const CostValue new_price = price_[new_mate] - gap - epsilon_;
1178 price_[new_mate] = new_price;
1179 return new_price >= price_lower_bound_;
// Refine(): one epsilon-refinement pass.
// - Releases matches via SaturateNegativeArcs() and seeds active_nodes_.
// - Repeatedly DoublePush()es active nodes while total_excess_ > 0.
// - A failing DoublePush() signals infeasibility. This is logged as DFATAL if
//   an earlier refinement had already completed (a feasible assignment had
//   been found).
// - Counts one refinement in iteration_stats_.
// Several lines, including the return statements, are not visible here.
1182template <
typename GraphType,
typename CostValue>
1183bool LinearSumAssignment<GraphType, CostValue>::Refine() {
1184 SaturateNegativeArcs();
1185 InitializeActiveNodeContainer();
1186 while (total_excess_ > 0) {
1189 const NodeIndex node = active_nodes_->Get();
1190 if (!DoublePush(node)) {
1198 LOG_IF(DFATAL, total_stats_.refinements_ > 0)
1199 <<
"Infeasibility detection triggered after first iteration found "
1200 <<
"a feasible assignment!";
1204 DCHECK(active_nodes_->Empty());
1205 iteration_stats_.refinements_ += 1;
// BestArcAndGap(left_node): scans the outgoing arcs of an active left node.
// Returns (best_arc, gap), where best_arc has the minimum partial reduced cost
// and gap is the difference to the second-smallest one. The gap is capped at
// max_gap = slack_relabeling_price_ - epsilon_.
// Seeding the second minimum with min + max_gap means arcs that are not within
// max_gap of the current minimum fail the first comparison and are skipped.
// Some lines (e.g. the update of best_arc) are not visible in this extract.
1223template <
typename GraphType,
typename CostValue>
1224inline typename LinearSumAssignment<GraphType, CostValue>::ImplicitPriceSummary
1225LinearSumAssignment<GraphType, CostValue>::BestArcAndGap(
1227 DCHECK(IsActive(left_node))
1228 <<
"Node " << left_node <<
" must be active (unmatched)!";
1229 DCHECK_GT(epsilon_, 0);
1230 typename GraphType::OutgoingArcIterator arc_it(*graph_, left_node);
1231 ArcIndex best_arc = arc_it.Index();
1232 CostValue min_partial_reduced_cost = PartialReducedCost(best_arc);
1238 const CostValue max_gap = slack_relabeling_price_ - epsilon_;
1239 CostValue second_min_partial_reduced_cost =
1240 min_partial_reduced_cost + max_gap;
1241 for (arc_it.Next(); arc_it.Ok(); arc_it.Next()) {
1242 const ArcIndex arc = arc_it.Index();
1243 const CostValue partial_reduced_cost = PartialReducedCost(arc);
1244 if (partial_reduced_cost < second_min_partial_reduced_cost) {
1245 if (partial_reduced_cost < min_partial_reduced_cost) {
1247 second_min_partial_reduced_cost = min_partial_reduced_cost;
1248 min_partial_reduced_cost = partial_reduced_cost;
1250 second_min_partial_reduced_cost = partial_reduced_cost;
1254 const CostValue gap = std::min<CostValue>(
1255 second_min_partial_reduced_cost - min_partial_reduced_cost, max_gap);
1257 return std::make_pair(best_arc, gap);
// ImplicitPrice(left_node): the price implied for a left node. It is the
// negated minimum partial reduced cost over the node's arcs other than its
// matched arc. In a branch that is only partly visible (presumably when no
// other arc exists), it instead returns
// -(partial reduced cost of the chosen arc + slack_relabeling_price_).
// Several lines of this definition are not visible in this extract.
1264template <
typename GraphType,
typename CostValue>
1265inline CostValue LinearSumAssignment<GraphType, CostValue>::ImplicitPrice(
1267 DCHECK_GT(num_left_nodes_, left_node);
1268 DCHECK_GT(epsilon_, 0);
1269 typename GraphType::OutgoingArcIterator arc_it(*graph_, left_node);
1271 DCHECK(arc_it.Ok());
1272 ArcIndex best_arc = arc_it.Index();
1273 if (best_arc == matched_arc_[left_node]) {
1276 best_arc = arc_it.Index();
1279 CostValue min_partial_reduced_cost = PartialReducedCost(best_arc);
1285 return -(min_partial_reduced_cost + slack_relabeling_price_);
1287 for (arc_it.Next(); arc_it.Ok(); arc_it.Next()) {
1288 const ArcIndex arc = arc_it.Index();
1289 if (arc != matched_arc_[left_node]) {
1290 const CostValue partial_reduced_cost = PartialReducedCost(arc);
1291 if (partial_reduced_cost < min_partial_reduced_cost) {
1292 min_partial_reduced_cost = partial_reduced_cost;
1296 return -min_partial_reduced_cost;
// AllMatched(), debug only: checks every node of the graph with
// IsActiveForDebugging(). Presumably it returns false if any node is still
// active; the rest of the body is not visible in this extract.
1300template <
typename GraphType,
typename CostValue>
1301bool LinearSumAssignment<GraphType, CostValue>::AllMatched()
const {
1302 for (
NodeIndex node = 0; node < graph_->num_nodes(); ++node) {
1303 if (IsActiveForDebugging(node)) {
// Fragment presumably of EpsilonOptimal(); the signature and the enclosing
// loop over left nodes are not visible. For left_node it takes the node's
// implicit price and computes each outgoing arc's reduced cost.
// The matched arc appears to be checked with reduced_cost > epsilon_ and the
// other arcs with reduced_cost < 0. What happens on a violation is not visible.
1311template <
typename GraphType,
typename CostValue>
1316 CostValue left_node_price = ImplicitPrice(left_node);
1317 for (
typename GraphType::OutgoingArcIterator arc_it(*graph_, left_node);
1318 arc_it.Ok(); arc_it.Next()) {
1319 const ArcIndex arc = arc_it.Index();
1320 const CostValue reduced_cost = left_node_price + PartialReducedCost(arc);
1325 if (matched_arc_[left_node] == arc) {
1329 if (reduced_cost > epsilon_) {
1335 if (reduced_cost < 0) {
// Fragment of the setup routine (presumably FinalizeSetup(); its signature and
// several lines are not visible). It:
//  - sets epsilon_ to the largest scaled cost magnitude, and at least
//    kMinEpsilon + 1;
//  - unmatches every left node, clearing incidence_precondition_satisfied_
//    when a left node fails a check on its outgoing arcs (condition not
//    visible);
//  - unmatches every right node;
//  - sums 2 * PriceChangeBound() over the whole epsilon schedule, in double
//    precision to avoid integer overflow, to obtain price_lower_bound_. The
//    result is clamped to -max(CostValue), and a warning is logged (presumably
//    when in_range ends up false) because overflow can no longer be ruled out.
1344template <
typename GraphType,
typename CostValue>
1346 incidence_precondition_satisfied_ =
true;
1350 epsilon_ = std::max(largest_scaled_cost_magnitude_, kMinEpsilon + 1);
1351 VLOG(2) <<
"Largest given cost magnitude: "
1352 << largest_scaled_cost_magnitude_ / cost_scaling_factor_;
1355 for (
NodeIndex node = 0; node < num_left_nodes_; ++node) {
1356 matched_arc_[node] = GraphType::kNilArc;
1357 typename GraphType::OutgoingArcIterator arc_it(*graph_, node);
1359 incidence_precondition_satisfied_ =
false;
1364 for (
NodeIndex node = num_left_nodes_; node < graph_->num_nodes(); ++node) {
1366 matched_node_[node] = GraphType::kNilNode;
1368 bool in_range =
true;
1369 double double_price_lower_bound = 0.0;
1371 CostValue old_error_parameter = epsilon_;
1373 new_error_parameter = NewEpsilon(old_error_parameter);
1374 double_price_lower_bound -=
1375 2.0 *
static_cast<double>(PriceChangeBound(
1376 old_error_parameter, new_error_parameter, &in_range));
1377 old_error_parameter = new_error_parameter;
1378 }
while (new_error_parameter != kMinEpsilon);
1379 const double limit =
1380 -
static_cast<double>(std::numeric_limits<CostValue>::max());
1381 if (double_price_lower_bound < limit) {
1383 price_lower_bound_ = -std::numeric_limits<CostValue>::max();
1385 price_lower_bound_ =
static_cast<CostValue>(double_price_lower_bound);
1387 VLOG(4) <<
"price_lower_bound_ == " << price_lower_bound_;
1388 DCHECK_LE(price_lower_bound_, 0);
1390 LOG(WARNING) <<
"Price change bound exceeds range of representable "
1391 <<
"costs; arithmetic overflow is not ruled out and "
1392 <<
"infeasibility might go undetected.";
// ReportAndAccumulateStats(): folds the current iteration's stats into
// total_stats_, logs them at VLOG(3), and resets iteration_stats_ for the next
// refinement. The closing brace is not visible in this extract.
1397template <
typename GraphType,
typename CostValue>
1398void LinearSumAssignment<GraphType, CostValue>::ReportAndAccumulateStats() {
1399 total_stats_.Add(iteration_stats_);
1400 VLOG(3) <<
"Iteration stats: " << iteration_stats_.StatsString();
1401 iteration_stats_.Clear();
// Fragment presumably of ComputeAssignment(); the signature and some lines are
// not visible.
// - Requires a graph.
// - Fails unless the graph has exactly 2 * num_left_nodes_ nodes and the
//   incidence precondition holds.
// - Runs UpdateEpsilon()/Refine() rounds, accumulating stats after each, until
//   epsilon_ reaches kMinEpsilon or a step fails.
// - On success, DCHECKs that every node is matched.
1404template <
typename GraphType,
typename CostValue>
1406 CHECK(graph_ !=
nullptr);
1407 bool ok = graph_->num_nodes() == 2 * num_left_nodes_;
1408 if (!ok)
return false;
1415 ok = ok && incidence_precondition_satisfied_;
1417 while (ok && epsilon_ > kMinEpsilon) {
1418 ok = ok && UpdateEpsilon();
1419 ok = ok && Refine();
1420 ReportAndAccumulateStats();
1422 DCHECK(!ok || AllMatched());
1425 VLOG(1) <<
"Overall stats: " << total_stats_.StatsString();
// Template header of the next out-of-line member definition; the rest of that
// definition is not visible in this extract.
1429template <
typename GraphType,
typename CostValue>
bool operator()(typename GraphType::ArcIndex a, typename GraphType::ArcIndex b) const
ArcIndexOrderingByTailNode(const GraphType &graph)
CostValueCycleHandler(const CostValueCycleHandler &)=delete
This type is neither copyable nor movable.
CostValueCycleHandler & operator=(const CostValueCycleHandler &)=delete
void SetTempFromIndex(ArcIndexType source) override
void SetIndexFromTemp(ArcIndexType destination) const override
Sets a data element from the temporary.
~CostValueCycleHandler() override
CostValueCycleHandler(std::vector< CostValue > *cost)
void SetIndexFromIndex(ArcIndexType source, ArcIndexType destination) const override
Moves a data element one step along its cycle.
This class does not take ownership of its underlying graph.
void SetArcCost(ArcIndex arc, CostValue cost)
Sets the cost of an arc already present in the given graph.
CostValue GetAssignmentCost(NodeIndex node) const
operations_research::PermutationCycleHandler< typename GraphType::ArcIndex > * ArcAnnotationCycleHandler()
Passes ownership of the cycle handler to the caller.
std::string StatsString() const
NodeIndex NumLeftNodes() const
NodeIndex NumNodes() const
Returns the total number of nodes in the given problem.
NodeIndex Head(ArcIndex arc) const
ArcIndex GetAssignmentArc(NodeIndex left_node) const
Returns the arc through which the given node is matched.
const GraphType & Graph() const
Allows tests, iterators, etc., to inspect our underlying graph.
NodeIndex GetMate(NodeIndex left_node) const
Returns the node to which the given node is matched.
GraphType::NodeIndex NodeIndex
bool EpsilonOptimal() const
Only for debugging.
CostValue GetCost() const
::util::IntegerRange< NodeIndex > BipartiteLeftNodes() const
Returns the range of valid left node indices.
LinearSumAssignment(const LinearSumAssignment &)=delete
This type is neither copyable nor movable.
void SetGraph(const GraphType *graph)
CostValue ArcCost(ArcIndex arc) const
LinearSumAssignment(const GraphType &graph, NodeIndex num_left_nodes)
LinearSumAssignment & operator=(const LinearSumAssignment &)=delete
void SetCostScalingDivisor(CostValue factor)
GraphType::ArcIndex ArcIndex
PermutationCycleHandler(const PermutationCycleHandler &)=delete
OR_DLL ABSL_DECLARE_FLAG(int64_t, assignment_alpha)
In SWIG mode, we don't want anything besides these top-level includes.