24#include "absl/container/flat_hash_set.h"
25#include "absl/status/status.h"
26#include "absl/status/statusor.h"
27#include "absl/strings/str_cat.h"
28#include "absl/time/clock.h"
29#include "absl/time/time.h"
30#include "absl/types/span.h"
36#include "ortools/math_opt/callback.pb.h"
40#include "ortools/math_opt/solution.pb.h"
42#include "ortools/math_opt/sparse_containers.pb.h"
// Length of the per-`where` lookup tables indexed by Gurobi callback codes;
// CheckedGuroibWhere() statically asserts every mapped code is below this.
constexpr int kNumGurobiEvents{9};
// Positive infinity; compared against cut bounds to detect unbounded sides.
constexpr double kInf{std::numeric_limits<double>::infinity()};
// Compile-time validation of a Gurobi callback `where` code: the
// static_assert below rejects any code outside [0, kNumGurobiEvents), so
// results can safely index tables of length kNumGurobiEvents.
// NOTE(review): the `where` template-parameter header and the end of the body
// are not visible in this view; callers use it as CheckedGuroibWhere<CODE>()
// and the declared return type is int. "Guroib" is a typo of "Gurobi", kept
// because GurobiEvent refers to this name.
56constexpr int CheckedGuroibWhere() {
57 static_assert(
where >= 0 &&
where < kNumGurobiEvents);
// Maps a MathOpt CallbackEventProto to the Gurobi callback `where` code
// (GRB_CB_*) that delivers it. Every mapping goes through CheckedGuroibWhere,
// so each returned code is statically known to be < kNumGurobiEvents.
// CALLBACK_EVENT_UNSPECIFIED aborts via LOG(FATAL).
// NOTE(review): the `switch (event)` line and the function's closing lines
// are not visible in this view; handling of other enum values is unknown here.
61inline int GurobiEvent(CallbackEventProto event) {
63 case CALLBACK_EVENT_PRESOLVE:
64 return CheckedGuroibWhere<GRB_CB_PRESOLVE>();
65 case CALLBACK_EVENT_SIMPLEX:
66 return CheckedGuroibWhere<GRB_CB_SIMPLEX>();
67 case CALLBACK_EVENT_MIP:
68 return CheckedGuroibWhere<GRB_CB_MIP>();
// MathOpt's MIP_SOLUTION / MIP_NODE correspond to Gurobi's MIPSOL / MIPNODE.
69 case CALLBACK_EVENT_MIP_SOLUTION:
70 return CheckedGuroibWhere<GRB_CB_MIPSOL>();
71 case CALLBACK_EVENT_MIP_NODE:
72 return CheckedGuroibWhere<GRB_CB_MIPNODE>();
73 case CALLBACK_EVENT_BARRIER:
74 return CheckedGuroibWhere<GRB_CB_BARRIER>();
75 case CALLBACK_EVENT_UNSPECIFIED:
77 LOG(FATAL) <<
"Unexpected callback event: " << event;
// Converts a dense Gurobi solution vector into a SparseDoubleVectorProto,
// keeping only the entries accepted by `filter`
// (SparseVectorFilterPredicate::AcceptsAndUpdate).
// NOTE(review): the `var_ids` parameter declaration and the line that records
// the accepted id are not visible in this view; `var_ids` presumably yields
// (MathOpt id, Gurobi column index) pairs — confirm against the declaration.
81SparseDoubleVectorProto ApplyFilter(
82 const std::vector<double>& grb_solution,
84 const SparseVectorFilterProto& filter) {
85 SparseVectorFilterPredicate predicate(filter);
86 SparseDoubleVectorProto result;
87 for (
const auto [
id, grb_index] : var_ids) {
// grb_index addresses grb_solution, i.e. Gurobi's own variable ordering.
88 const double val = grb_solution[grb_index];
89 if (predicate.AcceptsAndUpdate(
id, val)) {
91 result.add_values(val);
// Converts a double-valued Gurobi callback attribute (`result`, read by lines
// not visible in this view) to int64_t. Returns InternalError naming the
// attribute code `what` when the value does not round-trip exactly.
// NOTE(review): static_cast<int64_t> of a NaN or out-of-range double is
// undefined behavior in C++; the exactness check only runs after the cast, so
// it does not guard that case. Consider a range/NaN check before casting.
// NOTE(review): the message lacks a space before "with value:", so it renders
// as "...attribute: <what>with value: ...".
97absl::StatusOr<int64_t> CbGetInt64(
const Gurobi::CallbackContext&
context,
100 int64_t result64 =
static_cast<int64_t
>(result);
101 if (result !=
static_cast<double>(result64)) {
102 return absl::InternalError(
103 absl::StrCat(
"Error converting double attribute: ", what,
104 "with value: ", result,
" to int64_t exactly."));
// Converts an integer-valued Gurobi callback attribute (`result`, read by lines
// not visible in this view) to bool. Only 0 and 1 round-trip through
// bool -> int, so any other value yields InternalError naming `what`.
// NOTE(review): the message lacks a space before "with value:", so it renders
// as "...attribute: <what>with value: ...".
109absl::StatusOr<bool> CbGetBool(
const Gurobi::CallbackContext&
context,
112 bool result_bool =
static_cast<bool>(result);
113 if (result !=
static_cast<int>(result_bool)) {
114 return absl::InternalError(
115 absl::StrCat(
"Error converting int attribute: ", what,
116 "with value: ", result,
" to bool exactly."));
// MO_SET_OR_RET(setter, statusor): evaluates `statusor` exactly once. On error,
// it returns that status from the enclosing function, with "<file>:<line>"
// streamed onto it (this assumes RETURN_IF_ERROR yields a builder accepting
// `<<`). Otherwise it passes the moved-out value to `setter`.
// NOTE(review): the opening/closing lines of the macro body (presumably a
// do { ... } while (false) wrapper) are not visible in this view. Do not add
// comments inside the body: each line ends with a `\` continuation.
124#define MO_SET_OR_RET(setter, statusor) \
126 auto eval_status_or = statusor; \
127 RETURN_IF_ERROR(eval_status_or.status()) << __FILE__ << ":" << __LINE__; \
128 setter(*std::move(eval_status_or)); \
// Writes a runtime duration into callback_data.runtime and returns the
// resulting status.
// NOTE(review): the line computing the duration is not visible in this view;
// presumably it encodes (now - a start time held in callback_input) with
// EncodeGoogleApiProto — confirm against the full source.
133absl::Status SetRuntime(
const GurobiCallbackInput& callback_input,
134 CallbackDataProto& callback_data) {
136 callback_data.mutable_runtime());
// Builds the CallbackDataProto handed to the user callback for Gurobi callback
// context `c`. Each visible branch sets `event` and obtains the matching stats
// sub-message (presolve, simplex, barrier or MIP stats) as `s`.
// The MIP_SOLUTION and MIP_NODE branches also:
//  * read a dense vector of callback_input.num_gurobi_vars values;
//  * store its ApplyFilter'ed form in primal_solution_vector.
// An unknown `c.where()` aborts via LOG(FATAL). The runtime is encoded at the
// end, with failures annotated by the `where` code, and the proto is returned.
// NOTE(review): the switch/case lines and the calls that fill `s` are not
// visible in this view. Given the std::optional return type, some
// non-visible branch presumably returns std::nullopt, which the caller
// treats as "nothing to report".
141absl::StatusOr<std::optional<CallbackDataProto>> CreateCallbackDataProto(
142 const Gurobi::CallbackContext&
c,
const GurobiCallbackInput& callback_input,
143 MessageCallbackData& message_callback_data) {
144 CallbackDataProto callback_data;
// Presolve event.
149 callback_data.set_event(CALLBACK_EVENT_PRESOLVE);
150 CallbackDataProto::PresolveStats*
const s =
151 callback_data.mutable_presolve_stats();
// Simplex event.
159 callback_data.set_event(CALLBACK_EVENT_SIMPLEX);
160 CallbackDataProto::SimplexStats*
const s =
161 callback_data.mutable_simplex_stats();
// Barrier event.
172 callback_data.set_event(CALLBACK_EVENT_BARRIER);
173 CallbackDataProto::BarrierStats*
const s =
174 callback_data.mutable_barrier_stats();
// Generic MIP progress event.
189 callback_data.set_event(CALLBACK_EVENT_MIP);
190 CallbackDataProto::MipStats*
const s = callback_data.mutable_mip_stats();
// New incumbent found: read the full solution, keep entries accepted by
// mip_solution_filter.
204 callback_data.set_event(CALLBACK_EVENT_MIP_SOLUTION);
205 CallbackDataProto::MipStats*
const s = callback_data.mutable_mip_stats();
211 std::vector<double> var_values(callback_input.num_gurobi_vars);
214 <<
"Error reading solution at event MIP_SOLUTION";
215 *callback_data.mutable_primal_solution_vector() =
216 ApplyFilter(var_values, callback_input.variable_ids,
217 callback_input.mip_solution_filter);
// MIP node event: the node status is read first, then a solution (presumably
// the node relaxation) is read and filtered with mip_node_filter.
221 callback_data.set_event(CALLBACK_EVENT_MIP_NODE);
222 CallbackDataProto::MipStats*
const s = callback_data.mutable_mip_stats();
232 <<
"Error reading solution status at event MIP_NODE";
234 std::vector<double> var_values(callback_input.num_gurobi_vars);
237 <<
"Error reading solution at event MIP_NODE";
238 *callback_data.mutable_primal_solution_vector() =
239 ApplyFilter(var_values, callback_input.variable_ids,
240 callback_input.mip_node_filter);
246 LOG(FATAL) <<
"Unknown gurobi callback code " <<
c.where();
250 <<
"Error encoding runtime at callback event: " <<
c.where();
252 return callback_data;
// Applies a user CallbackResultProto from inside a Gurobi callback:
//  * Generated linear constraints: each cut's ids are mapped to Gurobi column
//    indices via callback_input.variable_ids.at(). One row is then emitted per
//    (sense, bound) pair:
//      - GRB_EQUAL when lower == upper;
//      - otherwise GRB_LESS_EQUAL for a finite upper bound, plus a row for a
//        finite lower bound (presumably GRB_GREATER_EQUAL; not visible here).
//  * Suggested solutions: each one is scattered into a dense vector of
//    callback_input.num_gurobi_vars entries. The initial fill value and the
//    submitting call are not visible here.
//  * terminate: local_interrupter.Interrupt() is called.
// Returns OK on all visible success paths.
// NOTE(review): if lower_bound == upper_bound == +/-infinity, the equality
// branch is taken and a GRB_EQUAL row with an infinite right-hand side is
// emitted; consider rejecting such cuts explicitly.
// NOTE(review): how `.at(id)` fails for an unknown id depends on the type of
// variable_ids, which is not visible here.
257absl::Status ApplyResult(
const Gurobi::CallbackContext&
context,
258 const GurobiCallbackInput& callback_input,
259 const CallbackResultProto& result,
260 SolveInterrupter& local_interrupter) {
261 for (
const CallbackResultProto::GeneratedLinearConstraint& cut :
263 std::vector<int> gurobi_vars;
264 gurobi_vars.reserve(cut.linear_expression().ids_size());
265 for (
const int64_t
id : cut.linear_expression().ids()) {
266 gurobi_vars.push_back(callback_input.variable_ids.at(
id));
// Expand the two-sided bound into one-sided (sense, rhs) rows for Gurobi.
268 std::vector<std::pair<char, double>> sense_bound_pairs;
269 if (cut.lower_bound() == cut.upper_bound()) {
270 sense_bound_pairs.emplace_back(
GRB_EQUAL, cut.upper_bound());
272 if (cut.upper_bound() <
kInf) {
273 sense_bound_pairs.emplace_back(
GRB_LESS_EQUAL, cut.upper_bound());
275 if (cut.lower_bound() > -
kInf) {
// The two visible calls share arguments; the line choosing between them is
// not visible (presumably lazy constraint vs. user cut).
279 for (
const auto [sense,
bound] : sense_bound_pairs) {
282 gurobi_vars, cut.linear_expression().values(), sense,
bound));
285 gurobi_vars, cut.linear_expression().values(), sense,
bound));
289 for (
const SparseDoubleVectorProto& solution_vector :
290 result.suggested_solutions()) {
293 std::vector<double> gurobi_var_values(callback_input.num_gurobi_vars,
296 gurobi_var_values[callback_input.variable_ids.at(
id)] =
value;
// Early termination requested by the user callback.
301 if (result.terminate()) {
302 local_interrupter.Interrupt();
303 return absl::OkStatus();
305 return absl::OkStatus();
311 const absl::flat_hash_set<CallbackEventProto>& events) {
// Body of EventToGurobiWhere. It builds a mask of length kNumGurobiEvents with
// entry GurobiEvent(e) set to true for every requested event e. The index is
// in range because GurobiEvent only returns codes that CheckedGuroibWhere has
// statically bounded by kNumGurobiEvents.
312 std::vector<bool> result(kNumGurobiEvents);
313 for (
const auto event : events) {
314 result[GurobiEvent(event)] =
true;
// Body of GurobiCallbackImpl; its signature and several statements are not
// visible in this view. Visible flow:
//  1. If a local interrupter exists and has already been triggered, the
//     branch ends by returning OK (statements inside it are not visible).
//  2. A message string is fetched with context.CbGetMessage() and split into
//     lines via MessageCallbackData::Parse. The guard selecting this path and
//     the code consuming non-empty `lines` are not visible; presumably only
//     for Gurobi message callbacks.
//  3. OK is returned when there is no user callback (the other half of that
//     condition is not visible).
//  4. Otherwise a local interrupter is required (CHECK), callback data is
//     built, and OK is returned if none was produced. Then user_cb is invoked,
//     its error status is propagated, and its result is applied via
//     ApplyResult.
332 if (local_interrupter !=
nullptr && local_interrupter->
IsInterrupted()) {
342 return absl::OkStatus();
346 const absl::StatusOr<std::string> msg =
context.CbGetMessage();
348 <<
"Error getting message string in callback";
349 const std::vector<std::string> lines = message_callback_data.
Parse(*msg);
350 if (!lines.empty()) {
354 return absl::OkStatus();
357 if (callback_input.
user_cb ==
nullptr ||
359 return absl::OkStatus();
363 CHECK(local_interrupter !=
nullptr);
366 const std::optional<CallbackDataProto> callback_data,
367 CreateCallbackDataProto(
context, callback_input, message_callback_data));
368 if (!callback_data) {
369 return absl::OkStatus();
371 const absl::StatusOr<CallbackResultProto> result =
372 callback_input.
user_cb(*callback_data);
375 return result.status();
378 ApplyResult(
context, callback_input, *result, *local_interrupter));
379 return absl::OkStatus();
// Part of GurobiCallbackImplFlush (its signature and remaining body are not
// visible in this view). It drains remaining text from message_callback_data
// via Flush(); presumably any partial line buffered since the last Parse call.
384 const std::vector<std::string> lines = message_callback_data.
Flush();
// NOTE(review): the lines below look like a symbol index appended by
// extraction, not real code. The signatures have no bodies or semicolons, and
// the trailing prose line is not a comment. If compiled, the empty #defines
// would redefine macros used above (e.g. ASSIGN_OR_RETURN, RETURN_IF_ERROR,
// MO_SET_OR_RET, GRB_GREATER_EQUAL) as empty. Restore the original source
// rather than relying on this section.
#define ASSIGN_OR_RETURN(lhs, rexpr)
#define RETURN_IF_ERROR(expr)
bool IsInterrupted() const
std::vector< std::string > Parse(absl::string_view message)
std::vector< std::string > Flush()
#define GRB_CB_MIP_NODLFT
#define GRB_CB_SPX_ISPERT
#define GRB_CB_SPX_OBJVAL
#define GRB_CB_PRE_COECHG
#define GRB_GREATER_EQUAL
#define GRB_CB_MIPSOL_OBJBST
#define GRB_CB_MIP_ITRCNT
#define GRB_CB_MIPSOL_SOLCNT
#define GRB_CB_SPX_PRIMINF
#define GRB_CB_MIPNODE_REL
#define GRB_CB_MIPNODE_OBJBND
#define GRB_CB_MIPNODE_OBJBST
#define GRB_CB_SPX_ITRCNT
#define GRB_CB_BARRIER_COMPL
#define GRB_CB_MIPSOL_NODCNT
#define GRB_CB_MIPNODE_STATUS
#define GRB_CB_PRE_ROWDEL
#define GRB_CB_BARRIER_PRIMOBJ
#define GRB_CB_MIPNODE_SOLCNT
#define GRB_CB_SPX_DUALINF
#define GRB_CB_MIP_OBJBST
#define GRB_CB_BARRIER_DUALINF
#define GRB_CB_BARRIER_ITRCNT
#define GRB_CB_MIP_CUTCNT
#define GRB_CB_BARRIER_DUALOBJ
#define GRB_CB_MIP_SOLCNT
#define GRB_CB_PRE_COLDEL
#define GRB_CB_MIP_NODCNT
#define GRB_CB_MIPSOL_SOL
#define GRB_CB_MIPSOL_OBJBND
#define GRB_CB_MIP_OBJBND
#define GRB_CB_BARRIER_PRIMINF
#define GRB_CB_PRE_BNDCHG
#define GRB_CB_MIPNODE_NODCNT
#define MO_SET_OR_RET(setter, statusor)
GurobiMPCallbackContext * context
std::vector< bool > EventToGurobiWhere(const absl::flat_hash_set< CallbackEventProto > &events)
void GurobiCallbackImplFlush(const GurobiCallbackInput &callback_input, MessageCallbackData &message_callback_data)
SparseVectorView< T > MakeView(absl::Span< const int64_t > ids, const Collection &values)
absl::Status GurobiCallbackImpl(const Gurobi::CallbackContext &context, const GurobiCallbackInput &callback_input, MessageCallbackData &message_callback_data, SolveInterrupter *const local_interrupter)
In SWIG mode, we don't want anything besides these top-level includes.
inline ::absl::StatusOr< google::protobuf::Duration > EncodeGoogleApiProto(absl::Duration d)