193#ifndef OR_TOOLS_GRAPH_LINEAR_ASSIGNMENT_H_
194#define OR_TOOLS_GRAPH_LINEAR_ASSIGNMENT_H_
206#include "absl/flags/declare.h"
207#include "absl/flags/flag.h"
208#include "absl/strings/str_format.h"
224template <
typename GraphType,
typename CostValue =
int64_t>
252 DCHECK(graph_ ==
nullptr);
// Allows tests, iterators, etc., to inspect the underlying graph.
// NOTE(review): dereferences graph_ with no null check. graph_ is a raw
// pointer that this class elsewhere compares against nullptr, so callers must
// make sure a graph has been supplied before calling this accessor.
270 inline const GraphType&
Graph()
const {
return *graph_; }
281 DCHECK_EQ(0, scaled_arc_cost_[arc] % cost_scaling_factor_);
282 return scaled_arc_cost_[arc] / cost_scaling_factor_;
313 if (graph_ ==
nullptr) {
318 return graph_->num_nodes();
328 DCHECK_LT(left_node, num_left_nodes_);
329 return matched_arc_[left_node];
340 DCHECK_LT(left_node, num_left_nodes_);
342 DCHECK_NE(GraphType::kNilArc, matching_arc);
343 return Head(matching_arc);
// Returns a human-readable summary of the statistics accumulated so far.
// Delegates to total_stats_, which ReportAndAccumulateStats() updates by
// adding each iteration's stats.
346 std::string
StatsString()
const {
return total_stats_.StatsString(); }
350 return ::util::IntegerRange<NodeIndex>(0, num_left_nodes_);
// Starts with every counter (pushes, double pushes, relabelings, refinements)
// set to zero.
361 Stats() : pushes(0), double_pushes(0), relabelings(0), refinements(0) {}
368 void Add(
const Stats& that) {
369 pushes += that.pushes;
370 double_pushes += that.double_pushes;
371 relabelings += that.relabelings;
372 refinements += that.refinements;
375 return absl::StrFormat(
376 "%d refinements; %d relabelings; "
377 "%d double pushes; %d pushes",
378 refinements, relabelings, double_pushes, pushes);
381 int64_t double_pushes;
387 class ActiveNodeContainerInterface {
// Virtual so that concrete containers (ActiveNodeStack / ActiveNodeQueue)
// are destroyed correctly through the base-class pointer held in
// active_nodes_ (a std::unique_ptr<ActiveNodeContainerInterface>).
389 virtual ~ActiveNodeContainerInterface() {}
// Returns true iff the container currently holds no active nodes.
390 virtual bool Empty()
const = 0;
395 class ActiveNodeStack :
public ActiveNodeContainerInterface {
// Nothing to release explicitly; v_ cleans itself up.
397 ~ActiveNodeStack()
override {}
// True iff no active nodes are stored in v_.
399 bool Empty()
const override {
return v_.empty(); }
// Appends node to the back of v_.
// NOTE(review): the LIFO (stack) order presumably comes from Get() also
// taking from the back; Get() is not visible here, so confirm.
401 void Add(
NodeIndex node)
override { v_.push_back(node); }
411 std::vector<NodeIndex> v_;
414 class ActiveNodeQueue :
public ActiveNodeContainerInterface {
// Nothing to release explicitly; q_ cleans itself up.
416 ~ActiveNodeQueue()
override {}
// True iff no active nodes are stored in q_.
418 bool Empty()
const override {
return q_.empty(); }
// Inserts node at the front of q_.
// NOTE(review): FIFO (queue) order requires Get() to take elements from the
// opposite end (the back); Get() is not visible here, so confirm.
420 void Add(
NodeIndex node)
override { q_.push_front(node); }
430 std::deque<NodeIndex> q_;
441 typedef std::pair<ArcIndex, CostValue> ImplicitPriceSummary;
445 bool AllMatched()
const;
452 inline ImplicitPriceSummary BestArcAndGap(
NodeIndex left_node)
const;
456 void ReportAndAccumulateStats();
467 bool UpdateEpsilon();
471 inline bool IsActive(
NodeIndex left_node)
const;
478 inline bool IsActiveForDebugging(
NodeIndex node)
const;
485 void InitializeActiveNodeContainer();
493 void SaturateNegativeArcs();
502 return scaled_arc_cost_[arc] - price_[
Head(arc)];
507 const GraphType* graph_;
516 bool incidence_precondition_satisfied_;
533 static constexpr CostValue kMinEpsilon = 1;
843 bool* in_range)
const {
856 const double result =
857 static_cast<double>(std::max<CostValue>(1, n / 2 - 1)) *
858 (
static_cast<double>(old_epsilon) +
static_cast<double>(new_epsilon));
860 static_cast<double>(std::numeric_limits<CostValue>::max());
861 if (result > limit) {
863 if (in_range !=
nullptr) *in_range =
false;
864 return std::numeric_limits<CostValue>::max();
884 CostValue largest_scaled_cost_magnitude_;
897 ZVector<CostValue> price_;
902 std::vector<ArcIndex> matched_arc_;
910 ZVector<NodeIndex> matched_node_;
915 std::vector<CostValue> scaled_arc_cost_;
920 std::unique_ptr<ActiveNodeContainerInterface> active_nodes_;
928 Stats iteration_stats_;
934template <
typename GraphType,
typename CostValue>
936 const GraphType& graph,
const NodeIndex num_left_nodes)
938 num_left_nodes_(num_left_nodes),
940 cost_scaling_factor_(1 + num_left_nodes),
941 alpha_(
absl::GetFlag(FLAGS_assignment_alpha)),
943 price_lower_bound_(0),
944 slack_relabeling_price_(0),
945 largest_scaled_cost_magnitude_(0),
947 price_(num_left_nodes, 2 * num_left_nodes - 1),
948 matched_arc_(num_left_nodes, 0),
949 matched_node_(num_left_nodes, 2 * num_left_nodes - 1),
950 scaled_arc_cost_(graph.arc_capacity(), 0),
951 active_nodes_(
absl::GetFlag(FLAGS_assignment_stack_order)
952 ? static_cast<ActiveNodeContainerInterface*>(
953 new ActiveNodeStack())
954 : static_cast<ActiveNodeContainerInterface*>(
955 new ActiveNodeQueue())) {}
957template <
typename GraphType,
typename CostValue>
961 num_left_nodes_(num_left_nodes),
963 cost_scaling_factor_(1 + num_left_nodes),
964 alpha_(
absl::GetFlag(FLAGS_assignment_alpha)),
966 price_lower_bound_(0),
967 slack_relabeling_price_(0),
968 largest_scaled_cost_magnitude_(0),
970 price_(num_left_nodes, 2 * num_left_nodes - 1),
971 matched_arc_(num_left_nodes, 0),
972 matched_node_(num_left_nodes, 2 * num_left_nodes - 1),
973 scaled_arc_cost_(num_arcs, 0),
974 active_nodes_(
absl::GetFlag(FLAGS_assignment_stack_order)
975 ? static_cast<ActiveNodeContainerInterface*>(
976 new ActiveNodeStack())
977 : static_cast<ActiveNodeContainerInterface*>(
978 new ActiveNodeQueue())) {}
980template <
typename GraphType,
typename CostValue>
983 if (graph_ !=
nullptr) {
985 DCHECK_LT(arc, graph_->num_arcs());
987 DCHECK_LE(num_left_nodes_, head);
989 cost *= cost_scaling_factor_;
990 const CostValue cost_magnitude = std::abs(cost);
991 largest_scaled_cost_magnitude_ =
992 std::max(largest_scaled_cost_magnitude_, cost_magnitude);
993 scaled_arc_cost_[arc] = cost;
996template <
typename ArcIndexType,
typename CostValue>
1000 : temp_(0), cost_(cost) {}
1007 temp_ = (*cost_)[source];
1011 ArcIndexType destination)
const override {
1012 (*cost_)[destination] = (*cost_)[source];
1016 (*cost_)[destination] = temp_;
1023 std::vector<CostValue>*
const cost_;
1031template <
typename GraphType>
1040 typename GraphType::ArcIndex b)
const {
1041 return ((graph_.Tail(a) < graph_.Tail(b)) ||
1042 ((graph_.Tail(a) == graph_.Tail(b)) &&
1043 (graph_.Head(a) < graph_.Head(b))));
1047 const GraphType& graph_;
1055template <
typename GraphType,
typename CostValue>
1056PermutationCycleHandler<typename GraphType::ArcIndex>*
1062template <
typename GraphType,
typename CostValue>
1063CostValue LinearSumAssignment<GraphType, CostValue>::NewEpsilon(
1064 const CostValue current_epsilon)
const {
1065 return std::max(current_epsilon / alpha_, kMinEpsilon);
1068template <
typename GraphType,
typename CostValue>
1069bool LinearSumAssignment<GraphType, CostValue>::UpdateEpsilon() {
1070 CostValue new_epsilon = NewEpsilon(epsilon_);
1071 slack_relabeling_price_ = PriceChangeBound(epsilon_, new_epsilon,
nullptr);
1072 epsilon_ = new_epsilon;
1073 VLOG(3) <<
"Updated: epsilon_ == " << epsilon_;
1074 VLOG(4) <<
"slack_relabeling_price_ == " << slack_relabeling_price_;
1075 DCHECK_GT(slack_relabeling_price_, 0);
1083template <
typename GraphType,
typename CostValue>
1084inline bool LinearSumAssignment<GraphType, CostValue>::IsActive(
1086 DCHECK_LT(left_node, num_left_nodes_);
1087 return matched_arc_[left_node] == GraphType::kNilArc;
1093template <
typename GraphType,
typename CostValue>
1094inline bool LinearSumAssignment<GraphType, CostValue>::IsActiveForDebugging(
1096 if (node < num_left_nodes_) {
1097 return IsActive(node);
1099 return matched_node_[node] == GraphType::kNilNode;
1103template <
typename GraphType,
typename CostValue>
1105 CostValue>::InitializeActiveNodeContainer() {
1106 DCHECK(active_nodes_->Empty());
1107 for (
const NodeIndex node : BipartiteLeftNodes()) {
1108 if (IsActive(node)) {
1109 active_nodes_->Add(node);
1124template <
typename GraphType,
typename CostValue>
1125void LinearSumAssignment<GraphType, CostValue>::SaturateNegativeArcs() {
1127 for (
const NodeIndex node : BipartiteLeftNodes()) {
1128 if (IsActive(node)) {
1136 matched_arc_[node] = GraphType::kNilArc;
1137 matched_node_[mate] = GraphType::kNilNode;
1143template <
typename GraphType,
typename CostValue>
1144bool LinearSumAssignment<GraphType, CostValue>::DoublePush(
NodeIndex source) {
1145 DCHECK_GT(num_left_nodes_, source);
1146 DCHECK(IsActive(source)) <<
"Node " << source
1147 <<
"must be active (unmatched)!";
1148 ImplicitPriceSummary summary = BestArcAndGap(source);
1149 const ArcIndex best_arc = summary.first;
1154 if (best_arc == GraphType::kNilArc) {
1157 const NodeIndex new_mate = Head(best_arc);
1158 const NodeIndex to_unmatch = matched_node_[new_mate];
1159 if (to_unmatch != GraphType::kNilNode) {
1162 matched_arc_[to_unmatch] = GraphType::kNilArc;
1163 active_nodes_->Add(to_unmatch);
1165 iteration_stats_.double_pushes += 1;
1170 iteration_stats_.pushes += 1;
1172 matched_arc_[source] = best_arc;
1173 matched_node_[new_mate] = source;
1175 iteration_stats_.relabelings += 1;
1176 const CostValue new_price = price_[new_mate] - gap - epsilon_;
1177 price_[new_mate] = new_price;
1178 return new_price >= price_lower_bound_;
1181template <
typename GraphType,
typename CostValue>
1182bool LinearSumAssignment<GraphType, CostValue>::Refine() {
1183 SaturateNegativeArcs();
1184 InitializeActiveNodeContainer();
1185 while (total_excess_ > 0) {
1188 const NodeIndex node = active_nodes_->Get();
1189 if (!DoublePush(node)) {
1197 LOG_IF(DFATAL, total_stats_.refinements > 0)
1198 <<
"Infeasibility detection triggered after first iteration found "
1199 <<
"a feasible assignment!";
1203 DCHECK(active_nodes_->Empty());
1204 iteration_stats_.refinements += 1;
1222template <
typename GraphType,
typename CostValue>
1223inline typename LinearSumAssignment<GraphType, CostValue>::ImplicitPriceSummary
1224LinearSumAssignment<GraphType, CostValue>::BestArcAndGap(
1226 DCHECK(IsActive(left_node))
1227 <<
"Node " << left_node <<
" must be active (unmatched)!";
1228 DCHECK_GT(epsilon_, 0);
1229 const auto arcs = graph_->OutgoingArcs(left_node);
1230 auto arc_it = arcs.begin();
1231 DCHECK(!arcs.empty());
1232 ArcIndex best_arc = *arc_it;
1233 CostValue min_partial_reduced_cost = PartialReducedCost(best_arc);
1239 const CostValue max_gap = slack_relabeling_price_ - epsilon_;
1240 CostValue second_min_partial_reduced_cost =
1241 min_partial_reduced_cost + max_gap;
1242 for (++arc_it; arc_it != arcs.end(); ++arc_it) {
1243 const ArcIndex arc = *arc_it;
1244 const CostValue partial_reduced_cost = PartialReducedCost(arc);
1245 if (partial_reduced_cost < second_min_partial_reduced_cost) {
1246 if (partial_reduced_cost < min_partial_reduced_cost) {
1248 second_min_partial_reduced_cost = min_partial_reduced_cost;
1249 min_partial_reduced_cost = partial_reduced_cost;
1251 second_min_partial_reduced_cost = partial_reduced_cost;
1255 const CostValue gap = std::min<CostValue>(
1256 second_min_partial_reduced_cost - min_partial_reduced_cost, max_gap);
1258 return std::make_pair(best_arc, gap);
1265template <
typename GraphType,
typename CostValue>
1266inline CostValue LinearSumAssignment<GraphType, CostValue>::ImplicitPrice(
1268 DCHECK_GT(num_left_nodes_, left_node);
1269 DCHECK_GT(epsilon_, 0);
1270 const auto arcs = graph_->OutgoingArcs(left_node);
1272 DCHECK(!arcs.empty());
1273 auto arc_it = arcs.begin();
1274 ArcIndex best_arc = *arc_it;
1275 if (best_arc == matched_arc_[left_node]) {
1277 if (arc_it != arcs.end()) {
1281 CostValue min_partial_reduced_cost = PartialReducedCost(best_arc);
1282 if (arc_it == arcs.end()) {
1287 return -(min_partial_reduced_cost + slack_relabeling_price_);
1289 for (++arc_it; arc_it != arcs.end(); ++arc_it) {
1290 const ArcIndex arc = *arc_it;
1291 if (arc != matched_arc_[left_node]) {
1292 const CostValue partial_reduced_cost = PartialReducedCost(arc);
1293 if (partial_reduced_cost < min_partial_reduced_cost) {
1294 min_partial_reduced_cost = partial_reduced_cost;
1298 return -min_partial_reduced_cost;
1302template <
typename GraphType,
typename CostValue>
1303bool LinearSumAssignment<GraphType, CostValue>::AllMatched()
const {
1304 for (
NodeIndex node = 0; node < graph_->num_nodes(); ++node) {
1305 if (IsActiveForDebugging(node)) {
1313template <
typename GraphType,
typename CostValue>
1318 CostValue left_node_price = ImplicitPrice(left_node);
1319 for (
const ArcIndex arc : graph_->OutgoingArcs(left_node)) {
1320 const CostValue reduced_cost = left_node_price + PartialReducedCost(arc);
1325 if (matched_arc_[left_node] == arc) {
1329 if (reduced_cost > epsilon_) {
1335 if (reduced_cost < 0) {
1344template <
typename GraphType,
typename CostValue>
1346 incidence_precondition_satisfied_ =
true;
1350 epsilon_ = std::max(largest_scaled_cost_magnitude_, kMinEpsilon + 1);
1351 VLOG(2) <<
"Largest given cost magnitude: "
1352 << largest_scaled_cost_magnitude_ / cost_scaling_factor_;
1355 for (
NodeIndex node = 0; node < num_left_nodes_; ++node) {
1356 matched_arc_[node] = GraphType::kNilArc;
1357 if (graph_->OutgoingArcs(node).empty()) {
1358 incidence_precondition_satisfied_ =
false;
1363 for (
NodeIndex node = num_left_nodes_; node < graph_->num_nodes(); ++node) {
1365 matched_node_[node] = GraphType::kNilNode;
1367 bool in_range =
true;
1368 double double_price_lower_bound = 0.0;
1370 CostValue old_error_parameter = epsilon_;
1372 new_error_parameter = NewEpsilon(old_error_parameter);
1373 double_price_lower_bound -=
1374 2.0 *
static_cast<double>(PriceChangeBound(
1375 old_error_parameter, new_error_parameter, &in_range));
1376 old_error_parameter = new_error_parameter;
1377 }
while (new_error_parameter != kMinEpsilon);
1378 const double limit =
1379 -
static_cast<double>(std::numeric_limits<CostValue>::max());
1380 if (double_price_lower_bound < limit) {
1382 price_lower_bound_ = -std::numeric_limits<CostValue>::max();
1384 price_lower_bound_ =
static_cast<CostValue>(double_price_lower_bound);
1386 VLOG(4) <<
"price_lower_bound_ == " << price_lower_bound_;
1387 DCHECK_LE(price_lower_bound_, 0);
1389 LOG(WARNING) <<
"Price change bound exceeds range of representable "
1390 <<
"costs; arithmetic overflow is not ruled out and "
1391 <<
"infeasibility might go undetected.";
1396template <
typename GraphType,
typename CostValue>
1397void LinearSumAssignment<GraphType, CostValue>::ReportAndAccumulateStats() {
1398 total_stats_.Add(iteration_stats_);
1399 VLOG(3) <<
"Iteration stats: " << iteration_stats_.StatsString();
1400 iteration_stats_.Clear();
1403template <
typename GraphType,
typename CostValue>
1405 CHECK(graph_ !=
nullptr);
1406 bool ok = graph_->num_nodes() == 2 * num_left_nodes_;
1407 if (!ok)
return false;
1414 ok = ok && incidence_precondition_satisfied_;
1416 while (ok && epsilon_ > kMinEpsilon) {
1417 ok = ok && UpdateEpsilon();
1418 ok = ok && Refine();
1419 ReportAndAccumulateStats();
1421 DCHECK(!ok || AllMatched());
1424 VLOG(1) <<
"Overall stats: " << total_stats_.StatsString();
1428template <
typename GraphType,
typename CostValue>
bool operator()(typename GraphType::ArcIndex a, typename GraphType::ArcIndex b) const
ArcIndexOrderingByTailNode(const GraphType &graph)
CostValueCycleHandler(const CostValueCycleHandler &)=delete
This type is neither copyable nor movable.
CostValueCycleHandler & operator=(const CostValueCycleHandler &)=delete
void SetTempFromIndex(ArcIndexType source) override
void SetIndexFromTemp(ArcIndexType destination) const override
Sets a data element from the temporary.
~CostValueCycleHandler() override
CostValueCycleHandler(std::vector< CostValue > *cost)
void SetIndexFromIndex(ArcIndexType source, ArcIndexType destination) const override
Moves a data element one step along its cycle.
This class does not take ownership of its underlying graph.
void SetArcCost(ArcIndex arc, CostValue cost)
Sets the cost of an arc already present in the given graph.
CostValue GetAssignmentCost(NodeIndex node) const
operations_research::PermutationCycleHandler< typename GraphType::ArcIndex > * ArcAnnotationCycleHandler()
Passes ownership of the cycle handler to the caller.
std::string StatsString() const
NodeIndex NumLeftNodes() const
NodeIndex NumNodes() const
Returns the total number of nodes in the given problem.
NodeIndex Head(ArcIndex arc) const
ArcIndex GetAssignmentArc(NodeIndex left_node) const
Returns the arc through which the given node is matched.
const GraphType & Graph() const
Allows tests, iterators, etc., to inspect our underlying graph.
NodeIndex GetMate(NodeIndex left_node) const
Returns the node to which the given node is matched.
GraphType::NodeIndex NodeIndex
bool EpsilonOptimal() const
Only for debugging.
CostValue GetCost() const
::util::IntegerRange< NodeIndex > BipartiteLeftNodes() const
Returns the range of valid left node indices.
LinearSumAssignment(const LinearSumAssignment &)=delete
This type is neither copyable nor movable.
void SetGraph(const GraphType *graph)
CostValue ArcCost(ArcIndex arc) const
LinearSumAssignment(const GraphType &graph, NodeIndex num_left_nodes)
LinearSumAssignment & operator=(const LinearSumAssignment &)=delete
void SetCostScalingDivisor(CostValue factor)
GraphType::ArcIndex ArcIndex
PermutationCycleHandler(const PermutationCycleHandler &)=delete
OR_DLL ABSL_DECLARE_FLAG(int64_t, assignment_alpha)
In SWIG mode, we don't want anything besides these top-level includes.
BlossomGraph::CostValue CostValue
BlossomGraph::NodeIndex NodeIndex