22#include "absl/log/check.h"
23#include "absl/types/span.h"
38 os <<
"CANNOT_PROPAGATE";
41 os <<
"CAN_PROPAGATE_ENFORCEMENT";
52 trail_(*model->GetOrCreate<
Trail>()),
59 rev_int_repository_.SetLevel(trail_.CurrentDecisionLevel());
60 rev_int_repository_.SaveStateWithStamp(&rev_stack_size_, &rev_stamp_);
63 if (literal.
Index() >=
static_cast<int>(watcher_.size()))
continue;
66 auto& watch_list = watcher_[literal.
Index()];
67 for (
const EnforcementId
id : watch_list) {
68 const LiteralIndex index = ProcessIdOnTrue(literal,
id);
71 watch_list[new_size++] = id;
74 CHECK_NE(index, literal.
Index());
75 watcher_[index].push_back(
id);
78 watch_list.resize(new_size);
81 for (
const EnforcementId
id : watcher_[literal.
NegatedIndex()]) {
85 rev_stack_size_ =
static_cast<int>(untrail_stack_.size());
89 for (
const EnforcementId
id : ids_to_fix_until_next_root_level_) {
92 if (trail_.CurrentDecisionLevel() == 0) {
93 ids_to_fix_until_next_root_level_.clear();
100 rev_int_repository_.SetLevel(trail_.CurrentDecisionLevel());
102 const int size =
static_cast<int>(untrail_stack_.size());
103 for (
int i = size - 1;
i >= rev_stack_size_; --
i) {
104 const auto [id, status] = untrail_stack_[
i];
105 statuses_[id] = status;
106 if (callbacks_[
id] !=
nullptr) callbacks_[id](id, status);
108 untrail_stack_.resize(rev_stack_size_);
118 absl::Span<const Literal> enforcement,
122 temp_literals_.clear();
123 const int level = trail_.CurrentDecisionLevel();
124 for (
const Literal l : enforcement) {
126 const int size = std::max(l.Index().value(), l.NegatedIndex().value()) + 1;
127 if (size >
static_cast<int>(watcher_.size())) {
128 watcher_.resize(size);
130 if (assignment_.LiteralIsTrue(l)) {
131 if (level == 0 || trail_.Info(l.Variable()).level == 0)
continue;
133 }
else if (assignment_.LiteralIsFalse(l)) {
136 temp_literals_.push_back(l);
141 if (temp_literals_.empty()) {
142 if (callback !=
nullptr)
144 return EnforcementId(-1);
147 const EnforcementId id(
static_cast<int>(callbacks_.size()));
148 callbacks_.push_back(std::move(callback));
150 CHECK(!temp_literals_.empty());
151 buffer_.insert(buffer_.end(), temp_literals_.begin(), temp_literals_.end());
152 starts_.push_back(buffer_.size());
155 statuses_.push_back(temp_literals_.size() == 1
159 if (temp_literals_.size() == 1) {
160 watcher_[temp_literals_[0].Index()].push_back(
id);
163 const auto span = GetSpan(
id);
164 int num_not_true = 0;
165 for (
int i = 0;
i < span.size(); ++
i) {
166 if (assignment_.LiteralIsTrue(span[
i]))
continue;
167 std::swap(span[num_not_true], span[
i]);
169 if (num_not_true == 2)
break;
173 if (num_not_true == 1) {
174 int max_level = trail_.Info(span[1].
Variable()).level;
175 for (
int i = 2;
i < span.size(); ++
i) {
176 const int level = trail_.Info(span[
i].
Variable()).level;
177 if (level > max_level) {
179 std::swap(span[1], span[
i]);
184 watcher_[span[0].Index()].push_back(
id);
185 watcher_[span[1].Index()].push_back(
id);
192 }
else if (num_true == temp_literals_.size()) {
194 }
else if (num_true + 1 == temp_literals_.size()) {
197 if (temp_literals_.size() == 1) {
198 if (callbacks_[
id] !=
nullptr) {
206 if (trail_.CurrentDecisionLevel() > 0 &&
208 ids_to_fix_until_next_root_level_.push_back(
id);
216 EnforcementId
id, std::vector<Literal>* reason)
const {
217 for (
const Literal l : GetSpan(
id)) {
218 reason->push_back(l.Negated());
222absl::Span<Literal> EnforcementPropagator::GetSpan(EnforcementId
id) {
223 if (
id < 0)
return {};
224 DCHECK_LE(
id + 1, starts_.size());
225 const int size = starts_[
id + 1] - starts_[id];
227 return absl::MakeSpan(&buffer_[starts_[
id]], size);
230absl::Span<const Literal> EnforcementPropagator::GetSpan(
231 EnforcementId
id)
const {
232 if (
id < 0)
return {};
233 DCHECK_LE(
id + 1, starts_.size());
234 const int size = starts_[
id + 1] - starts_[id];
236 return absl::MakeSpan(&buffer_[starts_[
id]], size);
239LiteralIndex EnforcementPropagator::ProcessIdOnTrue(
Literal watched,
244 const auto span = GetSpan(
id);
245 if (span.size() == 1) {
251 const int watched_pos = (span[0] == watched) ? 0 : 1;
252 CHECK_EQ(span[watched_pos], watched);
253 if (assignment_.LiteralIsFalse(span[watched_pos ^ 1])) {
258 for (
int i = 2;
i < span.size(); ++
i) {
259 const Literal l = span[
i];
260 if (assignment_.LiteralIsFalse(l)) {
264 if (!assignment_.LiteralIsAssigned(l)) {
267 std::swap(span[watched_pos], span[
i]);
268 return span[watched_pos].Index();
273 if (assignment_.LiteralIsTrue(span[watched_pos ^ 1])) {
285void EnforcementPropagator::ChangeStatus(EnforcementId
id,
288 if (old_status == new_status)
return;
289 if (trail_.CurrentDecisionLevel() != 0) {
290 untrail_stack_.push_back({id, old_status});
292 statuses_[id] = new_status;
293 if (callbacks_[
id] !=
nullptr) callbacks_[id](id, new_status);
297 absl::Span<const Literal> enforcement)
const {
299 for (
const Literal l : enforcement) {
300 if (assignment_.LiteralIsFalse(l)) {
303 if (assignment_.LiteralIsTrue(l)) ++num_true;
305 const int size = enforcement.size();
313 return Status(GetSpan(
id));
EnforcementPropagator(Model *model)
bool Propagate(Trail *trail) final
EnforcementStatus Status(EnforcementId id) const
void AddEnforcementReason(EnforcementId id, std::vector< Literal > *reason) const
void Untrail(const Trail &trail, int trail_index) final
EnforcementId Register(absl::Span< const Literal > enforcement, std::function< void(EnforcementId, EnforcementStatus)> callback=nullptr)
EnforcementStatus DebugStatus(EnforcementId id)
LiteralIndex NegatedIndex() const
LiteralIndex Index() const
SatPropagator(const std::string &name)
int propagation_trail_index_
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
const LiteralIndex kNoLiteralIndex(-1)
std::ostream & operator<<(std::ostream &os, const BoolVar &var)
@ CAN_PROPAGATE_ENFORCEMENT