Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions src/callback.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
//! # User Callbacks in HiGHS

use std::ffi::{c_int, c_void};

/// User callbacks while solving
pub trait Callback {
/// The main callback routine
fn callback(&mut self, context: CallbackOuterContext<'_>) -> CallbackReturn;
}

/// The context of a user callback
pub struct CallbackOuterContext<'a> {
data: &'a highs_sys::HighsCallbackDataOut,
callback_type: c_int,
}

// Applicable in all contexts
impl<'a> CallbackOuterContext<'a> {
/// Gets the inner callback context
pub fn inner(self) -> CallbackType<'a> {
match self.callback_type {
highs_sys::kHighsCallbackLogging => CallbackType::Logging(CallbackContext {
data: self.data,
_ctx: CbCtxLogging,
}),
highs_sys::kHighsCallbackSimplexInterrupt => {
CallbackType::SimplexInterrupt(CallbackContext {
data: self.data,
_ctx: CbCtxSimplexInterrupt,
})
}
highs_sys::kHighsCallbackIpmInterrupt => CallbackType::IpmInterrupt(CallbackContext {
data: self.data,
_ctx: CbCtxIpmInterrupt,
}),
highs_sys::kHighsCallbackMipSolution => CallbackType::MipSolution(CallbackContext {
data: self.data,
_ctx: CbCtxMipSolution,
}),
highs_sys::kHighsCallbackMipImprovingSolution => {
CallbackType::MipImprovingSolution(CallbackContext {
data: self.data,
_ctx: CbCtxMipImprovingSolution,
})
}
highs_sys::kHighsCallbackMipLogging => CallbackType::MipLogging(CallbackContext {
data: self.data,
_ctx: CbCtxMipLogging,
}),
highs_sys::kHighsCallbackMipInterrupt => CallbackType::MipInterrupt(CallbackContext {
data: self.data,
_ctx: CbCtxMipInterrupt,
}),
highs_sys::kHighsCallbackMipGetCutPool => {
CallbackType::MipGetCutPool(CallbackContext {
data: self.data,
_ctx: CbCtxMipGetCutPool,
})
}
highs_sys::kHighsCallbackMipDefineLazyConstraints => {
CallbackType::MipDefineLazyConstraints(CallbackContext {
data: self.data,
_ctx: CbCtxMipDefineLazyConstraints,
})
}
_ => unreachable!(),
}
}

/// Gets the running time of the solver
pub fn get_running_time(&self) -> f64 {
self.data.running_time
}
}

/// The type of callback
pub enum CallbackType<'a> {
/// Logging callback
Logging(CallbackContext<'a, CbCtxLogging>),
/// Simplex interrupt callback
SimplexInterrupt(CallbackContext<'a, CbCtxSimplexInterrupt>),
/// IPM interrupt callback
IpmInterrupt(CallbackContext<'a, CbCtxIpmInterrupt>),
/// Found a MIP solution
MipSolution(CallbackContext<'a, CbCtxMipSolution>),
/// Found an improving MIP solution
MipImprovingSolution(CallbackContext<'a, CbCtxMipImprovingSolution>),
/// MIP logging callback
MipLogging(CallbackContext<'a, CbCtxMipLogging>),
/// MIP interrupt callback
MipInterrupt(CallbackContext<'a, CbCtxMipInterrupt>),
/// MIP get cut pool callback
MipGetCutPool(CallbackContext<'a, CbCtxMipGetCutPool>),
/// MIP define lazy constraints callback
MipDefineLazyConstraints(CallbackContext<'a, CbCtxMipDefineLazyConstraints>),
}

/// Logging callback context
pub struct CbCtxLogging;
/// Simplex interrupt callback context
pub struct CbCtxSimplexInterrupt;
/// IPM interrupt callback context
pub struct CbCtxIpmInterrupt;
/// MIP solution callback context
pub struct CbCtxMipSolution;
/// MIP improving solution callback context
pub struct CbCtxMipImprovingSolution;
/// MIP logging callback context
pub struct CbCtxMipLogging;
/// MIP interrupt callback context
pub struct CbCtxMipInterrupt;
/// MIP get cut pool callback context
pub struct CbCtxMipGetCutPool;
/// MIP define lazy constraints callback context
pub struct CbCtxMipDefineLazyConstraints;

/// An inner callback context
pub struct CallbackContext<'a, Ctx> {
data: &'a highs_sys::HighsCallbackDataOut,
_ctx: Ctx,
}

// Applicable in all contexts
impl<Ctx> CallbackContext<'_, Ctx> {
/// Gets the running time of the solver
pub fn get_running_time(&self) -> f64 {
self.data.running_time
}
}

/// The return type for a user callback
#[derive(Debug, Default)]
pub struct CallbackReturn {
user_interrupt: bool,
}

impl CallbackReturn {
/// Sets the user interrupt value
pub fn set_interrupt(&mut self, interrupt: bool) -> &mut Self {
self.user_interrupt = interrupt;
self
}
}

pub(crate) struct UserCallbackData<'a>(pub &'a mut dyn Callback);

pub(crate) unsafe extern "C" fn callback(
callback_type: c_int,
_message: *const i8,
out_data: *const highs_sys::HighsCallbackDataOut,
in_data: *mut highs_sys::HighsCallbackDataIn,
user_callback_data: *mut c_void,
) {
let user_callback_data = &mut *user_callback_data.cast::<UserCallbackData>();
let ctx = CallbackOuterContext {
data: &*out_data,
callback_type,
};
let res = user_callback_data.0.callback(ctx);
if res.user_interrupt {
(*in_data).user_interrupt = 1;
}
}
36 changes: 36 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,14 @@ pub type RowProblem = Problem<RowMatrix>;
/// See [`Problem<ColMatrix>`](Problem#impl).
pub type ColProblem = Problem<ColMatrix>;

pub mod callback;
mod matrix_col;
mod matrix_row;
mod options;
mod status;

pub use callback::Callback;

/// A complete optimization problem.
/// Depending on the `MATRIX` type parameter, the problem will be built
/// constraint by constraint (with [ColProblem]), or
Expand Down Expand Up @@ -375,6 +378,39 @@ impl Model {
.map(|_| SolvedModel { highs: self.highs })
}

/// Like [`Self::solve`], but with a user callback
pub fn solve_with_callback<Cb>(self, cb: &mut Cb) -> SolvedModel
where
Cb: Callback,
{
self.try_solve_with_callback(cb)
.expect("HiGHS error: invalid problem")
}

/// Like [`Self::try_solve`], but with a user callback
pub fn try_solve_with_callback<Cb>(mut self, cb: &mut Cb) -> Result<SolvedModel, HighsStatus>
where
Cb: Callback,
{
let mut user_callback_data = callback::UserCallbackData(cb);
unsafe {
highs_call!(Highs_setCallback(
self.highs.mut_ptr(),
Some(callback::callback),
(&mut user_callback_data as *mut callback::UserCallbackData).cast()
))
}?;
unsafe { highs_call!(Highs_run(self.highs.mut_ptr())) }?;
unsafe {
highs_call!(Highs_setCallback(
self.highs.mut_ptr(),
None,
std::ptr::null_mut()
))
}?;
Ok(SolvedModel { highs: self.highs })
}

/// Adds a new constraint to the highs model.
///
/// Returns the added row index.
Expand Down