14#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_STORAGE_H_
15#define OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_STORAGE_H_
25#include "absl/container/flat_hash_map.h"
26#include "absl/container/flat_hash_set.h"
46 size_t size()
const {
return key_set_.size(); }
51 for (
const Key& key : key_set_) {
59 void Insert(
const Key& key) { key_set_.push_back(key); }
61 auto begin()
const {
return key_set_.begin(); }
62 auto end()
const {
return key_set_.end(); }
65 std::vector<Key> key_set_;
78 : key_set_(dense_set.
begin(), dense_set.
end()) {}
80 size_t size()
const {
return key_set_.size(); }
85 for (
const Key& key : key_set_) {
90 void Erase(
const Key& key) { key_set_.erase(key); }
91 void Insert(
const Key& key) { key_set_.insert(key); }
94 absl::flat_hash_set<Key> key_set_;
106 return std::visit([](
const auto& impl) {
return impl.size(); }, impl_);
110 template <
typename F>
114 [f = std::move(f)](
const auto& impl) {
115 return impl.ForEach(std::move(f));
120 auto Erase(
const Key& key) {
return AsSparse().Erase(key); }
123 std::visit([&](
auto& impl) { impl.Insert(key); }, impl_);
132 impl_ = SparseKeySet<n>(std::get<DenseKeySet<n>>(impl_));
133 return std::get<SparseKeySet<n>>(impl_);
136 std::variant<DenseKeySet<n>, SparseKeySet<n>> impl_;
141template <
int n,
typename Symmetry,
typename =
void>
148 ForEachDimension([
this, key]<
int i>() {
149 if (MustInsertNondefault<i>(key, Symmetry{})) {
150 key_nondefaults_[i][key[i]].Insert(key.template RemoveElement<i>());
158 ForEachDimension([
this, key]<
int i>() {
159 const auto& key_elem = key[i];
160 auto& nondefaults = key_nondefaults_[i];
161 if (nondefaults[key_elem].size() == 1) {
162 nondefaults.erase(key_elem);
164 nondefaults[key_elem].Erase(key.template RemoveElement<i>());
170 for (
auto& key_nondefaults : key_nondefaults_) {
171 key_nondefaults.clear();
176 std::vector<Key>
Slice(
const int64_t key_elem)
const {
178 key_elem, Symmetry{},
180 [key_elem]<
int... is>(KeySetExpansion<is>... expansions) {
181 std::vector<Key> slice((expansions.key_set.size() + ...));
182 Key* out = slice.data();
184 const auto append = [key_elem, &out]<
int j>(
185 const KeySetExpansion<j>& expansion) {
186 expansion.key_set.ForEach(
188 *out = other_elems.template AddElement<j, Symmetry>(key_elem);
192 (append(expansions), ...);
199 return SliceImpl<i>(key_elem, Symmetry{}, [](
const auto... expansions) {
200 return (expansions.key_set.size() + ...);
207 using NonDefaultKeySet =
KeySet<n - 1>;
209 using NonDefaultKeySetById = absl::flat_hash_map<int64_t, NonDefaultKeySet>;
211 using NonDefaultsPerDimension = std::array<NonDefaultKeySetById, n>;
216 struct KeySetExpansion {
217 KeySetExpansion(
const NonDefaultsPerDimension& key_nondefaults,
219 : key_set(
gtl::FindWithDefault(key_nondefaults[i], key_elem)) {}
220 const NonDefaultKeySet& key_set;
225 template <
typename F,
int i = n - 1>
226 static void ForEachDimension(
const F& f) {
227 f.template operator()<i>();
228 if constexpr (i > 0) {
229 ForEachDimension<F, i - 1>(f);
234 static bool MustInsertNondefault(
const Key&, NoSymmetry) {
238 template <
int i,
int k,
int l>
239 static bool MustInsertNondefault(
const Key& key, ElementSymmetry<k, l>) {
244 if constexpr (
i == l) {
245 const bool is_diagonal = key[k] == key[l];
253 template <
int i,
typename Fn>
254 auto SliceImpl(
const int64_t key_elem, NoSymmetry,
const Fn& fn)
const {
255 static_assert(n > 1);
256 return fn(KeySetExpansion<i>(key_nondefaults_, key_elem));
259 template <
int i,
int k,
int l,
typename Fn>
260 auto SliceImpl(
const int64_t key_elem, ElementSymmetry<k, l>,
261 const Fn& fn)
const {
262 static_assert(n > 1);
263 if constexpr (
i != k &&
i != l) {
265 return SliceImpl<i>(key_elem, NoSymmetry(), fn);
269 return fn(KeySetExpansion<k>(key_nondefaults_, key_elem),
270 KeySetExpansion<l>(key_nondefaults_, key_elem));
274 NonDefaultsPerDimension key_nondefaults_;
278template <
int n,
typename Symmetry>
298template <
typename V,
int n,
typename Symmetry>
304 static_assert(std::is_trivially_copyable_v<V>);
311 explicit AttrStorage(
const V default_value) : default_value_(default_value) {}
321 return non_default_values_.contains(key);
326 std::optional<V>
Set(
const Key key,
const V value) {
327 bool is_default = value == default_value_;
329 const auto it = non_default_values_.find(key);
330 if (it == non_default_values_.end()) {
333 const V prev_value = it->second;
334 non_default_values_.erase(it);
335 slicing_support_.ClearRowsAndColumns(key);
338 const auto [it, inserted] = non_default_values_.try_emplace(key, value);
340 slicing_support_.AddRowsAndColumns(key);
341 return default_value_;
344 if (value == it->second) {
347 return std::exchange(it->second, value);
358 auto it = non_default_values_.find(key);
359 if (it == non_default_values_.end()) {
367 if (non_default_values_.erase(key)) {
368 slicing_support_.ClearRowsAndColumns(key);
375 std::vector<Key> result;
376 result.reserve(non_default_values_.size());
377 for (
const auto& [key, unused] : non_default_values_) {
378 result.push_back(key);
388 std::vector<Key>
Slice(
const int64_t key_elem)
const {
389 static_assert(n >= 1);
390 if constexpr (n == 1) {
391 return non_default_values_.contains(
Key(key_elem))
392 ? std::vector<Key>({
Key(key_elem)})
393 : std::vector<Key>();
395 return slicing_support_.template
Slice<i>(key_elem);
403 static_assert(n >= 1);
404 if constexpr (n == 1) {
405 return non_default_values_.count(
Key(key_elem));
417 non_default_values_.clear();
418 slicing_support_.Clear();
int64_t num_non_defaults() const
AttrStorage & operator=(const AttrStorage &)=default
std::optional< V > GetIfNonDefault(const Key key) const
Returns the value of the attribute for key, or nullopt.
std::vector< Key > NonDefaults() const
AttrStorage(const V default_value)
std::optional< V > Set(const Key key, const V value)
AttrStorage()
Generally avoid, provided to make working with std::array easier.
AttrKey< n, Symmetry > Key
AttrStorage(AttrStorage &&)=default
AttrStorage(const AttrStorage &)=default
int64_t GetSliceSize(const int64_t key_elem) const
void Erase(const Key key)
Sets the value of the attribute for key to the default value.
AttrStorage & operator=(AttrStorage &&)=default
bool IsNonDefault(const Key key) const
void Clear()
Restore all elements to their default value for this attribute.
std::vector< Key > Slice(const int64_t key_elem) const
V Get(const Key key) const
void Insert(const Key &key)
AttrKey< n, NoSymmetry > Key
void ForEach(F f) const
requires std::invocable<F, const Key&>
void ForEach(F f) const
We can't do begin/end because the iterator types are not the same.
AttrKey< n, NoSymmetry > Key
auto Erase(const Key &key)
void Insert(const Key &key)
std::vector< Key > Slice(const int64_t key_elem) const
AttrKey< n, Symmetry > Key
int64_t GetSliceSize(const int64_t key_elem) const
void ClearRowsAndColumns(Key key)
Requires key is currently stored with a non-default value.
void AddRowsAndColumns(const Key key)
SparseKeySet(const DenseKeySet< n > &dense_set)
void Insert(const Key &key)
void ForEach(F f) const
requires std::invocable<F, const Key&>
void Erase(const Key &key)
AttrKey< n, NoSymmetry > Key
An object oriented wrapper for quadratic constraints in ModelStorage.
std::conditional_t<(AttrKeyT::size() > 0), absl::flat_hash_map< AttrKeyT, V >, detail::AttrKey0RawSet< typename AttrKeyT::SymmetryT, std::pair< AttrKeyT, V > > > AttrKeyHashMap
ClosedInterval::Iterator end(ClosedInterval interval)
ClosedInterval::Iterator begin(ClosedInterval interval)
void AddRowsAndColumns(Key)
void ClearRowsAndColumns(Key)
AttrKey< n, Symmetry > Key