Google OR-Tools v9.14
a fast and portable software suite for combinatorial optimization
Loading...
Searching...
No Matches
elemental_export_model.cc
Go to the documentation of this file.
1// Copyright 2010-2025 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14#include <array>
15#include <cstdint>
16#include <limits>
17#include <optional>
18#include <utility>
19#include <vector>
20
21#include "absl/algorithm/container.h"
22#include "absl/container/flat_hash_map.h"
23#include "absl/container/flat_hash_set.h"
24#include "absl/log/check.h"
25#include "absl/status/status.h"
26#include "absl/strings/string_view.h"
27#include "absl/types/span.h"
40
42
43namespace {
44
// Upper bound on the number of entries a protobuf RepeatedField can hold.
constexpr int32_t kInt32Max{std::numeric_limits<int32_t>::max()};
46
// Returns OK if a repeated proto field with `num_entries` entries can be
// exported, i.e. if `num_entries` <= kInt32Max (the RepeatedField capacity).
// Otherwise returns an error status whose message reports `num_entries`.
47absl::Status CanExportToProto(int64_t num_entries) {
48 if (num_entries > kInt32Max) {
50 << "Cannot export to proto, RepeatedField can hold at most "
51 "std::numeric_limits<int32_t>::max() = 2**31-1 = 2147483647 "
52 "entries "
53 "but found: "
54 << num_entries << " entries";
55 }
56 return absl::OkStatus();
57}
58
59// Invokes fn<i>() -> Status sequentially for i in [0..n) and returns the first
60// error (stopping early) or an OK status if every invocation succeeds.
61//
62// NOTE: move arrays.h if we need to reuse this, unit test:
63// https://paste.googleplex.com/5462207396839424
64template <int n, typename Fn>
65absl::Status ForEachIndexUntilError(Fn&& fn) {
66 absl::Status result = absl::OkStatus();
67 // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
68 ForEachIndex<n>([&result, &fn]<int i>() {
69 if (!result.ok()) {
70 return;
71 }
72 result.Update(fn.template operator()<i>());
73 });
74 return result;
75}
76
77// Applies `fn` on each value for each attribute type until an error is found,
78// then returns that error, or OK if no error is found. `fn` must have a
79// overload set of `operator<AttrType a>() -> Status` that accepts a `Type<i>`
80// for `i` in `0..kSize`.
81//
82// NOTE: move derived_data.h if we need to reuse this, unit test:
83// https://paste.googleplex.com/4999175629701120
84template <typename Fn>
85absl::Status ForEachAttrUntilError(Fn&& fn) {
86 absl::Status result;
87 AllAttrs::ForEachAttr([&result, &fn](auto attr) {
88 if (!result.ok()) {
89 return;
90 }
91 result.Update(fn(attr));
92 });
93 return result;
94}
95
97// ExportModelProto
99
100// Returns an error if there are more than 2**31-1 elements of any element type
101// in `model`.
102absl::Status ValidateElementsFitInProto(const Elemental& model) {
103 return ForEachIndexUntilError<kNumElements>(
104 // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
105 [&model]<int e>() -> absl::Status {
106 constexpr auto element_type = static_cast<ElementType>(e);
107 RETURN_IF_ERROR(CanExportToProto(model.NumElements(element_type)))
108 << "too many elements of type: " << element_type;
109 return absl::OkStatus();
110 });
111}
112
113// Returns an error if any attribute has more than 2**31-1 keys with a
114// non-default value. We only check attributes with a key size >= 2, as we have
115// already validated that the Elements fit in proto (which implies attr1s will
116// fit).
117absl::Status ValidateAttrsFitInProto(const Elemental& model) {
118 return ForEachAttrUntilError([&model](auto attr) -> absl::Status {
119 if constexpr (GetAttrKeySize<decltype(attr)>() > 1) {
120 RETURN_IF_ERROR(CanExportToProto(model.AttrNumNonDefaults(attr)))
121 << "too many non-default values for attribute: " << attr;
122 }
123 return absl::OkStatus();
124 });
125}
126
127// Returns an error if model will not fit into a ModelProto.
128absl::Status ValidateModelFitsInProto(const Elemental& model) {
129 RETURN_IF_ERROR(ValidateElementsFitInProto(model));
130 RETURN_IF_ERROR(ValidateAttrsFitInProto(model));
131 return absl::OkStatus();
132}
133
134template <typename T>
135std::vector<T> Sorted(std::vector<T> vec) {
136 absl::c_sort(vec);
137 return vec;
138}
139
140template <typename T>
141std::vector<T> SortSet(const absl::flat_hash_set<T>& s) {
142 return Sorted(std::vector<T>{s.begin(), s.end()});
143}
144
// Overload of Sorted() for a typed element-id vector (judging by the callers,
// the type returned by Elemental::AllElements<e>() -- confirm): sorts the
// underlying container in place and returns the vector.
145template <ElementType e>
147 absl::c_sort(vec.container());
148 return vec;
149}
150
// Exports the values of `double_attr` at `keys`, in the given order, as a
// SparseDoubleVectorProto whose ids are key[0]. Returns std::nullopt if `keys`
// is empty. Values are read with Elemental::UBPolicy, so every key must refer
// to an existing element (presumably undefined behavior otherwise). Callers in
// this file are expected to pass keys in increasing id order.
151// The caller must ensure that keys has at most 2**31-1 elements.
152std::optional<SparseDoubleVectorProto> ExportSparseDoubleVector(
153 const Elemental& elemental, DoubleAttr1 double_attr,
154 absl::Span<const AttrKey<1>> keys) {
155 if (keys.empty()) {
156 return std::nullopt;
157 }
// Crash rather than silently truncate in the int32_t cast below if the
// caller violated the size precondition.
158 CHECK_LE(keys.size(), kInt32Max);
160 const int32_t num_keys = static_cast<int32_t>(keys.size());
161 result.mutable_ids()->Reserve(num_keys);
162 result.mutable_values()->Reserve(num_keys);
163 for (const AttrKey<1> key : keys) {
164 result.add_ids(key[0]);
165 result.add_values(elemental.GetAttr<Elemental::UBPolicy>(double_attr, key));
166 }
167 return result;
168}
169
170std::optional<SparseDoubleVectorProto> ExportSparseDoubleVector(
171 const Elemental& elemental, DoubleAttr1 double_attr) {
172 return ExportSparseDoubleVector(
173 elemental, double_attr, Sorted(elemental.AttrNonDefaults(double_attr)));
174}
175
// Exports the values of the two-key, double-valued attribute `attr` at `keys`,
// in the given order, as a SparseDoubleMatrixProto (row id key[0], column id
// key[1]). Returns std::nullopt if `keys` is empty. Values are read with
// UBPolicy, so every key must exist in the model.
176// DAttr2 will be DoubleAttr2 or SymmetricDoubleAttr2.
177//
178// The caller is responsible for ensuring that there are at most 2**31-1 keys,
179// otherwise UB/crash, e.g. by calling ValidateModelFitsInProto().
180//
181// Keys must be sorted!
182template <typename DAttr2>
183std::optional<SparseDoubleMatrixProto> ExportSparseDoubleMatrix(
184 const Elemental& elemental, const DAttr2 attr,
185 const std::vector<AttrKeyFor<DAttr2>>& keys) {
186 static_assert(GetAttrKeySize<DAttr2>() == 2,
187 "Attribute must have key size two");
188 static_assert(
189 std::is_same_v<double, typename AttrTypeDescriptorT<DAttr2>::ValueType>,
190 "Attribute must be double valued");
191 if (keys.empty()) {
192 return std::nullopt;
193 }
194 CHECK_LE(keys.size(), kInt32Max);
196 // See function level comment, the caller must ensure this is safe.
197 const int nnz = static_cast<int>(keys.size());
198 result.mutable_row_ids()->Reserve(nnz);
199 result.mutable_column_ids()->Reserve(nnz);
200 result.mutable_coefficients()->Reserve(nnz);
201 for (const AttrKeyFor<DAttr2> key : keys) {
202 result.add_row_ids(key[0]);
203 result.add_column_ids(key[1]);
204 // We're using `UBPolicy` for `GetAttr` since we just obtained `keys` from
205 // the model.
206 result.add_coefficients(elemental.GetAttr<Elemental::UBPolicy>(attr, key));
207 }
208 return result;
209}
210
// Exports the keys returned by Elemental::Slice<key_index>(attr,
// slice_element_id) as a SparseDoubleVectorProto. Presumably these are the
// keys whose element at position `key_index` is `slice_element_id`. Each id is
// the key's other element (the key with position `key_index` removed) and each
// value is the attribute value. Keys are sorted before export. Returns
// std::nullopt if the slice is empty.
211// It is the caller's responsibility to ensure that the size of the slice is at
212// most 2**31-1.
213template <int key_index>
214std::optional<SparseDoubleVectorProto> ExportSparseDoubleMatrixSlice(
215 const Elemental& elemental, const DoubleAttr2 attr,
216 const int64_t slice_element_id) {
217 std::vector<AttrKey<2>> slice =
218 elemental.Slice<key_index>(attr, slice_element_id);
219 if (slice.empty()) {
220 return std::nullopt;
221 }
222 CHECK_LE(slice.size(), kInt32Max);
223 const int slice_size = static_cast<int>(slice.size());
224 absl::c_sort(slice);
226 vec.mutable_ids()->Reserve(slice_size);
227 vec.mutable_values()->Reserve(slice_size);
228 for (const AttrKey<2> key : slice) {
229 vec.add_ids(key.RemoveElement<key_index>()[0]);
230 vec.add_values(elemental.GetAttr<Elemental::UBPolicy>(attr, key));
231 }
232 return vec;
233}
234
235template <typename DAttr2>
236std::optional<SparseDoubleMatrixProto> ExportSparseDoubleMatrix(
237 const Elemental& elemental, const DAttr2 attr) {
238 return ExportSparseDoubleMatrix(elemental, attr,
239 Sorted(elemental.AttrNonDefaults(attr)));
240}
241
// Exports the variables in `var_ids`, in the given order, as a VariablesProto:
// ids, integrality, lower/upper bounds, and names unless `remove_names`.
// Returns std::nullopt if `var_ids` is empty. Every id must refer to an
// existing variable: attributes are read with UBPolicy, and a failure to fetch
// a name CHECK-fails. ExportModelProto() passes sorted ids.
242std::optional<VariablesProto> ExportVariables(
243 const Elemental& elemental,
245 const bool remove_names) {
246 if (var_ids.empty()) {
247 return std::nullopt;
248 }
249 VariablesProto vars_proto;
250 // Safe because we have called ValidateModelFitsInProto().
251 const int num_vars = static_cast<int>(var_ids.size());
252 vars_proto.mutable_ids()->Reserve(num_vars);
253 vars_proto.mutable_integers()->Reserve(num_vars);
254 vars_proto.mutable_lower_bounds()->Reserve(num_vars);
255 vars_proto.mutable_upper_bounds()->Reserve(num_vars);
256 if (!remove_names) {
257 vars_proto.mutable_names()->Reserve(num_vars);
258 }
259 for (const VariableId var : var_ids) {
260 vars_proto.add_ids(var.value());
261 // We're using `UBPolicy` for `GetAttr` since we just obtained `var_ids`
262 // from the model.
263 vars_proto.add_integers(elemental.GetAttr<Elemental::UBPolicy>(
264 BoolAttr1::kVarInteger, AttrKey(var)));
265 vars_proto.add_lower_bounds(elemental.GetAttr<Elemental::UBPolicy>(
266 DoubleAttr1::kVarLb, AttrKey(var)));
267 vars_proto.add_upper_bounds(elemental.GetAttr<Elemental::UBPolicy>(
268 DoubleAttr1::kVarUb, AttrKey(var)));
269 if (!remove_names) {
270 const auto name = elemental.GetElementName(var);
271 CHECK_OK(name);
272 vars_proto.add_names(*name);
273 }
274 }
275 return vars_proto;
276}
277
// Exports the linear constraints in `lin_con_ids`, in the given order, as a
// LinearConstraintsProto: ids, lower/upper bounds, and names unless
// `remove_names`. The coefficient matrix is exported separately. Returns
// std::nullopt if `lin_con_ids` is empty. Every id must refer to an existing
// constraint: attributes are read with UBPolicy, and a failure to fetch a name
// CHECK-fails.
278std::optional<LinearConstraintsProto> ExportLinearConstraints(
279 const Elemental& elemental,
281 const bool remove_names) {
282 if (lin_con_ids.empty()) {
283 return std::nullopt;
284 }
285 LinearConstraintsProto lin_cons_proto;
286 // Safe because we have called ValidateModelFitsInProto().
287 const int num_lin_cons = static_cast<int>(lin_con_ids.size());
288
289 lin_cons_proto.mutable_ids()->Reserve(num_lin_cons);
290 lin_cons_proto.mutable_lower_bounds()->Reserve(num_lin_cons);
291 lin_cons_proto.mutable_upper_bounds()->Reserve(num_lin_cons);
292 if (!remove_names) {
293 lin_cons_proto.mutable_names()->Reserve(num_lin_cons);
294 }
295 for (const LinearConstraintId lin_con : lin_con_ids) {
296 lin_cons_proto.add_ids(lin_con.value());
297 // We're using `UBPolicy` for `GetAttr` since we just obtained
298 // `lin_con_ids` from the model.
299 lin_cons_proto.add_lower_bounds(elemental.GetAttr<Elemental::UBPolicy>(
300 DoubleAttr1::kLinConLb, AttrKey(lin_con)));
301 lin_cons_proto.add_upper_bounds(elemental.GetAttr<Elemental::UBPolicy>(
302 DoubleAttr1::kLinConUb, AttrKey(lin_con)));
303 if (!remove_names) {
304 const auto name = elemental.GetElementName(lin_con);
305 CHECK_OK(name);
306 lin_cons_proto.add_names(*name);
307 }
308 }
309 return lin_cons_proto;
310}
311
// Exports each quadratic constraint in `quad_con_ids` as a
// QuadraticConstraintProto keyed by its id. Each proto holds lower/upper
// bounds, linear terms, quadratic terms, and the name unless `remove_names`.
// Quadratic terms come from the sorted 3-element keys: row id key[1] and
// column id key[2]. Every id must refer to an existing quadratic constraint
// (attributes are read with UBPolicy).
312// This function will crash if there are more than 2**31-1 elements in
313// quad_con_ids.
314absl::flat_hash_map<QuadraticConstraintId, QuadraticConstraintProto>
315ExportQuadraticConstraints(
316 const Elemental& elemental,
318 const bool remove_names) {
319 absl::flat_hash_map<QuadraticConstraintId, QuadraticConstraintProto> result;
320 CHECK_LE(quad_con_ids.size(), kInt32Max);
321 for (const QuadraticConstraintId id : quad_con_ids) {
323 if (!remove_names) {
324 const auto name = elemental.GetElementName(id);
325 CHECK_OK(name);
326 quad_con.set_name(*name);
327 }
328 quad_con.set_lower_bound(elemental.GetAttr<Elemental::UBPolicy>(
329 DoubleAttr1::kQuadConLb, AttrKey(id)));
330 quad_con.set_upper_bound(elemental.GetAttr<Elemental::UBPolicy>(
331 DoubleAttr1::kQuadConUb, AttrKey(id)));
332 if (std::optional<SparseDoubleVectorProto> lin_coefs =
333 ExportSparseDoubleMatrixSlice<0>(
334 elemental, DoubleAttr2::kQuadConLinCoef, id.value());
335 lin_coefs.has_value()) {
336 *quad_con.mutable_linear_terms() = *std::move(lin_coefs);
337 }
338 if (std::vector<AttrKey<3, ElementSymmetry<1, 2>>> quad_con_quad_coefs =
339 elemental.Slice<0>(SymmetricDoubleAttr3::kQuadConQuadCoef,
340 id.value());
341 !quad_con_quad_coefs.empty()) {
342 absl::c_sort(quad_con_quad_coefs);
343 SparseDoubleMatrixProto& quad_coef_proto =
344 *quad_con.mutable_quadratic_terms();
345 for (const AttrKey<3, ElementSymmetry<1, 2>> key : quad_con_quad_coefs) {
346 quad_coef_proto.add_row_ids(key[1]);
347 quad_coef_proto.add_column_ids(key[2]);
348 quad_coef_proto.add_coefficients(elemental.GetAttr<Elemental::UBPolicy>(
349 SymmetricDoubleAttr3::kQuadConQuadCoef, key));
350 }
351 }
// CHECK-fails if `quad_con_ids` contains the same id twice.
352 auto [it, inserted] = result.insert({id, std::move(quad_con)});
353 CHECK(inserted);
354 }
355 return result;
356}
357
// Exports each indicator constraint in `ind_con_ids` as an
// IndicatorConstraintProto keyed by its id. Each proto holds lower/upper
// bounds, the linear expression (sorted by variable id), activate_on_zero,
// the indicator variable id (only when that attribute is non-default), and the
// name unless `remove_names`. CHECK-fails if there are more than 2**31-1 ids
// or if an id is repeated. Attributes are read with UBPolicy, so every id must
// refer to an existing constraint.
358absl::flat_hash_map<IndicatorConstraintId, IndicatorConstraintProto>
359ExportIndicatorConstraints(
360 const Elemental& elemental,
362 const bool remove_names) {
363 absl::flat_hash_map<IndicatorConstraintId, IndicatorConstraintProto> result;
364 CHECK_LE(ind_con_ids.size(), kInt32Max);
365 for (const IndicatorConstraintId id : ind_con_ids) {
367 if (!remove_names) {
368 const auto name = elemental.GetElementName(id);
369 CHECK_OK(name);
370 ind_con.set_name(*name);
371 }
372 ind_con.set_lower_bound(elemental.GetAttr<Elemental::UBPolicy>(
373 DoubleAttr1::kIndConLb, AttrKey(id)));
374 ind_con.set_upper_bound(elemental.GetAttr<Elemental::UBPolicy>(
375 DoubleAttr1::kIndConUb, AttrKey(id)));
376 if (std::optional<SparseDoubleVectorProto> lin_coefs =
377 ExportSparseDoubleMatrixSlice<0>(
378 elemental, DoubleAttr2::kIndConLinCoef, id.value());
379 lin_coefs.has_value()) {
380 *ind_con.mutable_expression() = *std::move(lin_coefs);
381 }
382 ind_con.set_activate_on_zero(elemental.GetAttr<Elemental::UBPolicy>(
383 BoolAttr1::kIndConActivateOnZero, AttrKey(id)));
// Only export the indicator variable when it has been set; presumably the
// default value means "no indicator variable" -- confirm.
384 if (elemental.AttrIsNonDefault<Elemental::UBPolicy>(
385 VariableAttr1::kIndConIndicator, AttrKey(id))) {
386 ind_con.set_indicator_id(
387 elemental
388 .GetAttr<Elemental::UBPolicy>(VariableAttr1::kIndConIndicator,
389 AttrKey(id))
390 .value());
391 }
392 CHECK(result.insert({id, std::move(ind_con)}).second);
393 }
394 return result;
395}
396
397std::optional<ObjectiveProto> ExportObjective(const Elemental& elemental,
398 const bool remove_names) {
399 const bool has_offset =
400 elemental.AttrIsNonDefault(DoubleAttr0::kObjOffset, AttrKey());
401 const bool has_maximize =
402 elemental.AttrIsNonDefault(BoolAttr0::kMaximize, AttrKey());
403 const bool has_priority =
404 elemental.AttrIsNonDefault(IntAttr0::kObjPriority, AttrKey());
405 // We have less than 2**31 elements from existing validation.
406 std::optional<SparseDoubleVectorProto> lin_obj_vec =
407 ExportSparseDoubleVector(elemental, DoubleAttr1::kObjLinCoef);
408 // We have less than 2**31 elements from existing validation.
409 std::optional<SparseDoubleMatrixProto> quad_obj_mat =
410 ExportSparseDoubleMatrix(elemental, SymmetricDoubleAttr2::kObjQuadCoef);
411 const bool has_name =
412 !remove_names && !elemental.primary_objective_name().empty();
413 if (!has_offset && !has_maximize && !has_priority &&
414 !lin_obj_vec.has_value() && !quad_obj_mat.has_value() && !has_name) {
415 return std::nullopt;
416 }
417 ObjectiveProto result;
418 if (!remove_names) {
419 result.set_name(elemental.primary_objective_name());
420 }
421 result.set_maximize(elemental.GetAttr(BoolAttr0::kMaximize, AttrKey()));
422 result.set_offset(elemental.GetAttr(DoubleAttr0::kObjOffset, AttrKey()));
423 result.set_priority(elemental.GetAttr(IntAttr0::kObjPriority, AttrKey()));
424 if (lin_obj_vec.has_value()) {
425 *result.mutable_linear_coefficients() = *std::move(lin_obj_vec);
426 }
427 if (quad_obj_mat.has_value()) {
428 *result.mutable_quadratic_coefficients() = *std::move(quad_obj_mat);
429 }
430 return result;
431}
432
433absl::StatusOr<ObjectiveProto> ExportAuxiliaryObjective(
434 const Elemental& elemental, const AuxiliaryObjectiveId id,
435 const bool remove_names) {
436 ObjectiveProto result;
437 if (!remove_names) {
438 ASSIGN_OR_RETURN(const absl::string_view name,
439 elemental.GetElementName(id));
440 result.set_name(name);
441 }
442 result.set_maximize(
443 elemental.GetAttr(BoolAttr1::kAuxObjMaximize, AttrKey(id)));
444 result.set_offset(elemental.GetAttr(DoubleAttr1::kAuxObjOffset, AttrKey(id)));
445 result.set_priority(
446 elemental.GetAttr(IntAttr1::kAuxObjPriority, AttrKey(id)));
447 if (std::optional<SparseDoubleVectorProto> lin_coefs =
448 ExportSparseDoubleMatrixSlice<0>(
449 elemental, DoubleAttr2::kAuxObjLinCoef, id.value());
450 lin_coefs.has_value()) {
451 *result.mutable_linear_coefficients() = *std::move(lin_coefs);
452 }
453 return result;
454}
455
// Exports `elemental` as a ModelProto, omitting the model and element names
// when `remove_names` is true. Returns an error if the model is too large for
// a ModelProto (see ValidateModelFitsInProto()) or if exporting an auxiliary
// objective returns an error. Variables and constraints are exported in
// increasing id order.
456absl::StatusOr<ModelProto> ExportModelProto(const Elemental& elemental,
457 const bool remove_names) {
458 RETURN_IF_ERROR(ValidateModelFitsInProto(elemental));
459 ModelProto result;
460 if (!remove_names) {
461 result.set_name(elemental.model_name());
462 }
463 if (auto vars = ExportVariables(
464 elemental,
465 Sorted(elemental.AllElements<ElementType::kVariable>()).view(),
466 remove_names);
467 vars.has_value()) {
468 *result.mutable_variables() = *std::move(vars);
469 }
470
471 // ObjectiveProto
472 if (auto obj = ExportObjective(elemental, remove_names); obj.has_value()) {
473 *result.mutable_objective() = *std::move(obj);
474 }
475 // Auxiliary objectives
476 for (const AuxiliaryObjectiveId aux_obj_id :
477 elemental.AllElements<ElementType::kAuxiliaryObjective>()) {
479 ((*result.mutable_auxiliary_objectives())[aux_obj_id.value()]),
480 ExportAuxiliaryObjective(elemental, aux_obj_id, remove_names));
481 }
482 // LinearConstraintsProto
483 if (auto lin_cons = ExportLinearConstraints(
484 elemental,
485 Sorted(elemental.AllElements<ElementType::kLinearConstraint>())
486 .view(),
487 remove_names);
488 lin_cons.has_value()) {
489 *result.mutable_linear_constraints() = *std::move(lin_cons);
490 }
491
492 // Linear constraint matrix proto
493 if (auto mat = ExportSparseDoubleMatrix(
494 elemental, DoubleAttr2::kLinConCoef,
495 Sorted(elemental.AttrNonDefaults(DoubleAttr2::kLinConCoef)));
496 mat.has_value()) {
497 *result.mutable_linear_constraint_matrix() = *std::move(mat);
498 }
499
500 // Quadratic constraints
501 for (auto& [id, quad_con_coef] : ExportQuadraticConstraints(
502 elemental,
503 Sorted(elemental.AllElements<ElementType::kQuadraticConstraint>())
504 .view(),
505 remove_names)) {
506 (*result.mutable_quadratic_constraints())[id.value()] =
507 std::move(quad_con_coef);
508 }
509 // Indicator constraints
510 for (auto& [id, ind_con] : ExportIndicatorConstraints(
511 elemental,
512 Sorted(elemental.AllElements<ElementType::kIndicatorConstraint>())
513 .view(),
514 remove_names)) {
515 (*result.mutable_indicator_constraints())[id.value()] = std::move(ind_con);
516 }
517
518 return result;
519}
520
521} // namespace
522
523absl::StatusOr<ModelProto> Elemental::ExportModel(
524 const bool remove_names) const {
525 // It intentional that that this function is implemented without access to the
526 // private API of elemental. This allows us to change the implementation
527 // elemental without breaking the proto export code.
528 return ExportModelProto(*this, remove_names);
529}
530
532// ExportModelUpdateProto
534
535namespace {
536
537// Returns an error if there are more than 2**31-1 new elements or deleted
538// elements of any element type.
// `new_elements[e]` lists the new elements of the element type whose integer
// value is e (as built by ElementsSinceCheckpointPerType()). Deleted elements
// are read from `diff`.
539absl::Status ValidateElementUpdatesFitInProto(
540 const Diff& diff,
541 const std::array<std::vector<int64_t>, kNumElements>& new_elements) {
542 return ForEachIndexUntilError<kNumElements>(
543 // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
544 [&diff, &new_elements]<int e>() -> absl::Status {
545 constexpr auto element_type = static_cast<ElementType>(e);
547 CanExportToProto(diff.deleted_elements(element_type).size()))
548 << "too many deleted elements of type: " << element_type;
549 RETURN_IF_ERROR(CanExportToProto(new_elements[e].size()))
550 << "too many new elements of type: " << element_type;
551 return absl::OkStatus();
552 });
553}
554
555// Returns an error if the number number of tracked modifications exceeds
556// 2**31-1 for any attribute.
557//
558// TODO(b/372411343): this is too conservative for quadratic constraints.
559absl::Status ValidateAttrUpdatesFitInProto(const Diff& diff) {
560 return ForEachAttrUntilError([&diff](auto attr) -> absl::Status {
561 RETURN_IF_ERROR(CanExportToProto(diff.modified_keys(attr).size()))
562 << "too many modifications for attribute: " << attr;
563 return absl::OkStatus();
564 });
565}
566
567// Checks some necessary (but not sufficient) conditions that we can build a
568// ModelUpdateProto for this diff.
569//
570// Validates that:
571// * For each element type, we deletes at most 2**31-1 existing elements.
572// * For each element type, we add at most 2**31-1 new elements
573// * For each attribute, we update at most 2**31-1 keys on existing elements.
574//
575// This validation does not ensure we can actually build a ModelUpdateProto,
576// further validation is required, some of which is specific to how
577// ModelUpdateProto stores attributes and elements. For example:
578// * For any attribute with key size >= 2, we have not checked that the number
579// of keys containing a new element is at most 2**31-1.
580// * The linear objective coefficients and linear constraint coefficients
581// store both updates to keys on existing elements and attribute values for
582// keys containing a new elements in the same repeated field, so we need to
583// check that their combined size is at most 2**31-1.
584absl::Status ValidateModelUpdateFitsInProto(
585 const Diff& diff,
586 const std::array<std::vector<int64_t>, kNumElements>& new_elements) {
587 RETURN_IF_ERROR(ValidateElementUpdatesFitInProto(diff, new_elements));
588 RETURN_IF_ERROR(ValidateAttrUpdatesFitInProto(diff));
589 return absl::OkStatus();
590}
591
592// No need to return optional, repeated fields have no presence.
593template <ElementType e>
594google::protobuf::RepeatedField<int64_t> DeletedIdsSorted(const Diff& diff) {
595 const std::vector<int64_t> sorted_del_vars =
596 SortSet(diff.deleted_elements(e));
597 return {sorted_del_vars.begin(), sorted_del_vars.end()};
598}
599
// Exports the current value of `a` for every key that `diff` records as
// modified and that still exists (per Elemental::ModifiedKeysThatExist),
// sorted by id. Returns std::nullopt if there are no such keys.
600std::optional<SparseDoubleVectorProto> ExportAttrDiff(
601 const Elemental& elemental, DoubleAttr1 a, const Diff& diff) {
602 std::vector<AttrKey<1>> keys = elemental.ModifiedKeysThatExist(a, diff);
603 if (keys.empty()) {
604 return std::nullopt;
605 }
606 absl::c_sort(keys);
608 // NOTE: this cast is safe, we called ValidateModelUpdateFitsInProto().
609 const int num_keys = static_cast<int>(keys.size());
610 result.mutable_ids()->Reserve(num_keys);
611 result.mutable_values()->Reserve(num_keys);
612 for (const auto key : keys) {
613 result.add_ids(key[0]);
614 result.add_values(elemental.GetAttr(a, key));
615 }
616 return result;
617}
618
619absl::StatusOr<std::optional<SparseDoubleVectorProto>> ExportLinObjCoefUpdate(
620 const Elemental& elemental, const Diff& diff,
621 const ElementIdsSpan<ElementType::kVariable> new_var_ids_sorted) {
622 std::vector<AttrKey<1>> keys =
623 elemental.ModifiedKeysThatExist(DoubleAttr1::kObjLinCoef, diff);
624 absl::c_sort(keys);
625 // This is double the hashing we should be doing.
626 for (const VariableId id : new_var_ids_sorted) {
627 if (elemental.AttrIsNonDefault<Elemental::UBPolicy>(
628 DoubleAttr1::kObjLinCoef, AttrKey(id))) {
629 keys.push_back(AttrKey(id));
630 }
631 }
632 RETURN_IF_ERROR(CanExportToProto(keys.size()))
633 << "cannot export linear objective coefficients in model update";
634 return ExportSparseDoubleVector(elemental, DoubleAttr1::kObjLinCoef, keys);
635}
636
637absl::StatusOr<std::optional<SparseDoubleMatrixProto>> ExportQuadObjCoefUpdate(
638 const Elemental& elemental, const Diff& diff,
639 const ElementIdsSpan<ElementType::kVariable> new_var_ids_sorted) {
641 std::vector<Key> keys =
642 elemental.ModifiedKeysThatExist(SymmetricDoubleAttr2::kObjQuadCoef, diff);
643 if (!new_var_ids_sorted.empty()) {
644 const int64_t smallest_key = new_var_ids_sorted[0].value();
645 // This is ~double the hashing we should be doing.
646 for (const VariableId id : new_var_ids_sorted) {
647 for (const Key key :
648 elemental.Slice<0>(SymmetricDoubleAttr2::kObjQuadCoef, id.value())) {
649 if (key[0] < smallest_key || key[1] < smallest_key ||
650 key[0] == id.value()) {
651 keys.push_back(key);
652 }
653 }
654 }
655 }
656 RETURN_IF_ERROR(CanExportToProto(keys.size()))
657 << "cannot export linear objective coefficients in model update";
658 absl::c_sort(keys);
659 return ExportSparseDoubleMatrix(elemental, SymmetricDoubleAttr2::kObjQuadCoef,
660 keys);
661}
662
// Bool overload: exports the current value of `a` for every key that `diff`
// records as modified and that still exists, sorted by id. Returns
// std::nullopt if there are no such keys.
663std::optional<SparseBoolVectorProto> ExportAttrDiff(const Elemental& elemental,
664 const BoolAttr1 a,
665 const Diff& diff) {
666 std::vector<AttrKey<1>> keys = elemental.ModifiedKeysThatExist(a, diff);
667 if (keys.empty()) {
668 return std::nullopt;
669 }
670 absl::c_sort(keys);
672 // NOTE: this cast is safe, we called ValidateModelUpdateFitsInProto().
673 const int num_keys = static_cast<int>(keys.size());
674 result.mutable_ids()->Reserve(num_keys);
675 result.mutable_values()->Reserve(num_keys);
676 for (const auto id : keys) {
677 result.add_ids(id[0]);
678 result.add_values(elemental.GetAttr(a, id));
679 }
680 return result;
681}
682
683std::vector<int64_t> ElementsSinceCheckpoint(const ElementType e,
684 const Elemental& elemental,
685 const Diff& diff) {
686 std::vector<int64_t> result;
687 for (int64_t id = diff.checkpoint(e); id < elemental.NextElementId(e); ++id) {
688 if (elemental.ElementExistsUntyped(e, id)) {
689 result.push_back(id);
690 }
691 }
692 return result;
693}
694
695std::array<std::vector<int64_t>, kNumElements> ElementsSinceCheckpointPerType(
696 const Elemental& elemental, const Diff& diff) {
697 std::array<std::vector<int64_t>, kNumElements> result;
698 for (const ElementType e : kElements) {
699 result[static_cast<int>(e)] = ElementsSinceCheckpoint(e, elemental, diff);
700 };
701 return result;
702}
703
704std::optional<VariableUpdatesProto> ExportVariableUpdates(
705 const Elemental& elemental, const Diff& diff) {
706 auto ubs = ExportAttrDiff(elemental, DoubleAttr1::kVarUb, diff);
707 auto lbs = ExportAttrDiff(elemental, DoubleAttr1::kVarLb, diff);
708 auto integers = ExportAttrDiff(elemental, BoolAttr1::kVarInteger, diff);
709 if (!ubs.has_value() && !lbs.has_value() && !integers.has_value()) {
710 return std::nullopt;
711 }
712 VariableUpdatesProto var_updates;
713 if (ubs.has_value()) {
714 *var_updates.mutable_upper_bounds() = *std::move(ubs);
715 }
716 if (lbs.has_value()) {
717 *var_updates.mutable_lower_bounds() = *std::move(lbs);
718 }
719 if (integers.has_value()) {
720 *var_updates.mutable_integers() = *std::move(integers);
721 }
722 return var_updates;
723}
724
725std::optional<LinearConstraintUpdatesProto> ExportLinearConstraintUpdates(
726 const Elemental& elemental, const Diff& diff) {
727 auto ubs = ExportAttrDiff(elemental, DoubleAttr1::kLinConUb, diff);
728 auto lbs = ExportAttrDiff(elemental, DoubleAttr1::kLinConLb, diff);
729 if (!ubs.has_value() && !lbs.has_value()) {
730 return std::nullopt;
731 }
732 LinearConstraintUpdatesProto lin_con_updates;
733 if (ubs.has_value()) {
734 *lin_con_updates.mutable_upper_bounds() = *std::move(ubs);
735 }
736 if (lbs.has_value()) {
737 *lin_con_updates.mutable_lower_bounds() = *std::move(lbs);
738 }
739 return lin_con_updates;
740}
741
742absl::StatusOr<std::optional<ObjectiveUpdatesProto>> ExportObjectiveUpdates(
743 const Elemental& elemental, const Diff& diff,
744 const ElementIdsSpan<ElementType::kVariable> new_var_ids) {
745 ASSIGN_OR_RETURN(std::optional<SparseDoubleVectorProto> lin_coef_updates,
746 ExportLinObjCoefUpdate(elemental, diff, new_var_ids));
747 ASSIGN_OR_RETURN(std::optional<SparseDoubleMatrixProto> quad_coef_updates,
748 ExportQuadObjCoefUpdate(elemental, diff, new_var_ids));
749 const bool maximize_modified =
750 diff.modified_keys(BoolAttr0::kMaximize).contains(AttrKey());
751 const bool offset_modified =
752 diff.modified_keys(DoubleAttr0::kObjOffset).contains(AttrKey());
753 const bool priority_modified =
754 diff.modified_keys(IntAttr0::kObjPriority).contains(AttrKey());
755 if (!lin_coef_updates.has_value() && !quad_coef_updates.has_value() &&
756 !maximize_modified && !offset_modified && !priority_modified) {
757 return std::nullopt;
758 }
759 ObjectiveUpdatesProto objective_updates;
760 if (maximize_modified) {
761 objective_updates.set_direction_update(
762 elemental.GetAttr(BoolAttr0::kMaximize, AttrKey()));
763 }
764 if (offset_modified) {
765 objective_updates.set_offset_update(
766 elemental.GetAttr(DoubleAttr0::kObjOffset, AttrKey()));
767 }
768 if (priority_modified) {
769 objective_updates.set_priority_update(
770 elemental.GetAttr(IntAttr0::kObjPriority, AttrKey()));
771 }
772 if (lin_coef_updates.has_value()) {
773 *objective_updates.mutable_linear_coefficients() =
774 *std::move(lin_coef_updates);
775 }
776 if (quad_coef_updates.has_value()) {
777 *objective_updates.mutable_quadratic_coefficients() =
778 *std::move(quad_coef_updates);
779 }
780 return objective_updates;
781}
782
783bool ModelUpdateIsEmpty(
784 const Diff& diff,
785 const std::array<std::vector<int64_t>, kNumElements> new_elements) {
786 for (const std::vector<int64_t>& els : new_elements) {
787 if (!els.empty()) {
788 return false;
789 }
790 }
791 bool is_empty = true;
792 for (const ElementType e : kElements) {
793 is_empty =
794 is_empty && diff.deleted_elements(static_cast<ElementType>(e)).empty();
795 }
796 // Subtle: we do not need to check for attribute modifications on a key
797 // containing a new element, as if there is a new element, we have already
798 // shown that update is non-empty.
799 AllAttrs::ForEachAttr([&diff, &is_empty](auto attr) {
800 is_empty = is_empty && diff.modified_keys(attr).empty();
801 });
802 return is_empty;
803}
804
// Returns OK if `diff` records no modified keys for `attr`. Otherwise returns
// an error saying that modifying `attr` is not supported by ModelUpdateProto
// export.
805template <typename AttrType>
806absl::Status EnsureAttrModificationsEmpty(const Diff& diff, AttrType attr) {
807 if (!diff.modified_keys(attr).empty()) {
809 << "Modification for attribute " << attr
810 << " is not supported for ModelUpdateProto export.";
811 }
812 return absl::OkStatus();
813}
814
// Exports the quadratic constraint part of a model update: the sorted ids of
// deleted constraints and the full protos of the new constraints
// `new_quad_cons` (names omitted when `remove_names`). Returns an error if
// `diff` modifies any attribute of an existing quadratic constraint.
// Returns std::nullopt if nothing was deleted or added.
815absl::StatusOr<std::optional<QuadraticConstraintUpdatesProto>>
816ExportQuadraticConstraintsUpdates(
817 const Elemental& elemental, const Diff& diff,
819 const bool remove_names) {
820 // Quadratic constraints are currently immutable (beyond variable deletions)
821 RETURN_IF_ERROR(EnsureAttrModificationsEmpty(diff, DoubleAttr1::kQuadConLb));
822 RETURN_IF_ERROR(EnsureAttrModificationsEmpty(diff, DoubleAttr1::kQuadConUb));
824 EnsureAttrModificationsEmpty(diff, DoubleAttr2::kQuadConLinCoef));
825 RETURN_IF_ERROR(EnsureAttrModificationsEmpty(
826 diff, SymmetricDoubleAttr3::kQuadConQuadCoef));
827 auto deleted = DeletedIdsSorted<ElementType::kQuadraticConstraint>(diff);
828 if (deleted.empty() && new_quad_cons.empty()) {
829 return std::nullopt;
830 }
832 *result.mutable_deleted_constraint_ids() = std::move(deleted);
833 for (auto& [id, quad_con] :
834 ExportQuadraticConstraints(elemental, new_quad_cons, remove_names)) {
835 result.mutable_new_constraints()->insert({id.value(), std::move(quad_con)});
836 }
837 return result;
838}
839
// Exports the indicator constraint part of a model update: the sorted ids of
// deleted constraints and the full protos of the new constraints
// `new_ind_cons` (names omitted when `remove_names`). Returns an error if
// `diff` modifies any attribute of an existing indicator constraint.
// Returns std::nullopt if nothing was deleted or added.
840absl::StatusOr<std::optional<IndicatorConstraintUpdatesProto>>
841ExportIndicatorConstraintsUpdates(
842 const Elemental& elemental, const Diff& diff,
844 const bool remove_names) {
845 // Indicator constraints are currently immutable (beyond variable deletions)
847 EnsureAttrModificationsEmpty(diff, BoolAttr1::kIndConActivateOnZero));
849 EnsureAttrModificationsEmpty(diff, VariableAttr1::kIndConIndicator));
850 RETURN_IF_ERROR(EnsureAttrModificationsEmpty(diff, DoubleAttr1::kIndConLb));
851 RETURN_IF_ERROR(EnsureAttrModificationsEmpty(diff, DoubleAttr1::kIndConUb));
853 EnsureAttrModificationsEmpty(diff, DoubleAttr2::kIndConLinCoef));
854 auto deleted = DeletedIdsSorted<ElementType::kIndicatorConstraint>(diff);
855 if (deleted.empty() && new_ind_cons.empty()) {
856 return std::nullopt;
857 }
859 *result.mutable_deleted_constraint_ids() = std::move(deleted);
860 for (auto& [id, ind_con] :
861 ExportIndicatorConstraints(elemental, new_ind_cons, remove_names)) {
862 result.mutable_new_constraints()->insert({id.value(), std::move(ind_con)});
863 }
864 return result;
865}
866
867absl::StatusOr<std::optional<AuxiliaryObjectivesUpdatesProto>>
868ExportAuxiliaryObjectivesUpdates(
869 const Elemental& elemental, const Diff& diff,
872 const bool remove_names) {
874 auto deleted = DeletedIdsSorted<ElementType::kAuxiliaryObjective>(diff);
875 // Look for modifications to existing objectives, if we have any existing
876 // auxiliary objectives.
877 if (diff.checkpoint(ElementType::kAuxiliaryObjective) > 0) {
878 google::protobuf::Map<int64_t, ObjectiveUpdatesProto>& mods =
879 *result.mutable_objective_updates();
880 for (const AttrKey<1> aux_obj :
881 diff.modified_keys(BoolAttr1::kAuxObjMaximize)) {
882 mods[aux_obj[0]].set_direction_update(
883 elemental.GetAttr<Elemental::UBPolicy>(BoolAttr1::kAuxObjMaximize,
884 aux_obj));
885 }
886 for (const AttrKey<1> aux_obj :
887 diff.modified_keys(IntAttr1::kAuxObjPriority)) {
888 mods[aux_obj[0]].set_priority_update(
889 elemental.GetAttr<Elemental::UBPolicy>(IntAttr1::kAuxObjPriority,
890 aux_obj));
891 }
892 for (const AttrKey<1> aux_obj :
893 diff.modified_keys(DoubleAttr1::kAuxObjOffset)) {
894 mods[aux_obj[0]].set_offset_update(elemental.GetAttr<Elemental::UBPolicy>(
895 DoubleAttr1::kAuxObjOffset, aux_obj));
896 }
897 absl::flat_hash_map<int64_t, std::vector<std::pair<int64_t, double>>>
898 lin_con_updates;
899 for (const AttrKey<2> aux_lin_obj_var :
900 elemental.ModifiedKeysThatExist(DoubleAttr2::kAuxObjLinCoef, diff)) {
901 lin_con_updates[aux_lin_obj_var[0]].push_back(
902 {aux_lin_obj_var[1],
903 elemental.GetAttr<Elemental::UBPolicy>(DoubleAttr2::kAuxObjLinCoef,
904 aux_lin_obj_var)});
905 }
906 for (const VariableId new_var : new_vars) {
907 for (const AttrKey<2> aux_lin_obj_var :
908 elemental.Slice<1>(DoubleAttr2::kAuxObjLinCoef, new_var.value())) {
909 const int64_t aux_obj = aux_lin_obj_var[0];
910 if (aux_obj >= diff.checkpoint(ElementType::kAuxiliaryObjective)) {
911 continue;
912 }
913 lin_con_updates[aux_obj].push_back(
914 {new_var.value(),
915 elemental.GetAttr(DoubleAttr2::kAuxObjLinCoef, aux_lin_obj_var)});
916 }
917 }
918 for (auto& [aux_obj, lin_terms] : lin_con_updates) {
919 absl::c_sort(lin_terms);
920 SparseDoubleVectorProto& proto_terms =
921 *mods[aux_obj].mutable_linear_coefficients();
922 for (auto [var, coef] : lin_terms) {
923 proto_terms.add_ids(var);
924 proto_terms.add_values(coef);
925 }
926 }
927 }
928 if (deleted.empty() && new_aux_objs.empty() &&
929 result.objective_updates().empty()) {
930 return std::nullopt;
931 }
932 *result.mutable_deleted_objective_ids() = std::move(deleted);
933 for (const AuxiliaryObjectiveId id : new_aux_objs) {
934 ASSIGN_OR_RETURN(((*result.mutable_new_objectives())[id.value()]),
935 ExportAuxiliaryObjective(elemental, id, remove_names));
936 }
937 return result;
938}
939
940absl::StatusOr<std::optional<ModelUpdateProto>> ExportModelUpdateProto(
941 const Elemental& elemental, const Diff& diff, const bool remove_names) {
942 const std::array<std::vector<int64_t>, kNumElements> new_elements =
943 ElementsSinceCheckpointPerType(elemental, diff);
944 if (ModelUpdateIsEmpty(diff, new_elements)) {
945 return std::nullopt;
946 }
947 // Warning: further validation is required, see comments on
948 // ValidateModelUpdateFitsInProto().
949 RETURN_IF_ERROR(ValidateModelUpdateFitsInProto(diff, new_elements));
950
951 ModelUpdateProto result;
952 const int64_t var_checkpoint = diff.checkpoint(ElementType::kVariable);
954 &new_elements[static_cast<int>(ElementType::kVariable)]);
956 &new_elements[static_cast<int>(ElementType::kLinearConstraint)]);
958 &new_elements[static_cast<int>(ElementType::kQuadraticConstraint)]);
960 &new_elements[static_cast<int>(ElementType::kIndicatorConstraint)]);
962 &new_elements[static_cast<int>(ElementType::kAuxiliaryObjective)]);
963
964 // Variables
965 *result.mutable_deleted_variable_ids() =
966 DeletedIdsSorted<ElementType::kVariable>(diff);
967 if (auto var_updates = ExportVariableUpdates(elemental, diff);
968 var_updates.has_value()) {
969 *result.mutable_variable_updates() = *std::move(var_updates);
970 }
971 if (auto vars = ExportVariables(elemental, new_var_ids, remove_names);
972 vars.has_value()) {
973 *result.mutable_new_variables() = *std::move(vars);
974 }
975
976 // Objective
977 {
978 ASSIGN_OR_RETURN(std::optional<ObjectiveUpdatesProto> objective_updates,
979 ExportObjectiveUpdates(elemental, diff, new_var_ids));
980 if (objective_updates.has_value()) {
981 *result.mutable_objective_updates() = *std::move(objective_updates);
982 }
983 }
984 // Auxiliary objectives
985 {
987 std::optional<AuxiliaryObjectivesUpdatesProto> aux_objs,
988 ExportAuxiliaryObjectivesUpdates(elemental, diff, new_var_ids,
989 new_aux_objs, remove_names));
990 if (aux_objs.has_value()) {
991 *result.mutable_auxiliary_objectives_updates() = *std::move(aux_objs);
992 }
993 }
994
995 // Linear constraints
996 *result.mutable_deleted_linear_constraint_ids() =
997 DeletedIdsSorted<ElementType::kLinearConstraint>(diff);
998 if (auto lin_con_updates = ExportLinearConstraintUpdates(elemental, diff);
999 lin_con_updates.has_value()) {
1000 *result.mutable_linear_constraint_updates() = *std::move(lin_con_updates);
1001 }
1002 if (auto lin_cons =
1003 ExportLinearConstraints(elemental, new_lin_cons, remove_names);
1004 lin_cons.has_value()) {
1005 *result.mutable_new_linear_constraints() = *std::move(lin_cons);
1006 }
1007
1008 // Linear constraint matrix
1009 {
1010 std::vector<AttrKey<2>> mat_keys =
1011 elemental.ModifiedKeysThatExist(DoubleAttr2::kLinConCoef, diff);
1012 for (const VariableId new_var : new_var_ids) {
1013 for (AttrKey<2> related_key :
1014 elemental.Slice<1>(DoubleAttr2::kLinConCoef, new_var.value())) {
1015 mat_keys.push_back(related_key);
1016 }
1017 }
1018 for (const LinearConstraintId new_con : new_lin_cons) {
1019 for (AttrKey<2> related_key :
1020 elemental.Slice<0>(DoubleAttr2::kLinConCoef, new_con.value())) {
1021 // When related_var >= checkpoint, we got this case from the loop above.
1022 // We do at most twice as much work here as should from this.
1023 if (related_key[1] < var_checkpoint) {
1024 mat_keys.push_back(related_key);
1025 }
1026 }
1027 }
1028 RETURN_IF_ERROR(CanExportToProto(mat_keys.size()))
1029 << "too many linear constraint matrix nonzeros in model update";
1030 absl::c_sort(mat_keys);
1031 if (auto mat = ExportSparseDoubleMatrix(elemental, DoubleAttr2::kLinConCoef,
1032 mat_keys);
1033 mat.has_value()) {
1034 *result.mutable_linear_constraint_matrix_updates() = *std::move(mat);
1035 }
1036 }
1037 // Quadratic constraints
1038 {
1040 std::optional<QuadraticConstraintUpdatesProto> quad_updates,
1041 ExportQuadraticConstraintsUpdates(elemental, diff, new_quad_cons,
1042 remove_names));
1043 if (quad_updates.has_value()) {
1044 *result.mutable_quadratic_constraint_updates() = *std::move(quad_updates);
1045 }
1046 }
1047 // Indicator constraints
1048 {
1049 ASSIGN_OR_RETURN(std::optional<IndicatorConstraintUpdatesProto> ind_updates,
1050 ExportIndicatorConstraintsUpdates(
1051 elemental, diff, new_ind_cons, remove_names));
1052 if (ind_updates.has_value()) {
1053 *result.mutable_indicator_constraint_updates() = *std::move(ind_updates);
1054 }
1055 }
1056 return result;
1057}
1058
1059} // namespace
1060
1061absl::StatusOr<std::optional<ModelUpdateProto>> Elemental::ExportModelUpdate(
1062 const DiffHandle diff, const bool remove_names) const {
1063 if (&diff.diffs_ != diffs_.get()) {
1065 << "diff with id: " << diff.id() << " is from another Elemental";
1066 }
1067 Diff* diff_value = diffs_->Get(diff.id());
1068 if (diff_value == nullptr) {
1070 << "Model has no diff with id: " << diff.id();
1071 }
1072 // It intentional that that this function is implemented without access to the
1073 // private API of elemental. This allows us to change the implementation
1074 // elemental without breaking the proto export code.
1075 return ExportModelUpdateProto(*this, *diff_value, remove_names);
1076}
1077
1078} // namespace operations_research::math_opt
#define ASSIGN_OR_RETURN(lhs, rexpr)
#define RETURN_IF_ERROR(expr)
const absl::flat_hash_set< int64_t > & deleted_elements(const ElementType e) const
Definition diff.h:76
absl::StatusOr< ModelProto > ExportModel(bool remove_names=false) const
absl::StatusOr< std::optional< ModelUpdateProto > > ExportModelUpdate(DiffHandle diff, bool remove_names=false) const
::google::protobuf::RepeatedField<::int64_t > *PROTOBUF_NONNULL mutable_deleted_constraint_ids()
::google::protobuf::RepeatedField<::int64_t > *PROTOBUF_NONNULL mutable_ids()
Definition model.pb.h:3189
::google::protobuf::RepeatedField<::int64_t > *PROTOBUF_NONNULL mutable_ids()
::google::protobuf::RepeatedField<::int64_t > *PROTOBUF_NONNULL mutable_row_ids()
::google::protobuf::RepeatedField<::int64_t > *PROTOBUF_NONNULL mutable_ids()
::google::protobuf::RepeatedField<::int64_t > *PROTOBUF_NONNULL mutable_ids()
Definition model.pb.h:2608
An object oriented wrapper for quadratic constraints in ModelStorage.
Definition gurobi_isv.cc:28
BoolAttr1TypeDescriptor::AttrType BoolAttr1
Definition attributes.h:300
AttrKey< AttrTypeDescriptorT< AttrType >::kNumKeyElements, typename AttrTypeDescriptorT< AttrType >::Symmetry > AttrKeyFor
The type of the AttrKey for attribute type AttrType.
ElementId< ElementType::kAuxiliaryObjective > AuxiliaryObjectiveId
Definition elements.h:266
ElementId< ElementType::kVariable > VariableId
Definition elements.h:264
AttrKey(Ints... dims) -> AttrKey< sizeof...(Ints), NoSymmetry >
CTAD for AttrKey(1,2).
DoubleAttr1TypeDescriptor::AttrType DoubleAttr1
Definition attributes.h:304
typename ElementIdsVector< element_type >::View ElementIdsSpan
Definition elements.h:252
ElementId< ElementType::kQuadraticConstraint > QuadraticConstraintId
Definition elements.h:267
DoubleAttr2TypeDescriptor::AttrType DoubleAttr2
Definition attributes.h:305
ElementId< ElementType::kLinearConstraint > LinearConstraintId
Definition elements.h:265
constexpr decltype(auto) ForEachIndex(Fn &&fn)
Definition arrays.h:56
ElementId< ElementType::kIndicatorConstraint > IndicatorConstraintId
Definition elements.h:268
StatusBuilder InvalidArgumentErrorBuilder()