109 absl::BitGenRef random_;
110 std::vector<Index> equivalent_choices_;
131 HeapElement() =
default;
139 double operator<(
const HeapElement& other)
const {
140 return value > other.value;
144 std::vector<HeapElement> tops_;
147 struct QueryStats :
public StatsGroup {
150 get_maximum(
"get_maximum", this),
151 heap_size_on_hit(
"heap_size_on_hit", this),
152 random_choices(
"random_choices", this) {}
153 TimeDistribution get_maximum;
154 IntegerDistribution heap_size_on_hit;
155 IntegerDistribution random_choices;
160template <
typename Index>
165 is_candidate_.ClearAndResize(n);
168template <
typename Index>
170 is_candidate_.Clear(position);
173template <
typename Index>
180template <
typename Index>
183 DCHECK(!std::isnan(
value));
184 DCHECK(tops_.empty());
185 is_candidate_.Set(position);
186 values_[position] =
value;
189template <
typename Index>
192 DCHECK(!std::isnan(
value));
193 is_candidate_.Set(position);
194 values_[position] =
value;
195 if (
value >= threshold_) UpdateTopK(position,
value);
198template <
typename Index>
200 if (equivalent_choices_.empty())
return best;
201 equivalent_choices_.push_back(best);
202 stats_.random_choices.Add(equivalent_choices_.size());
204 return equivalent_choices_[std::uniform_int_distribution<int>(
205 0, equivalent_choices_.size() - 1)(random_)];
208template <
typename Index>
212 Index best_position(-1);
213 equivalent_choices_.clear();
223 if (!tops_.empty()) {
225 for (
const HeapElement e : tops_) {
227 if (!is_candidate_[e.index])
continue;
228 if (values_[e.index] != e.value)
continue;
230 tops_[new_size++] = e;
231 if (e.value >= best_value) {
232 if (e.value == best_value) {
233 equivalent_choices_.push_back(e.index);
236 equivalent_choices_.clear();
237 best_value = e.value;
238 best_position = e.index;
241 tops_.resize(new_size);
243 stats_.heap_size_on_hit.Add(new_size);
244 return RandomizeIfManyChoices(best_position);
250 DCHECK(tops_.empty());
251 const auto values = values_.const_view();
252 for (
const Index position : is_candidate_) {
257 if (
value < threshold_)
continue;
258 UpdateTopK(position,
value);
260 if (
value >= best_value) {
261 if (
value == best_value) {
262 equivalent_choices_.push_back(position);
265 equivalent_choices_.clear();
267 best_position = position;
271 return RandomizeIfManyChoices(best_position);
274template <
typename Index>
278 DCHECK_GE(
value, threshold_);
284 constexpr int k = 31;
285 static_assert(((k + 1) & k) == 0,
"k + 1 should be a power of 2.");
288 if (tops_.size() < k) {
289 tops_.emplace_back(position,
value);
290 if (tops_.size() == k) {
291 std::make_heap(tops_.begin(), tops_.end());
292 threshold_ = tops_[0].value;
307 if (absl::Bernoulli(random_, 0.5)) {
308 tops_[0].index = position;
317 std::pop_heap(tops_.begin(), tops_.end());
318 tops_.back() = HeapElement(position,
value);
319 std::push_heap(tops_.begin(), tops_.end());
320 threshold_ = tops_[0].value;
328 DCHECK_EQ(tops_.size(), k);
329 constexpr int limit = k / 2;
331 const int left_child = 2 *
i + 1;
333 const Fractional l_value = tops_[left_child].value;
335 if (l_value > r_value) {
336 if (
value <= r_value)
break;
340 if (
value <= l_value)
break;
341 tops_[
i] = tops_[left_child];
345 tops_[
i] = HeapElement(position,
value);
346 threshold_ = tops_[0].value;
347 DCHECK(std::is_heap(tops_.begin(), tops_.end()));