21#include "absl/algorithm/container.h"
22#include "absl/container/flat_hash_map.h"
23#include "absl/container/flat_hash_set.h"
24#include "absl/meta/type_traits.h"
25#include "absl/types/span.h"
33 auto related_vars = related_variables_.find(variable);
34 if (related_vars == related_variables_.end()) {
37 for (
const VariableId related : related_vars->second) {
38 auto mat_value = values_.find(make_key(variable, related));
39 if (mat_value != values_.end() && mat_value->second != 0.0) {
41 mat_value->second = 0.0;
49 std::vector<VariableId> result;
50 if (!related_variables_.contains(variable)) {
53 for (
const VariableId second : related_variables_.at(variable)) {
54 if (
get(variable, second) != 0) {
55 result.push_back(second);
62 std::vector<VariableId> result;
65 for (
const auto& [var, related] : related_variables_) {
67 if (
get(var, other) != 0.0) {
68 result.push_back(var);
78 std::vector<std::pair<VariableId, double>> result;
79 if (!related_variables_.contains(variable)) {
82 for (
const VariableId second : related_variables_.at(variable)) {
83 double val =
get(variable, second);
85 result.push_back({second, val});
91std::vector<std::tuple<VariableId, VariableId, double>>
93 std::vector<std::tuple<VariableId, VariableId, double>> result;
94 result.reserve(nonzeros_);
95 for (
const auto& [var_pair, value] : values_) {
97 result.push_back({var_pair.first, var_pair.second, value});
103void SparseSymmetricMatrix::CompactIfNeeded() {
104 const int64_t zeros = values_.size() - nonzeros_;
105 if (values_.empty() ||
109 for (
auto related_var_it = related_variables_.begin();
110 related_var_it != related_variables_.end();) {
112 std::vector<VariableId>& related = related_var_it->second;
114 for (int64_t read = 0; read < related.size(); ++read) {
115 auto val = values_.find(make_key(v, related[read]));
116 if (val != values_.end()) {
117 if (val->second != 0) {
118 related[write] = related[read];
126 related_variables_.erase(related_var_it++);
128 related.resize(write);
135 related_variables_.clear();
143 std::vector<VariableId> vars_in_order;
144 for (
const auto& [v, _] : related_variables_) {
145 vars_in_order.push_back(v);
147 absl::c_sort(vars_in_order);
152 std::vector<std::pair<VariableId, double>> related =
Terms(v);
153 absl::c_sort(related);
154 for (
const auto& [other, coef] : related) {
166 const absl::flat_hash_set<VariableId>& deleted_variables,
167 const absl::Span<const VariableId> new_variables,
168 const absl::flat_hash_set<std::pair<VariableId, VariableId>>& dirty)
const {
169 std::vector<std::tuple<VariableId, VariableId, double>> updates;
170 for (
const std::pair<VariableId, VariableId>& pair : dirty) {
172 if (deleted_variables.contains(pair.first) ||
173 deleted_variables.contains(pair.second)) {
176 updates.push_back({pair.first, pair.second,
get(pair.first, pair.second)});
180 if (related_variables_.contains(v)) {
182 for (
const auto& [other, coef] :
Terms(v)) {
184 updates.push_back({v, other, coef});
185 }
else if (new_variables.empty() || other < new_variables[0]) {
186 updates.push_back({other, v, coef});
void add_column_ids(::int64_t value)
void add_coefficients(double value)
void add_row_ids(::int64_t value)
SparseDoubleMatrixProto Update(const absl::flat_hash_set< VariableId > &deleted_variables, absl::Span< const VariableId > new_variables, const absl::flat_hash_set< std::pair< VariableId, VariableId > > &dirty) const
std::vector< VariableId > RelatedVariables(VariableId variable) const
std::vector< std::tuple< VariableId, VariableId, double > > Terms() const
std::vector< VariableId > Variables() const
double get(VariableId first, VariableId second) const
void Delete(VariableId variable)
SparseDoubleMatrixProto Proto() const
constexpr double kZerosCleanup
SparseDoubleMatrixProto EntriesToMatrixProto(std::vector< std::tuple< RowId, ColumnId, double > > entries)
ElementId< ElementType::kVariable > VariableId