radix_sort.h
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef OR_TOOLS_ALGORITHMS_RADIX_SORT_H_
#define OR_TOOLS_ALGORITHMS_RADIX_SORT_H_
// This can be MUCH faster than std::sort() on numerical arrays (int32_t,
// float, int64_t, double, ...), when the size is ≥ 8k:
//   ~10x faster on int32_t or float data,
//   ~3-5x faster on int64_t or double data.
//
// Unlike std::sort(), it uses extra, temporary buffers: the radix/count-sort
// counters, and a copy of the data, i.e. between 1x and 2x your input size.
//
// RadixSort() falls back to std::sort() for small sizes, so that you get
// the best performance in any case.
//
// CAVEAT: std::sort() is *very* fast when the array is almost sorted, or
// almost reverse-sorted: in those cases, RadixSort() can easily be much
// slower. But the worst-case performance of RadixSort() is much better than
// the worst-case performance of std::sort(). To be sure, benchmark your
// use case.
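//
// Example usage, as a minimal sketch (the values here are hypothetical):
//
//   std::vector<int64_t> values = {40, -2, 17, 0, 40, -33};
//   RadixSort(absl::MakeSpan(values));
//   // values is now {-33, -2, 0, 17, 40, 40}.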

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/casts.h"
#include "absl/base/log_severity.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/numeric/bits.h"
#include "absl/types/span.h"

namespace operations_research {

// Sorts an array of int, double, or other numeric types. Up to ~10x faster
// than std::sort() when size ≥ 8k: go/radix-sort-bench. See the file-level
// comment.
template <typename T>
void RadixSort(
    absl::Span<T> values,
    // ADVANCED USAGE: if you're sorting nonnegative integers, and suspect that
    // their values use fewer bits than their full bit width, you may improve
    // performance by setting `num_bits` to a lower value, for example
    // NumBitsForZeroTo(max_value). It might even be faster to scan the values
    // once just to do that, e.g., RadixSort(values,
    // NumBitsForZeroTo(*absl::c_max_element(values)));
    int num_bits = sizeof(T) * 8);

template <typename T>
int NumBitsForZeroTo(T max_value);

// ADVANCED USAGE: For power users who know which radix_width or num_passes
// they need, possibly differing from the canonical values used by RadixSort().
template <typename T, int radix_width, int num_passes>
void RadixSortTpl(absl::Span<T> values);
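//
// For example, a caller that knows its values are nonnegative and fit in 24
// bits could force three 8-bit passes (a hypothetical sketch; `my_values` is
// not part of this API, and the best widths should come from benchmarks):
//
//   RadixSortTpl<uint32_t, /*radix_width=*/8, /*num_passes=*/3>(
//       absl::MakeSpan(my_values));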

// TODO(user): Support the user providing already-allocated memory buffers
// for the radix counts and/or for the temporary vector<T> copy.

// _____________________________________________________________________________
// The rest of this .h is the implementation of the above templates.

namespace internal {
// to_uint<T> converts a numerical type T (int, int64_t, float, double, ...) to
// the unsigned integer of the same bit width.
template <typename T>
struct ToUInt : public std::make_unsigned<T> {};

template <>
struct ToUInt<double> {
  typedef uint64_t type;
};
template <>
struct ToUInt<float> {
  typedef uint32_t type;
};

template <typename T>
using to_uint = typename ToUInt<T>::type;
}  // namespace internal
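
// Illustrative compile-time sanity checks of the to_uint mapping above
// (example assertions only; they compile away entirely):
static_assert(std::is_same_v<internal::to_uint<float>, uint32_t>,
              "float maps to the 32-bit unsigned integer");
static_assert(std::is_same_v<internal::to_uint<int64_t>, uint64_t>,
              "int64_t maps to the 64-bit unsigned integer");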

// The internal template that does all the work.
template <typename T, int radix_width, int num_passes>
void RadixSortTpl(absl::Span<T> values) {
  // Internally we assume our values are unsigned integers. This works for both
  // signed integers and IEEE754 floating-point types, with various twists for
  // negative numbers. In particular, two nonnegative floats compare exactly as
  // their unsigned integer bitcasts do.
  typedef internal::to_uint<T> U;

  // NOTE(user): We could support sizes beyond the uint32_t range if needed. We
  // use uint32_t for sizes, instead of size_t, to spare memory for the large
  // radix widths. We could use uint64_t for sizes > 4G, but until we need it,
  // just using uint32_t is simpler. Using smaller types (uint16_t or uint8_t)
  // for smaller sizes was noticeably slower.
  DCHECK_LE(values.size(),
            static_cast<size_t>(std::numeric_limits<uint32_t>::max()));
  const uint32_t size = values.size();

  // Main Radix/Count-sort counters. Radix sort normally uses several passes,
  // but to speed things up, we compute all radix counters for all passes at
  // once in a single initial sweep over the data.
  //
  // count[] is actually a 2-dimensional array [num_passes][1 << radix_width],
  // flattened for performance and in a vector<> because it can be too big for
  // the stack.
  std::vector<uint32_t> count(num_passes << radix_width, 0);
  uint32_t* const count_ptr = count.data();

  // Perform the radix count all at once, in 'parallel' (the CPU should be able
  // to parallelize the inner loop).
  constexpr uint32_t kRadixMask = (1 << radix_width) - 1;
  for (const T value : values) {
    for (int p = 0; p < num_passes; ++p) {
      ++count_ptr[(p << radix_width) +
                  ((absl::bit_cast<U>(value) >> (radix_width * p)) &
                   kRadixMask)];
    }
  }
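  // For example, with radix_width = 8 and num_passes = 2, the value 0x1234
  // increments count_ptr[(0 << 8) + 0x34] and count_ptr[(1 << 8) + 0x12]:
  // one counter per 8-bit digit, each in its own flattened sub-array.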

  // Convert the counts into offsets via a cumulative sum.
  uint32_t sum[num_passes] = {};
  for (int i = 0; i < (1 << radix_width); ++i) {
    // This inner loop should be parallelizable by the CPU.
    for (int p = 0; p < num_passes; ++p) {
      const uint32_t old_sum = sum[p];
      sum[p] += count_ptr[(p << radix_width) + i];
      count_ptr[(p << radix_width) + i] = old_sum;
    }
  }
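  // E.g., for a single pass with counts {2, 0, 3, 1}, count[] now holds the
  // exclusive prefix sums {0, 2, 2, 5}: values with radix 0 will be scattered
  // to indices 0..1, values with radix 2 to indices 2..4, and so on.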

  // FIRST-TIME READER: Skip this section, which is only for signed integers:
  // you can go back to it at the end.
  //
  // If T is signed, and if there were any negative numbers, we'll need to
  // account for that. For floating-point types, we do that at the end of this
  // function.
  // For integer types, fortunately, it's easy and fast to do it now: negative
  // numbers were treated as top-half values in the last radix pass. We can poll
  // the most significant count[] bucket corresponding to the min negative
  // number, immediately see if there were any negative numbers, and patch the
  // last count[] offsets in that case.
  // Number of bits of the radix in the last pass. E.g. if U has 32 bits,
  // radix_width = 11 and num_passes = 3, the last pass sorts on the top
  // 32 - 2 * 11 = 10 bits.
  constexpr int kNumBitsInTopRadix =
      std::numeric_limits<U>::digits - (num_passes - 1) * radix_width;
  // TODO: remove the if constexpr so that compilation catches the bad cases.
  if constexpr (std::is_integral_v<T> && std::is_signed_v<T> &&
                kNumBitsInTopRadix > 0 && kNumBitsInTopRadix <= radix_width) {
    uint32_t* const last_pass_count =
        count_ptr + ((num_passes - 1) << radix_width);
    const uint32_t num_nonnegative_values =
        last_pass_count[1 << (kNumBitsInTopRadix - 1)];
    if (num_nonnegative_values != size) {
      // There are some negative values, and they're sorted last instead of
      // first, since we considered them as unsigned so far. E.g., with bytes:
      // 00000000, ..., 01111111, 10000000, ..., 11111111.
      // Fixing that is easy: we take the 10000000..11111111 chunk and shift
      // it before all the 00000000..01111111 ones.
      const uint32_t num_negative_values = size - num_nonnegative_values;
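      // E.g., with size = 10 and 3 negative values: the nonnegative buckets
      // shift up to offsets [3, 10), the negative buckets down to [0, 3).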
      for (int i = 0; i < (1 << (kNumBitsInTopRadix - 1)); ++i) {
        // Shift non-negatives by +num_negative_values...
        last_pass_count[i] += num_negative_values;
        // ... and negatives by -num_nonnegative_values.
        last_pass_count[i + (1 << (kNumBitsInTopRadix - 1))] -=
            num_nonnegative_values;
      }
    }
  }

  // Perform the radix sort, using a temporary buffer.
  std::vector<T> tmp(size);
  T* from = values.data();
  T* to = tmp.data();
  int radix = 0;
  for (int pass = 0; pass < num_passes; ++pass, radix += radix_width) {
    uint32_t* const cur_count_ptr = count_ptr + (pass << radix_width);
    const T* const from_end = from + size;
    for (T* ptr = from; ptr < from_end; ++ptr) {
      to[cur_count_ptr[(absl::bit_cast<U>(*ptr) >> radix) & kRadixMask]++] =
          *ptr;
    }
    std::swap(from, to);
  }
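  // At this point `from` points at the sorted data: values.data() if
  // num_passes is even (we swapped back an even number of times), and
  // tmp.data() if it is odd.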

  // FIRST-TIME READER: Skip this section, which is only for negative floats.
  // We fix mis-sorted negative floating-point numbers here.
  if constexpr (!std::is_integral_v<T> && std::is_signed_v<T> &&
                kNumBitsInTopRadix > 0 && kNumBitsInTopRadix <= radix_width) {
    uint32_t* const last_pass_count =
        count_ptr + ((num_passes - 1) << radix_width);
    const uint32_t num_nonnegative_values =
        last_pass_count[(1 << (kNumBitsInTopRadix - 1)) - 1];
    if (num_nonnegative_values != size) {
      // Negative floating-point numbers are sorted exactly in the reverse
      // order. Unlike for integers, we need to std::reverse() them, and also
      // shift them back before the positive ones.
      const uint32_t num_negative_values = size - num_nonnegative_values;
      if constexpr (num_passes % 2) {
        // If we swapped an odd number of times, we're lucky: we don't need to
        // make an extra copy.
        std::memcpy(values.data() + num_negative_values, tmp.data(),
                    num_nonnegative_values * sizeof(T));
        // TODO(user): See if this is faster than memcpy + std::reverse().
        DCHECK_EQ(from, tmp.data());
        for (uint32_t i = 0; i < num_negative_values; ++i) {
          values[i] = from[size - 1 - i];  // from[] = tmp[].
        }
      } else {
        // We can't move + reverse in place, so we need the temporary buffer.
        // First, we copy all negative numbers to the temporary buffer.
        std::memcpy(tmp.data(), values.data() + num_nonnegative_values,
                    num_negative_values * sizeof(T));
        // Then we shift the nonnegative values.
        // TODO(user): See if memcpy everything + memcpy here is faster than
        // memmove().
        std::memmove(values.data() + num_negative_values, values.data(),
                     num_nonnegative_values * sizeof(T));
        DCHECK_EQ(to, tmp.data());
        for (uint32_t i = 0; i < num_negative_values; ++i) {
          values[i] = to[num_negative_values - 1 - i];  // to[] = tmp[].
        }
      }
      // There were negative floats: all the data is back in 'values', sorted,
      // so we're done.
      return;
    }
    // Else there were no negative floats: fall through to the generic final
    // copy below, which may still need to move the data back into 'values'.
  }

  // If we swapped an odd number of times, copy tmp[] onto values[].
  if constexpr (num_passes % 2) {
    std::memcpy(values.data(), from, size * sizeof(T));
  }
}

template <typename T>
int NumBitsForZeroTo(T max_value) {
  if constexpr (!std::is_integral_v<T>) {
    return sizeof(T) * 8;
  } else {
    using U = std::make_unsigned_t<T>;
    DCHECK_GE(max_value, 0);
    return std::numeric_limits<U>::digits - absl::countl_zero<U>(max_value);
  }
}
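// E.g., NumBitsForZeroTo(uint32_t{255}) == 8, NumBitsForZeroTo(uint32_t{256})
// == 9, and NumBitsForZeroTo(uint32_t{0}) == 0.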

#ifdef NDEBUG
const bool DEBUG_MODE = false;
#else
const bool DEBUG_MODE = true;
#endif

template <typename T>
void RadixSort(absl::Span<T> values, int num_bits) {
  // Debug-check that num_bits is valid w.r.t. the values given.
  if constexpr (DEBUG_MODE) {
    if constexpr (!std::is_integral_v<T>) {
      DCHECK_EQ(num_bits, sizeof(T) * 8);
    } else if (!values.empty()) {
      auto minmax_it = absl::c_minmax_element(values);
      const T min_val = *minmax_it.first;
      const T max_val = *minmax_it.second;
      if (num_bits == 0) {
        DCHECK_EQ(max_val, 0);
      } else {
        using U = std::make_unsigned_t<T>;
        // We only shift by num_bits - 1, to avoid potentially shifting by the
        // entire bit width, which would be undefined behavior.
        DCHECK_LE(static_cast<U>(max_val) >> (num_bits - 1), 1);
        DCHECK_LE(static_cast<U>(min_val) >> (num_bits - 1), 1);
      }
    }
  }

  // It is important that this shortcut comes early, guarded by as few "if"
  // branches as possible, for the use case where the array is very small.
  // For the larger arrays handled below, the overhead of a few extra "if"
  // branches is negligible.
  if (values.size() < 300) {
    absl::c_sort(values);
    return;
  }

  // TODO(user): More complex decision tree, based on benchmarks. This one
  // is already nice, but some cases can surely be optimized.
  if (num_bits <= 16) {
    if (num_bits <= 8) {
      RadixSortTpl<T, /*radix_width=*/8, /*num_passes=*/1>(values);
    } else {
      RadixSortTpl<T, /*radix_width=*/8, /*num_passes=*/2>(values);
    }
  } else if (num_bits <= 32) {  // num_bits ∈ [17..32]
    if (values.size() < 1000) {
      if (num_bits <= 24) {
        RadixSortTpl<T, /*radix_width=*/8, /*num_passes=*/3>(values);
      } else {
        RadixSortTpl<T, /*radix_width=*/8, /*num_passes=*/4>(values);
      }
    } else if (values.size() < 2'500'000) {
      if (num_bits <= 22) {
        RadixSortTpl<T, /*radix_width=*/11, /*num_passes=*/2>(values);
      } else {
        RadixSortTpl<T, /*radix_width=*/11, /*num_passes=*/3>(values);
      }
    } else {
      RadixSortTpl<T, /*radix_width=*/16, /*num_passes=*/2>(values);
    }
  } else if (num_bits <= 64) {  // num_bits ∈ [33..64]
    if (values.size() < 5000) {
      absl::c_sort(values);
    } else if (values.size() < 1'500'000) {
      if (num_bits <= 33) {
        RadixSortTpl<T, /*radix_width=*/11, /*num_passes=*/3>(values);
      } else if (num_bits <= 44) {
        RadixSortTpl<T, /*radix_width=*/11, /*num_passes=*/4>(values);
      } else if (num_bits <= 55) {
        RadixSortTpl<T, /*radix_width=*/11, /*num_passes=*/5>(values);
      } else {
        RadixSortTpl<T, /*radix_width=*/11, /*num_passes=*/6>(values);
      }
    } else {
      if (num_bits <= 48) {
        RadixSortTpl<T, /*radix_width=*/16, /*num_passes=*/3>(values);
      } else {
        RadixSortTpl<T, /*radix_width=*/16, /*num_passes=*/4>(values);
      }
    }
  } else {
    LOG(DFATAL) << "RadixSort() called with unsupported value type";
    absl::c_sort(values);
  }
}

}  // namespace operations_research

#endif  // OR_TOOLS_ALGORITHMS_RADIX_SORT_H_