Google OR-Tools v9.12
a fast and portable software suite for combinatorial optimization
Loading...
Searching...
No Matches
cp_model_table.cc
Go to the documentation of this file.
1// Copyright 2010-2025 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
15
16#include <algorithm>
17#include <cstdint>
18#include <functional>
19#include <optional>
20#include <vector>
21
22#include "absl/container/flat_hash_map.h"
23#include "absl/container/flat_hash_set.h"
24#include "absl/container/inlined_vector.h"
25#include "absl/log/check.h"
26#include "absl/types/span.h"
30
31namespace operations_research {
32namespace sat {
33
34void CanonicalizeTable(PresolveContext* context, ConstraintProto* ct) {
35 if (context->ModelIsUnsat()) return;
36
37 DCHECK(ct->table().vars().empty());
38 if (ct->table().exprs().empty()) {
39 CHECK(ct->table().values().empty());
40 return;
41 }
42
43 if (ct->table().values().empty()) {
44 // Make the trivial table constraint canonical.
45 ct->mutable_table()->clear_exprs();
46 ct->mutable_table()->add_exprs()->set_offset(0);
47 return;
48 }
49
50 const int num_exprs = ct->table().exprs_size();
51 const int num_tuples = ct->table().values_size() / num_exprs;
52
53 // Detect expressions sharing the same variable as a previous expression.
54 absl::flat_hash_map<int, int> var_to_position;
55
56 // The mapping between the position in the original list of expressions, and
57 // the position in the reduced list of expressions.
58 std::vector<std::optional<int>> position_mapping(num_exprs, std::nullopt);
59 int num_shared_vars = 0;
60 int num_fixed_exprs = 0;
61 for (int i = 0; i < num_exprs; ++i) {
62 const LinearExpressionProto& expr = ct->table().exprs(i);
63 if (context->IsFixed(expr)) {
64 ++num_fixed_exprs;
65 continue;
66 }
67
68 const int var = expr.vars(0);
69 const auto [it, inserted] =
70 var_to_position.insert({var, var_to_position.size()});
71 if (!inserted) {
72 ++num_shared_vars;
73 position_mapping[i] = it->second;
74 }
75 }
76
77 const int num_kept_exprs = num_exprs - num_shared_vars - num_fixed_exprs;
78
79 std::vector<std::vector<int64_t>> new_tuples;
80 new_tuples.reserve(num_tuples);
81
82 std::vector<int64_t> new_scaled_values;
83 new_scaled_values.reserve(num_kept_exprs);
84
85 for (int t = 0; t < num_tuples; ++t) {
86 bool tuple_is_valid = true;
87 new_scaled_values.clear();
88
89 for (int e = 0; e < num_exprs; ++e) {
90 const int64_t value = ct->table().values(t * num_exprs + e);
91 const LinearExpressionProto& expr = ct->table().exprs(e);
92 if (context->IsFixed(expr)) {
93 if (value != context->FixedValue(expr)) {
94 tuple_is_valid = false;
95 break;
96 }
97 } else if (position_mapping[e].has_value()) {
98 const int var_first_position = position_mapping[e].value();
99 const int64_t var_value = new_scaled_values[var_first_position];
100 const int64_t forced_value = AffineExpressionValueAt(expr, var_value);
101 if (value != forced_value) {
102 tuple_is_valid = false;
103 break;
104 }
105 } else {
106 if (!context->DomainContains(expr, value)) {
107 tuple_is_valid = false;
108 break;
109 }
110 new_scaled_values.push_back(GetInnerVarValue(expr, value));
111 }
112 }
113
114 if (tuple_is_valid) {
115 DCHECK_EQ(new_scaled_values.size(), num_kept_exprs);
116 new_tuples.push_back(new_scaled_values);
117 }
118 }
119
120 // Remove all scaling on expressions as we have stored the inner values.
121 for (int e = 0; e < num_exprs; ++e) {
122 if (position_mapping[e].has_value()) continue;
123 if (context->IsFixed(ct->table().exprs(e))) continue;
124 DCHECK_EQ(ct->table().exprs(e).coeffs_size(), 1);
125 ct->mutable_table()->mutable_exprs(e)->set_offset(0);
126 ct->mutable_table()->mutable_exprs(e)->set_coeffs(0, 1);
127 }
128
129 if (num_kept_exprs < num_exprs) {
130 int index = 0;
131 for (int e = 0; e < num_exprs; ++e) {
132 if (position_mapping[e].has_value()) continue;
133 if (context->IsFixed(ct->table().exprs(e))) continue;
134 ct->mutable_table()->mutable_exprs()->SwapElements(index++, e);
135 }
136 CHECK_EQ(index, num_kept_exprs);
137 ct->mutable_table()->mutable_exprs()->DeleteSubrange(index,
138 num_exprs - index);
139 context->UpdateRuleStats("table: remove expressions");
140 }
141
143 if (new_tuples.size() < num_tuples) {
144 context->UpdateRuleStats("table: remove tuples");
145 }
146
147 if (num_kept_exprs == 0) {
148 // The table was not empty from the beginning (we test it), but it became
149 // empty after removing all fixed variables. So either we also remove all
150 // the tuples, in which case there was no tuple that matched, or some tuple
151 // (of size 0!) remained and in this case we did find a match.
152 context->UpdateRuleStats("table: all constant");
153 const bool all_tuples_invalid = new_tuples.empty();
154 const bool is_trivially_sat = all_tuples_invalid == ct->table().negated();
155 ct->mutable_table()->clear_exprs();
156 ct->mutable_table()->clear_values();
157 ct->mutable_table()->add_exprs()->set_offset(0);
158 ct->mutable_table()->set_negated(is_trivially_sat);
159 return;
160 }
161
162 if (new_tuples.empty()) {
163 // Add a trivially unsat (or trivially sat if negated) table constraint so
164 // code downstream can handle any eventual enforcement literals.
165 context->UpdateRuleStats("table: all tuples invalid");
166 ct->mutable_table()->clear_exprs();
167 ct->mutable_table()->clear_values();
168 if (!ct->table().negated()) {
169 ct->mutable_table()->add_exprs()->set_offset(0);
170 }
171 ct->mutable_table()->set_negated(false);
172 return;
173 }
174
175 // Write sorted tuples.
176 ct->mutable_table()->clear_values();
177 for (const std::vector<int64_t>& tuple : new_tuples) {
178 ct->mutable_table()->mutable_values()->Add(tuple.begin(), tuple.end());
179 }
180}
181
182void CompressTuples(absl::Span<const int64_t> domain_sizes,
183 std::vector<std::vector<int64_t>>* tuples) {
184 if (tuples->empty()) return;
185
186 // Remove duplicates if any.
188
189 const int num_vars = (*tuples)[0].size();
190
191 std::vector<int> to_remove;
192 std::vector<int64_t> tuple_minus_var_i(num_vars - 1);
193 for (int i = 0; i < num_vars; ++i) {
194 const int64_t domain_size = domain_sizes[i];
195 if (domain_size == 1) continue;
196 absl::flat_hash_map<std::vector<int64_t>, std::vector<int>>
197 masked_tuples_to_indices;
198 for (int t = 0; t < tuples->size(); ++t) {
199 int out = 0;
200 for (int j = 0; j < num_vars; ++j) {
201 if (i == j) continue;
202 tuple_minus_var_i[out++] = (*tuples)[t][j];
203 }
204 masked_tuples_to_indices[tuple_minus_var_i].push_back(t);
205 }
206 to_remove.clear();
207 for (const auto& it : masked_tuples_to_indices) {
208 if (it.second.size() != domain_size) continue;
209 (*tuples)[it.second.front()][i] = kTableAnyValue;
210 to_remove.insert(to_remove.end(), it.second.begin() + 1, it.second.end());
211 }
212 std::sort(to_remove.begin(), to_remove.end(), std::greater<int>());
213 for (const int t : to_remove) {
214 (*tuples)[t] = tuples->back();
215 tuples->pop_back();
216 }
217 }
218}
219
220namespace {
221
// We will call FullyCompressTuplesRecursive() for a set of prefixes of the
// original tuples, each having the same suffix (in reversed_suffix).
//
// For such set, we will compress it on the last variable of the prefixes. We
// will then for each unique compressed set of value of that variable, call
// a new FullyCompressTuplesRecursive() on the corresponding subset.
//
// Output format: each entry appended to `output` is one compressed tuple where
// each position holds the set of allowed values for that column; an empty set
// encodes "any value in the domain".
void FullyCompressTuplesRecursive(
    absl::Span<const int64_t> domain_sizes,
    absl::Span<std::vector<int64_t>> tuples,
    std::vector<absl::InlinedVector<int64_t, 2>>* reversed_suffix,
    std::vector<std::vector<absl::InlinedVector<int64_t, 2>>>* output) {
  // One entry per group of tuples sharing the same prefix: the set of values
  // seen in the last column, and the index of the first tuple of the group.
  struct TempData {
    absl::InlinedVector<int64_t, 2> values;
    int index;

    bool operator<(const TempData& other) const {
      return values < other.values;
    }
  };
  std::vector<TempData> temp_data;

  CHECK(!tuples.empty());
  CHECK(!tuples[0].empty());
  // Domain size of the column currently being compressed (the last one).
  const int64_t domain_size = domain_sizes[tuples[0].size() - 1];

  // Sort tuples and regroup common prefix in temp_data.
  // Note that the last element of each grouped tuple is popped in place.
  std::sort(tuples.begin(), tuples.end());
  for (int i = 0; i < tuples.size();) {
    const int start = i;
    temp_data.push_back({{tuples[start].back()}, start});
    tuples[start].pop_back();
    for (++i; i < tuples.size(); ++i) {
      const int64_t v = tuples[i].back();
      tuples[i].pop_back();
      if (tuples[i] == tuples[start]) {
        temp_data.back().values.push_back(v);
      } else {
        // Different prefix: restore the popped value and start a new group.
        tuples[i].push_back(v);
        break;
      }
    }

    // If one of the value is the special value kTableAnyValue, we convert
    // it to the "empty means any" format.
    for (const int64_t v : temp_data.back().values) {
      if (v == kTableAnyValue) {
        temp_data.back().values.clear();
        break;
      }
    }
    gtl::STLSortAndRemoveDuplicates(&temp_data.back().values);

    // If values cover the whole domain, we clear the vector. This allows to
    // use less space and avoid creating unneeded clauses.
    if (temp_data.back().values.size() == domain_size) {
      temp_data.back().values.clear();
    }
  }

  // Base case: a single group. Emit prefix columns, then the compressed last
  // column, then the suffix accumulated by the callers.
  if (temp_data.size() == 1) {
    output->push_back({});
    for (const int64_t v : tuples[temp_data[0].index]) {
      if (v == kTableAnyValue) {
        output->back().push_back({});
      } else {
        output->back().push_back({v});
      }
    }
    output->back().push_back(temp_data[0].values);
    // reversed_suffix is stored innermost-first, so append it in reverse.
    for (int i = reversed_suffix->size(); --i >= 0;) {
      output->back().push_back((*reversed_suffix)[i]);
    }
    return;
  }

  // Sort temp_data and make recursive call for all tuples that share the
  // same suffix.
  std::sort(temp_data.begin(), temp_data.end());
  std::vector<std::vector<int64_t>> temp_tuples;
  for (int i = 0; i < temp_data.size();) {
    reversed_suffix->push_back(temp_data[i].values);
    const int start = i;
    temp_tuples.clear();
    for (; i < temp_data.size(); i++) {
      if (temp_data[start].values != temp_data[i].values) break;
      temp_tuples.push_back(tuples[temp_data[i].index]);
    }
    FullyCompressTuplesRecursive(domain_sizes, absl::MakeSpan(temp_tuples),
                                 reversed_suffix, output);
    reversed_suffix->pop_back();
  }
}
314
315} // namespace
316
317// TODO(user): We can probably reuse the tuples memory always and never create
318// new one. We should also be able to code an iterative version of this. Note
319// however that the recursion level is bounded by the number of columns which
320// should be small.
321std::vector<std::vector<absl::InlinedVector<int64_t, 2>>> FullyCompressTuples(
322 absl::Span<const int64_t> domain_sizes,
323 std::vector<std::vector<int64_t>>* tuples) {
324 std::vector<absl::InlinedVector<int64_t, 2>> reversed_suffix;
325 std::vector<std::vector<absl::InlinedVector<int64_t, 2>>> output;
326 FullyCompressTuplesRecursive(domain_sizes, absl::MakeSpan(*tuples),
327 &reversed_suffix, &output);
328 return output;
329}
330
331// TODO(user): Note that if we have duplicate variables controlling different
332// time point, this might not reach the fixed point. Fix? it is not that
333// important as the expansion take care of this case anyway.
334void PropagateAutomaton(const AutomatonConstraintProto& proto,
335 const PresolveContext& context,
336 std::vector<absl::flat_hash_set<int64_t>>* states,
337 std::vector<absl::flat_hash_set<int64_t>>* labels) {
338 const int n = proto.exprs_size();
339 const absl::flat_hash_set<int64_t> final_states(
340 {proto.final_states().begin(), proto.final_states().end()});
341
342 labels->clear();
343 labels->resize(n);
344 states->clear();
345 states->resize(n + 1);
346 (*states)[0].insert(proto.starting_state());
347
348 // Forward pass.
349 for (int time = 0; time < n; ++time) {
350 for (int t = 0; t < proto.transition_tail_size(); ++t) {
351 const int64_t tail = proto.transition_tail(t);
352 const int64_t label = proto.transition_label(t);
353 const int64_t head = proto.transition_head(t);
354 if (!(*states)[time].contains(tail)) continue;
355 if (!context.DomainContains(proto.exprs(time), label)) continue;
356 if (time == n - 1 && !final_states.contains(head)) continue;
357 (*labels)[time].insert(label);
358 (*states)[time + 1].insert(head);
359 }
360 }
361
362 // Backward pass.
363 for (int time = n - 1; time >= 0; --time) {
364 absl::flat_hash_set<int64_t> new_states;
365 absl::flat_hash_set<int64_t> new_labels;
366 for (int t = 0; t < proto.transition_tail_size(); ++t) {
367 const int64_t tail = proto.transition_tail(t);
368 const int64_t label = proto.transition_label(t);
369 const int64_t head = proto.transition_head(t);
370
371 if (!(*states)[time].contains(tail)) continue;
372 if (!(*labels)[time].contains(label)) continue;
373 if (!(*states)[time + 1].contains(head)) continue;
374 new_labels.insert(label);
375 new_states.insert(tail);
376 }
377 (*labels)[time].swap(new_labels);
378 (*states)[time].swap(new_states);
379 }
380}
381
382} // namespace sat
383} // namespace operations_research
void UpdateRuleStats(const std::string &name, int num_times=1)
bool DomainContains(int ref, int64_t value) const
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
Definition stl_util.h:58
constexpr int64_t kTableAnyValue
void CompressTuples(absl::Span< const int64_t > domain_sizes, std::vector< std::vector< int64_t > > *tuples)
std::vector< std::vector< absl::InlinedVector< int64_t, 2 > > > FullyCompressTuples(absl::Span< const int64_t > domain_sizes, std::vector< std::vector< int64_t > > *tuples)
int64_t GetInnerVarValue(const LinearExpressionProto &expr, int64_t value)
int64_t AffineExpressionValueAt(const LinearExpressionProto &expr, int64_t value)
Evaluates an affine expression at the given value.
void PropagateAutomaton(const AutomatonConstraintProto &proto, const PresolveContext &context, std::vector< absl::flat_hash_set< int64_t > > *states, std::vector< absl::flat_hash_set< int64_t > > *labels)
Fills and propagates the set of reachable states/labels.
void CanonicalizeTable(PresolveContext *context, ConstraintProto *ct)
In SWIG mode, we don't want anything besides these top-level includes.