22#include "absl/container/flat_hash_map.h"
23#include "absl/container/flat_hash_set.h"
24#include "absl/container/inlined_vector.h"
25#include "absl/log/check.h"
26#include "absl/types/span.h"
54 absl::flat_hash_map<int, int> var_to_position;
58 std::vector<std::optional<int>> position_mapping(num_exprs, std::nullopt);
59 int num_shared_vars = 0;
60 int num_fixed_exprs = 0;
61 for (
int i = 0;
i < num_exprs; ++
i) {
68 const int var = expr.
vars(0);
69 const auto [it, inserted] =
70 var_to_position.insert({var, var_to_position.size()});
73 position_mapping[
i] = it->second;
77 const int num_kept_exprs = num_exprs - num_shared_vars - num_fixed_exprs;
79 std::vector<std::vector<int64_t>> new_tuples;
80 new_tuples.reserve(num_tuples);
82 std::vector<int64_t> new_scaled_values;
83 new_scaled_values.reserve(num_kept_exprs);
85 for (
int t = 0; t < num_tuples; ++t) {
86 bool tuple_is_valid =
true;
87 new_scaled_values.clear();
89 for (
int e = 0; e < num_exprs; ++e) {
90 const int64_t value = ct->
table().
values(t * num_exprs + e);
94 tuple_is_valid =
false;
97 }
else if (position_mapping[e].has_value()) {
98 const int var_first_position = position_mapping[e].value();
99 const int64_t var_value = new_scaled_values[var_first_position];
101 if (value != forced_value) {
102 tuple_is_valid =
false;
107 tuple_is_valid =
false;
114 if (tuple_is_valid) {
115 DCHECK_EQ(new_scaled_values.size(), num_kept_exprs);
116 new_tuples.push_back(new_scaled_values);
121 for (
int e = 0; e < num_exprs; ++e) {
122 if (position_mapping[e].has_value())
continue;
129 if (num_kept_exprs < num_exprs) {
131 for (
int e = 0; e < num_exprs; ++e) {
132 if (position_mapping[e].has_value())
continue;
136 CHECK_EQ(index, num_kept_exprs);
143 if (new_tuples.size() < num_tuples) {
147 if (num_kept_exprs == 0) {
153 const bool all_tuples_invalid = new_tuples.empty();
154 const bool is_trivially_sat = all_tuples_invalid == ct->
table().
negated();
162 if (new_tuples.empty()) {
177 for (
const std::vector<int64_t>& tuple : new_tuples) {
183 std::vector<std::vector<int64_t>>* tuples) {
184 if (tuples->empty())
return;
189 const int num_vars = (*tuples)[0].size();
191 std::vector<int> to_remove;
192 std::vector<int64_t> tuple_minus_var_i(num_vars - 1);
193 for (
int i = 0;
i < num_vars; ++
i) {
194 const int64_t domain_size = domain_sizes[
i];
195 if (domain_size == 1)
continue;
196 absl::flat_hash_map<std::vector<int64_t>, std::vector<int>>
197 masked_tuples_to_indices;
198 for (
int t = 0; t < tuples->size(); ++t) {
200 for (
int j = 0; j < num_vars; ++j) {
201 if (
i == j)
continue;
202 tuple_minus_var_i[out++] = (*tuples)[t][j];
204 masked_tuples_to_indices[tuple_minus_var_i].push_back(t);
207 for (
const auto& it : masked_tuples_to_indices) {
208 if (it.second.size() != domain_size)
continue;
210 to_remove.insert(to_remove.end(), it.second.begin() + 1, it.second.end());
212 std::sort(to_remove.begin(), to_remove.end(), std::greater<int>());
213 for (
const int t : to_remove) {
214 (*tuples)[t] = tuples->back();
228void FullyCompressTuplesRecursive(
229 absl::Span<const int64_t> domain_sizes,
230 absl::Span<std::vector<int64_t>> tuples,
231 std::vector<absl::InlinedVector<int64_t, 2>>* reversed_suffix,
232 std::vector<std::vector<absl::InlinedVector<int64_t, 2>>>* output) {
234 absl::InlinedVector<int64_t, 2> values;
237 bool operator<(
const TempData& other)
const {
238 return values < other.values;
241 std::vector<TempData> temp_data;
243 CHECK(!tuples.empty());
244 CHECK(!tuples[0].empty());
245 const int64_t domain_size = domain_sizes[tuples[0].size() - 1];
248 std::sort(tuples.begin(), tuples.end());
249 for (
int i = 0;
i < tuples.size();) {
251 temp_data.push_back({{tuples[start].back()}, start});
252 tuples[start].pop_back();
253 for (++
i;
i < tuples.size(); ++
i) {
254 const int64_t v = tuples[
i].back();
255 tuples[
i].pop_back();
256 if (tuples[
i] == tuples[start]) {
257 temp_data.back().values.push_back(v);
259 tuples[
i].push_back(v);
266 for (
const int64_t v : temp_data.back().values) {
268 temp_data.back().values.clear();
276 if (temp_data.back().values.size() == domain_size) {
277 temp_data.back().values.clear();
281 if (temp_data.size() == 1) {
282 output->push_back({});
283 for (
const int64_t v : tuples[temp_data[0].index]) {
285 output->back().push_back({});
287 output->back().push_back({v});
290 output->back().push_back(temp_data[0].values);
291 for (
int i = reversed_suffix->size(); --
i >= 0;) {
292 output->back().push_back((*reversed_suffix)[
i]);
299 std::sort(temp_data.begin(), temp_data.end());
300 std::vector<std::vector<int64_t>> temp_tuples;
301 for (
int i = 0;
i < temp_data.size();) {
302 reversed_suffix->push_back(temp_data[
i].values);
305 for (;
i < temp_data.size();
i++) {
306 if (temp_data[start].values != temp_data[
i].values)
break;
307 temp_tuples.push_back(tuples[temp_data[
i].index]);
309 FullyCompressTuplesRecursive(domain_sizes, absl::MakeSpan(temp_tuples),
310 reversed_suffix, output);
311 reversed_suffix->pop_back();
322 absl::Span<const int64_t> domain_sizes,
323 std::vector<std::vector<int64_t>>* tuples) {
324 std::vector<absl::InlinedVector<int64_t, 2>> reversed_suffix;
325 std::vector<std::vector<absl::InlinedVector<int64_t, 2>>> output;
326 FullyCompressTuplesRecursive(domain_sizes, absl::MakeSpan(*tuples),
327 &reversed_suffix, &output);
336 std::vector<absl::flat_hash_set<int64_t>>* states,
337 std::vector<absl::flat_hash_set<int64_t>>* labels) {
339 const absl::flat_hash_set<int64_t> final_states(
345 states->resize(n + 1);
349 for (
int time = 0; time < n; ++time) {
354 if (!(*states)[time].contains(tail))
continue;
356 if (time == n - 1 && !final_states.contains(head))
continue;
357 (*labels)[time].insert(label);
358 (*states)[time + 1].insert(head);
363 for (
int time = n - 1; time >= 0; --time) {
364 absl::flat_hash_set<int64_t> new_states;
365 absl::flat_hash_set<int64_t> new_labels;
371 if (!(*states)[time].contains(tail))
continue;
372 if (!(*labels)[time].contains(label))
continue;
373 if (!(*states)[time + 1].contains(head))
continue;
374 new_labels.insert(label);
375 new_states.insert(tail);
377 (*labels)[time].swap(new_labels);
378 (*states)[time].swap(new_states);
::int64_t transition_label(int index) const
::int64_t final_states(int index) const
const ::operations_research::sat::LinearExpressionProto & exprs(int index) const
::int64_t transition_tail(int index) const
::int64_t starting_state() const
int exprs_size() const
repeated .operations_research.sat.LinearExpressionProto exprs = 8;
int transition_tail_size() const
repeated int64 transition_tail = 4;
::int64_t transition_head(int index) const
const ::operations_research::sat::TableConstraintProto & table() const
::operations_research::sat::TableConstraintProto *PROTOBUF_NONNULL mutable_table()
::int32_t vars(int index) const
void set_offset(::int64_t value)
void set_coeffs(int index, ::int64_t value)
int coeffs_size() const
repeated int64 coeffs = 2;
int64_t FixedValue(int ref) const
bool IsFixed(int ref) const
void UpdateRuleStats(const std::string &name, int num_times=1)
bool DomainContains(int ref, int64_t value) const
bool ModelIsUnsat() const
::operations_research::sat::LinearExpressionProto *PROTOBUF_NONNULL mutable_exprs(int index)
::int64_t values(int index) const
::google::protobuf::RepeatedField<::int64_t > *PROTOBUF_NONNULL mutable_values()
::int32_t vars(int index) const
const ::operations_research::sat::LinearExpressionProto & exprs(int index) const
::operations_research::sat::LinearExpressionProto *PROTOBUF_NONNULL add_exprs()
int values_size() const
repeated int64 values = 2;
int exprs_size() const
repeated .operations_research.sat.LinearExpressionProto exprs = 4;
void set_negated(bool value)
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
constexpr int64_t kTableAnyValue
void CompressTuples(absl::Span< const int64_t > domain_sizes, std::vector< std::vector< int64_t > > *tuples)
std::vector< std::vector< absl::InlinedVector< int64_t, 2 > > > FullyCompressTuples(absl::Span< const int64_t > domain_sizes, std::vector< std::vector< int64_t > > *tuples)
int64_t GetInnerVarValue(const LinearExpressionProto &expr, int64_t value)
int64_t AffineExpressionValueAt(const LinearExpressionProto &expr, int64_t value)
Evaluates an affine expression at the given value.
void PropagateAutomaton(const AutomatonConstraintProto &proto, const PresolveContext &context, std::vector< absl::flat_hash_set< int64_t > > *states, std::vector< absl::flat_hash_set< int64_t > > *labels)
Fills and propagates the set of reachable states/labels.
void CanonicalizeTable(PresolveContext *context, ConstraintProto *ct)
In SWIG mode, we don't want anything besides these top-level includes.