Google OR-Tools v9.11
a fast and portable software suite for combinatorial optimization
Loading...
Searching...
No Matches
radix_sort.h
Go to the documentation of this file.
1// Copyright 2010-2024 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14#ifndef OR_TOOLS_ALGORITHMS_RADIX_SORT_H_
15#define OR_TOOLS_ALGORITHMS_RADIX_SORT_H_
16
17// This can be MUCH faster than std::sort() on numerical arrays (int32_t, float,
18// int64_t, double, ..), when the size is ≥8k:
19// ~10x faster on int32_t or float data
20// ~3-5x faster on int64_t or double data
21//
22// Unlike std::sort(), it uses extra, temporary buffers: the radix/count-sort
23// counters, and a copy of the data, i.e. between 1x and 2x your input size.
24//
25// RadixSort() falls back to std::sort() for small sizes, so that you get
26// the best performance in any case.
27//
28// CAVEAT: std::sort() is *very* fast when the array is almost-sorted, or
29// almost reverse-sorted: in this case, RadixSort() can easily be much slower.
30// But the worst-case performance of RadixSort() is much faster than the
31// worst-case performance of std::sort().
32// To be sure, you should benchmark your use case.
33//
34// TODO: it could be even faster than that when the values are in [0..N) for a
35// known value N that's significantly lower than the max integer value.
36
37#include <algorithm>
38#include <cstddef>
39#include <cstdint>
40#include <cstring>
41#include <limits>
42#include <type_traits>
43#include <utility>
44#include <vector>
45
46#include "absl/algorithm/container.h"
47#include "absl/base/casts.h"
48#include "absl/log/check.h"
49#include "absl/types/span.h"
51
52namespace operations_research {
53
// Sorts an array of int, double, or other numeric types. Up to ~10x faster than
// std::sort() when size ≥ 8k: go/radix-sort-bench. See file-level comment.
// Falls back to std::sort() below size-dependent thresholds, and dispatches to
// RadixSortTpl() with tuned (radix_width, num_passes) otherwise.
template <typename T>
void RadixSort(absl::Span<T> values);
58
// ADVANCED USAGE: For power users who know which radix_width or num_passes
// they need, possibly differing from the canonical values used by RadixSort().
// Constraint: radix_width * num_passes must cover all bits of T for the sort
// to be total (the last pass may cover fewer than radix_width bits).
template <typename T, int radix_width, int num_passes>
void RadixSortTpl(absl::Span<T> values);
63
64// TODO(user): Support arbitrary types with an int() or other numerical getter.
65// TODO(user): Support the user providing already-allocated memory buffers
66// for the radix counts and/or for the temporary vector<T> copy.
67
68// _____________________________________________________________________________
69// The rest of this .h is the implementation of the above templates.
70
namespace internal {
// Maps a numeric type T (int, int64_t, float, double, ...) to the unsigned
// integer type with the same bit width. The primary template handles all
// integral types via std::make_unsigned; IEEE754 floating-point types need
// explicit specializations since make_unsigned rejects them.
template <typename T>
struct ToUInt : public std::make_unsigned<T> {};

template <>
struct ToUInt<float> {
  using type = uint32_t;
};
template <>
struct ToUInt<double> {
  using type = uint64_t;
};

// Convenience alias: internal::to_uint<T> is the unsigned bit-twin of T.
template <typename T>
using to_uint = typename ToUInt<T>::type;
}  // namespace internal
89
// The internal template that does all the work: a least-significant-digit
// radix sort with num_passes passes of radix_width bits each, plus sign
// fix-ups for signed integers and IEEE754 floats.
template <typename T, int radix_width, int num_passes>
void RadixSortTpl(absl::Span<T> values) {
  // Internally we assume our values are unsigned integers. This works for both
  // signed integers and IEEE754 floating-point types, with various twists for
  // negative numbers. In particular, two nonnegative floats compare exactly as
  // their unsigned integer bitcast do.
  typedef internal::to_uint<T> U;

  // NOTE(user): We could support sizes > kint32max if needed. We use uint32_t
  // for sizes, instead of size_t, to spare memory for the large radix widths.
  // We could use uint64_t for sizes > 4G, but until we need it, just using
  // uint32_t is simpler. Using smaller types (uint16_t or uint8_t) for smaller
  // sizes was noticeably slower.
  DCHECK_LE(values.size(),
            static_cast<size_t>(std::numeric_limits<uint32_t>::max()));
  const uint32_t size = values.size();

  // Main Radix/Count-sort counters. Radix sort normally uses several passes,
  // but to speed things up, we compute all radix counters for all passes at
  // once in a single initial sweep over the data.
  //
  // count[] is actually a 2-dimensional array [num_passes][1 << radix_width],
  // flattened for performance and in a vector<> because it can be too big for
  // the stack.
  std::vector<uint32_t> count(num_passes << radix_width, 0);
  uint32_t* const count_ptr = count.data();

  // Perform the radix count all at once, in 'parallel' (the CPU should be able
  // to parallelize the inner loop): for each value, bump the histogram bucket
  // of its p-th radix digit, for every pass p.
  constexpr uint32_t kRadixMask = (1 << radix_width) - 1;
  for (const T value : values) {
    for (int p = 0; p < num_passes; ++p) {
      ++count_ptr[(p << radix_width) +
                  ((absl::bit_cast<U>(value) >> (radix_width * p)) &
                   kRadixMask)];
    }
  }

  // Convert the counts into offsets via a cumulative sum (exclusive prefix
  // sum): after this, count[p][d] is the output index where the first value
  // whose p-th digit equals d must go in pass p.
  uint32_t sum[num_passes] = {};
  for (int i = 0; i < (1 << radix_width); ++i) {
    // This inner loop should be parallelizable by the CPU.
    for (int p = 0; p < num_passes; ++p) {
      const uint32_t old_sum = sum[p];
      sum[p] += count_ptr[(p << radix_width) + i];
      count_ptr[(p << radix_width) + i] = old_sum;
    }
  }

  // FIRST-TIME READER: Skip this section, which is only for signed integers:
  // you can go back to it at the end.
  //
  // If T is signed, and if there were any negative numbers, we'll need to
  // account for that. For floating-point types, we do that at the end of this
  // function.
  // For integer types, fortunately, it's easy and fast to do it now: negative
  // numbers were treated as top-half values in the last radix pass. We can poll
  // the most significant count[] bucket corresponding to the min negative
  // number, immediately see if there were any negative numbers, and patch the
  // last count[] offsets in that case.
  // Number of bits of the radix in the last pass. E.g. if U is 32 bits,
  // radix_width=11 and num_passes=3, the last pass covers the remaining
  // 32 - 2*11 = 10 top bits, including the sign bit.
  constexpr int kNumBitsInTopRadix =
      std::numeric_limits<U>::digits - (num_passes - 1) * radix_width;
  // TODO: remove the if constexpr so that compilation catches the bad cases.
  if constexpr (std::is_integral_v<T> && std::is_signed_v<T> &&
                kNumBitsInTopRadix > 0 && kNumBitsInTopRadix <= radix_width) {
    uint32_t* const last_pass_count =
        count_ptr + ((num_passes - 1) << radix_width);
    // Offset of the first negative bucket == number of nonnegative values,
    // since offsets are a cumulative sum of the counts.
    const uint32_t num_nonnegative_values =
        last_pass_count[1 << (kNumBitsInTopRadix - 1)];
    if (num_nonnegative_values != size) {
      // There are some negative values, and they're sorted last instead of
      // first, since we considered them as unsigned so far. E.g., with bytes:
      // 00000000, ..., 01111111, 10000000, ..., 11111111.
      // Fixing that is easy: we take the 10000000..11111111 chunks and shift
      // it before all the 00000000..01111111 ones.
      const uint32_t num_negative_values = size - num_nonnegative_values;
      for (int i = 0; i < (1 << (kNumBitsInTopRadix - 1)); ++i) {
        // Shift non-negatives by +num_negative_values...
        last_pass_count[i] += num_negative_values;
        // ... and negatives by -num_nonnegative_values.
        last_pass_count[i + (1 << (kNumBitsInTopRadix - 1))] -=
            num_nonnegative_values;
      }
    }
  }

  // Perform the radix sort, using a temporary buffer: each pass scatters
  // 'from' into 'to' by the pass's digit (stable within a pass), then the
  // buffers are swapped. After an odd number of passes the sorted data lives
  // in tmp[], after an even number it's back in values[].
  std::vector<T> tmp(size);
  T* from = values.data();
  T* to = tmp.data();
  int radix = 0;
  for (int pass = 0; pass < num_passes; ++pass, radix += radix_width) {
    uint32_t* const cur_count_ptr = count_ptr + (pass << radix_width);
    const T* const from_end = from + size;
    for (T* ptr = from; ptr < from_end; ++ptr) {
      to[cur_count_ptr[(absl::bit_cast<U>(*ptr) >> radix) & kRadixMask]++] =
          *ptr;
    }
    std::swap(from, to);
  }

  // FIRST-TIME READER: Skip this section, which is only for negative floats.
  // We fix mis-sorted negative floating-point numbers here. Unlike negative
  // integers, negative IEEE754 floats bitcast to *decreasing* unsigned values,
  // so the negative chunk ends up last AND in reverse order.
  if constexpr (!std::is_integral_v<T> && std::is_signed_v<T> &&
                kNumBitsInTopRadix > 0 && kNumBitsInTopRadix <= radix_width) {
    uint32_t* const last_pass_count =
        count_ptr + ((num_passes - 1) << radix_width);
    // After the last pass, the offsets were bumped past each bucket's values,
    // so this entry is the total count of values with a clear sign bit.
    const uint32_t num_nonnegative_values =
        last_pass_count[(1 << (kNumBitsInTopRadix - 1)) - 1];
    if (num_nonnegative_values != size) {
      // Negative floating-point numbers are sorted exactly in the reverse
      // order. Unlike for integers, we need to std::reverse() them, and also
      // shift them back before the positive ones.
      const uint32_t num_negative_values = size - num_nonnegative_values;
      if constexpr (num_passes % 2) {
        // If we swapped an odd number of times, we're lucky: we don't need to
        // make an extra copy.
        std::memcpy(values.data() + num_negative_values, tmp.data(),
                    num_nonnegative_values * sizeof(T));
        // TODO(user): See if this is faster than memcpy + std::reverse().
        DCHECK_EQ(from, tmp.data());
        for (uint32_t i = 0; i < num_negative_values; ++i) {
          values[i] = from[size - 1 - i];  // from[] = tmp[]
        }
      } else {
        // We can't move + reverse in-place, so we need the temporary buffer.
        // First, we copy all negative numbers to the temporary buffer.
        std::memcpy(tmp.data(), values.data() + num_nonnegative_values,
                    num_negative_values * sizeof(T));
        // Then we shift the nonnegative.
        // TODO(user): See if memcpy everything + memcpy here is faster than
        // memmove().
        std::memmove(values.data() + num_negative_values, values.data(),
                     num_nonnegative_values * sizeof(T));
        DCHECK_EQ(to, tmp.data());
        for (uint32_t i = 0; i < num_negative_values; ++i) {
          values[i] = to[num_negative_values - 1 - i];  // to[] = tmp[].
        }
      }
      // If there were negative floats, we've done our work and are done. Else
      // we still may need to move the data from the temp buffer to 'values'.
      return;
    }
  }

  // If we swapped an odd number of times, copy tmp[] onto values[].
  if constexpr (num_passes % 2) {
    std::memcpy(values.data(), from, size * sizeof(T));
  }
}
242
243// TODO(user): Expose an API that takes the "max value" as argument, for
244// users who want to take advantage of that knowledge to reduce the number of
245// passes.
246template <typename T>
247void RadixSort(absl::Span<T> values) {
248 switch (sizeof(T)) {
249 case 1:
250 if (values.size() < 300) {
251 absl::c_sort(values);
252 } else {
253 RadixSortTpl<T, /*radix_width=*/8, /*num_passes=*/1>(values);
254 }
255 return;
256 case 2:
257 if (values.size() < 300) {
258 absl::c_sort(values);
259 } else {
260 RadixSortTpl<T, /*radix_width=*/8, /*num_passes=*/2>(values);
261 }
262 return;
263 case 4:
264 if (values.size() < 300) {
265 absl::c_sort(values);
266 } else if (values.size() < 1000) {
267 RadixSortTpl<T, /*radix_width=*/8, /*num_passes=*/4>(values);
268 } else if (values.size() < 2'500'000) {
269 RadixSortTpl<T, /*radix_width=*/11, /*num_passes=*/3>(values);
270 } else {
271 RadixSortTpl<T, /*radix_width=*/16, /*num_passes=*/2>(values);
272 }
273 return;
274 case 8:
275 if (values.size() < 5000) {
276 absl::c_sort(values);
277 } else if (values.size() < 1'500'000) {
278 RadixSortTpl<T, /*radix_width=*/11, /*num_passes=*/6>(values);
279 } else {
280 RadixSortTpl<T, /*radix_width=*/16, /*num_passes=*/4>(values);
281 }
282 return;
283 }
284 LOG(DFATAL) << "RadixSort() called with unsupported value type";
285 absl::c_sort(values);
286}
287
288} // namespace operations_research
289
290#endif // OR_TOOLS_ALGORITHMS_RADIX_SORT_H_
IntegerValue size
int64_t value
typename ToUInt< T >::type to_uint
Definition radix_sort.h:87
In SWIG mode, we don't want anything besides these top-level includes.
void RadixSortTpl(absl::Span< T > values)
The internal template that does all the work.
Definition radix_sort.h:92
void RadixSort(absl::Span< T > values)
Definition radix_sort.h:247
trees with all degrees equal to