Skip to content

Commit 704baac

Browse files
authored
feat(solver): compact SVD matrices by removing unused columns (#123)
1 parent 6562a12 commit 704baac

File tree

3 files changed

+43
-30
lines changed

3 files changed

+43
-30
lines changed

core/compiler/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ edition = "2024"
55

66
[dependencies]
77
derive-where = { version = "1", features = ["serde"] }
8-
nalgebra = "0.34"
8+
nalgebra = { version = "0.34", features = ["sparse"] }
99
klayout-lyp = "0.1.1"
1010
gds21 = "0.2"
1111

core/compiler/src/compile.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2263,7 +2263,7 @@ impl<'a> ExecPass<'a> {
22632263
if require_progress && !progress {
22642264
let state = self.cell_state_mut(cell_id);
22652265
if state.unsolved_vars.is_none() {
2266-
state.unsolved_vars = Some(state.solver.unsolved_vars());
2266+
state.unsolved_vars = Some(state.solver.unsolved_vars().clone());
22672267
self.errors.push(ExecError {
22682268
span: None,
22692269
cell: cell_id,

core/compiler/src/solver.rs

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use approx::{relative_eq, relative_ne};
22
use indexmap::{IndexMap, IndexSet};
3-
use itertools::{Either, Itertools};
4-
use nalgebra::{DMatrix, DVector};
3+
use itertools::{Either, Itertools, multiunzip};
4+
use nalgebra::{CsMatrix, DMatrix, DVector};
55
use serde::{Deserialize, Serialize};
66

77
const EPSILON: f64 = 1e-8;
@@ -17,7 +17,9 @@ pub struct Solver {
1717
next_constraint: ConstraintId,
1818
constraints: IndexMap<ConstraintId, LinearExpr>,
1919
var_to_constraints: IndexMap<Var, IndexSet<ConstraintId>>,
20+
// Solved and unsolved vars are separate to reduce overhead of many solved variables.
2021
solved_vars: IndexMap<Var, f64>,
22+
unsolved_vars: IndexSet<Var>,
2123
inconsistent_constraints: IndexSet<ConstraintId>,
2224
invalid_rounding: IndexSet<Var>,
2325
}
@@ -33,22 +35,21 @@ impl Solver {
3335

3436
pub fn new_var(&mut self) -> Var {
3537
let var = Var(self.next_var);
38+
self.unsolved_vars.insert(var);
3639
self.next_var += 1;
3740
var
3841
}
3942

4043
/// Returns true if all variables have been solved.
4144
pub fn fully_solved(&self) -> bool {
42-
self.solved_vars.len() == self.next_var as usize
45+
self.unsolved_vars.is_empty()
4346
}
4447

4548
pub fn force_solution(&mut self) {
4649
while !self.fully_solved() {
4750
// Find any unsolved variable and constrain it to equal 0.
48-
let v = (0..self.next_var)
49-
.find(|&i| !self.solved_vars.contains_key(&Var(i)))
50-
.unwrap();
51-
self.constrain_eq0(LinearExpr::from(Var(v)));
51+
let v = self.unsolved_vars.first().unwrap();
52+
self.constrain_eq0(LinearExpr::from(*v));
5253
self.solve();
5354
}
5455
}
@@ -63,8 +64,13 @@ impl Solver {
6364
&self.invalid_rounding
6465
}
6566

66-
pub fn unsolved_vars(&self) -> IndexSet<Var> {
67-
IndexSet::from_iter((0..self.next_var).map(Var).filter(|&v| !self.is_solved(v)))
67+
pub fn unsolved_vars(&self) -> &IndexSet<Var> {
68+
&self.unsolved_vars
69+
}
70+
71+
pub fn solve_var(&mut self, var: Var, val: f64) {
72+
self.solved_vars.insert(var, val);
73+
self.unsolved_vars.swap_remove(&var);
6874
}
6975

7076
/// Constrains the value of `expr` to 0.
@@ -109,7 +115,7 @@ impl Solver {
109115
if relative_ne!(val, rounded_val, epsilon = EPSILON) {
110116
self.invalid_rounding.insert(var);
111117
}
112-
self.solved_vars.insert(var, rounded_val);
118+
self.solve_var(var, rounded_val);
113119
}
114120
self.constraints.swap_remove(&constraint_id);
115121
for constraint in self
@@ -127,16 +133,29 @@ impl Solver {
127133

128134
/// Solves for as many variables as possible and substitutes their values into existing constraints.
129135
/// Deletes constraints that no longer contain unsolved variables.
136+
///
137+
/// Constraints should be simplified before this function is invoked.
130138
pub fn solve(&mut self) {
131-
let n_vars = self.next_var as usize;
139+
// Snapshot unsolved variables before solving.
140+
let unsolved_vars = self.unsolved_vars.clone();
141+
let n_vars = unsolved_vars.len();
132142
if n_vars == 0 || self.constraints.is_empty() {
133143
return;
134144
}
135-
let a = DMatrix::from_row_iterator(
145+
let (i, j, val): (Vec<_>, Vec<_>, Vec<_>) =
146+
multiunzip(self.constraints.values().enumerate().flat_map(|(i, expr)| {
147+
expr.coeffs.iter().map({
148+
let unsolved_vars = &unsolved_vars;
149+
move |(coeff, var)| (i, unsolved_vars.get_index_of(var).unwrap(), *coeff)
150+
})
151+
}));
152+
let a = DMatrix::from(CsMatrix::from_triplet(
136153
self.constraints.len(),
137154
n_vars,
138-
self.constraints.values().flat_map(|c| c.coeff_vec(n_vars)),
139-
);
155+
&i,
156+
&j,
157+
&val,
158+
));
140159
let b = DVector::from_iterator(
141160
self.constraints.len(),
142161
self.constraints.values().map(|c| -c.constant),
@@ -151,12 +170,11 @@ impl Solver {
151170
let vt_recons = vt.rows(0, r);
152171
let sol = svd.solve(&b, EPSILON).unwrap();
153172

154-
for i in 0..self.next_var {
155-
let recons = (vt_recons.transpose() * vt_recons.column(i as usize))[((i as usize), 0)];
156-
if !self.solved_vars.contains_key(&Var(i))
157-
&& relative_eq!(recons, 1., epsilon = EPSILON)
158-
{
159-
self.solved_vars.insert(Var(i), sol[(i as usize, 0)]);
173+
for var in &unsolved_vars {
174+
let i = unsolved_vars.get_index_of(var).unwrap();
175+
let recons = (vt_recons.transpose() * vt_recons.column(i))[(i, 0)];
176+
if relative_eq!(recons, 1., epsilon = EPSILON) {
177+
self.solve_var(*var, sol[(i, 0)]);
160178
}
161179
}
162180
for (id, constraint) in self.constraints.iter_mut() {
@@ -209,14 +227,6 @@ pub struct LinearExpr {
209227
}
210228

211229
impl LinearExpr {
212-
pub fn coeff_vec(&self, n_vars: usize) -> Vec<f64> {
213-
let mut out = vec![0.; n_vars];
214-
for (val, var) in &self.coeffs {
215-
out[var.0 as usize] += *val;
216-
}
217-
out
218-
}
219-
220230
pub fn add(lhs: impl Into<LinearExpr>, rhs: impl Into<LinearExpr>) -> Self {
221231
lhs.into() + rhs.into()
222232
}
@@ -357,5 +367,8 @@ mod tests {
357367
assert_relative_eq!(*solver.solved_vars.get(&x).unwrap(), 5., epsilon = EPSILON);
358368
assert_relative_eq!(*solver.solved_vars.get(&y).unwrap(), 5., epsilon = EPSILON);
359369
assert!(!solver.solved_vars.contains_key(&z));
370+
assert!(!solver.unsolved_vars.contains(&x));
371+
assert!(!solver.unsolved_vars.contains(&y));
372+
assert!(solver.unsolved_vars.contains(&z));
360373
}
361374
}

0 commit comments

Comments
 (0)