23#include "Eigen/SparseCore"
24#include "absl/log/check.h"
25#include "absl/strings/string_view.h"
30#include "ortools/pdlp/solvers.pb.h"
// Emits a warning when the densest column of `matrix` holds more than
// `density_limit` non-zeros; per the message text, such columns hurt
// parallelization. `matrix_name` is interpolated into the message so callers
// can say which matrix (e.g. the constraint matrix or its transpose) is at fault.
// NOTE(review): this excerpt is missing several original lines (the update of
// `worst_column` inside the loop, the heads of both logging statements, the
// arguments interpolated before "th column" / after "at most order", and the
// closing braces). The comments below describe only what is visible.
39void WarnIfMatrixUnbalanced(
40 const Eigen::SparseMatrix<double, Eigen::ColMajor, int64_t>& matrix,
41 absl::string_view matrix_name, int64_t density_limit,
42 operations_research::SolverLogger* logger) {
// A matrix with no columns has nothing to inspect.
43 if (matrix.cols() == 0)
return;
// Index of the column with the largest non-zero count seen so far. The
// strict '>' comparison below means ties keep the earlier column (assuming
// the missing loop-body line assigns `worst_column = col` — confirm).
44 int64_t worst_column = 0;
45 for (int64_t col = 0; col < matrix.cols(); ++col) {
46 if (matrix.col(col).nonZeros() > matrix.col(worst_column).nonZeros()) {
// Only warn when the densest column exceeds the caller-supplied limit.
50 if (matrix.col(worst_column).nonZeros() > density_limit) {
// Variant of the warning routed through the SolverLogger `logger`
// (the logging call head is not visible in this excerpt).
55 logger,
"WARNING: The ", matrix_name,
" has ",
56 matrix.col(worst_column).nonZeros(),
" non-zeros in its ",
58 "th column. For best parallelization all rows and columns should "
59 "have at most order ",
61 " non-zeros. Consider rewriting the QP to split the corresponding "
62 "variable or constraint.");
// Stream-style variant of the same message, without the "WARNING: " prefix.
// NOTE(review): the condition that selects this variant over the logger
// variant above is not visible in this excerpt.
65 <<
"The " << matrix_name <<
" has "
66 << matrix.col(worst_column).nonZeros() <<
" non-zeros in its "
68 <<
"th column. For best parallelization all rows and columns should "
71 <<
" non-zeros. Consider rewriting the QP to split the corresponding "
72 "variable or constraint.";
// Member initializers and body of the ShardedQuadraticProgram constructor
// (its signature precedes this excerpt). It builds:
// - a transposed copy of the constraint matrix;
// - a scheduler only when num_threads != 1 (the non-null branch is not
//   visible here);
// - sharders for the constraint matrix and its transpose;
// - a primal sharder sized by the number of variable lower bounds;
// - a dual sharder sized by the number of constraint lower bounds.
83 transposed_constraint_matrix_(qp_.constraint_matrix.transpose()),
84 scheduler_(num_threads == 1 ? nullptr
86 constraint_matrix_sharder_(qp_.constraint_matrix, num_shards,
88 transposed_constraint_matrix_sharder_(transposed_constraint_matrix_,
89 num_shards, scheduler_.get()),
90 primal_sharder_(qp_.variable_lower_bounds.size(), num_shards,
92 dual_sharder_(qp_.constraint_lower_bounds.size(), num_shards,
// NOTE(review): these argument checks run in the constructor body, i.e.
// only after every member initializer above has already consumed
// num_threads / num_shards. Invalid values therefore reach the scheduler and
// sharder constructors before the CHECKs can fire.
94 CHECK_GE(num_threads, 1);
95 CHECK_GE(num_shards, num_threads);
// Balance diagnostics only matter when work is actually split across threads.
96 if (num_threads > 1) {
// Rough per-iteration work estimate: constraint-matrix non-zeros plus the
// number of variables and constraints. A single column carrying more than
// a 1/num_threads share of that work is reported as unbalanced. The check
// runs on both the matrix (its columns) and its transpose (the original rows).
97 const int64_t work_per_iteration = qp_.constraint_matrix.nonZeros() +
98 qp_.variable_lower_bounds.size() +
99 qp_.constraint_lower_bounds.size();
100 const int64_t column_density_limit = work_per_iteration / num_threads;
101 WarnIfMatrixUnbalanced(qp_.constraint_matrix,
"constraint matrix",
102 column_density_limit, logger);
103 WarnIfMatrixUnbalanced(transposed_constraint_matrix_,
104 "transposed constraint matrix", column_density_limit,
// Parameter list and body of ScaleMatrix (its name line precedes this excerpt).
// It requires matrix.cols() == col_scaling_vec.size() and
// matrix.rows() == row_scaling_vec.size().
// For each stored entry it forms the factor
// row_scaling_vec[row] * col_scaling_vec[col], working on the per-shard views
// given by `sharder`.
// NOTE(review): the ParallelForEachShard call head (which presumably binds
// `shard` as the lambda argument) and the left-hand side of the scaling
// statement are not visible here. Presumably each entry is multiplied in
// place by that factor — confirm.
115 const Eigen::VectorXd& col_scaling_vec,
116 const Eigen::VectorXd& row_scaling_vec,
const Sharder& sharder,
117 Eigen::SparseMatrix<double, Eigen::ColMajor, int64_t>& matrix) {
118 CHECK_EQ(matrix.cols(), col_scaling_vec.size());
119 CHECK_EQ(matrix.rows(), row_scaling_vec.size());
// Shard-local views: the column factor is read through the same shard as the
// matrix columns. The row factor below indexes the full row_scaling_vec
// because the shard splits columns, not rows (presumably — confirm).
121 auto matrix_shard = shard(matrix);
122 auto col_scaling_vec_shard = shard(col_scaling_vec);
// NOTE(review): `shard(matrix)` is re-evaluated on every loop test.
// `matrix_shard.outerSize()` would reuse the view built above.
123 for (int64_t col_num = 0; col_num < shard(matrix).outerSize(); ++col_num) {
124 for (
decltype(matrix_shard)::InnerIterator it(matrix_shard, col_num); it;
127 row_scaling_vec[it.row()] * col_scaling_vec_shard[it.col()];
// Replaces, in place, every entry of `vector` that is <= -threshold with
// -infinity and every entry that is >= threshold with +infinity.
// NOTE(review): callers pass a Sharder as the second argument, but that
// parameter and the per-shard loop head that binds `shard` are missing from
// this excerpt.
// NOTE(review): presumably expects threshold > 0. With threshold <= 0 every
// finite entry satisfies one of the two tests and becomes infinite.
133void ReplaceLargeValuesWithInfinity(
const double threshold,
135 Eigen::VectorXd& vector) {
136 constexpr double kInfinity = std::numeric_limits<double>::infinity();
138 auto vector_shard = shard(vector);
139 for (int64_t i = 0;
i < vector_shard.size(); ++
i) {
140 if (vector_shard[i] <= -threshold) vector_shard[
i] = -
kInfinity;
141 if (vector_shard[i] >= threshold) vector_shard[
i] =
kInfinity;
// Parameters and body of RescaleQuadraticProgram (its name line precedes this
// excerpt). It applies diagonal scalings in place:
// - col_scaling_vec (length PrimalSize()) to the primal data;
// - row_scaling_vec (length DualSize()) to the dual data;
// - both to the constraint matrix and to its stored transpose.
// Every scaling factor must be strictly positive; this is CHECKed per shard.
149 const Eigen::VectorXd& col_scaling_vec,
150 const Eigen::VectorXd& row_scaling_vec) {
151 CHECK_EQ(
PrimalSize(), col_scaling_vec.size());
152 CHECK_EQ(
DualSize(), row_scaling_vec.size());
// Primal side:
// - objective coefficients are multiplied by the column scale;
// - variable bounds are divided by it.
// This keeps each objective term c_j * x_j unchanged under x_j -> x_j / d_j.
153 primal_sharder_.ParallelForEachShard([&](
const Sharder::Shard& shard) {
154 CHECK((shard(col_scaling_vec).array() > 0.0).all());
155 shard(qp_.objective_vector) =
156 shard(qp_.objective_vector).cwiseProduct(shard(col_scaling_vec));
157 shard(qp_.variable_lower_bounds) =
158 shard(qp_.variable_lower_bounds).cwiseQuotient(shard(col_scaling_vec));
159 shard(qp_.variable_upper_bounds) =
160 shard(qp_.variable_upper_bounds).cwiseQuotient(shard(col_scaling_vec));
// Diagonal objective matrix update using the elementwise square of the
// column scale.
// NOTE(review): the guard around this statement (orig line 161) and the
// operator applied on orig line 164 are not visible in this excerpt.
162 shard(qp_.objective_matrix->diagonal()) =
163 shard(qp_.objective_matrix->diagonal())
165 shard(col_scaling_vec).cwiseProduct(shard(col_scaling_vec)));
// Dual side: constraint bounds are multiplied by the row scale.
168 dual_sharder_.ParallelForEachShard([&](
const Sharder::Shard& shard) {
169 CHECK((shard(row_scaling_vec).array() > 0.0).all());
170 shard(qp_.constraint_lower_bounds) =
171 shard(qp_.constraint_lower_bounds).cwiseProduct(shard(row_scaling_vec));
172 shard(qp_.constraint_upper_bounds) =
173 shard(qp_.constraint_upper_bounds).cwiseProduct(shard(row_scaling_vec));
// The stored transpose swaps rows and columns, so its scaling vectors are
// passed in swapped order to keep both copies consistent.
176 ScaleMatrix(col_scaling_vec, row_scaling_vec, constraint_matrix_sharder_,
177 qp_.constraint_matrix);
178 ScaleMatrix(row_scaling_vec, col_scaling_vec,
179 transposed_constraint_matrix_sharder_,
180 transposed_constraint_matrix_);
// Parameter and body of ReplaceLargeConstraintBoundsWithInfinity (its name
// line precedes this excerpt). It maps constraint lower and upper bounds that
// are <= -threshold to -infinity and those >= threshold to +infinity. The work
// is split with the dual (per-constraint) sharder.
184 const double threshold) {
185 ReplaceLargeValuesWithInfinity(threshold,
DualSharder(),
186 qp_.constraint_lower_bounds);
187 ReplaceLargeValuesWithInfinity(threshold,
DualSharder(),
188 qp_.constraint_upper_bounds);
// Parameters and body of SetConstraintBounds (its name line precedes this
// excerpt). It overwrites the lower and/or upper bound of the constraint at
// `constraint_index`; a bound passed as std::nullopt is left unchanged.
// It CHECK-fails unless 0 <= constraint_index < DualSize().
// NOTE(review): no lower <= upper consistency check is visible here.
192 int64_t constraint_index, std::optional<double> lower_bound,
193 std::optional<double> upper_bound) {
194 CHECK_LT(constraint_index,
DualSize());
195 CHECK_GE(constraint_index, 0);
196 if (lower_bound.has_value()) {
197 qp_.constraint_lower_bounds[constraint_index] = *lower_bound;
199 if (upper_bound.has_value()) {
200 qp_.constraint_upper_bounds[constraint_index] = *upper_bound;
int64_t PrimalSize() const
void SetConstraintBounds(int64_t constraint_index, std::optional< double > lower_bound, std::optional< double > upper_bound)
const Sharder & DualSharder() const
Returns a Sharder intended for dual vectors.
void ReplaceLargeConstraintBoundsWithInfinity(double threshold)
ShardedQuadraticProgram(QuadraticProgram qp, int num_threads, int num_shards, SchedulerType scheduler_type=SCHEDULER_TYPE_GOOGLE_THREADPOOL, operations_research::SolverLogger *logger=nullptr)
void RescaleQuadraticProgram(const Eigen::VectorXd &col_scaling_vec, const Eigen::VectorXd &row_scaling_vec)
void ParallelForEachShard(const std::function< void(const Shard &)> &func) const
Runs func on each of the shards.
Validation utilities for solvers.proto.
std::unique_ptr< Scheduler > MakeScheduler(SchedulerType type, int num_threads)
Convenience factory function.
constexpr double kInfinity
bool IsLinearProgram(const QuadraticProgram &qp)
#define SOLVER_LOG(logger,...)