48 const SparseDoubleMatrixProto& matrix,
const int64_t start_row_id,
49 const std::optional<int64_t> end_row_id,
const int64_t start_col_id,
50 const std::optional<int64_t> end_col_id) {
51 const int matrix_size = matrix.row_ids_size();
52 CHECK_EQ(matrix_size, matrix.column_ids_size());
53 CHECK_EQ(matrix_size, matrix.coefficients_size());
54 const IndexRange row_range = {.start = start_row_id, .end = end_row_id};
55 const IndexRange col_range = {.start = start_col_id, .end = end_col_id};
60 for (
int row_start = 0, next_row_start; row_start < matrix_size;
63 row_start = next_row_start) {
66 const int64_t row_id = matrix.row_ids(row_start);
67 int row_end = row_start + 1;
68 while (row_end < matrix_size && matrix.row_ids(row_end) == row_id) {
73 next_row_start = row_end;
76 if (!row_range.Contains(row_id)) {
81 int row_cols_start = row_start;
82 while (row_cols_start < row_end &&
83 !col_range.Contains(matrix.column_ids(row_cols_start))) {
89 int row_cols_end = row_cols_start;
90 while (row_cols_end < row_end &&
91 col_range.Contains(matrix.column_ids(row_cols_end))) {
94 const int row_cols_len = row_cols_end - row_cols_start;
96 if (row_cols_len != 0) {
97 filtered_rows.emplace_back(
98 row_id,
MakeView(absl::MakeConstSpan(matrix.column_ids())
99 .subspan(row_cols_start, row_cols_len),
100 absl::MakeConstSpan(matrix.coefficients())
101 .subspan(row_cols_start, row_cols_len)));
105 return filtered_rows;
112 absl::flat_hash_map<int64_t, SparseVector<double>> filtered_columns;
113 for (
const auto& [row_id, column_values] : submatrix_by_rows) {
114 for (
const auto [column_id,
value] : column_values) {
116 row_values.
ids.push_back(row_id);
122 std::vector<std::pair<int64_t, SparseVector<double>>> sorted_filtered_columns(
123 std::make_move_iterator(filtered_columns.begin()),
124 std::make_move_iterator(filtered_columns.end()));
125 std::sort(sorted_filtered_columns.begin(), sorted_filtered_columns.end(),
128 return lhs.first < rhs.first;
131 return sorted_filtered_columns;
SparseSubmatrixRowsView SparseSubmatrixByRows(const SparseDoubleMatrixProto &matrix, const int64_t start_row_id, const std::optional< int64_t > end_row_id, const int64_t start_col_id, const std::optional< int64_t > end_col_id)