11use approx:: { relative_eq, relative_ne} ;
22use indexmap:: { IndexMap , IndexSet } ;
3- use itertools:: { Either , Itertools } ;
4- use nalgebra:: { DMatrix , DVector } ;
3+ use itertools:: { Either , Itertools , multiunzip } ;
4+ use nalgebra:: { CsMatrix , DMatrix , DVector } ;
55use serde:: { Deserialize , Serialize } ;
66
77const EPSILON : f64 = 1e-8 ;
@@ -17,7 +17,9 @@ pub struct Solver {
1717 next_constraint : ConstraintId ,
1818 constraints : IndexMap < ConstraintId , LinearExpr > ,
1919 var_to_constraints : IndexMap < Var , IndexSet < ConstraintId > > ,
20+ // Solved and unsolved vars are separate to reduce overhead of many solved variables.
2021 solved_vars : IndexMap < Var , f64 > ,
22+ unsolved_vars : IndexSet < Var > ,
2123 inconsistent_constraints : IndexSet < ConstraintId > ,
2224 invalid_rounding : IndexSet < Var > ,
2325}
@@ -33,22 +35,21 @@ impl Solver {
3335
3436 pub fn new_var ( & mut self ) -> Var {
3537 let var = Var ( self . next_var ) ;
38+ self . unsolved_vars . insert ( var) ;
3639 self . next_var += 1 ;
3740 var
3841 }
3942
4043 /// Returns true if all variables have been solved.
4144 pub fn fully_solved ( & self ) -> bool {
42- self . solved_vars . len ( ) == self . next_var as usize
45+ self . unsolved_vars . is_empty ( )
4346 }
4447
4548 pub fn force_solution ( & mut self ) {
4649 while !self . fully_solved ( ) {
4750 // Find any unsolved variable and constrain it to equal 0.
48- let v = ( 0 ..self . next_var )
49- . find ( |& i| !self . solved_vars . contains_key ( & Var ( i) ) )
50- . unwrap ( ) ;
51- self . constrain_eq0 ( LinearExpr :: from ( Var ( v) ) ) ;
51+ let v = self . unsolved_vars . first ( ) . unwrap ( ) ;
52+ self . constrain_eq0 ( LinearExpr :: from ( * v) ) ;
5253 self . solve ( ) ;
5354 }
5455 }
@@ -63,8 +64,13 @@ impl Solver {
6364 & self . invalid_rounding
6465 }
6566
66- pub fn unsolved_vars ( & self ) -> IndexSet < Var > {
67- IndexSet :: from_iter ( ( 0 ..self . next_var ) . map ( Var ) . filter ( |& v| !self . is_solved ( v) ) )
67+ pub fn unsolved_vars ( & self ) -> & IndexSet < Var > {
68+ & self . unsolved_vars
69+ }
70+
71+ pub fn solve_var ( & mut self , var : Var , val : f64 ) {
72+ self . solved_vars . insert ( var, val) ;
73+ self . unsolved_vars . swap_remove ( & var) ;
6874 }
6975
7076 /// Constrains the value of `expr` to 0.
@@ -109,7 +115,7 @@ impl Solver {
109115 if relative_ne ! ( val, rounded_val, epsilon = EPSILON ) {
110116 self . invalid_rounding . insert ( var) ;
111117 }
112- self . solved_vars . insert ( var, rounded_val) ;
118+ self . solve_var ( var, rounded_val) ;
113119 }
114120 self . constraints . swap_remove ( & constraint_id) ;
115121 for constraint in self
@@ -127,16 +133,29 @@ impl Solver {
127133
128134 /// Solves for as many variables as possible and substitutes their values into existing constraints.
129135 /// Deletes constraints that no longer contain unsolved variables.
136+ ///
137+ /// Constraints should be simplified before this function is invoked.
130138 pub fn solve ( & mut self ) {
131- let n_vars = self . next_var as usize ;
139+ // Snapshot unsolved variables before solving.
140+ let unsolved_vars = self . unsolved_vars . clone ( ) ;
141+ let n_vars = unsolved_vars. len ( ) ;
132142 if n_vars == 0 || self . constraints . is_empty ( ) {
133143 return ;
134144 }
135- let a = DMatrix :: from_row_iterator (
145+ let ( i, j, val) : ( Vec < _ > , Vec < _ > , Vec < _ > ) =
146+ multiunzip ( self . constraints . values ( ) . enumerate ( ) . flat_map ( |( i, expr) | {
147+ expr. coeffs . iter ( ) . map ( {
148+ let unsolved_vars = & unsolved_vars;
149+ move |( coeff, var) | ( i, unsolved_vars. get_index_of ( var) . unwrap ( ) , * coeff)
150+ } )
151+ } ) ) ;
152+ let a = DMatrix :: from ( CsMatrix :: from_triplet (
136153 self . constraints . len ( ) ,
137154 n_vars,
138- self . constraints . values ( ) . flat_map ( |c| c. coeff_vec ( n_vars) ) ,
139- ) ;
155+ & i,
156+ & j,
157+ & val,
158+ ) ) ;
140159 let b = DVector :: from_iterator (
141160 self . constraints . len ( ) ,
142161 self . constraints . values ( ) . map ( |c| -c. constant ) ,
@@ -151,12 +170,11 @@ impl Solver {
151170 let vt_recons = vt. rows ( 0 , r) ;
152171 let sol = svd. solve ( & b, EPSILON ) . unwrap ( ) ;
153172
154- for i in 0 ..self . next_var {
155- let recons = ( vt_recons. transpose ( ) * vt_recons. column ( i as usize ) ) [ ( ( i as usize ) , 0 ) ] ;
156- if !self . solved_vars . contains_key ( & Var ( i) )
157- && relative_eq ! ( recons, 1. , epsilon = EPSILON )
158- {
159- self . solved_vars . insert ( Var ( i) , sol[ ( i as usize , 0 ) ] ) ;
173+ for var in & unsolved_vars {
174+ let i = unsolved_vars. get_index_of ( var) . unwrap ( ) ;
175+ let recons = ( vt_recons. transpose ( ) * vt_recons. column ( i) ) [ ( i, 0 ) ] ;
176+ if relative_eq ! ( recons, 1. , epsilon = EPSILON ) {
177+ self . solve_var ( * var, sol[ ( i, 0 ) ] ) ;
160178 }
161179 }
162180 for ( id, constraint) in self . constraints . iter_mut ( ) {
@@ -209,14 +227,6 @@ pub struct LinearExpr {
209227}
210228
211229impl LinearExpr {
212- pub fn coeff_vec ( & self , n_vars : usize ) -> Vec < f64 > {
213- let mut out = vec ! [ 0. ; n_vars] ;
214- for ( val, var) in & self . coeffs {
215- out[ var. 0 as usize ] += * val;
216- }
217- out
218- }
219-
220230 pub fn add ( lhs : impl Into < LinearExpr > , rhs : impl Into < LinearExpr > ) -> Self {
221231 lhs. into ( ) + rhs. into ( )
222232 }
@@ -357,5 +367,8 @@ mod tests {
357367 assert_relative_eq ! ( * solver. solved_vars. get( & x) . unwrap( ) , 5. , epsilon = EPSILON ) ;
358368 assert_relative_eq ! ( * solver. solved_vars. get( & y) . unwrap( ) , 5. , epsilon = EPSILON ) ;
359369 assert ! ( !solver. solved_vars. contains_key( & z) ) ;
370+ assert ! ( !solver. unsolved_vars. contains( & x) ) ;
371+ assert ! ( !solver. unsolved_vars. contains( & y) ) ;
372+ assert ! ( solver. unsolved_vars. contains( & z) ) ;
360373 }
361374}
0 commit comments