Google OR-Tools v9.12
a fast and portable software suite for combinatorial optimization
Loading...
Searching...
No Matches
n_choose_k.cc
Go to the documentation of this file.
1// Copyright 2010-2025 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
15
16#include <cmath>
17#include <cstdint>
18#include <limits>
19#include <vector>
20
21#include "absl/log/check.h"
22#include "absl/numeric/int128.h"
23#include "absl/status/status.h"
24#include "absl/status/statusor.h"
25#include "absl/strings/str_format.h"
26#include "absl/time/clock.h"
27#include "absl/time/time.h"
31
32namespace operations_research {
33namespace {
// Returns (n choose k), using O(k) multiplications and divisions.
// Callers must guarantee 0 < k <= n - k; both conditions are DCHECKed.
template <typename Int>
Int InternalChoose(Int n, Int k) {
  DCHECK_LE(k, n - k);
  DCHECK_GT(k, 0);  // With k > 0 the running product can be seeded with n.
  // The value is n * (n-1) * ... * (n-k+1) / k!, evaluated by alternating one
  // multiplication with one division. Every division is exact, and dividing
  // as early as possible keeps the intermediate values small. It is still not
  // perfect: the last step is a division by k, so we may overflow even when
  // the final result doesn't (by a factor of up to k).
  Int binomial = n;
  Int factor = n;
  for (Int divisor = 2; divisor <= k; ++divisor) {
    --factor;  // Now factor == n + 1 - divisor.
    binomial *= factor;
    // The product of `divisor` consecutive integers is divisible by divisor!.
    binomial /= divisor;
  }
  return binomial;
}
50
51// This function precomputes the maximum N such that (N choose K) doesn't
52// overflow, for all K.
53// When `overflows_intermediate_computation` is true, "overflow" means
54// "some overflow happens inside InternalChoose<int64_t>()", and when it's false
55// it simply means "the result doesn't fit in an int64_t".
56// This is only used in contexts where K ≤ N-K, which implies N ≥ 2K, thus we
57// can stop when (2K Choose K) overflows, because at and beyond such K,
58// (N Choose K) will always overflow. In practice that happens for K=31 or 34
59// depending on `overflows_intermediate_computation`.
60template <class Int>
61std::vector<Int> LastNThatDoesNotOverflowForAllK(
62 bool overflows_intermediate_computation) {
63 absl::Time start_time = absl::Now();
64 // Given the algorithm used in InternalChoose(), it's not hard to
65 // find out when (N choose K) overflows an int64_t during its internal
66 // computation: that's when (N choose K) > MAX_INT / k.
67
68 // For K ≤ 2, we hardcode the values of the maximum N. That's because
69 // the binary search done below uses MathUtil::LogCombinations, which only
70 // works on int32_t, and that's problematic for the max N we get for K=2.
71 //
72 // For K=2, we want N(N-1) ≤ 2^num_digits, or N(N-1)/2 ≤ 2^num_digits if
73 // !overflows_intermediate_computation, i.e. N(N-1) ≤ 2^(num_digits+1).
74 // Then, when d is even, N(N-1) ≤ 2^d ⇔ N ≤ 2^(d/2), which is simple.
75 // When d is odd, it's harder: N(N-1)≈(N-0.5)² and thus we get the bound
76 // N ≤ pow(2.0, d/2)+0.5.
77 const int bound_digits = std::numeric_limits<Int>::digits +
78 (overflows_intermediate_computation ? 0 : 1);
79 std::vector<Int> result = {
80 std::numeric_limits<Int>::max(), // K=0
81 std::numeric_limits<Int>::max(), // K=1
82 bound_digits % 2 == 0
83 ? Int{1} << (bound_digits / 2)
84 : static_cast<Int>(
85 0.5 + std::pow(2.0, 0.5 * std::numeric_limits<Int>::digits)),
86 };
87 // We find the last N with binary search, for all K. We stop growing K
88 // when (2*K Choose K) overflows.
89 for (Int k = 3;; ++k) {
90 const double max_log_comb =
91 overflows_intermediate_computation
92 ? std::numeric_limits<Int>::digits * std::log(2) - std::log(k)
93 : std::numeric_limits<Int>::digits * std::log(2);
94 result.push_back(BinarySearch<Int>(
95 /*x_true*/ k,
96 // x_false=X, X needs to be large enough so that X choose 3 overflows:
97 // (X choose 3)≈(X-1)³/6, so we pick X = 2+6*2^(num_digits/3+1).
98 /*x_false=*/
99 (static_cast<Int>(
100 2 + 6 * std::pow(2.0, std::numeric_limits<Int>::digits / 3 + 1))),
101 [k, max_log_comb](Int n) {
102 return MathUtil::LogCombinations(n, k) <= max_log_comb;
103 }));
104 if (result.back() < 2 * k) {
105 result.pop_back();
106 break;
107 }
108 }
109 // Some DCHECKs for int64_t, which should validate the general formulaes.
110 if constexpr (std::numeric_limits<Int>::digits == 63) {
111 DCHECK_EQ(result.size(),
112 overflows_intermediate_computation
113 ? 31 // 60 Choose 30 < 2^63/30 but 62 Choose 31 > 2^63/31.
114 : 34); // 66 Choose 33 < 2^63 but 68 Choose 34 > 2^63.
115 }
116 VLOG(1) << "LastNThatDoesNotOverflowForAllK(): " << absl::Now() - start_time;
117 return result;
118}
119
120template <typename Int>
121bool NChooseKIntermediateComputationOverflowsInt(Int n, Int k) {
122 DCHECK_LE(k, n - k);
123 static const auto* const result =
124 new std::vector<Int>(LastNThatDoesNotOverflowForAllK<Int>(
125 /*overflows_intermediate_computation=*/true));
126 return k < result->size() ? n > (*result)[k] : true;
127}
128
129template <typename Int>
130bool NChooseKResultOverflowsInt(Int n, Int k) {
131 DCHECK_LE(k, n - k);
132 static const auto* const result =
133 new std::vector<Int>(LastNThatDoesNotOverflowForAllK<Int>(
134 /*overflows_intermediate_computation=*/false));
135 return k < result->size() ? n > (*result)[k] : true;
136}
137} // namespace
138
139// NOTE(user): If performance ever matters, we could simply precompute and
140// store all (N choose K) that don't overflow, there aren't that many of them:
141// only a few tens of thousands, after removing simple cases like k ≤ 5.
142absl::StatusOr<int64_t> NChooseK(int64_t n, int64_t k) {
143 if (n < 0) {
144 return absl::InvalidArgumentError(absl::StrFormat("n is negative (%d)", n));
145 }
146 if (k < 0) {
147 return absl::InvalidArgumentError(absl::StrFormat("k is negative (%d)", k));
148 }
149 if (k > n / 2) {
150 if (k > n) return 0; // No way to choose more than n elements from n.
151 k = n - k;
152 }
153 if (k == 0) return 1;
154 if (n < std::numeric_limits<uint32_t>::max() &&
155 !NChooseKIntermediateComputationOverflowsInt<uint32_t>(n, k)) {
156 return static_cast<int64_t>(InternalChoose<uint32_t>(n, k));
157 }
158 if (!NChooseKIntermediateComputationOverflowsInt<int64_t>(n, k)) {
159 return InternalChoose<uint64_t>(n, k);
160 }
161 if (NChooseKResultOverflowsInt<int64_t>(n, k)) {
162 return absl::InvalidArgumentError(
163 absl::StrFormat("(%d choose %d) overflows int64", n, k));
164 }
165 return static_cast<int64_t>(InternalChoose<absl::uint128>(n, k));
166}
167
168} // namespace operations_research
static double LogCombinations(int n, int k)
Definition mathutil.cc:33
In SWIG mode, we don't want anything besides these top-level includes.
absl::StatusOr< int64_t > NChooseK(int64_t n, int64_t k)
Point BinarySearch(Point x_true, Point x_false, std::function< bool(Point)> f)