Skip to content

Commit 2e57d72

Browse files
Parallelize OOD evaluation. (#1306)
1 parent ad5aad5 commit 2e57d72

File tree

2 files changed

+37
-9
lines changed

2 files changed

+37
-9
lines changed

crates/stwo/src/core/pcs/utils.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,22 @@ impl<T> TreeVec<ColumnVec<T>> {
7777
)
7878
}
7979

80+
#[cfg(feature = "parallel")]
81+
pub fn par_map_cols<U, F>(self, f: F) -> TreeVec<ColumnVec<U>>
82+
where
83+
T: Send,
84+
U: Send,
85+
F: Fn(T) -> U + Sync + Send,
86+
{
87+
use rayon::iter::{IntoParallelIterator, ParallelIterator};
88+
TreeVec(
89+
self.0
90+
.into_par_iter()
91+
.map(|column| column.into_par_iter().map(&f).collect::<Vec<_>>())
92+
.collect(),
93+
)
94+
}
95+
8096
/// Zips two [`TreeVec<ColumnVec<T>>`] with the same structure (number of columns in each tree).
8197
/// The resulting [`TreeVec<ColumnVec<T>>`] has the same structure, with each value being a
8298
/// tuple of the corresponding values from the input [`TreeVec<ColumnVec<T>>`].

crates/stwo/src/prover/pcs/mod.rs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,18 +150,30 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
150150
} else {
151151
Some(self.build_weights_hash_map(&sampled_points))
152152
};
153+
154+
// Lambda that evaluates a polynomial on a collection of circle points and returns a vector
155+
// of point samples.
156+
let eval_at_points = |(poly, points): (&Poly<B>, &Vec<CirclePoint<SecureField>>)| {
157+
points
158+
.iter()
159+
.map(|&point| PointSample {
160+
point,
161+
value: poly.eval_at_point(point, weights_hash_map.as_ref()),
162+
})
163+
.collect_vec()
164+
};
165+
166+
#[cfg(not(feature = "parallel"))]
153167
let samples: TreeVec<Vec<Vec<PointSample>>> = self
154168
.polynomials()
155169
.zip_cols(&sampled_points)
156-
.map_cols(|(poly, points)| {
157-
points
158-
.iter()
159-
.map(|&point| PointSample {
160-
point,
161-
value: poly.eval_at_point(point, weights_hash_map.as_ref()),
162-
})
163-
.collect_vec()
164-
});
170+
.map_cols(eval_at_points);
171+
#[cfg(feature = "parallel")]
172+
let samples: TreeVec<Vec<Vec<PointSample>>> = self
173+
.polynomials()
174+
.zip_cols(&sampled_points)
175+
.par_map_cols(eval_at_points);
176+
165177
span.exit();
166178
let sampled_values = samples
167179
.as_cols_ref()

0 commit comments

Comments
 (0)