24#ifndef OR_TOOLS_GRAPH_ROOTED_TREE_H_
25#define OR_TOOLS_GRAPH_ROOTED_TREE_H_
34#include "absl/algorithm/container.h"
35#include "absl/log/check.h"
36#include "absl/status/status.h"
37#include "absl/status/statusor.h"
38#include "absl/strings/str_cat.h"
39#include "absl/strings/str_join.h"
40#include "absl/types/span.h"
86template <
typename NodeType =
int32_t>
89 static constexpr NodeType
kNullParent =
static_cast<NodeType
>(-1);
98 static absl::StatusOr<RootedTree>
Create(
100 std::vector<NodeType>* error_cycle =
nullptr,
108 NodeType
root()
const {
return root_; }
111 NodeType
num_nodes()
const {
return static_cast<NodeType
>(parents_.size()); }
115 absl::Span<const NodeType>
parents()
const {
return parents_; }
119 std::vector<NodeType>
PathToRoot(NodeType node)
const;
123 std::vector<NodeType>
PathFromRoot(NodeType node)
const;
151 std::vector<NodeType>
Path(NodeType
start, NodeType
end, NodeType lca)
const;
181 NodeType
start, NodeType
end, NodeType lca,
210 template <
typename T>
222 template <
typename T>
249 absl::Span<const NodeType> depths)
const;
257 NodeType n1, NodeType n2, std::vector<bool>& visited_workspace)
const;
261 void Evert(NodeType new_root);
264 static_assert(std::is_integral_v<NodeType>,
265 "NodeType must be an integral type.");
266 static_assert(
sizeof(NodeType) <=
sizeof(std::size_t),
267 "NodeType cannot be larger than size_t, because NodeType is "
268 "used to index into std::vector.");
271 NodeType AppendToPath(NodeType
start, NodeType
end,
272 std::vector<NodeType>& path)
const;
275 NodeType ReverseAppendToPath(NodeType
start, NodeType
end,
276 std::vector<NodeType>& path)
const;
280 template <
typename T>
281 void AllDistancesToRootInPlace(
283 absl::Span<T> arc_lengths_in_distances_out)
const;
293 double DistanceOfUpwardPath(NodeType
start, NodeType
end,
297 std::vector<NodeType> parents_;
319template <
typename Graph>
321 typename Graph::NodeType root,
const Graph&
graph,
323 std::vector<typename Graph::NodeType>* depths =
nullptr);
331template <
typename NodeType>
332bool IsValidParent(
const NodeType node,
const NodeType num_tree_nodes) {
334 (node >= NodeType{0} && node < num_tree_nodes);
337template <
typename NodeType>
338absl::Status
IsValidNode(
const NodeType node,
const NodeType num_tree_nodes) {
339 if (node < NodeType{0} || node >= num_tree_nodes) {
341 <<
"nodes must be in [0.." << num_tree_nodes
342 <<
"), found bad node: " << node;
344 return absl::OkStatus();
347template <
typename NodeType>
348std::vector<NodeType>
ExtractCycle(absl::Span<const NodeType> parents,
349 const NodeType node_in_cycle) {
350 std::vector<NodeType> cycle;
351 cycle.push_back(node_in_cycle);
352 for (NodeType i = parents[node_in_cycle]; i != node_in_cycle;
355 <<
"node_in_cycle: " << node_in_cycle
356 <<
" not in cycle, reached the root";
358 CHECK_LE(cycle.size(), parents.size())
359 <<
"node_in_cycle: " << node_in_cycle
360 <<
" not in cycle, just (transitively) leads to a cycle";
362 absl::c_rotate(cycle, absl::c_min_element(cycle));
363 cycle.push_back(cycle[0]);
367template <
typename NodeType>
369 CHECK_GT(cycle.size(), 0);
370 const NodeType
start = cycle[0];
371 std::string cycle_string;
372 if (cycle.size() > 10) {
373 cycle_string = absl::StrCat(
374 absl::StrJoin(absl::MakeConstSpan(cycle).subspan(0, 8),
", "),
377 cycle_string = absl::StrJoin(cycle,
", ");
379 return absl::StrCat(
"found cycle of size: ", cycle.size(),
380 " with nodes: ", cycle_string);
385template <
typename NodeType>
386std::vector<NodeType>
CheckForCycle(absl::Span<const NodeType> parents,
388 const NodeType n =
static_cast<NodeType
>(parents.size());
393 std::vector<bool> visited(n);
394 std::vector<NodeType> dfs_stack;
395 for (NodeType i = 0; i < n; ++i) {
401 dfs_stack.push_back(
next);
402 if (dfs_stack.size() > n) {
411 absl::c_reverse(dfs_stack);
412 for (
const NodeType j : dfs_stack) {
425template <
typename NodeType>
426NodeType RootedTree<NodeType>::AppendToPath(
const NodeType
start,
428 std::vector<NodeType>& path)
const {
429 NodeType num_new = 0;
430 for (NodeType node =
start; node !=
end; node = parents_[node]) {
431 DCHECK_NE(node, kNullParent);
432 path.push_back(node);
439template <
typename NodeType>
440NodeType RootedTree<NodeType>::ReverseAppendToPath(
441 NodeType
start, NodeType
end, std::vector<NodeType>& path)
const {
442 NodeType result = AppendToPath(
start,
end, path);
443 std::reverse(path.end() - result, path.end());
447template <
typename NodeType>
448double RootedTree<NodeType>::DistanceOfUpwardPath(
449 const NodeType
start,
const NodeType
end,
454 DCHECK_NE(
next, root_);
460template <
typename NodeType>
462 const NodeType root, std::vector<NodeType> parents,
463 std::vector<NodeType>* error_cycle,
465 const NodeType num_nodes =
static_cast<NodeType
>(parents.size());
467 if (parents[root] != kNullParent) {
469 <<
"root should have no parent (-1), but found parent of: "
472 for (NodeType i = 0; i < num_nodes; ++i) {
477 <<
"invalid value for parent of node: " << i;
479 std::vector<NodeType> cycle =
481 if (!cycle.empty()) {
482 std::string error_message =
484 if (error_cycle !=
nullptr) {
485 *error_cycle = std::move(cycle);
487 return absl::InvalidArgumentError(std::move(error_message));
489 return RootedTree(root, std::move(parents));
492template <
typename NodeType>
494 const NodeType node)
const {
495 std::vector<NodeType> path;
499 path.push_back(root_);
503template <
typename NodeType>
505 const NodeType node)
const {
506 std::vector<NodeType> result = PathToRoot(node);
507 absl::c_reverse(result);
511template <
typename NodeType>
513 std::vector<NodeType> result;
514 const std::vector<NodeType> cycle =
517 absl::MakeConstSpan(cycle));
521template <
typename NodeType>
527template <
typename NodeType>
528std::pair<double, std::vector<NodeType>>
533 std::vector<NodeType> path;
535 path.push_back(
next);
538 path.push_back(root_);
542template <
typename NodeType>
545 const NodeType lca)
const {
546 std::vector<NodeType> result;
548 result.push_back(
start);
552 ReverseAppendToPath(
end, lca, result);
556 AppendToPath(
start, lca, result);
559 AppendToPath(
start, lca, result);
561 ReverseAppendToPath(
end, lca, result);
565template <
typename NodeType>
567 const NodeType
start,
const NodeType
end,
const NodeType lca,
573template <
typename NodeType>
575 const NodeType
start,
const NodeType
end,
const NodeType lca,
577 std::vector<NodeType> path = Path(
start,
end, lca);
579 return {dist, std::move(path)};
582template <
typename NodeType>
584 absl::Span<const NodeType> path,
588 for (
int i = 0; i + 1 < path.size(); ++i) {
589 if (parents_[path[i]] != path[i + 1]) {
591 }
else if (parents_[path[i + 1]] == path[i]) {
594 LOG(FATAL) <<
"bad edge in path from " << path[i] <<
" to "
601template <
typename NodeType>
603 const NodeType n1,
const NodeType n2,
604 absl::Span<const NodeType> depths)
const {
605 CHECK_EQ(num_nodes(), depths.size());
606 const NodeType n = num_nodes();
609 CHECK_EQ(depths.size(), n);
610 if (n1 == root_ || n2 == root_) {
618 while (depths[next1] > depths[next2]) {
619 next1 = parents_[next1];
621 while (depths[next2] > depths[next1]) {
622 next2 = parents_[next2];
624 while (next1 != next2) {
625 next1 = parents_[next1];
626 next2 = parents_[next2];
631template <
typename NodeType>
633 const NodeType n1,
const NodeType n2,
634 std::vector<bool>& visited_workspace)
const {
635 const NodeType n = num_nodes();
638 CHECK_EQ(visited_workspace.size(), n);
639 if (n1 == root_ || n2 == root_) {
647 visited_workspace[n1] =
true;
648 visited_workspace[n2] =
true;
649 NodeType lca = kNullParent;
650 NodeType lca_distance =
654 if (next1 != root_) {
655 next1 = parents_[next1];
656 if (visited_workspace[next1]) {
661 if (next2 != root_) {
662 visited_workspace[next1] =
true;
663 next2 = parents_[next2];
664 if (visited_workspace[next2]) {
668 visited_workspace[next2] =
true;
672 auto cleanup = [
this, lca_distance, &visited_workspace](NodeType
next) {
673 for (NodeType i = 0;
i < lca_distance &&
next != kNullParent; ++
i) {
674 visited_workspace[
next] =
false;
683template <
typename NodeType>
685 NodeType previous_node = kNullParent;
686 for (NodeType node = new_root; node != kNullParent;) {
687 NodeType next_node = parents_[node];
688 parents_[node] = previous_node;
689 previous_node = node;
695template <
typename NodeType>
697void RootedTree<NodeType>::AllDistancesToRootInPlace(
699 absl::Span<T> arc_lengths_in_distances_out)
const {
700 CHECK_EQ(num_nodes(), arc_lengths_in_distances_out.size());
706 if (parents_[node] == kNullParent) {
707 arc_lengths_in_distances_out[node] = 0;
709 arc_lengths_in_distances_out[node] +=
710 arc_lengths_in_distances_out[parents_[node]];
715template <
typename NodeType>
718 std::vector<NodeType> arc_length_in_distance_out(num_nodes(), 1);
720 absl::MakeSpan(arc_length_in_distance_out));
721 return arc_length_in_distance_out;
724template <
typename NodeType>
728 return AllDistancesToRoot(
arc_lengths, TopologicalSort());
731template <
typename NodeType>
741template <
typename Graph>
745 std::vector<typename Graph::NodeIndex>*
const depths) {
749 <<
"invalid root node";
755 if (depths !=
nullptr) {
757 depths->resize(num_nodes, 0);
760 auto visited = [&tree, root](
const NodeIndex node) {
766 std::vector<NodeIndex> must_search_children = {root};
767 while (!must_search_children.empty()) {
769 must_search_children.pop_back();
771 if (visited(neighbor)) {
772 if (tree[
next] == neighbor) {
777 <<
"graph has cycle containing arc from " <<
next <<
" to "
781 tree[neighbor] =
next;
785 if (depths !=
nullptr) {
786 (*depths)[neighbor] = (*depths)[
next] + 1;
788 must_search_children.push_back(neighbor);
794 <<
"graph is not connected, no path to " <<
i;
797 return RootedTree<NodeIndex>(root, tree);
#define RETURN_IF_ERROR(expr)
double DistanceToRoot(NodeType start, absl::Span< const double > arc_lengths) const
double Distance(NodeType start, NodeType end, NodeType lca, absl::Span< const double > arc_lengths) const
RootedTree(NodeType root, std::vector< NodeType > parents)
Like Create(), but data is not validated (UB on bad input).
std::vector< NodeType > PathToRoot(NodeType node) const
std::pair< double, std::vector< NodeType > > DistanceAndPath(NodeType start, NodeType end, NodeType lca, absl::Span< const double > arc_lengths) const
absl::Span< const NodeType > parents() const
NodeType LowestCommonAncestorByDepth(NodeType n1, NodeType n2, absl::Span< const NodeType > depths) const
NodeType LowestCommonAncestorBySearch(NodeType n1, NodeType n2, std::vector< bool > &visited_workspace) const
static constexpr NodeType kNullParent
std::vector< NodeType > AllDepths() const
std::vector< T > AllDistancesToRoot(absl::Span< const T > arc_lengths) const
NodeType root() const
The root node of this rooted tree.
NodeType num_nodes() const
The number of nodes in this rooted tree.
std::vector< NodeType > TopologicalSort() const
double DistanceOfPath(absl::Span< const NodeType > path, absl::Span< const double > arc_lengths) const
std::vector< NodeType > PathFromRoot(NodeType node) const
static absl::StatusOr< RootedTree > Create(NodeType root, std::vector< NodeType > parents, std::vector< NodeType > *error_cycle=nullptr, std::vector< NodeType > *topological_order=nullptr)
void Evert(NodeType new_root)
std::pair< double, std::vector< NodeType > > DistanceAndPathToRoot(NodeType start, absl::Span< const double > arc_lengths) const
std::vector< NodeType > Path(NodeType start, NodeType end, NodeType lca) const
std::vector< NodeIndex > topological_order
std::vector< double > arc_lengths
std::string CycleErrorMessage(absl::Span< const NodeType > cycle)
bool IsValidParent(const NodeType node, const NodeType num_tree_nodes)
std::vector< NodeType > ExtractCycle(absl::Span< const NodeType > parents, const NodeType node_in_cycle)
std::vector< NodeType > CheckForCycle(absl::Span< const NodeType > parents, std::vector< NodeType > *topological_order)
absl::Status IsValidNode(const NodeType node, const NodeType num_tree_nodes)
In SWIG mode, we don't want anything besides these top-level includes.
absl::StatusOr< RootedTree< typename Graph::NodeType > > RootedTreeFromGraph(typename Graph::NodeType root, const Graph &graph, std::vector< typename Graph::NodeType > *topological_order=nullptr, std::vector< typename Graph::NodeType > *depths=nullptr)
util::ReverseArcStaticGraph Graph
Type of graph to use.
StatusBuilder InvalidArgumentErrorBuilder()
std::optional< int64_t > end