26#include "absl/container/inlined_vector.h"
27#include "absl/numeric/int128.h"
28#include "absl/strings/str_format.h"
29#include "absl/types/span.h"
38 return absl::StrFormat(
"[%d,%d]",
start,
end);
42 absl::Span<const ClosedInterval> intervals) {
43 for (
int i = 1; i < intervals.size(); ++i) {
44 if (intervals[i - 1].start > intervals[i - 1].
end)
return false;
46 if (intervals[i - 1].
end >= intervals[i].start ||
47 intervals[i - 1].
end + 1 >= intervals[i].start) {
51 return intervals.empty() ? true
52 : intervals.back().start <= intervals.back().end;
57template <
class Intervals>
58std::string IntervalsAsString(
const Intervals& intervals) {
60 for (ClosedInterval interval : intervals) {
61 result += interval.DebugString();
63 if (result.empty()) result =
"[]";
67void SortAndRemoveInvalidIntervals(
68 absl::InlinedVector<ClosedInterval, 1>* intervals) {
69 intervals->erase(std::remove_if(intervals->begin(), intervals->end(),
71 return interval.start > interval.end;
74 std::sort(intervals->begin(), intervals->end());
79void UnionOfSortedIntervals(absl::InlinedVector<ClosedInterval, 1>* intervals) {
80 DCHECK(std::is_sorted(intervals->begin(), intervals->end()));
81 const int size = intervals->size();
82 if (size == 0)
return;
85 for (
int i = 1;
i < size; ++
i) {
87 const int64_t
end = (*intervals)[new_size - 1].end;
88 if (
end == std::numeric_limits<int64_t>::max() ||
89 current.start <=
end + 1) {
90 (*intervals)[new_size - 1].end = std::max(current.end,
end);
93 (*intervals)[new_size++] = current;
95 intervals->resize(new_size);
99 intervals->shrink_to_fit();
106int64_t
CeilRatio(int64_t value, int64_t positive_coeff) {
107 DCHECK_GT(positive_coeff, 0);
108 const int64_t result = value / positive_coeff;
109 const int64_t adjust =
static_cast<int64_t
>(result * positive_coeff < value);
110 return result + adjust;
114 DCHECK_GT(positive_coeff, 0);
115 const int64_t result = value / positive_coeff;
116 const int64_t adjust =
static_cast<int64_t
>(result * positive_coeff > value);
117 return result - adjust;
125 const std::vector<ClosedInterval>& intervals) {
126 return out << IntervalsAsString(intervals);
130 return out << IntervalsAsString(domain);
142inline ClosedInterval UncheckedClosedInterval(int64_t s, int64_t e) {
151 : intervals_({UncheckedClosedInterval(left, right)}) {
152 if (left > right) intervals_.clear();
164 std::sort(values.begin(), values.end());
166 for (
const int64_t v : values) {
167 if (result.intervals_.empty() || v > result.intervals_.back().end + 1) {
168 result.intervals_.push_back({v, v});
170 result.intervals_.back().end = v;
172 if (result.intervals_.back().end == std::numeric_limits<int64_t>::max()) {
181 result.intervals_.assign(intervals.begin(), intervals.end());
182 SortAndRemoveInvalidIntervals(&result.intervals_);
183 UnionOfSortedIntervals(&result.intervals_);
188 absl::Span<const int64_t> flat_intervals) {
189 DCHECK(flat_intervals.size() % 2 == 0) << flat_intervals.size();
191 result.intervals_.reserve(flat_intervals.size() / 2);
192 for (
int i = 0; i < flat_intervals.size(); i += 2) {
193 result.intervals_.push_back({flat_intervals[i], flat_intervals[i + 1]});
195 SortAndRemoveInvalidIntervals(&result.intervals_);
196 UnionOfSortedIntervals(&result.intervals_);
205 const std::vector<std::vector<int64_t>>& intervals) {
207 for (
const std::vector<int64_t>& interval : intervals) {
208 if (interval.size() == 1) {
209 result.intervals_.push_back({interval[0], interval[0]});
211 DCHECK_EQ(interval.size(), 2);
212 result.intervals_.push_back({interval[0], interval[1]});
215 SortAndRemoveInvalidIntervals(&result.intervals_);
216 UnionOfSortedIntervals(&result.intervals_);
238 return intervals_.front().start;
243 return intervals_.back().end;
248 int64_t result =
Min();
250 if (interval.start <= 0 && interval.end >= 0)
return 0;
251 for (
const int64_t b : {interval.start, interval.end}) {
252 if (b > 0 && b <= std::abs(result)) result = b;
253 if (b < 0 && -b < std::abs(result)) result = b;
261 if (interval.start <= 0 && interval.end >= 0) {
262 return Domain(interval.start, interval.end);
271 if (wanted <= intervals_[0].start) {
272 return intervals_[0].start;
275 int64_t best_distance;
277 if (interval.start <= wanted && wanted <= interval.end) {
279 }
else if (interval.start >= wanted) {
280 return CapSub(interval.start, wanted) <= best_distance ? interval.start
283 best_point = interval.end;
284 best_distance =
CapSub(wanted, interval.end);
294 auto it = std::upper_bound(intervals_.begin(), intervals_.end(),
296 if (it == intervals_.begin())
return input;
305 auto it = std::upper_bound(intervals_.begin(), intervals_.end(),
307 if (it == intervals_.end())
return input;
308 const int64_t candidate = it->start;
309 if (it == intervals_.begin())
return candidate;
316 return intervals_.front().start;
323 auto it = std::upper_bound(intervals_.begin(), intervals_.end(),
325 if (it == intervals_.begin())
return false;
327 return value <= it->end;
332 int64_t min_distance = std::numeric_limits<int64_t>::max();
334 if (value >= interval.start && value <= interval.end)
return 0;
335 if (interval.start > value) {
336 min_distance = std::min(min_distance, interval.start - value);
339 min_distance = value - interval.end;
347 const auto& others = domain.intervals_;
350 for (; i < others.size() && interval.end > others[i].end; ++i) {
352 if (i == others.size())
return false;
353 if (interval.start < others[i].start)
return false;
359 const auto& a = intervals_;
360 const auto& b = domain.intervals_;
361 for (
int i = 0, j = 0; i < a.size() && j < b.size();) {
362 if (a[i].start <= b[j].start) {
363 if (a[i].
end < b[j].start) {
371 if (b[j].
end < a[i].start) {
384 result.intervals_.reserve(intervals_.size() + 1);
387 result.intervals_.push_back({next_start, interval.start - 1});
389 if (interval.end ==
kint64max)
return result;
390 next_start = interval.
end + 1;
392 result.intervals_.push_back({next_start,
kint64max});
399 result.NegateInPlace();
403void Domain::NegateInPlace() {
404 if (intervals_.empty())
return;
405 std::reverse(intervals_.begin(), intervals_.end());
406 if (intervals_.back().end ==
kint64min) {
408 intervals_.pop_back();
410 for (ClosedInterval& ref : intervals_) {
411 std::swap(ref.start, ref.end);
420 const auto& a = intervals_;
421 const auto& b = domain.intervals_;
422 for (
int i = 0, j = 0; i < a.size() && j < b.size();) {
423 if (a[i].start <= b[j].start) {
424 if (a[i].
end < b[j].start) {
430 if (a[i].
end <= b[j].
end) {
431 result.intervals_.push_back({b[j].start, a[i].end});
434 result.intervals_.push_back({b[j].start, b[j].end});
440 if (b[j].
end < a[i].start) {
443 if (b[j].
end <= a[i].
end) {
444 result.intervals_.push_back({a[i].start, b[j].end});
447 result.intervals_.push_back({a[i].start, a[i].end});
459 const auto& a = intervals_;
460 const auto& b = domain.intervals_;
461 result.intervals_.resize(a.size() + b.size());
462 std::merge(a.begin(), a.end(), b.begin(), b.end(), result.intervals_.begin());
463 UnionOfSortedIntervals(&result.intervals_);
471 const auto& a = intervals_;
472 const auto& b = domain.intervals_;
473 result.intervals_.reserve(a.size() * b.size());
476 if (i.start > 0 && j.start > 0) {
478 result.intervals_.push_back({i.start + j.start,
CapAdd(i.end, j.end)});
479 }
else if (i.end < 0 && j.end < 0) {
481 result.intervals_.push_back({
CapAdd(i.start, j.start), i.end + j.end});
483 result.intervals_.push_back(
490 if (a.size() > 1 && b.size() > 1) {
491 std::sort(result.intervals_.begin(), result.intervals_.end());
493 UnionOfSortedIntervals(&result.intervals_);
506 if (exact !=
nullptr) *exact =
true;
507 if (intervals_.empty())
return {};
508 if (coeff == 0)
return Domain(0);
510 const int64_t abs_coeff = std::abs(coeff);
511 const int64_t size_if_non_trivial = abs_coeff > 1 ?
Size() : 0;
512 if (size_if_non_trivial > kDomainComplexityLimit) {
513 if (exact !=
nullptr) *exact =
false;
519 const int64_t max_value =
kint64max / abs_coeff;
520 const int64_t min_value =
kint64min / abs_coeff;
521 result.intervals_.reserve(size_if_non_trivial);
523 for (int64_t v = i.start;; ++v) {
525 if (v >= min_value && v <= max_value) {
527 const int64_t new_value = v * abs_coeff;
528 result.intervals_.push_back({new_value, new_value});
532 if (v == i.end)
break;
538 if (coeff < 0) result.NegateInPlace();
544 const int64_t abs_coeff = std::abs(coeff);
546 i.start =
CapProd(i.start, abs_coeff);
547 i.end =
CapProd(i.end, abs_coeff);
549 UnionOfSortedIntervals(&result.intervals_);
550 if (coeff < 0) result.NegateInPlace();
563 new_interval.
start = std::min({a, b, c, d});
564 new_interval.
end = std::max({a, b, c, d});
565 result.intervals_.push_back(new_interval);
568 std::sort(result.intervals_.begin(), result.intervals_.end());
569 UnionOfSortedIntervals(&result.intervals_);
576 const int64_t abs_coeff = std::abs(coeff);
578 i.start = i.start / abs_coeff;
579 i.end = i.end / abs_coeff;
581 UnionOfSortedIntervals(&result.intervals_);
582 if (coeff < 0) result.NegateInPlace();
592 const int64_t abs_coeff = std::abs(coeff);
594 const int64_t start =
CeilRatio(i.start, abs_coeff);
596 if (start >
end)
continue;
597 if (new_size > 0 && start == result.intervals_[new_size - 1].end + 1) {
598 result.intervals_[new_size - 1].end =
end;
600 result.intervals_[new_size++] = {start,
end};
603 result.intervals_.resize(new_size);
604 result.intervals_.shrink_to_fit();
606 if (coeff < 0) result.NegateInPlace();
611Domain ModuloHelper(int64_t min, int64_t max,
const Domain& modulo) {
613 DCHECK_GT(modulo.
Min(), 0);
614 const int64_t max_mod = modulo.
Max() - 1;
618 if (modulo.
Min() == modulo.
Max()) {
619 const int64_t size = max - min;
620 const int64_t v1 = min % modulo.
Max();
621 if (v1 + size > max_mod)
return Domain(0, max_mod);
622 return Domain(v1, v1 + size);
626 return Domain(0, std::min(max, max_mod));
632 CHECK_GT(modulo.
Min(), 0);
633 const int64_t max_mod = modulo.
Max() - 1;
634 if (
Max() >= 0 &&
Min() <= 0) {
635 return Domain(std::max(
Min(), -max_mod), std::min(
Max(), max_mod));
638 return ModuloHelper(
Min(),
Max(), modulo);
641 return ModuloHelper(-
Max(), -
Min(), modulo).Negation();
646 CHECK_GT(divisor.
Min(), 0);
648 std::max(
Max() / divisor.
Min(),
Max() / divisor.
Max()));
656 {0, std::numeric_limits<int64_t>::max()}));
657 if (abs_domain.
Size() >= kDomainComplexityLimit) {
660 for (
const auto& interval : abs_domain) {
661 result.intervals_.push_back(
663 CapProd(interval.end, interval.end)));
665 UnionOfSortedIntervals(&result.intervals_);
668 std::vector<int64_t> values;
669 values.reserve(abs_domain.
Size());
670 for (
const int64_t value : abs_domain.
Values()) {
671 values.push_back(
CapProd(value, value));
678ClosedInterval EvaluateQuadraticProdInterval(int64_t a, int64_t b, int64_t c,
679 int64_t d, int64_t variable_min,
680 int64_t variable_max) {
691 const absl::int128 nominator =
692 -absl::int128{a} * absl::int128{d} - absl::int128{
b} * absl::int128{
c};
693 const absl::int128 denominator = absl::int128{a} * absl::int128{
c};
694 const absl::int128 evaluated_minimum_point = (nominator / denominator) / 2;
696 const auto& evaluate = [&a, &
b, &
c, &d](
const int64_t
x) {
700 const int64_t at_min_x = evaluate(variable_min);
701 const int64_t at_max_x = evaluate(variable_max);
702 int64_t min_var = std::min(at_min_x, at_max_x);
703 int64_t max_var = std::max(at_min_x, at_max_x);
705 if (evaluated_minimum_point > variable_min &&
706 evaluated_minimum_point < variable_max) {
707 const int64_t point_at_minimum_64 =
708 static_cast<int64_t
>(evaluated_minimum_point);
709 const int rounder = ((nominator > 0) == (denominator > 0) ? 1 : -1);
710 const int64_t point1 = evaluate(point_at_minimum_64);
711 const int64_t point2 = evaluate(point_at_minimum_64 + rounder);
712 min_var = std::min(min_var, std::min(point1, point2));
713 max_var = std::max(max_var, std::max(point1, point2));
724 if (
Size() < kDomainComplexityLimit) {
725 std::vector<int64_t> values;
726 values.reserve(
Size());
727 for (
const int64_t value :
Values()) {
744 for (
const auto& interval : intervals_) {
745 result.intervals_.push_back(EvaluateQuadraticProdInterval(
746 a, b, c, d, interval.start, interval.end));
748 std::sort(result.intervals_.begin(), result.intervals_.end());
749 UnionOfSortedIntervals(&result.intervals_);
759 if (implied_domain.
IsEmpty())
return result;
764 bool started =
false;
770 if (started && implied_domain.intervals_[i].start < interval.start) {
771 result.intervals_.push_back({min_point, max_point});
778 for (; i < implied_domain.intervals_.size(); ++i) {
780 if (current.
end >= interval.start && current.
start <= interval.end) {
782 const int64_t inter_max = std::min(interval.end, current.
end);
785 min_point = std::max(interval.start, current.
start);
786 max_point = inter_max;
790 DCHECK_GE(inter_max, max_point);
791 max_point = inter_max;
794 if (current.
end > interval.end)
break;
796 if (i == implied_domain.intervals_.size())
break;
799 result.intervals_.push_back({min_point, max_point});
806 std::vector<int64_t> result;
808 result.push_back(interval.start);
809 result.push_back(interval.end);
815 const auto& d1 = intervals_;
816 const auto& d2 = other.intervals_;
817 const int common_size = std::min(d1.size(), d2.size());
818 for (
int i = 0; i < common_size; ++i) {
823 if (i1.
end < i2.
end)
return true;
824 if (i1.
end > i2.
end)
return false;
826 return d1.size() < d2.size();
832 int64_t current_sum = 0.0;
833 int current_index = 0;
835 if (current_index >= k)
break;
836 for (
int v(interval.start); v <= interval.end; ++v) {
837 if (current_index >= k)
break;
852 const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends) {
857 const std::vector<int>& starts,
const std::vector<int>& ends) {
862 const std::vector<ClosedInterval>& intervals) {
872 int64_t next_start = start;
875 const int64_t next_end =
CapSub(interval.
start, 1);
876 if (next_end >
end)
break;
877 if (next_start <= next_end) {
882 if (next_start <=
end) {
885 return interval_list;
889 int64_t start, int64_t
end) {
894 return intervals_.
end();
897 auto result = intervals_.insert({start,
end});
898 if (!result.second)
return result.first;
909 auto it1 = result.first;
911 it1 = intervals_.begin();
913 const int64_t before_start = start - 1;
914 while (it1 != intervals_.begin()) {
917 if (prev_it->end < before_start)
break;
924 auto it2 = result.first;
926 it2 = intervals_.end();
928 const int64_t after_end =
end + 1;
931 }
while (it2 != intervals_.end() && it2->start <= after_end);
939 if (it1 == it3)
return it3;
940 const int64_t new_start = std::min(it1->start, start);
941 const int64_t new_end = std::max(it3->end,
end);
942 auto it = intervals_.erase(it1, it3);
953void SortedDisjointIntervalList::InsertAll(
const std::vector<T>& starts,
954 const std::vector<T>& ends) {
955 CHECK_EQ(starts.size(), ends.size());
956 for (
int i = 0; i < starts.size(); ++i)
InsertInterval(starts[i], ends[i]);
960 const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends) {
961 InsertAll(starts, ends);
965 const std::vector<int>& ends) {
967 InsertAll(starts, ends);
972 const auto it = intervals_.upper_bound({value,
kint64max});
973 if (it ==
begin())
return it;
976 DCHECK_LE(it_prev->start, value);
977 return it_prev->end >= value ? it_prev : it;
982 const auto it = intervals_.upper_bound({value,
kint64max});
992 str += interval.DebugString();
Domain SimplifyUsingImpliedDomain(const Domain &implied_domain) const
static Domain FromVectorIntervals(const std::vector< std::vector< int64_t > > &intervals)
bool OverlapsWith(const Domain &domain) const
Domain MultiplicationBy(int64_t coeff, bool *exact=nullptr) const
static Domain FromValues(std::vector< int64_t > values)
bool operator<(const Domain &other) const
int64_t ValueAtOrAfter(int64_t input) const
std::vector< int64_t > FlattenedIntervals() const
Domain SquareSuperset() const
Domain IntersectionWith(const Domain &domain) const
static Domain LowerOrEqual(int64_t value)
Domain QuadraticSuperset(int64_t a, int64_t b, int64_t c, int64_t d) const
static Domain FromIntervals(absl::Span< const ClosedInterval > intervals)
Domain ContinuousMultiplicationBy(int64_t coeff) const
bool Contains(int64_t value) const
DomainIteratorBeginEnd Values() const &
Domain PositiveModuloBySuperset(const Domain &modulo) const
Domain AdditionWith(const Domain &domain) const
bool IsIncludedIn(const Domain &domain) const
static Domain FromFlatIntervals(const std::vector< int64_t > &flat_intervals)
int64_t SmallestValue() const
Domain()
By default, Domain will be empty.
int64_t FixedValue() const
Domain DivisionBy(int64_t coeff) const
int64_t Distance(int64_t value) const
std::string ToString() const
static Domain AllValues()
static Domain FromFlatSpanOfIntervals(absl::Span< const int64_t > flat_intervals)
absl::InlinedVector< ClosedInterval, 1 >::const_iterator end() const
Domain PositiveDivisionBySuperset(const Domain &divisor) const
Domain RelaxIfTooComplex() const
static Domain GreaterOrEqual(int64_t value)
Domain UnionWith(const Domain &domain) const
Domain Complement() const
Domain PartAroundZero() const
Domain InverseMultiplicationBy(int64_t coeff) const
int64_t ValueAtOrBefore(int64_t input) const
int64_t ClosestValue(int64_t wanted) const
void InsertIntervals(const std::vector< int64_t > &starts, const std::vector< int64_t > &ends)
SortedDisjointIntervalList()
Iterator LastIntervalLessOrEqual(int64_t value) const
std::string DebugString() const
Iterator InsertInterval(int64_t start, int64_t end)
ConstIterator begin() const
SortedDisjointIntervalList BuildComplementOnInterval(int64_t start, int64_t end)
IntervalSet::iterator Iterator
Iterator FirstIntervalGreaterOrEqual(int64_t value) const
ConstIterator end() const
int64_t CapAdd(int64_t x, int64_t y)
int64_t FloorRatio(int64_t value, int64_t positive_coeff)
int64_t CapSub(int64_t x, int64_t y)
int64_t CeilRatio(int64_t value, int64_t positive_coeff)
ClosedInterval::Iterator end(ClosedInterval interval)
int64_t SumOfKMinValueInDomain(const Domain &domain, int k)
std::ostream & operator<<(std::ostream &out, const Assignment &assignment)
bool AddOverflows(int64_t x, int64_t y)
int64_t CapProd(int64_t x, int64_t y)
bool IntervalsAreSortedAndNonAdjacent(absl::Span< const ClosedInterval > intervals)
int64_t SumOfKMaxValueInDomain(const Domain &domain, int k)
static int input(yyscan_t yyscanner)
std::string DebugString() const
static const int64_t kint64max
static const int64_t kint64min