26#include "absl/container/inlined_vector.h"
27#include "absl/strings/str_format.h"
28#include "absl/types/span.h"
37 return absl::StrFormat(
"[%d,%d]",
start,
end);
41 absl::Span<const ClosedInterval> intervals) {
42 for (
int i = 1; i < intervals.size(); ++i) {
43 if (intervals[i - 1].
start > intervals[i - 1].
end)
return false;
45 if (intervals[i - 1].
end >= intervals[i].
start ||
46 intervals[i - 1].
end + 1 >= intervals[i].
start) {
50 return intervals.empty() ? true
51 : intervals.back().start <= intervals.back().end;
56template <
class Intervals>
57std::string IntervalsAsString(
const Intervals& intervals) {
59 for (ClosedInterval
interval : intervals) {
62 if (result.empty()) result =
"[]";
68void UnionOfSortedIntervals(absl::InlinedVector<ClosedInterval, 1>* intervals) {
69 DCHECK(std::is_sorted(intervals->begin(), intervals->end()));
70 const int size = intervals->size();
71 if (
size == 0)
return;
74 for (
int i = 1;
i <
size; ++
i) {
75 const ClosedInterval& current = (*intervals)[
i];
76 const int64_t
end = (*intervals)[new_size - 1].end;
77 if (
end == std::numeric_limits<int64_t>::max() ||
78 current.start <=
end + 1) {
79 (*intervals)[new_size - 1].end = std::max(current.end,
end);
82 (*intervals)[new_size++] = current;
84 intervals->resize(new_size);
88 intervals->shrink_to_fit();
96 DCHECK_GT(positive_coeff, 0);
97 const int64_t result =
value / positive_coeff;
98 const int64_t adjust =
static_cast<int64_t
>(result * positive_coeff <
value);
99 return result + adjust;
103 DCHECK_GT(positive_coeff, 0);
104 const int64_t result =
value / positive_coeff;
105 const int64_t adjust =
static_cast<int64_t
>(result * positive_coeff >
value);
106 return result - adjust;
114 const std::vector<ClosedInterval>& intervals) {
115 return out << IntervalsAsString(intervals);
119 return out << IntervalsAsString(domain);
131inline ClosedInterval UncheckedClosedInterval(int64_t s, int64_t e) {
140 : intervals_({UncheckedClosedInterval(left, right)}) {
141 if (left > right) intervals_.clear();
147 std::sort(values.begin(), values.end());
149 for (
const int64_t v : values) {
150 if (result.intervals_.empty() || v > result.intervals_.back().end + 1) {
151 result.intervals_.push_back({v, v});
153 result.intervals_.back().end = v;
155 if (result.intervals_.back().end == std::numeric_limits<int64_t>::max()) {
164 result.intervals_.assign(intervals.begin(), intervals.end());
165 std::sort(result.intervals_.begin(), result.intervals_.end());
166 UnionOfSortedIntervals(&result.intervals_);
171 absl::Span<const int64_t> flat_intervals) {
172 DCHECK(flat_intervals.size() % 2 == 0) << flat_intervals.size();
174 result.intervals_.reserve(flat_intervals.size() / 2);
175 for (
int i = 0; i < flat_intervals.size(); i += 2) {
176 result.intervals_.push_back({flat_intervals[i], flat_intervals[i + 1]});
178 std::sort(result.intervals_.begin(), result.intervals_.end());
179 UnionOfSortedIntervals(&result.intervals_);
188 const std::vector<std::vector<int64_t>>& intervals) {
190 for (
const std::vector<int64_t>&
interval : intervals) {
198 std::sort(result.intervals_.begin(), result.intervals_.end());
199 UnionOfSortedIntervals(&result.intervals_);
221 return intervals_.front().start;
226 return intervals_.back().end;
231 int64_t result =
Min();
235 if (
b > 0 &&
b <= std::abs(result)) result =
b;
236 if (
b < 0 && -
b < std::abs(result)) result =
b;
254 if (wanted <= intervals_[0].
start) {
255 return intervals_[0].start;
258 int64_t best_distance;
262 }
else if (
interval.start >= wanted) {
277 auto it = std::upper_bound(intervals_.begin(), intervals_.end(),
279 if (it == intervals_.begin())
return input;
288 auto it = std::upper_bound(intervals_.begin(), intervals_.end(),
290 if (it == intervals_.end())
return input;
291 const int64_t candidate = it->start;
292 if (it == intervals_.begin())
return candidate;
299 return intervals_.front().start;
306 auto it = std::upper_bound(intervals_.begin(), intervals_.end(),
308 if (it == intervals_.begin())
return false;
310 return value <= it->end;
315 int64_t min_distance = std::numeric_limits<int64_t>::max();
319 min_distance = std::min(min_distance,
interval.start -
value);
330 const auto& others = domain.intervals_;
333 for (; i < others.size() &&
interval.end > others[i].end; ++i) {
335 if (i == others.size())
return false;
336 if (
interval.start < others[i].start)
return false;
344 result.intervals_.reserve(intervals_.size() + 1);
347 result.intervals_.push_back({next_start,
interval.start - 1});
352 result.intervals_.push_back({next_start,
kint64max});
359 result.NegateInPlace();
363void Domain::NegateInPlace() {
364 if (intervals_.empty())
return;
365 std::reverse(intervals_.begin(), intervals_.end());
366 if (intervals_.back().end ==
kint64min) {
368 intervals_.pop_back();
370 for (ClosedInterval& ref : intervals_) {
371 std::swap(ref.start, ref.end);
380 const auto&
a = intervals_;
381 const auto&
b = domain.intervals_;
382 for (
int i = 0, j = 0; i <
a.size() && j <
b.size();) {
391 result.intervals_.push_back({
b[j].start,
a[i].end});
394 result.intervals_.push_back({
b[j].start,
b[j].end});
404 result.intervals_.push_back({
a[i].start,
b[j].end});
407 result.intervals_.push_back({
a[i].start,
a[i].end});
419 const auto&
a = intervals_;
420 const auto&
b = domain.intervals_;
421 result.intervals_.resize(
a.size() +
b.size());
422 std::merge(
a.begin(),
a.end(),
b.begin(),
b.end(), result.intervals_.begin());
423 UnionOfSortedIntervals(&result.intervals_);
431 const auto&
a = intervals_;
432 const auto&
b = domain.intervals_;
433 result.intervals_.reserve(
a.size() *
b.size());
436 if (i.start > 0 && j.start > 0) {
438 result.intervals_.push_back({i.start + j.start,
CapAdd(i.end, j.end)});
439 }
else if (i.end < 0 && j.end < 0) {
441 result.intervals_.push_back({
CapAdd(i.start, j.start), i.end + j.end});
443 result.intervals_.push_back(
450 if (
a.size() > 1 &&
b.size() > 1) {
451 std::sort(result.intervals_.begin(), result.intervals_.end());
453 UnionOfSortedIntervals(&result.intervals_);
466 if (exact !=
nullptr) *exact =
true;
467 if (intervals_.empty())
return {};
468 if (coeff == 0)
return Domain(0);
470 const int64_t abs_coeff = std::abs(coeff);
471 const int64_t size_if_non_trivial = abs_coeff > 1 ?
Size() : 0;
472 if (size_if_non_trivial > kDomainComplexityLimit) {
473 if (exact !=
nullptr) *exact =
false;
479 const int64_t max_value =
kint64max / abs_coeff;
480 const int64_t min_value =
kint64min / abs_coeff;
481 result.intervals_.reserve(size_if_non_trivial);
483 for (int64_t v = i.start;; ++v) {
485 if (v >= min_value && v <= max_value) {
487 const int64_t new_value = v * abs_coeff;
488 result.intervals_.push_back({new_value, new_value});
492 if (v == i.end)
break;
498 if (coeff < 0) result.NegateInPlace();
504 const int64_t abs_coeff = std::abs(coeff);
506 i.start =
CapProd(i.start, abs_coeff);
507 i.end =
CapProd(i.end, abs_coeff);
509 UnionOfSortedIntervals(&result.intervals_);
510 if (coeff < 0) result.NegateInPlace();
523 new_interval.
start = std::min({
a,
b, c, d});
524 new_interval.
end = std::max({
a,
b, c, d});
525 result.intervals_.push_back(new_interval);
528 std::sort(result.intervals_.begin(), result.intervals_.end());
529 UnionOfSortedIntervals(&result.intervals_);
536 const int64_t abs_coeff = std::abs(coeff);
538 i.start = i.start / abs_coeff;
539 i.end = i.end / abs_coeff;
541 UnionOfSortedIntervals(&result.intervals_);
542 if (coeff < 0) result.NegateInPlace();
552 const int64_t abs_coeff = std::abs(coeff);
557 if (new_size > 0 &&
start == result.intervals_[new_size - 1].end + 1) {
558 result.intervals_[new_size - 1].end =
end;
560 result.intervals_[new_size++] = {
start,
end};
563 result.intervals_.resize(new_size);
564 result.intervals_.shrink_to_fit();
566 if (coeff < 0) result.NegateInPlace();
573 DCHECK_GT(modulo.
Min(), 0);
574 const int64_t max_mod = modulo.
Max() - 1;
578 if (modulo.
Min() == modulo.
Max()) {
580 const int64_t v1 =
min % modulo.
Max();
581 if (v1 +
size > max_mod)
return Domain(0, max_mod);
586 return Domain(0, std::min(
max, max_mod));
592 CHECK_GT(modulo.
Min(), 0);
593 const int64_t max_mod = modulo.
Max() - 1;
594 if (
Max() >= 0 &&
Min() <= 0) {
595 return Domain(std::max(
Min(), -max_mod), std::min(
Max(), max_mod));
598 return ModuloHelper(
Min(),
Max(), modulo);
601 return ModuloHelper(-
Max(), -
Min(), modulo).Negation();
606 CHECK_GT(divisor.
Min(), 0);
608 std::max(
Max() / divisor.
Min(),
Max() / divisor.
Max()));
616 {0, std::numeric_limits<int64_t>::max()}));
617 if (abs_domain.
Size() >= kDomainComplexityLimit) {
620 for (
const auto&
interval : abs_domain) {
621 result.intervals_.push_back(
625 UnionOfSortedIntervals(&result.intervals_);
628 std::vector<int64_t> values;
629 values.reserve(abs_domain.
Size());
643 if (implied_domain.
IsEmpty())
return result;
648 bool started =
false;
654 if (started && implied_domain.intervals_[i].start <
interval.start) {
655 result.intervals_.push_back({min_point, max_point});
662 for (; i < implied_domain.intervals_.size(); ++i) {
666 const int64_t inter_max = std::min(
interval.end, current.
end);
670 max_point = inter_max;
674 DCHECK_GE(inter_max, max_point);
675 max_point = inter_max;
680 if (i == implied_domain.intervals_.size())
break;
683 result.intervals_.push_back({min_point, max_point});
690 std::vector<int64_t> result;
699 const auto& d1 = intervals_;
700 const auto& d2 = other.intervals_;
701 const int common_size = std::min(d1.size(), d2.size());
702 for (
int i = 0; i < common_size; ++i) {
707 if (i1.
end < i2.
end)
return true;
708 if (i1.
end > i2.
end)
return false;
710 return d1.size() < d2.size();
716 int64_t current_sum = 0.0;
717 int current_index = 0;
719 if (current_index >= k)
break;
721 if (current_index >= k)
break;
736 const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends) {
741 const std::vector<int>& starts,
const std::vector<int>& ends) {
746 const std::vector<ClosedInterval>& intervals) {
756 int64_t next_start =
start;
760 if (next_end >
end)
break;
761 if (next_start <= next_end) {
766 if (next_start <=
end) {
769 return interval_list;
778 return intervals_.end();
781 auto result = intervals_.insert({
start,
end});
782 if (!result.second)
return result.first;
793 auto it1 = result.first;
795 it1 = intervals_.begin();
797 const int64_t before_start =
start - 1;
798 while (it1 != intervals_.begin()) {
801 if (prev_it->end < before_start)
break;
808 auto it2 = result.first;
810 it2 = intervals_.end();
812 const int64_t after_end =
end + 1;
815 }
while (it2 != intervals_.end() && it2->start <= after_end);
823 if (it1 == it3)
return it3;
824 const int64_t new_start = std::min(it1->start,
start);
825 const int64_t new_end = std::max(it3->end,
end);
826 auto it = intervals_.erase(it1, it3);
837 int64_t
value, int64_t* newly_covered) {
846 *newly_covered =
value;
847 if (it ==
end() || it->start !=
value + 1) {
855 DCHECK_EQ(it->start,
value + 1);
864 CHECK_NE(
kint64max, it_prev->end) <<
"Cannot grow right by one: the interval "
865 "that would grow already ends at "
867 *newly_covered = it_prev->end + 1;
868 if (it !=
end() && it_prev->end + 2 == it->start) {
871 intervals_.erase(it);
879void SortedDisjointIntervalList::InsertAll(
const std::vector<T>& starts,
880 const std::vector<T>& ends) {
881 CHECK_EQ(starts.size(), ends.size());
882 for (
int i = 0; i < starts.size(); ++i)
InsertInterval(starts[i], ends[i]);
886 const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends) {
887 InsertAll(starts, ends);
891 const std::vector<int>& ends) {
893 InsertAll(starts, ends);
899 if (it ==
begin())
return it;
902 DCHECK_LE(it_prev->start,
value);
903 return it_prev->end >=
value ? it_prev : it;
Domain SimplifyUsingImpliedDomain(const Domain &implied_domain) const
static Domain FromVectorIntervals(const std::vector< std::vector< int64_t > > &intervals)
Domain MultiplicationBy(int64_t coeff, bool *exact=nullptr) const
static Domain FromValues(std::vector< int64_t > values)
bool operator<(const Domain &other) const
int64_t ValueAtOrAfter(int64_t input) const
std::vector< int64_t > FlattenedIntervals() const
Domain SquareSuperset() const
Domain IntersectionWith(const Domain &domain) const
static Domain FromIntervals(absl::Span< const ClosedInterval > intervals)
Domain ContinuousMultiplicationBy(int64_t coeff) const
bool Contains(int64_t value) const
DomainIteratorBeginEnd Values() const &
Domain PositiveModuloBySuperset(const Domain &modulo) const
Domain AdditionWith(const Domain &domain) const
bool IsIncludedIn(const Domain &domain) const
static Domain FromFlatIntervals(const std::vector< int64_t > &flat_intervals)
int64_t SmallestValue() const
Domain()
By default, Domain will be empty.
int64_t FixedValue() const
Domain DivisionBy(int64_t coeff) const
int64_t Distance(int64_t value) const
std::string ToString() const
static Domain AllValues()
static Domain FromFlatSpanOfIntervals(absl::Span< const int64_t > flat_intervals)
absl::InlinedVector< ClosedInterval, 1 >::const_iterator end() const
Domain PositiveDivisionBySuperset(const Domain &divisor) const
Domain RelaxIfTooComplex() const
Domain UnionWith(const Domain &domain) const
Domain Complement() const
Domain PartAroundZero() const
Domain InverseMultiplicationBy(int64_t coeff) const
int64_t ValueAtOrBefore(int64_t input) const
int64_t ClosestValue(int64_t wanted) const
std::string DebugString() const override
void InsertIntervals(const std::vector< int64_t > &starts, const std::vector< int64_t > &ends)
SortedDisjointIntervalList()
Iterator LastIntervalLessOrEqual(int64_t value) const
std::string DebugString() const
IntervalSet::iterator Iterator
Iterator InsertInterval(int64_t start, int64_t end)
ConstIterator begin() const
SortedDisjointIntervalList BuildComplementOnInterval(int64_t start, int64_t end)
Iterator FirstIntervalGreaterOrEqual(int64_t value) const
ConstIterator end() const
Iterator GrowRightByOne(int64_t value, int64_t *newly_covered)
In SWIG mode, we don't want anything besides these top-level includes.
int64_t CapAdd(int64_t x, int64_t y)
int64_t FloorRatio(int64_t value, int64_t positive_coeff)
int64_t CapSub(int64_t x, int64_t y)
int64_t CeilRatio(int64_t value, int64_t positive_coeff)
int64_t SumOfKMinValueInDomain(const Domain &domain, int k)
Returns the sum of smallest k values in the domain.
std::ostream & operator<<(std::ostream &out, const Assignment &assignment)
bool AddOverflows(int64_t x, int64_t y)
int64_t CapProd(int64_t x, int64_t y)
bool IntervalsAreSortedAndNonAdjacent(absl::Span< const ClosedInterval > intervals)
int64_t SumOfKMaxValueInDomain(const Domain &domain, int k)
Returns the sum of largest k values in the domain.
static int input(yyscan_t yyscanner)
std::optional< int64_t > end
std::string DebugString() const
static const int64_t kint64max
static const int64_t kint64min