26#include "absl/container/inlined_vector.h"
27#include "absl/numeric/int128.h"
28#include "absl/strings/str_format.h"
29#include "absl/types/span.h"
38 return absl::StrFormat(
"[%d,%d]",
start,
end);
42 absl::Span<const ClosedInterval> intervals) {
43 for (
int i = 1; i < intervals.size(); ++i) {
44 if (intervals[i - 1].start > intervals[i - 1].end)
return false;
46 if (intervals[i - 1].end >= intervals[i].start ||
47 intervals[i - 1].end + 1 >= intervals[i].start) {
51 return intervals.empty() ? true
52 : intervals.back().start <= intervals.back().end;
57template <
class Intervals>
58std::string IntervalsAsString(
const Intervals& intervals) {
60 for (ClosedInterval interval : intervals) {
61 result += interval.DebugString();
63 if (result.empty()) result =
"[]";
67void SortAndRemoveInvalidIntervals(
68 absl::InlinedVector<ClosedInterval, 1>* intervals) {
69 intervals->erase(std::remove_if(intervals->begin(), intervals->end(),
71 return interval.start > interval.end;
74 std::sort(intervals->begin(), intervals->end());
79void UnionOfSortedIntervals(absl::InlinedVector<ClosedInterval, 1>* intervals) {
80 DCHECK(std::is_sorted(intervals->begin(), intervals->end()));
81 const int size = intervals->size();
82 if (size == 0)
return;
85 for (
int i = 1;
i < size; ++
i) {
87 const int64_t end = (*intervals)[new_size - 1].
end;
88 if (end == std::numeric_limits<int64_t>::max() ||
89 current.start <= end + 1) {
90 (*intervals)[new_size - 1].end = std::max(current.end, end);
93 (*intervals)[new_size++] = current;
95 intervals->resize(new_size);
99 intervals->shrink_to_fit();
106int64_t
CeilRatio(int64_t value, int64_t positive_coeff) {
107 DCHECK_GT(positive_coeff, 0);
108 const int64_t result = value / positive_coeff;
109 const int64_t adjust =
static_cast<int64_t
>(result * positive_coeff < value);
110 return result + adjust;
114 DCHECK_GT(positive_coeff, 0);
115 const int64_t result = value / positive_coeff;
116 const int64_t adjust =
static_cast<int64_t
>(result * positive_coeff > value);
117 return result - adjust;
125 const std::vector<ClosedInterval>& intervals) {
126 return out << IntervalsAsString(intervals);
130 return out << IntervalsAsString(domain);
142inline ClosedInterval UncheckedClosedInterval(int64_t s, int64_t e) {
151 : intervals_({UncheckedClosedInterval(left, right)}) {
152 if (left > right) intervals_.clear();
158 std::sort(values.begin(), values.end());
160 for (
const int64_t v : values) {
161 if (result.intervals_.empty() || v > result.intervals_.back().end + 1) {
162 result.intervals_.push_back({v, v});
164 result.intervals_.back().end = v;
166 if (result.intervals_.back().end == std::numeric_limits<int64_t>::max()) {
175 result.intervals_.assign(intervals.begin(), intervals.end());
176 SortAndRemoveInvalidIntervals(&result.intervals_);
177 UnionOfSortedIntervals(&result.intervals_);
182 absl::Span<const int64_t> flat_intervals) {
183 DCHECK(flat_intervals.size() % 2 == 0) << flat_intervals.size();
185 result.intervals_.reserve(flat_intervals.size() / 2);
186 for (
int i = 0; i < flat_intervals.size(); i += 2) {
187 result.intervals_.push_back({flat_intervals[i], flat_intervals[i + 1]});
189 SortAndRemoveInvalidIntervals(&result.intervals_);
190 UnionOfSortedIntervals(&result.intervals_);
199 const std::vector<std::vector<int64_t>>& intervals) {
201 for (
const std::vector<int64_t>& interval : intervals) {
202 if (interval.size() == 1) {
203 result.intervals_.push_back({interval[0], interval[0]});
205 DCHECK_EQ(interval.size(), 2);
206 result.intervals_.push_back({interval[0], interval[1]});
209 SortAndRemoveInvalidIntervals(&result.intervals_);
210 UnionOfSortedIntervals(&result.intervals_);
232 return intervals_.front().start;
237 return intervals_.back().end;
242 int64_t result =
Min();
244 if (interval.start <= 0 && interval.end >= 0)
return 0;
245 for (
const int64_t b : {interval.start, interval.end}) {
246 if (b > 0 && b <= std::abs(result)) result = b;
247 if (b < 0 && -b < std::abs(result)) result = b;
255 if (interval.start <= 0 && interval.end >= 0) {
256 return Domain(interval.start, interval.end);
265 if (wanted <= intervals_[0].start) {
266 return intervals_[0].start;
269 int64_t best_distance;
271 if (interval.start <= wanted && wanted <= interval.end) {
273 }
else if (interval.start >= wanted) {
274 return CapSub(interval.start, wanted) <= best_distance ? interval.start
277 best_point = interval.end;
278 best_distance =
CapSub(wanted, interval.end);
288 auto it = std::upper_bound(intervals_.begin(), intervals_.end(),
290 if (it == intervals_.begin())
return input;
299 auto it = std::upper_bound(intervals_.begin(), intervals_.end(),
301 if (it == intervals_.end())
return input;
302 const int64_t candidate = it->start;
303 if (it == intervals_.begin())
return candidate;
310 return intervals_.front().start;
317 auto it = std::upper_bound(intervals_.begin(), intervals_.end(),
319 if (it == intervals_.begin())
return false;
321 return value <= it->end;
326 int64_t min_distance = std::numeric_limits<int64_t>::max();
328 if (value >= interval.start && value <= interval.end)
return 0;
329 if (interval.start > value) {
330 min_distance = std::min(min_distance, interval.start - value);
333 min_distance = value - interval.end;
341 const auto& others = domain.intervals_;
344 for (; i < others.size() && interval.end > others[i].end; ++i) {
346 if (i == others.size())
return false;
347 if (interval.start < others[i].start)
return false;
355 result.intervals_.reserve(intervals_.size() + 1);
358 result.intervals_.push_back({next_start, interval.start - 1});
360 if (interval.end ==
kint64max)
return result;
361 next_start = interval.
end + 1;
363 result.intervals_.push_back({next_start,
kint64max});
370 result.NegateInPlace();
374void Domain::NegateInPlace() {
375 if (intervals_.empty())
return;
376 std::reverse(intervals_.begin(), intervals_.end());
377 if (intervals_.back().end ==
kint64min) {
379 intervals_.pop_back();
381 for (ClosedInterval& ref : intervals_) {
382 std::swap(ref.start, ref.end);
391 const auto& a = intervals_;
392 const auto& b = domain.intervals_;
393 for (
int i = 0, j = 0; i < a.size() && j < b.size();) {
394 if (a[i].start <= b[j].start) {
395 if (a[i].
end < b[j].start) {
401 if (a[i].
end <= b[j].
end) {
402 result.intervals_.push_back({b[j].start, a[i].end});
405 result.intervals_.push_back({b[j].start, b[j].end});
411 if (b[j].
end < a[i].start) {
414 if (b[j].
end <= a[i].
end) {
415 result.intervals_.push_back({a[i].start, b[j].end});
418 result.intervals_.push_back({a[i].start, a[i].end});
430 const auto& a = intervals_;
431 const auto& b = domain.intervals_;
432 result.intervals_.resize(a.size() + b.size());
433 std::merge(a.begin(), a.end(), b.begin(), b.end(), result.intervals_.begin());
434 UnionOfSortedIntervals(&result.intervals_);
442 const auto& a = intervals_;
443 const auto& b = domain.intervals_;
444 result.intervals_.reserve(a.size() * b.size());
447 if (i.start > 0 && j.start > 0) {
449 result.intervals_.push_back({i.start + j.start,
CapAdd(i.end, j.end)});
450 }
else if (i.end < 0 && j.end < 0) {
452 result.intervals_.push_back({
CapAdd(i.start, j.start), i.end + j.end});
454 result.intervals_.push_back(
461 if (a.size() > 1 && b.size() > 1) {
462 std::sort(result.intervals_.begin(), result.intervals_.end());
464 UnionOfSortedIntervals(&result.intervals_);
477 if (exact !=
nullptr) *exact =
true;
478 if (intervals_.empty())
return {};
479 if (coeff == 0)
return Domain(0);
481 const int64_t abs_coeff = std::abs(coeff);
482 const int64_t size_if_non_trivial = abs_coeff > 1 ?
Size() : 0;
483 if (size_if_non_trivial > kDomainComplexityLimit) {
484 if (exact !=
nullptr) *exact =
false;
490 const int64_t max_value =
kint64max / abs_coeff;
491 const int64_t min_value =
kint64min / abs_coeff;
492 result.intervals_.reserve(size_if_non_trivial);
494 for (int64_t v = i.start;; ++v) {
496 if (v >= min_value && v <= max_value) {
498 const int64_t new_value = v * abs_coeff;
499 result.intervals_.push_back({new_value, new_value});
503 if (v == i.end)
break;
509 if (coeff < 0) result.NegateInPlace();
515 const int64_t abs_coeff = std::abs(coeff);
517 i.start =
CapProd(i.start, abs_coeff);
518 i.end =
CapProd(i.end, abs_coeff);
520 UnionOfSortedIntervals(&result.intervals_);
521 if (coeff < 0) result.NegateInPlace();
534 new_interval.
start = std::min({a, b, c, d});
535 new_interval.
end = std::max({a, b, c, d});
536 result.intervals_.push_back(new_interval);
539 std::sort(result.intervals_.begin(), result.intervals_.end());
540 UnionOfSortedIntervals(&result.intervals_);
547 const int64_t abs_coeff = std::abs(coeff);
549 i.start = i.start / abs_coeff;
550 i.end = i.end / abs_coeff;
552 UnionOfSortedIntervals(&result.intervals_);
553 if (coeff < 0) result.NegateInPlace();
563 const int64_t abs_coeff = std::abs(coeff);
565 const int64_t start =
CeilRatio(i.start, abs_coeff);
567 if (start >
end)
continue;
568 if (new_size > 0 && start == result.intervals_[new_size - 1].end + 1) {
569 result.intervals_[new_size - 1].end =
end;
571 result.intervals_[new_size++] = {start,
end};
574 result.intervals_.resize(new_size);
575 result.intervals_.shrink_to_fit();
577 if (coeff < 0) result.NegateInPlace();
582Domain ModuloHelper(int64_t min, int64_t max,
const Domain& modulo) {
584 DCHECK_GT(modulo.
Min(), 0);
585 const int64_t max_mod = modulo.
Max() - 1;
589 if (modulo.
Min() == modulo.
Max()) {
590 const int64_t size = max - min;
591 const int64_t v1 = min % modulo.
Max();
592 if (v1 + size > max_mod)
return Domain(0, max_mod);
593 return Domain(v1, v1 + size);
597 return Domain(0, std::min(max, max_mod));
603 CHECK_GT(modulo.
Min(), 0);
604 const int64_t max_mod = modulo.
Max() - 1;
605 if (
Max() >= 0 &&
Min() <= 0) {
606 return Domain(std::max(
Min(), -max_mod), std::min(
Max(), max_mod));
609 return ModuloHelper(
Min(),
Max(), modulo);
612 return ModuloHelper(-
Max(), -
Min(), modulo).Negation();
617 CHECK_GT(divisor.
Min(), 0);
619 std::max(
Max() / divisor.
Min(),
Max() / divisor.
Max()));
627 {0, std::numeric_limits<int64_t>::max()}));
628 if (abs_domain.
Size() >= kDomainComplexityLimit) {
631 for (
const auto& interval : abs_domain) {
632 result.intervals_.push_back(
634 CapProd(interval.end, interval.end)));
636 UnionOfSortedIntervals(&result.intervals_);
639 std::vector<int64_t> values;
640 values.reserve(abs_domain.
Size());
641 for (
const int64_t value : abs_domain.
Values()) {
642 values.push_back(
CapProd(value, value));
649ClosedInterval EvaluateQuadraticProdInterval(int64_t a, int64_t b, int64_t c,
650 int64_t d, int64_t variable_min,
651 int64_t variable_max) {
662 const absl::int128 nominator =
663 -absl::int128{a} * absl::int128{d} - absl::int128{
b} * absl::int128{
c};
664 const absl::int128 denominator = absl::int128{a} * absl::int128{
c};
665 const absl::int128 evaluated_minimum_point = (nominator / denominator) / 2;
667 const auto& evaluate = [&a, &
b, &
c, &d](
const int64_t
x) {
671 const int64_t at_min_x = evaluate(variable_min);
672 const int64_t at_max_x = evaluate(variable_max);
673 int64_t min_var = std::min(at_min_x, at_max_x);
674 int64_t max_var = std::max(at_min_x, at_max_x);
676 if (evaluated_minimum_point > variable_min &&
677 evaluated_minimum_point < variable_max) {
678 const int64_t point_at_minimum_64 =
679 static_cast<int64_t
>(evaluated_minimum_point);
680 const int rounder = ((nominator > 0) == (denominator > 0) ? 1 : -1);
681 const int64_t point1 = evaluate(point_at_minimum_64);
682 const int64_t point2 = evaluate(point_at_minimum_64 + rounder);
683 min_var = std::min(min_var, std::min(point1, point2));
684 max_var = std::max(max_var, std::max(point1, point2));
695 if (
Size() < kDomainComplexityLimit) {
696 std::vector<int64_t> values;
697 values.reserve(
Size());
698 for (
const int64_t value :
Values()) {
715 for (
const auto& interval : intervals_) {
716 result.intervals_.push_back(EvaluateQuadraticProdInterval(
717 a, b, c, d, interval.start, interval.end));
719 std::sort(result.intervals_.begin(), result.intervals_.end());
720 UnionOfSortedIntervals(&result.intervals_);
730 if (implied_domain.
IsEmpty())
return result;
735 bool started =
false;
741 if (started && implied_domain.intervals_[i].start < interval.start) {
742 result.intervals_.push_back({min_point, max_point});
749 for (; i < implied_domain.intervals_.size(); ++i) {
751 if (current.
end >= interval.start && current.
start <= interval.end) {
753 const int64_t inter_max = std::min(interval.end, current.
end);
756 min_point = std::max(interval.start, current.
start);
757 max_point = inter_max;
761 DCHECK_GE(inter_max, max_point);
762 max_point = inter_max;
765 if (current.
end > interval.end)
break;
767 if (i == implied_domain.intervals_.size())
break;
770 result.intervals_.push_back({min_point, max_point});
777 std::vector<int64_t> result;
779 result.push_back(interval.start);
780 result.push_back(interval.end);
786 const auto& d1 = intervals_;
787 const auto& d2 = other.intervals_;
788 const int common_size = std::min(d1.size(), d2.size());
789 for (
int i = 0; i < common_size; ++i) {
794 if (i1.
end < i2.
end)
return true;
795 if (i1.
end > i2.
end)
return false;
797 return d1.size() < d2.size();
803 int64_t current_sum = 0.0;
804 int current_index = 0;
806 if (current_index >= k)
break;
807 for (
int v(interval.start); v <= interval.end; ++v) {
808 if (current_index >= k)
break;
823 const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends) {
828 const std::vector<int>& starts,
const std::vector<int>& ends) {
833 const std::vector<ClosedInterval>& intervals) {
843 int64_t next_start = start;
846 const int64_t next_end =
CapSub(interval.
start, 1);
847 if (next_end >
end)
break;
848 if (next_start <= next_end) {
853 if (next_start <=
end) {
856 return interval_list;
860 int64_t start, int64_t
end) {
865 return intervals_.
end();
868 auto result = intervals_.insert({start,
end});
869 if (!result.second)
return result.first;
880 auto it1 = result.first;
882 it1 = intervals_.begin();
884 const int64_t before_start = start - 1;
885 while (it1 != intervals_.begin()) {
888 if (prev_it->end < before_start)
break;
895 auto it2 = result.first;
897 it2 = intervals_.end();
899 const int64_t after_end =
end + 1;
902 }
while (it2 != intervals_.end() && it2->start <= after_end);
910 if (it1 == it3)
return it3;
911 const int64_t new_start = std::min(it1->start, start);
912 const int64_t new_end = std::max(it3->end,
end);
913 auto it = intervals_.erase(it1, it3);
924 int64_t value, int64_t* newly_covered) {
925 auto it = intervals_.upper_bound({value,
kint64max});
932 if (it ==
begin() || ((value !=
kint64min) && it_prev->end < value - 1)) {
933 *newly_covered = value;
934 if (it ==
end() || it->start != value + 1) {
936 return intervals_.insert(it, {value, value});
942 DCHECK_EQ(it->start, value + 1);
951 CHECK_NE(
kint64max, it_prev->end) <<
"Cannot grow right by one: the interval "
952 "that would grow already ends at "
954 *newly_covered = it_prev->end + 1;
955 if (it !=
end() && it_prev->end + 2 == it->start) {
958 intervals_.erase(it);
966void SortedDisjointIntervalList::InsertAll(
const std::vector<T>& starts,
967 const std::vector<T>& ends) {
968 CHECK_EQ(starts.size(), ends.size());
969 for (
int i = 0; i < starts.size(); ++i)
InsertInterval(starts[i], ends[i]);
973 const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends) {
974 InsertAll(starts, ends);
978 const std::vector<int>& ends) {
980 InsertAll(starts, ends);
985 const auto it = intervals_.upper_bound({value,
kint64max});
986 if (it ==
begin())
return it;
989 DCHECK_LE(it_prev->start, value);
990 return it_prev->end >= value ? it_prev : it;
995 const auto it = intervals_.upper_bound({value,
kint64max});
1005 str += interval.DebugString();
Domain SimplifyUsingImpliedDomain(const Domain &implied_domain) const
static Domain FromVectorIntervals(const std::vector< std::vector< int64_t > > &intervals)
Domain MultiplicationBy(int64_t coeff, bool *exact=nullptr) const
static Domain FromValues(std::vector< int64_t > values)
bool operator<(const Domain &other) const
int64_t ValueAtOrAfter(int64_t input) const
std::vector< int64_t > FlattenedIntervals() const
Domain SquareSuperset() const
Domain IntersectionWith(const Domain &domain) const
Domain QuadraticSuperset(int64_t a, int64_t b, int64_t c, int64_t d) const
static Domain FromIntervals(absl::Span< const ClosedInterval > intervals)
Domain ContinuousMultiplicationBy(int64_t coeff) const
bool Contains(int64_t value) const
DomainIteratorBeginEnd Values() const &
Domain PositiveModuloBySuperset(const Domain &modulo) const
Domain AdditionWith(const Domain &domain) const
bool IsIncludedIn(const Domain &domain) const
static Domain FromFlatIntervals(const std::vector< int64_t > &flat_intervals)
int64_t SmallestValue() const
Domain()
By default, Domain will be empty.
int64_t FixedValue() const
Domain DivisionBy(int64_t coeff) const
int64_t Distance(int64_t value) const
std::string ToString() const
static Domain AllValues()
static Domain FromFlatSpanOfIntervals(absl::Span< const int64_t > flat_intervals)
absl::InlinedVector< ClosedInterval, 1 >::const_iterator end() const
Domain PositiveDivisionBySuperset(const Domain &divisor) const
Domain RelaxIfTooComplex() const
Domain UnionWith(const Domain &domain) const
Domain Complement() const
Domain PartAroundZero() const
Domain InverseMultiplicationBy(int64_t coeff) const
int64_t ValueAtOrBefore(int64_t input) const
int64_t ClosestValue(int64_t wanted) const
void InsertIntervals(const std::vector< int64_t > &starts, const std::vector< int64_t > &ends)
SortedDisjointIntervalList()
Iterator LastIntervalLessOrEqual(int64_t value) const
std::string DebugString() const
Iterator InsertInterval(int64_t start, int64_t end)
ConstIterator begin() const
SortedDisjointIntervalList BuildComplementOnInterval(int64_t start, int64_t end)
IntervalSet::iterator Iterator
Iterator FirstIntervalGreaterOrEqual(int64_t value) const
ConstIterator end() const
Iterator GrowRightByOne(int64_t value, int64_t *newly_covered)
In SWIG mode, we don't want anything besides these top-level includes.
int64_t CapAdd(int64_t x, int64_t y)
int64_t FloorRatio(int64_t value, int64_t positive_coeff)
int64_t CapSub(int64_t x, int64_t y)
int64_t CeilRatio(int64_t value, int64_t positive_coeff)
int64_t SumOfKMinValueInDomain(const Domain &domain, int k)
Returns the sum of smallest k values in the domain.
std::ostream & operator<<(std::ostream &out, const Assignment &assignment)
bool AddOverflows(int64_t x, int64_t y)
int64_t CapProd(int64_t x, int64_t y)
bool IntervalsAreSortedAndNonAdjacent(absl::Span< const ClosedInterval > intervals)
int64_t SumOfKMaxValueInDomain(const Domain &domain, int k)
Returns the sum of largest k values in the domain.
static int input(yyscan_t yyscanner)
std::string DebugString() const
static const int64_t kint64max
static const int64_t kint64min