Google OR-Tools v9.11
a fast and portable software suite for combinatorial optimization
rooted_tree.h
// Copyright 2010-2024 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Find paths and compute path distances between nodes on a rooted tree.
//
// A tree is a connected undirected graph with no cycles. A rooted tree is a
// directed graph derived from a tree, where a node is designated as the root,
// and then all edges are directed towards the root.
//
// This file provides the class RootedTree, which stores a rooted tree on dense
// integer nodes in a single vector, and a function RootedTreeFromGraph(), which
// converts the adjacency list of an undirected tree to a RootedTree.

#ifndef OR_TOOLS_GRAPH_ROOTED_TREE_H_
#define OR_TOOLS_GRAPH_ROOTED_TREE_H_

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"

namespace operations_research {

// A tree is an undirected graph with no cycles, n nodes, and n-1 undirected
// edges. Consequently, a tree is connected. Given a tree on the nodes [0..n),
// a RootedTree picks any node to be the root, and then converts all edges into
// (directed) arcs pointing at the root. Each non-root node has exactly one
// outgoing arc, so we can store the adjacency list of this directed view of the
// graph as a single vector of integers with length equal to the number of
// nodes. At the root index, we store RootedTree::kNullParent=-1, and at every
// other index, we store the next node towards the root (the parent in the
// tree).
//
// This class is templated on the NodeType, which must be an integer type, e.g.,
// int or int32_t (signed and unsigned types both work).
//
// The following operations are supported:
//  * Path from node to root in O(path length to root)
//  * Lowest Common Ancestor (LCA) of two nodes in O(path length between nodes)
//  * Depth of all nodes in O(num nodes)
//  * Topological sort in O(num nodes)
//  * Path between any two nodes in O(path length between nodes)
//
// Users can provide a vector<double> of arc lengths (indexed by source) to get:
//  * Distance from node to root in O(path length to root)
//  * Distance from all nodes to root in O(num nodes)
//  * Distance between any two nodes in O(path length between nodes)
//
// Operations on rooted trees are generally more efficient than on adjacency
// list representations because the entire tree is in one contiguous allocation.
// There is also an asymptotic advantage on path finding problems: a path query
// only visits the nodes on the path instead of searching the whole graph.
//
// Two methods for finding the LCA are provided. The first requires the depth of
// every node ahead of time. The second requires a workspace of n bools, all
// starting at false. These values are modified and restored to false when the
// LCA computation finishes. In both cases, if the depths/workspace allocation
// is an O(n) precomputation, then the LCA runs in O(path length).
// Non-asymptotically, the depth method requires more precomputation, but the
// LCA is faster and does not require the user to manage mutable state (so it
// may be better suited to multi-threaded computation).
//
// An operation that is missing is bulk LCA, see
// https://en.wikipedia.org/wiki/Tarjan%27s_off-line_lowest_common_ancestors_algorithm.
template <typename NodeType = int32_t>
class RootedTree {
 public:
  static constexpr NodeType kNullParent = static_cast<NodeType>(-1);

  // Like the constructor but checks that the tree is valid. Uses O(num nodes)
  // temporary space with O(log(n)) allocations.
  //
  // If the input is cyclic, an InvalidArgument error will be returned with
  // "cycle" as a substring. Further, if error_cycle is not null, it will be
  // cleared and then set to contain the cycle. We will not modify error_cycle
  // or return an error message containing the string "cycle" if there is no
  // cycle. The cycle output always begins and ends with its smallest element
  // (the first node is repeated at the end).
  //
  // If topological_order is not null, it is set to a topological order of the
  // nodes with the root first (it is left empty if a cycle is found).
  static absl::StatusOr<RootedTree> Create(
      NodeType root, std::vector<NodeType> parents,
      std::vector<NodeType>* error_cycle = nullptr,
      std::vector<NodeType>* topological_order = nullptr);

  // Like Create(), but data is not validated (UB on bad input).
  explicit RootedTree(NodeType root, std::vector<NodeType> parents)
      : root_(root), parents_(std::move(parents)) {}

  // The root node of this rooted tree.
  NodeType root() const { return root_; }

  // The number of nodes in this rooted tree.
  NodeType num_nodes() const { return static_cast<NodeType>(parents_.size()); }

  // A vector that holds the parent of each non-root node, and kNullParent at
  // the root.
  absl::Span<const NodeType> parents() const { return parents_; }

  // Returns the path from `node` to `root()` as a vector of nodes starting with
  // `node`.
  std::vector<NodeType> PathToRoot(NodeType node) const;

  // Returns the path from `root()` to `node` as a vector of nodes starting with
  // `root()` and ending with `node`.
  std::vector<NodeType> PathFromRoot(NodeType node) const;

  // Returns the sum of the arc lengths of the arcs in the path from `start` to
  // `root()`.
  //
  // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`.
  // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail.
  // The value of `arc_lengths[root()]` is unused.
  double DistanceToRoot(NodeType start,
                        absl::Span<const double> arc_lengths) const;

  // Returns the sum of the arc lengths of the arcs in the path from `start` to
  // `root()`, together with that path as a vector of nodes starting with
  // `start`.
  //
  // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`.
  // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail.
  // The value of `arc_lengths[root()]` is unused.
  std::pair<double, std::vector<NodeType>> DistanceAndPathToRoot(
      NodeType start, absl::Span<const double> arc_lengths) const;

  // Returns the path from `start` to `end` as a vector of nodes starting with
  // `start` and ending with `end`.
  //
  // `lca` is the lowest common ancestor of `start` and `end`. This can be
  // computed using LowestCommonAncestorByDepth() or
  // LowestCommonAncestorBySearch(), both defined on this class.
  //
  // Runs in time O(path length).
  std::vector<NodeType> Path(NodeType start, NodeType end, NodeType lca) const;

  // Returns the sum of the arc lengths of the arcs in the path from `start` to
  // `end`.
  //
  // `lca` is the lowest common ancestor of `start` and `end`. This can be
  // computed using LowestCommonAncestorByDepth() or
  // LowestCommonAncestorBySearch(), both defined on this class.
  //
  // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`.
  // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail.
  // The value of `arc_lengths[root()]` is unused.
  //
  // Runs in time O(number of edges connecting start to end).
  double Distance(NodeType start, NodeType end, NodeType lca,
                  absl::Span<const double> arc_lengths) const;

  // Returns the sum of the arc lengths of the arcs in the path from `start` to
  // `end`, together with that path as a vector of nodes starting with `start`.
  //
  // `lca` is the lowest common ancestor of `start` and `end`. This can be
  // computed using LowestCommonAncestorByDepth() or
  // LowestCommonAncestorBySearch(), both defined on this class.
  //
  // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`.
  // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail.
  // The value of `arc_lengths[root()]` is unused.
  //
  // Runs in time O(number of edges connecting start to end).
  std::pair<double, std::vector<NodeType>> DistanceAndPath(
      NodeType start, NodeType end, NodeType lca,
      absl::Span<const double> arc_lengths) const;

  // Given a path of nodes, returns the sum of the lengths of the arcs
  // connecting them.
  //
  // `path` must be a list of nodes in the tree where consecutive nodes are
  // joined by an arc in either direction, i.e., for every i, either
  // path[i+1] == parents()[path[i]] or path[i] == parents()[path[i+1]]
  // (otherwise we LOG(FATAL)).
  //
  // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`.
  // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail.
  // The value of `arc_lengths[root()]` is unused.
  double DistanceOfPath(absl::Span<const NodeType> path,
                        absl::Span<const double> arc_lengths) const;

  // Returns a topological ordering of the nodes where the root is first and
  // every other node appears after its parent.
  std::vector<NodeType> TopologicalSort() const;

  // Returns the distance of every node from `root()`.
  //
  // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`.
  // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail.
  // The value of `arc_lengths[root()]` is unused.
  //
  // If you already have a topological order, prefer
  // `AllDistancesToRoot(absl::Span<const T> arc_lengths,
  //                     absl::Span<const NodeType> topological_order)`.
  template <typename T>
  std::vector<T> AllDistancesToRoot(absl::Span<const T> arc_lengths) const;

  // Returns the distance from every node to root().
  //
  // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`.
  // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail.
  // The value of `arc_lengths[root()]` is unused.
  //
  // `topological_order` must have size equal to `num_nodes()` and start with
  // `root()`, or else we CHECK fail. It can be any order over the nodes in
  // which every node appears after its parent (i.e., a topological order of
  // the tree with its arcs reversed to point away from the root).
  template <typename T>
  std::vector<T> AllDistancesToRoot(
      absl::Span<const T> arc_lengths,
      absl::Span<const NodeType> topological_order) const;

  // Returns the depth of every node, i.e., the number of arcs on its path to
  // the root.
  //
  // If you already have a topological order, prefer
  // AllDepths(absl::Span<const NodeType>).
  std::vector<NodeType> AllDepths() const {
    return AllDepths(TopologicalSort());
  }

  // Returns the depth of every node, i.e., the number of arcs on its path to
  // the root.
  //
  // `topological_order` must have size equal to `num_nodes()` and start with
  // `root()`, or else we CHECK fail. It can be any order over the nodes in
  // which every node appears after its parent (i.e., a topological order of
  // the tree with its arcs reversed to point away from the root).
  std::vector<NodeType> AllDepths(
      absl::Span<const NodeType> topological_order) const;

  // Returns the lowest common ancestor of n1 and n2.
  //
  // `depths` must have size equal to `num_nodes()`, or else we CHECK fail.
  // Values must be the distance of each node to the root in arcs (see
  // AllDepths()).
  NodeType LowestCommonAncestorByDepth(NodeType n1, NodeType n2,
                                       absl::Span<const NodeType> depths) const;

  // Returns the lowest common ancestor of n1 and n2.
  //
  // `visited_workspace` must be a vector with num_nodes() size, or else we
  // CHECK fail. All values of `visited_workspace` should be false. It will be
  // modified and then restored to its starting state.
  NodeType LowestCommonAncestorBySearch(
      NodeType n1, NodeType n2, std::vector<bool>& visited_workspace) const;

  // Modifies the tree in place to change the root. Runs in
  // O(path length from root() to new_root).
  void Evert(NodeType new_root);

 private:
  static_assert(std::is_integral_v<NodeType>,
                "NodeType must be an integral type.");
  static_assert(sizeof(NodeType) <= sizeof(std::size_t),
                "NodeType cannot be larger than size_t, because NodeType is "
                "used to index into std::vector.");

  // Appends the nodes from `start` up to and including `end` (which must be
  // `start` or one of its ancestors) to `path`. Returns the number of nodes
  // appended.
  NodeType AppendToPath(NodeType start, NodeType end,
                        std::vector<NodeType>& path) const;

  // Like AppendToPath(), but the appended nodes are in reverse order, i.e.,
  // from `end` down to `start`. Returns the number of nodes appended.
  NodeType ReverseAppendToPath(NodeType start, NodeType end,
                               std::vector<NodeType>& path) const;

  // Like AllDistancesToRoot(), but the input arc_lengths is mutated to hold
  // the output, instead of just returning the output as a new vector.
  template <typename T>
  void AllDistancesToRootInPlace(
      absl::Span<const NodeType> topological_order,
      absl::Span<T> arc_lengths_in_distances_out) const;

  // Returns the sum of the arc lengths on the path from start up to end.
  //
  // end must be either equal to or an ancestor of start in the tree (otherwise
  // DCHECK/UB).
  //
  // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`.
  // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail.
  // The value of `arc_lengths[root()]` is unused.
  double DistanceOfUpwardPath(NodeType start, NodeType end,
                              absl::Span<const double> arc_lengths) const;

  NodeType root_;
  std::vector<NodeType> parents_;  // kNullParent=-1 if root
};

// Graph API

// Converts an adjacency list representation of an undirected tree into a rooted
// tree.
//
// Graph must meet the API defined in ortools/graph/graph.h, e.g., StaticGraph
// or ListGraph. Note that these are directed graph APIs, so the graph must
// contain both a forward and a backward arc for each edge in the tree.
//
// Graph must be a tree when viewed as an undirected graph. An InvalidArgument
// error is returned if `root` is not a node of the graph, or if the graph has a
// cycle (including a self loop) or is not connected.
//
// If topological_order is not null, it is set to a vector with one entry for
// each node giving a topological ordering over the nodes of the graph, with the
// root first.
//
// If depths is not null, it is set to a vector with one entry for each node,
// giving the depth in the tree of that node (the root has depth zero).
template <typename Graph>
absl::StatusOr<RootedTree<typename Graph::NodeIndex>> RootedTreeFromGraph(
    typename Graph::NodeIndex root, const Graph& graph,
    std::vector<typename Graph::NodeIndex>* topological_order = nullptr,
    std::vector<typename Graph::NodeIndex>* depths = nullptr);

// Template implementations

namespace internal {

template <typename NodeType>
bool IsValidParent(const NodeType node, const NodeType num_tree_nodes) {
  return node == RootedTree<NodeType>::kNullParent ||
         (node >= NodeType{0} && node < num_tree_nodes);
}

template <typename NodeType>
absl::Status IsValidNode(const NodeType node, const NodeType num_tree_nodes) {
  if (node < NodeType{0} || node >= num_tree_nodes) {
    return util::InvalidArgumentErrorBuilder()
           << "nodes must be in [0.." << num_tree_nodes
           << "), found bad node: " << node;
  }
  return absl::OkStatus();
}

template <typename NodeType>
std::vector<NodeType> ExtractCycle(absl::Span<const NodeType> parents,
                                   const NodeType node_in_cycle) {
  std::vector<NodeType> cycle;
  cycle.push_back(node_in_cycle);
  for (NodeType i = parents[node_in_cycle]; i != node_in_cycle;
       i = parents[i]) {
    CHECK_NE(i, RootedTree<NodeType>::kNullParent)
        << "node_in_cycle: " << node_in_cycle
        << " not in cycle, reached the root";
    cycle.push_back(i);
    CHECK_LE(cycle.size(), parents.size())
        << "node_in_cycle: " << node_in_cycle
        << " not in cycle, just (transitively) leads to a cycle";
  }
  // Start (and, after the push_back below, end) the cycle at its smallest node.
  absl::c_rotate(cycle, absl::c_min_element(cycle));
  cycle.push_back(cycle[0]);
  return cycle;
}

template <typename NodeType>
std::string CycleErrorMessage(absl::Span<const NodeType> cycle) {
  CHECK_GT(cycle.size(), 0);
  const NodeType start = cycle[0];
  std::string cycle_string;
  if (cycle.size() > 10) {
    cycle_string = absl::StrCat(
        absl::StrJoin(absl::MakeConstSpan(cycle).subspan(0, 8), ", "),
        ", ..., ", start);
  } else {
    cycle_string = absl::StrJoin(cycle, ", ");
  }
  return absl::StrCat("found cycle of size: ", cycle.size(),
                      " with nodes: ", cycle_string);
}

// Every element of parents must be in {kNullParent} union [0..parents.size()),
// otherwise UB. Returns an empty vector if there is no cycle, and otherwise the
// cycle in the format of ExtractCycle().
template <typename NodeType>
std::vector<NodeType> CheckForCycle(absl::Span<const NodeType> parents,
                                    std::vector<NodeType>* topological_order) {
  const NodeType n = static_cast<NodeType>(parents.size());
  if (topological_order != nullptr) {
    topological_order->clear();
    topological_order->reserve(n);
  }
  std::vector<bool> visited(n);
  std::vector<NodeType> dfs_stack;
  for (NodeType i = 0; i < n; ++i) {
    if (visited[i]) {
      continue;
    }
    NodeType next = i;
    while (next != RootedTree<NodeType>::kNullParent && !visited[next]) {
      dfs_stack.push_back(next);
      // A walk longer than n nodes must revisit a node, so `next` is on a
      // cycle.
      if (dfs_stack.size() > n) {
        if (topological_order != nullptr) {
          topological_order->clear();
        }
        return ExtractCycle(parents, next);
      }
      next = parents[next];
      DCHECK(IsValidParent(next, n)) << "next: " << next << ", n: " << n;
    }
    // The walk ended at the root or at an already ordered node, so emitting it
    // in reverse puts every node after its parent.
    absl::c_reverse(dfs_stack);
    for (const NodeType j : dfs_stack) {
      visited[j] = true;
      if (topological_order != nullptr) {
        topological_order->push_back(j);
      }
    }
    dfs_stack.clear();
  }
  return {};
}

}  // namespace internal

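The parent-pointer walk in `internal::CheckForCycle()` and `internal::ExtractCycle()` can be tried outside OR-Tools with the standalone sketch below. It assumes `int` nodes with -1 as the null parent, and `FindCycle` is a hypothetical name; unlike the library it does not validate parent values and returns the cycle directly instead of through `absl::Status`.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical standalone helper mirroring internal::CheckForCycle(): returns
// an empty vector if `parents` (-1 marks the root) has no cycle, otherwise the
// cycle starting and ending at its smallest node. On success, `order` holds a
// topological order in which every node appears after its parent.
std::vector<int> FindCycle(const std::vector<int>& parents,
                           std::vector<int>& order) {
  const int n = static_cast<int>(parents.size());
  order.clear();
  std::vector<bool> visited(n, false);
  std::vector<int> stack;
  for (int i = 0; i < n; ++i) {
    int next = i;
    // Walk towards the root until reaching it or an already ordered node.
    while (next != -1 && !visited[next]) {
      stack.push_back(next);
      if (static_cast<int>(stack.size()) > n) {
        // More than n nodes on one walk: `next` lies on a cycle.
        std::vector<int> cycle = {next};
        for (int j = parents[next]; j != next; j = parents[j]) {
          cycle.push_back(j);
        }
        std::rotate(cycle.begin(),
                    std::min_element(cycle.begin(), cycle.end()), cycle.end());
        cycle.push_back(cycle.front());
        order.clear();
        return cycle;
      }
      next = parents[next];
    }
    // Emit the walk ancestors-first so parents precede children.
    for (auto it = stack.rbegin(); it != stack.rend(); ++it) {
      visited[*it] = true;
      order.push_back(*it);
    }
    stack.clear();
  }
  return {};
}
```

For example, `FindCycle({-1, 2, 3, 1}, order)` reports the cycle `1, 2, 3, 1`. The walk bound of n nodes is what keeps the whole check O(n) while also guaranteeing that the node where it stops is on the cycle.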
template <typename NodeType>
NodeType RootedTree<NodeType>::AppendToPath(const NodeType start,
                                            const NodeType end,
                                            std::vector<NodeType>& path) const {
  NodeType num_new = 0;
  for (NodeType node = start; node != end; node = parents_[node]) {
    DCHECK_NE(node, kNullParent);
    path.push_back(node);
    num_new++;
  }
  path.push_back(end);
  return num_new + 1;
}

template <typename NodeType>
NodeType RootedTree<NodeType>::ReverseAppendToPath(
    NodeType start, NodeType end, std::vector<NodeType>& path) const {
  NodeType result = AppendToPath(start, end, path);
  std::reverse(path.end() - result, path.end());
  return result;
}

template <typename NodeType>
double RootedTree<NodeType>::DistanceOfUpwardPath(
    const NodeType start, const NodeType end,
    absl::Span<const double> arc_lengths) const {
  CHECK_EQ(num_nodes(), arc_lengths.size());
  double distance = 0.0;
  for (NodeType next = start; next != end; next = parents_[next]) {
    DCHECK_NE(next, root_);
    distance += arc_lengths[next];
  }
  return distance;
}

template <typename NodeType>
absl::StatusOr<RootedTree<NodeType>> RootedTree<NodeType>::Create(
    const NodeType root, std::vector<NodeType> parents,
    std::vector<NodeType>* error_cycle,
    std::vector<NodeType>* topological_order) {
  const NodeType num_nodes = static_cast<NodeType>(parents.size());
  RETURN_IF_ERROR(internal::IsValidNode(root, num_nodes)) << "invalid root";
  if (parents[root] != kNullParent) {
    return util::InvalidArgumentErrorBuilder()
           << "root should have no parent (-1), but found parent of: "
           << parents[root];
  }
  for (NodeType i = 0; i < num_nodes; ++i) {
    if (i == root) {
      continue;
    }
    RETURN_IF_ERROR(internal::IsValidNode(parents[i], num_nodes))
        << "invalid value for parent of node: " << i;
  }
  std::vector<NodeType> cycle =
      internal::CheckForCycle(absl::MakeConstSpan(parents), topological_order);
  if (!cycle.empty()) {
    std::string error_message =
        internal::CycleErrorMessage(absl::MakeConstSpan(cycle));
    if (error_cycle != nullptr) {
      *error_cycle = std::move(cycle);
    }
    return absl::InvalidArgumentError(std::move(error_message));
  }
  return RootedTree(root, std::move(parents));
}

template <typename NodeType>
std::vector<NodeType> RootedTree<NodeType>::PathToRoot(
    const NodeType node) const {
  std::vector<NodeType> path;
  for (NodeType next = node; next != root_; next = parents_[next]) {
    path.push_back(next);
  }
  path.push_back(root_);
  return path;
}

template <typename NodeType>
std::vector<NodeType> RootedTree<NodeType>::PathFromRoot(
    const NodeType node) const {
  std::vector<NodeType> result = PathToRoot(node);
  absl::c_reverse(result);
  return result;
}

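As a standalone illustration of `PathToRoot()` and `PathFromRoot()`, the sketch below works on a plain `std::vector<int>` parent array with -1 at the root (an assumption; the class template also supports other integer `NodeType`s). The helper names are hypothetical, and the loop stops at the null parent rather than comparing against a stored root, which is equivalent for a valid tree.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical standalone version of PathToRoot(): `node` first, root last.
// Runs in O(depth of node).
std::vector<int> PathToRootSketch(const std::vector<int>& parents, int node) {
  std::vector<int> path;
  for (int next = node; next != -1; next = parents[next]) {
    path.push_back(next);
  }
  return path;
}

// Hypothetical standalone version of PathFromRoot(): root first, `node` last.
std::vector<int> PathFromRootSketch(const std::vector<int>& parents,
                                    int node) {
  std::vector<int> path = PathToRootSketch(parents, node);
  std::reverse(path.begin(), path.end());
  return path;
}
```

With `parents = {-1, 0, 0, 1, 1}` (root 0, children 1 and 2, and nodes 3 and 4 under 1), the path from 3 to the root is `3, 1, 0`.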
template <typename NodeType>
std::vector<NodeType> RootedTree<NodeType>::TopologicalSort() const {
  std::vector<NodeType> result;
  const std::vector<NodeType> cycle =
      internal::CheckForCycle(absl::MakeConstSpan(parents_), &result);
  CHECK(cycle.empty()) << internal::CycleErrorMessage(
      absl::MakeConstSpan(cycle));
  return result;
}

template <typename NodeType>
double RootedTree<NodeType>::DistanceToRoot(
    const NodeType start, absl::Span<const double> arc_lengths) const {
  return DistanceOfUpwardPath(start, root_, arc_lengths);
}

template <typename NodeType>
std::pair<double, std::vector<NodeType>>
RootedTree<NodeType>::DistanceAndPathToRoot(
    const NodeType start, absl::Span<const double> arc_lengths) const {
  CHECK_EQ(num_nodes(), arc_lengths.size());
  double distance = 0.0;
  std::vector<NodeType> path;
  for (NodeType next = start; next != root_; next = parents_[next]) {
    path.push_back(next);
    distance += arc_lengths[next];
  }
  path.push_back(root_);
  return {distance, std::move(path)};
}

template <typename NodeType>
std::vector<NodeType> RootedTree<NodeType>::Path(const NodeType start,
                                                 const NodeType end,
                                                 const NodeType lca) const {
  std::vector<NodeType> result;
  if (start == end) {
    result.push_back(start);
    return result;
  }
  if (start == lca) {
    ReverseAppendToPath(end, lca, result);
    return result;
  }
  if (end == lca) {
    AppendToPath(start, lca, result);
    return result;
  }
  AppendToPath(start, lca, result);
  result.pop_back();  // Don't include the LCA twice.
  ReverseAppendToPath(end, lca, result);
  return result;
}

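`Path()` splices two upward walks at the LCA. A minimal standalone sketch of that splice, assuming `int` nodes, -1 at the root, and a caller-supplied `lca`; the hypothetical `PathSketch` does not verify that `lca` is correct.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of RootedTree::Path(): walk up from `start` to `lca`,
// then append the walk from `end` up to (but excluding) `lca`, reversed so the
// result ends at `end`.
std::vector<int> PathSketch(const std::vector<int>& parents, int start,
                            int end, int lca) {
  std::vector<int> path;
  for (int node = start; node != lca; node = parents[node]) {
    path.push_back(node);
  }
  path.push_back(lca);
  const std::ptrdiff_t up_size = static_cast<std::ptrdiff_t>(path.size());
  for (int node = end; node != lca; node = parents[node]) {
    path.push_back(node);
  }
  // The second walk was collected bottom-up; flip it in place.
  std::reverse(path.begin() + up_size, path.end());
  return path;
}
```

Unlike the library version, this single loop pair needs no special cases for `start == lca` or `end == lca`: one of the walks is simply empty.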
template <typename NodeType>
double RootedTree<NodeType>::Distance(
    const NodeType start, const NodeType end, const NodeType lca,
    absl::Span<const double> arc_lengths) const {
  return DistanceOfUpwardPath(start, lca, arc_lengths) +
         DistanceOfUpwardPath(end, lca, arc_lengths);
}

template <typename NodeType>
std::pair<double, std::vector<NodeType>> RootedTree<NodeType>::DistanceAndPath(
    const NodeType start, const NodeType end, const NodeType lca,
    absl::Span<const double> arc_lengths) const {
  std::vector<NodeType> path = Path(start, end, lca);
  const double dist = DistanceOfPath(path, arc_lengths);
  return {dist, std::move(path)};
}

template <typename NodeType>
double RootedTree<NodeType>::DistanceOfPath(
    absl::Span<const NodeType> path,
    absl::Span<const double> arc_lengths) const {
  CHECK_EQ(num_nodes(), arc_lengths.size());
  double distance = 0.0;
  for (std::size_t i = 0; i + 1 < path.size(); ++i) {
    if (parents_[path[i]] == path[i + 1]) {
      // Upward step: the arc leaves path[i].
      distance += arc_lengths[path[i]];
    } else if (parents_[path[i + 1]] == path[i]) {
      // Downward step: the arc leaves path[i + 1].
      distance += arc_lengths[path[i + 1]];
    } else {
      LOG(FATAL) << "bad edge in path from " << path[i] << " to "
                 << path[i + 1];
    }
  }
  return distance;
}

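`DistanceOfPath()` accepts steps in either direction because each arc's length is stored at its child endpoint. The hypothetical sketch below mirrors that rule with plain vectors (`int` nodes, -1 at the root); where the library calls `LOG(FATAL)` on a non-adjacent pair, this sketch throws `std::invalid_argument` instead.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical sketch of DistanceOfPath(): each consecutive pair must be
// joined by a tree arc in either direction, and the arc is charged to its
// child endpoint (arc_lengths[child] is the length of child -> parent).
double DistanceOfPathSketch(const std::vector<int>& parents,
                            const std::vector<double>& arc_lengths,
                            const std::vector<int>& path) {
  double distance = 0.0;
  for (std::size_t i = 0; i + 1 < path.size(); ++i) {
    const int a = path[i];
    const int b = path[i + 1];
    if (parents[a] == b) {
      distance += arc_lengths[a];  // Moving up: `a` is the child.
    } else if (parents[b] == a) {
      distance += arc_lengths[b];  // Moving down: `b` is the child.
    } else {
      throw std::invalid_argument("consecutive path nodes are not adjacent");
    }
  }
  return distance;
}
```

Walking the same path in either direction gives the same total, e.g. `3, 1, 0, 2` and `2, 0, 1, 3` both cost `0.25 + 1.5 + 2.0` with the lengths used in the checks.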
template <typename NodeType>
NodeType RootedTree<NodeType>::LowestCommonAncestorByDepth(
    const NodeType n1, const NodeType n2,
    absl::Span<const NodeType> depths) const {
  const NodeType n = num_nodes();
  CHECK_EQ(depths.size(), n);
  CHECK_OK(internal::IsValidNode(n1, n));
  CHECK_OK(internal::IsValidNode(n2, n));
  if (n1 == root_ || n2 == root_) {
    return root_;
  }
  if (n1 == n2) {
    return n1;
  }
  NodeType next1 = n1;
  NodeType next2 = n2;
  // Lift the deeper node until both nodes are at the same depth.
  while (depths[next1] > depths[next2]) {
    next1 = parents_[next1];
  }
  while (depths[next2] > depths[next1]) {
    next2 = parents_[next2];
  }
  // Then lift both in lockstep until they meet.
  while (next1 != next2) {
    next1 = parents_[next1];
    next2 = parents_[next2];
  }
  return next1;
}

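The two-phase climb in `LowestCommonAncestorByDepth()` is shown standalone below, assuming `int` nodes, a -1 root parent, and `depths` as produced by something like `AllDepths()`. `LcaByDepthSketch` is a hypothetical name and skips the library's argument checks.

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch of LowestCommonAncestorByDepth(): lift the deeper node
// until both nodes have the same depth, then lift both in lockstep until they
// meet. Runs in O(length of the path between n1 and n2).
int LcaByDepthSketch(const std::vector<int>& parents,
                     const std::vector<int>& depths, int n1, int n2) {
  while (depths[n1] > depths[n2]) n1 = parents[n1];
  while (depths[n2] > depths[n1]) n2 = parents[n2];
  while (n1 != n2) {
    n1 = parents[n1];
    n2 = parents[n2];
  }
  return n1;
}
```

Once the depths are equal, both nodes are the same number of steps below their LCA, which is why the lockstep loop cannot overshoot it.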
template <typename NodeType>
NodeType RootedTree<NodeType>::LowestCommonAncestorBySearch(
    const NodeType n1, const NodeType n2,
    std::vector<bool>& visited_workspace) const {
  const NodeType n = num_nodes();
  CHECK_OK(internal::IsValidNode(n1, n));
  CHECK_OK(internal::IsValidNode(n2, n));
  CHECK_EQ(visited_workspace.size(), n);
  if (n1 == root_ || n2 == root_) {
    return root_;
  }
  if (n1 == n2) {
    return n1;
  }
  NodeType next1 = n1;
  NodeType next2 = n2;
  visited_workspace[n1] = true;
  visited_workspace[n2] = true;
  NodeType lca = kNullParent;
  // Used only for cleanup purposes; it can overestimate the number of steps.
  NodeType lca_distance = 1;
  while (true) {
    lca_distance++;
    if (next1 != root_) {
      next1 = parents_[next1];
      if (visited_workspace[next1]) {
        lca = next1;
        break;
      }
      visited_workspace[next1] = true;
    }
    if (next2 != root_) {
      next2 = parents_[next2];
      if (visited_workspace[next2]) {
        lca = next2;
        break;
      }
      visited_workspace[next2] = true;
    }
  }
  CHECK_OK(internal::IsValidNode(lca, n));
  // Reset the marked nodes so the workspace is all false again.
  auto cleanup = [this, lca_distance, &visited_workspace](NodeType next) {
    for (NodeType i = 0; i < lca_distance && next != kNullParent; ++i) {
      visited_workspace[next] = false;
      next = parents_[next];
    }
  };
  cleanup(n1);
  cleanup(n2);
  return lca;
}

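`LowestCommonAncestorBySearch()` trades the depth precomputation for a reusable boolean workspace. The standalone sketch below follows the same alternating climb, assuming `int` nodes and -1 at the root; `LcaBySearchSketch` is hypothetical, and its cleanup is a swapped-in variant: instead of walking a bounded number of steps (`lca_distance`), it clears marks until it meets the first unmarked node.

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch of LowestCommonAncestorBySearch() on a parent vector
// (-1 at the root). `visited` must have one entry per node, all false; it is
// all false again when the function returns.
int LcaBySearchSketch(const std::vector<int>& parents, int root, int n1,
                      int n2, std::vector<bool>& visited) {
  if (n1 == n2) return n1;
  visited[n1] = true;
  visited[n2] = true;
  int a = n1;
  int b = n2;
  int lca = -1;
  // Climb from both nodes alternately; the first node reached that the other
  // climb already marked is the LCA.
  for (;;) {
    if (a != root) {
      a = parents[a];
      if (visited[a]) {
        lca = a;
        break;
      }
      visited[a] = true;
    }
    if (b != root) {
      b = parents[b];
      if (visited[b]) {
        lca = b;
        break;
      }
      visited[b] = true;
    }
  }
  // Marks form contiguous runs upward from n1 and n2, so clearing until the
  // first unmarked node restores the workspace.
  for (const int start : {n1, n2}) {
    for (int node = start; node != -1 && visited[node]; node = parents[node]) {
      visited[node] = false;
    }
  }
  return lca;
}
```

Because the workspace is restored, one `std::vector<bool>` can serve many queries, but it cannot be shared between threads, which matches the trade-off described in the class comment.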
template <typename NodeType>
void RootedTree<NodeType>::Evert(const NodeType new_root) {
  // Reverse every arc on the path from new_root to the old root.
  NodeType previous_node = kNullParent;
  for (NodeType node = new_root; node != kNullParent;) {
    NodeType next_node = parents_[node];
    parents_[node] = previous_node;
    previous_node = node;
    node = next_node;
  }
  root_ = new_root;
}

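Because `Evert()` only reverses the arcs on one root path, the rest of the parent vector is untouched. A hypothetical standalone sketch on a `std::vector<int>` with -1 at the root:

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch of RootedTree::Evert(): reverse every arc on the path
// from `new_root` to the old root, so `new_root` gets parent -1. Only nodes on
// that path change, hence O(path length).
void EvertSketch(std::vector<int>& parents, int new_root) {
  int previous = -1;
  for (int node = new_root; node != -1;) {
    const int next = parents[node];
    parents[node] = previous;
    previous = node;
    node = next;
  }
}
```

Re-rooting `{-1, 0, 0, 1, 1}` at node 3 changes only nodes 3, 1 and 0, giving `{1, 3, 0, -1, 1}`.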
template <typename NodeType>
template <typename T>
void RootedTree<NodeType>::AllDistancesToRootInPlace(
    absl::Span<const NodeType> topological_order,
    absl::Span<T> arc_lengths_in_distances_out) const {
  CHECK_EQ(num_nodes(), arc_lengths_in_distances_out.size());
  CHECK_EQ(num_nodes(), topological_order.size());
  if (!topological_order.empty()) {
    CHECK_EQ(topological_order[0], root_);
  }
  for (const NodeType node : topological_order) {
    if (parents_[node] == kNullParent) {
      arc_lengths_in_distances_out[node] = 0;
    } else {
      arc_lengths_in_distances_out[node] +=
          arc_lengths_in_distances_out[parents_[node]];
    }
  }
}

template <typename NodeType>
std::vector<NodeType> RootedTree<NodeType>::AllDepths(
    absl::Span<const NodeType> topological_order) const {
  std::vector<NodeType> arc_length_in_distance_out(num_nodes(), 1);
  AllDistancesToRootInPlace(topological_order,
                            absl::MakeSpan(arc_length_in_distance_out));
  return arc_length_in_distance_out;
}

template <typename NodeType>
template <typename T>
std::vector<T> RootedTree<NodeType>::AllDistancesToRoot(
    absl::Span<const T> arc_lengths) const {
  return AllDistancesToRoot(arc_lengths, TopologicalSort());
}

template <typename NodeType>
template <typename T>
std::vector<T> RootedTree<NodeType>::AllDistancesToRoot(
    absl::Span<const T> arc_lengths,
    absl::Span<const NodeType> topological_order) const {
  std::vector<T> distances(arc_lengths.begin(), arc_lengths.end());
  AllDistancesToRootInPlace(topological_order, absl::MakeSpan(distances));
  return distances;
}

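`AllDepths()` and `AllDistancesToRoot()` both reduce to one pass of `AllDistancesToRootInPlace()` over a topological order: depths are distances with every arc length set to 1. A standalone sketch, assuming `int` nodes and -1 at the root; `AllDistancesSketch` is a hypothetical name and, unlike the library, does not CHECK that the order starts at the root.

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch of AllDistancesToRootInPlace(): visit nodes so that every
// parent comes before its children; each node's distance is its own arc length
// plus its parent's already final distance. The root's entry is reset to 0.
template <typename T>
std::vector<T> AllDistancesSketch(const std::vector<int>& parents,
                                  const std::vector<int>& topological_order,
                                  std::vector<T> arc_lengths) {
  for (const int node : topological_order) {
    if (parents[node] == -1) {
      arc_lengths[node] = T{0};
    } else {
      arc_lengths[node] += arc_lengths[parents[node]];
    }
  }
  return arc_lengths;  // Now holds distances to the root.
}
```

Any order in which parents precede children works, which is why the library accepts an arbitrary caller-supplied topological order instead of always recomputing one.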
template <typename Graph>
absl::StatusOr<RootedTree<typename Graph::NodeIndex>> RootedTreeFromGraph(
    const typename Graph::NodeIndex root, const Graph& graph,
    std::vector<typename Graph::NodeIndex>* const topological_order,
    std::vector<typename Graph::NodeIndex>* const depths) {
  using NodeIndex = typename Graph::NodeIndex;
  const NodeIndex num_nodes = graph.num_nodes();
  RETURN_IF_ERROR(internal::IsValidNode(root, num_nodes))
      << "invalid root node";
  if (topological_order != nullptr) {
    topological_order->clear();
    topological_order->reserve(num_nodes);
    topological_order->push_back(root);
  }
  if (depths != nullptr) {
    depths->clear();
    depths->resize(num_nodes, 0);
  }
  std::vector<NodeIndex> tree(num_nodes, RootedTree<NodeIndex>::kNullParent);
  auto visited = [&tree, root](const NodeIndex node) {
    if (node == root) {
      return true;
    }
    return tree[node] != RootedTree<NodeIndex>::kNullParent;
  };
  std::vector<NodeIndex> must_search_children = {root};
  while (!must_search_children.empty()) {
    NodeIndex next = must_search_children.back();
    must_search_children.pop_back();
    for (const NodeIndex neighbor : graph[next]) {
      if (visited(neighbor)) {
        if (tree[next] == neighbor) {
          continue;
        } else {
          // NOTE: this will also catch nodes with self loops.
          return util::InvalidArgumentErrorBuilder()
                 << "graph has cycle containing arc from " << next << " to "
                 << neighbor;
        }
      }
      tree[neighbor] = next;
      if (topological_order != nullptr) {
        topological_order->push_back(neighbor);
      }
      if (depths != nullptr) {
        (*depths)[neighbor] = (*depths)[next] + 1;
      }
      must_search_children.push_back(neighbor);
    }
  }
  for (NodeIndex i = 0; i < num_nodes; ++i) {
    if (!visited(i)) {
      return util::InvalidArgumentErrorBuilder()
             << "graph is not connected, no path to " << i;
    }
  }
  return RootedTree<NodeIndex>(root, std::move(tree));
}

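`RootedTreeFromGraph()` is a depth-first search that records each node's discoverer as its parent. The sketch below replays that logic on a plain adjacency list rather than an OR-Tools graph (an assumption made so it runs standalone). `ParentsFromAdjacency` is hypothetical, uses a separate `seen` vector where the library tests `tree[node] != kNullParent`, and returns `std::nullopt` where the library returns an InvalidArgument status.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical sketch of RootedTreeFromGraph() on an adjacency list that lists
// each undirected edge in both directions. Returns the parent vector (-1 at the
// root), or std::nullopt if the graph has a cycle or is disconnected.
std::optional<std::vector<int>> ParentsFromAdjacency(
    const std::vector<std::vector<int>>& adjacency, int root) {
  const int n = static_cast<int>(adjacency.size());
  std::vector<int> parents(n, -1);
  std::vector<bool> seen(n, false);
  seen[root] = true;
  std::vector<int> stack = {root};
  while (!stack.empty()) {
    const int node = stack.back();
    stack.pop_back();
    for (const int neighbor : adjacency[node]) {
      if (seen[neighbor]) {
        if (parents[node] == neighbor) continue;  // The arc back to the parent.
        return std::nullopt;  // Found a second route to a known node: cycle.
      }
      seen[neighbor] = true;
      parents[neighbor] = node;
      stack.push_back(neighbor);
    }
  }
  for (int i = 0; i < n; ++i) {
    if (!seen[i]) return std::nullopt;  // No path from the root.
  }
  return parents;
}
```

The edges 0-1, 0-2, 1-3 and 1-4 rooted at 0 produce `{-1, 0, 0, 1, 1}`, the parent vector used in the other sketches in this file.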
}  // namespace operations_research

#endif  // OR_TOOLS_GRAPH_ROOTED_TREE_H_