@@ -599,6 +599,57 @@ class ReHLineSolver
599599 set_primal ();
600600 }
601601
602+
603+ // Warm start: set dual variables to be the given ones
604+ inline void warmstart_params (ConstRefVec xi_ws, ConstRefMat Lambda_ws, ConstRefMat Gamma_ws)
605+ {
606+ // Warmstart parameters
607+ if (m_K > 0 )
608+ {
609+ // Check shape of warmstart parameters
610+ if (xi_ws.size () != m_K) {
611+ throw std::invalid_argument (" xi_ws must have size K" );
612+ }
613+ // Check values of warmstart parameters
614+ if ((xi_ws.array () < 0 ).any ()) {
615+ throw std::invalid_argument (" xi_ws must be non-negative" );
616+ }
617+ m_xi = xi_ws;
618+ }
619+
620+
621+ if (m_L > 0 )
622+ {
623+ // Check shape of warmstart parameters
624+ if (Lambda_ws.rows () != m_L || Lambda_ws.cols () != m_n) {
625+ throw std::invalid_argument (" Lambda_ws must have shape (L, n)" );
626+ }
627+ // Check values of warmstart parameters
628+ if ((Lambda_ws.array () < 0 ).any () || (Lambda_ws.array () > 1 ).any ()) {
629+ throw std::invalid_argument (" Lambda_ws must be in [0, 1]" );
630+ }
631+ m_Lambda = Lambda_ws;
632+ }
633+
634+
635+ if (m_H > 0 )
636+ {
637+ // Check shape of warmstart parameters
638+ if (Gamma_ws.rows () != m_H || Gamma_ws.cols () != m_n) {
639+ throw std::invalid_argument (" Gamma_ws must have shape (H, n)" );
640+ }
641+ // Check values of warmstart parameters
642+ if ((Gamma_ws.array () < 0 ).any () || (Gamma_ws.array () > m_Tau.array ()).any ()) {
643+ throw std::invalid_argument (" Gamma_ws must be in [0, tau_hi]" );
644+ }
645+ m_Gamma = Gamma_ws;
646+ }
647+
648+ // Set primal variable based on duals
649+ set_primal ();
650+ }
651+
652+
602653 inline void set_seed (Index seed) { m_rng.seed (seed); }
603654
604655 inline Index solve_vanilla (
@@ -767,7 +818,17 @@ void rehline_solver(
767818 ReHLineSolver<typename DerivedMat::PlainObject, Index> solver (X, U, V, S, T, Tau, A, b);
768819
769820 // Initialize parameters
770- solver.init_params ();
821+ try {
822+ // Warm start parameters: if result contains warm start parameters then warm start
823+ if (result.xi .size () > 0 || result.Lambda .size () > 0 || result.Gamma .size () > 0 ) {
824+ solver.warmstart_params (result.xi , result.Lambda , result.Gamma );
825+ } else {
826+ solver.init_params ();
827+ }
828+ } catch (const std::exception& e) {
829+ std::cerr << " Warning: warmstart_params failed, using default initialization. Error: " << e.what () << std::endl;
830+ solver.init_params ();
831+ }
771832
772833 // Main iterations
773834 std::vector<typename DerivedMat::Scalar> dual_objfns;
0 commit comments