Skip to content

Commit 48ae8b4

Browse files
committed
c++ for warmstart rehline solver
1 parent 493455e commit 48ae8b4

File tree

3 files changed

+67
-3
lines changed

3 files changed

+67
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rehline"
3-
version = "0.0.5"
3+
version = "0.0.6"
44
description = "Regularized Composite ReLU-ReHU Loss Minimization with Linear Computation and Linear Convergence"
55
authors = [
66
{name = "Ben Dai", email = "[email protected]"},

setup.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pybind11.setup_helpers import Pybind11Extension, build_ext
77
from setuptools import setup
88

9-
__version__ = "0.0.5"
9+
__version__ = "0.0.6"
1010

1111
# The main interface is through Pybind11Extension.
1212
# * You can add cxx_std=11/14/17, and then build_ext can be removed.
@@ -77,3 +77,6 @@ def __str__(self) -> str:
7777
zip_safe=False,
7878
python_requires=">= 3.10",
7979
)
80+
81+
## build .so file
82+
## $ c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) ./src/rehline.cpp -o _internal$(python3-config --extension-suffix)

src/rehline.h

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,57 @@ class ReHLineSolver
599599
set_primal();
600600
}
601601

602+
603+
// Warm start: set dual variables to be the given ones
604+
inline void warmstart_params(ConstRefVec xi_ws, ConstRefMat Lambda_ws, ConstRefMat Gamma_ws)
605+
{
606+
// Warmstart parameters
607+
if (m_K > 0)
608+
{
609+
// Check shape of warmstart parameters
610+
if (xi_ws.size() != m_K) {
611+
throw std::invalid_argument("xi_ws must have size K");
612+
}
613+
// Check values of warmstart parameters
614+
if ((xi_ws.array() < 0).any()) {
615+
throw std::invalid_argument("xi_ws must be non-negative");
616+
}
617+
m_xi = xi_ws;
618+
}
619+
620+
621+
if (m_L > 0)
622+
{
623+
// Check shape of warmstart parameters
624+
if (Lambda_ws.rows() != m_L || Lambda_ws.cols() != m_n) {
625+
throw std::invalid_argument("Lambda_ws must have shape (L, n)");
626+
}
627+
// Check values of warmstart parameters
628+
if ((Lambda_ws.array() < 0).any() || (Lambda_ws.array() > 1).any()) {
629+
throw std::invalid_argument("Lambda_ws must be in [0, 1]");
630+
}
631+
m_Lambda = Lambda_ws;
632+
}
633+
634+
635+
if (m_H > 0)
636+
{
637+
// Check shape of warmstart parameters
638+
if (Gamma_ws.rows() != m_H || Gamma_ws.cols() != m_n) {
639+
throw std::invalid_argument("Gamma_ws must have shape (H, n)");
640+
}
641+
// Check values of warmstart parameters
642+
if ((Gamma_ws.array() < 0).any() || (Gamma_ws.array() > m_Tau.array()).any()) {
643+
throw std::invalid_argument("Gamma_ws must be in [0, tau_hi]");
644+
}
645+
m_Gamma = Gamma_ws;
646+
}
647+
648+
// Set primal variable based on duals
649+
set_primal();
650+
}
651+
652+
602653
inline void set_seed(Index seed) { m_rng.seed(seed); }
603654

604655
inline Index solve_vanilla(
@@ -767,7 +818,17 @@ void rehline_solver(
767818
ReHLineSolver<typename DerivedMat::PlainObject, Index> solver(X, U, V, S, T, Tau, A, b);
768819

769820
// Initialize parameters
770-
solver.init_params();
821+
try {
822+
// Warm start parameters: if result contains warm start parameters then warm start
823+
if (result.xi.size() > 0 || result.Lambda.size() > 0 || result.Gamma.size() > 0) {
824+
solver.warmstart_params(result.xi, result.Lambda, result.Gamma);
825+
} else {
826+
solver.init_params();
827+
}
828+
} catch (const std::exception& e) {
829+
std::cerr << "Warning: warmstart_params failed, using default initialization. Error: " << e.what() << std::endl;
830+
solver.init_params();
831+
}
771832

772833
// Main iterations
773834
std::vector<typename DerivedMat::Scalar> dual_objfns;

0 commit comments

Comments
 (0)