Skip to content

Commit 08ce040

Browse files
Copilot and relf authored
Parallelize distance calculations and eliminate allocations in hot paths (#328)
* Initial plan
* Optimize distance calculations and reduce allocations
  Co-authored-by: relf <[email protected]>
* Remove unused import
  Co-authored-by: relf <[email protected]>
* Cleanup
* Simplify pdist and warn about the order
* Cleanup

---------

Co-authored-by: copilot-swe-agent[bot] <[email protected]>
Co-authored-by: relf <[email protected]>
Co-authored-by: relf <[email protected]>
1 parent 57dce83 commit 08ce040

File tree

5 files changed

+60
-32
lines changed

5 files changed

+60
-32
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/doe/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ ndarray-rand.workspace = true
2828
ndarray-stats.workspace = true
2929
num-traits.workspace = true
3030
rand_xoshiro.workspace = true
31+
rayon.workspace = true
3132
serde = { version = "1", optional = true }
3233

3334
[dev-dependencies]

crates/doe/src/lhs.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ impl<F: Float, R: Rng> Lhs<F, R> {
184184
lhs_best
185185
}
186186

187-
fn _phip(&self, lhs: &ArrayBase<impl Data<Elem = F>, Ix2>, p: F) -> F {
187+
fn _phip(&self, lhs: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>, p: F) -> F {
188188
F::powf(pdist(lhs).mapv(|v| F::powf(v, -p)).sum(), F::one() / p)
189189
}
190190

@@ -208,8 +208,14 @@ impl<F: Float, R: Rng> Lhs<F, R> {
208208
let mut dist1 = cdist(&x.slice(s![i1..i1 + 1, ..]), &x_rest);
209209
let mut dist2 = cdist(&x.slice(s![i2..i2 + 1, ..]), &x_rest);
210210

211-
let m1 = (x_rest.column(k).to_owned() - x[[i1, k]]).map(|v| *v * *v);
212-
let m2 = (x_rest.column(k).to_owned() - x[[i2, k]]).map(|v| *v * *v);
211+
let m1 = x_rest.column(k).mapv(|v| {
212+
let diff = v - x[[i1, k]];
213+
diff * diff
214+
});
215+
let m2 = x_rest.column(k).mapv(|v| {
216+
let diff = v - x[[i2, k]];
217+
diff * diff
218+
});
213219

214220
let two = F::cast(2.);
215221
let mut d1 = dist1.mapv(|v| v * v) - &m1 + &m2;

crates/doe/src/utils.rs

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,36 @@
11
use linfa::Float;
2-
use ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
2+
use ndarray::{Array, Array1, Array2, ArrayBase, Data, Ix2, Zip};
33
use ndarray_stats::DeviationExt;
4+
use rayon::prelude::*;
45

5-
pub fn pdist<F: Float>(x: &ArrayBase<impl Data<Elem = F>, Ix2>) -> Array1<F> {
6+
/// Computes the pairwise distances between rows of a 2D-array using parallel processing
7+
/// Warning: The result is expected to be used in a context where order does not matter
8+
/// (e.g., get min distance) as the order of distances depends on the order of parallel execution
9+
pub fn pdist<F: Float>(x: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>) -> Array1<F> {
610
let nrows = x.nrows();
7-
let size: usize = (nrows - 1) * nrows / 2;
8-
let mut res: Array1<F> = Array1::zeros(size);
9-
let mut k = 0;
10-
for i in 0..nrows {
11-
for j in (i + 1)..nrows {
11+
12+
// Enumerate every (i, j) row pair with i < j up front, then compute the distances in parallel
13+
let pairs: Vec<_> = (0..nrows)
14+
.flat_map(|i| ((i + 1)..nrows).map(move |j| (i, j)))
15+
.collect();
16+
17+
let distances: Vec<_> = pairs
18+
.par_iter()
19+
.map(|&(i, j)| {
1220
let a = x.row(i);
1321
let b = x.row(j);
14-
res[k] = F::cast(a.l2_dist(&b).unwrap());
15-
k += 1;
16-
}
17-
}
18-
res
22+
F::cast(a.l2_dist(&b).unwrap())
23+
})
24+
.collect();
25+
26+
Array::from_vec(distances)
1927
}
2028

29+
/// Computes the pairwise distances between rows of two 2D arrays using parallel processing
30+
/// The resulting array has shape (ma, mb) where ma is the number of rows in xa and mb is the number of rows in xb
2131
pub fn cdist<F: Float>(
22-
xa: &ArrayBase<impl Data<Elem = F>, Ix2>,
23-
xb: &ArrayBase<impl Data<Elem = F>, Ix2>,
32+
xa: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
33+
xb: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
2434
) -> Array2<F> {
2535
let ma = xa.nrows();
2636
let mb = xb.nrows();
@@ -29,14 +39,15 @@ pub fn cdist<F: Float>(
2939
if na != nb {
3040
panic!("cdist: operands should have same nb of columns. Found {na} and {nb}");
3141
}
42+
3243
let mut res = Array2::zeros((ma, mb));
33-
for i in 0..ma {
34-
for j in 0..mb {
35-
let a = xa.row(i);
36-
let b = xb.row(j);
37-
res[[i, j]] = F::cast(a.l2_dist(&b).unwrap());
38-
}
39-
}
44+
Zip::from(res.rows_mut())
45+
.and(xa.rows())
46+
.par_for_each(|mut row_res, row_a| {
47+
for (j, row_b) in xb.rows().into_iter().enumerate() {
48+
row_res[j] = F::cast(row_a.l2_dist(&row_b).unwrap());
49+
}
50+
});
4051

4152
res
4253
}

crates/gp/src/utils.rs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use linfa::Float;
2-
use ndarray::{Array, Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2, s};
2+
use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2, s};
33
#[cfg(feature = "serializable")]
44
use serde::{Deserialize, Serialize};
55

@@ -111,13 +111,22 @@ pub fn pairwise_differences<F: Float>(
111111
y: &ArrayBase<impl Data<Elem = F>, Ix2>,
112112
) -> Array2<F> {
113113
assert!(x.ncols() == y.ncols());
114-
let x3 = x.to_owned().insert_axis(Axis(1));
115-
let y3 = y.to_owned().insert_axis(Axis(0));
116-
let d = x3 - y3;
117-
let n = d.len();
118-
let res = Array::from_iter(d.iter().cloned());
119-
res.into_shape_with_order((n / x.ncols(), x.ncols()))
120-
.unwrap()
114+
115+
let nx = x.nrows();
116+
let ny = y.nrows();
117+
let ncols = x.ncols();
118+
let mut result = Array2::zeros((nx * ny, ncols));
119+
120+
for (i, x_row) in x.rows().into_iter().enumerate() {
121+
for (j, y_row) in y.rows().into_iter().enumerate() {
122+
let idx = i * ny + j;
123+
for k in 0..ncols {
124+
result[[idx, k]] = x_row[k] - y_row[k];
125+
}
126+
}
127+
}
128+
129+
result
121130
}
122131

123132
/// Computes differences between x and each element of y

0 commit comments

Comments
 (0)