25#include "absl/log/check.h"
26#include "absl/log/log.h"
27#include "absl/status/status.h"
28#include "absl/status/statusor.h"
29#include "absl/strings/string_view.h"
44 for (
const auto a : Descriptor::Enumerate()) {
45 const int index =
static_cast<int>(a);
46 attrs_[a] = StorageForAttrType<attr_type_index>(
47 Descriptor::kAttrDescriptors[index].default_value);
53 const int64_t
id)
const {
54 if (diffs_->Get(
id) ==
nullptr) {
61 auto diff = std::make_unique<Diff>();
62 diff->Advance(CurrentCheckpoint());
63 const int64_t diff_id = diffs_->Insert(std::move(diff));
68 if (&diff.diffs_ != diffs_.get()) {
71 return diffs_->Erase(diff.diff_id_);
75 if (diffs_.get() != &diff.diffs_) {
78 Diff* d = diffs_->UpdateAndGet(diff.diff_id_);
82 d->
Advance(CurrentCheckpoint());
87 if (!mutable_element_storage(e).Erase(
id)) {
90 for (
auto& [unused, diff] : diffs_->UpdateAndGetAll()) {
91 diff->DeleteElement(e,
id);
99 UpdateAttrOnElementDeleted<AttrType, i>(a,
id);
106 const auto keys = element_ref_trackers_[a].GetKeysReferencing(
108 for (
const auto key : keys) {
112 for (
auto& [unused, diff] : diffs_->UpdateAndGetAll()) {
113 diff->EraseKeysForAttr(a, {key});
115 attrs_[a].Erase(key);
124template <
typename AttrType,
int i>
125void Elemental::UpdateAttrOnElementDeleted(
const AttrType a,
const int64_t
id) {
126 auto& attr_storage = attrs_[a];
133 for (
auto& [unused, diff] : diffs_->UpdateAndGetAll()) {
134 diff->EraseKeysForAttr(a, {
AttrKey(
id)});
136 attr_storage.Erase(
AttrKey(
id));
142 const std::vector<AttrKeyFor<AttrType>> keys =
144 for (
auto& [unused, diff] : diffs_->UpdateAndGetAll()) {
145 diff->EraseKeysForAttr(a, keys);
147 for (
const auto& key : keys) {
148 attr_storage.Erase(key);
153std::array<int64_t, kNumElements> Elemental::CurrentCheckpoint()
const {
154 std::array<int64_t, kNumElements> result;
156 result[
i] = elements_[
i].next_id();
162 std::optional<absl::string_view> new_model_name)
const {
163 Elemental result(std::string(new_model_name.value_or(model_name_)),
164 primary_objective_name_);
165 result.elements_ = elements_;
166 result.attrs_ = attrs_;
void Advance(const std::array< int64_t, kNumElements > &checkpoints)
Elemental(std::string model_name="", std::string primary_objective_name="")
bool Advance(DiffHandle diff)
std::optional< DiffHandle > GetDiffHandle(int64_t id) const
const std::string & primary_objective_name() const
std::string DebugString(bool print_diffs=true) const
const std::string & model_name() const
bool DeleteElementUntyped(ElementType e, int64_t id)
Policy::template Wrapped< std::vector< AttrKeyFor< AttrType > > > Slice(AttrType a, int64_t key_elem) const
Elemental Clone(std::optional< absl::string_view > new_model_name=std::nullopt) const
bool DeleteDiff(DiffHandle diff)
static constexpr bool is_element_id_v
AttrKey(Ints... dims) -> AttrKey< sizeof...(Ints), NoSymmetry >
constexpr int kNumElements
std::ostream & operator<<(std::ostream &ostr, const SecondOrderConeConstraint &constraint)
constexpr int GetAttrKeySize()
constexpr decltype(auto) ForEachIndex(Fn &&fn)
constexpr std::array< ElementType, GetAttrKeySize< attr >()> GetElementTypes()
typename AttrTypeDescriptorT< AttrType >::ValueType ValueTypeFor
std::tuple_element_t< i, AllAttrTypeDescriptors > TypeDescriptor
static void ForEachAttr(Fn &&fn)