Skip to content

Commit 5221f96

Browse files
committed
Adapt barycentric weights to the lifted protocol.
1 parent 47186dd commit 5221f96

File tree

2 files changed: +28 additions, −14 deletions

2 files changed: +28 additions, −14 deletions

crates/stwo/src/prover/pcs/mod.rs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
101101
pub fn build_weights_hash_map(
102102
&self,
103103
sampled_points: &TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>>,
104+
max_log_size: u32,
104105
) -> WeightsHashMap<B>
105106
where
106107
Col<B, SecureField>: Send + Sync,
@@ -120,16 +121,20 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
120121
};
121122

122123
let log_size = poly.evals.domain.log_size();
123-
124+
// For each sample point, compute the weights needed to evaluate the polynomial at
125+
// the folded sample point.
126+
// TODO(Leo): the computation `point.repeated_double(max_log_size - log_size)` is
127+
// likely repeated a bunch of times in a typical flat air. Consider moving it
128+
// outside the loop.
124129
#[cfg(not(feature = "parallel"))]
125-
points
126-
.iter()
127-
.for_each(|&point| compute_weights((log_size, point)));
130+
points.iter().for_each(|&point| {
131+
compute_weights((log_size, point.repeated_double(max_log_size - log_size)))
132+
});
128133

129134
#[cfg(feature = "parallel")]
130-
points
131-
.par_iter()
132-
.for_each(|&point| compute_weights((log_size, point)));
135+
points.par_iter().for_each(|&point| {
136+
compute_weights((log_size, point.repeated_double(max_log_size - log_size)))
137+
});
133138
});
134139

135140
weights_dashmap
@@ -147,12 +152,13 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
147152
class = "EvaluateOutOfDomain"
148153
)
149154
.entered();
155+
156+
let max_log_size = self.trees.last().unwrap().commitment.layers.len() as u32 - 1;
150157
let weights_hash_map = if self.store_polynomials_coefficients {
151158
None
152159
} else {
153-
Some(self.build_weights_hash_map(&sampled_points))
160+
Some(self.build_weights_hash_map(&sampled_points, max_log_size))
154161
};
155-
let max_log_size = self.trees.last().unwrap().commitment.layers.len() as u32 - 1;
156162
let samples: TreeVec<Vec<Vec<PointSample>>> = self
157163
.polynomials()
158164
.zip_cols(&sampled_points)

crates/stwo/src/prover/pcs/quotient_ops.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,10 @@ mod tests {
218218
polys
219219
}
220220

221-
fn prove_and_verify_pcs<B: BackendForChannel<Blake2sMerkleChannel>>(
222-
) -> Result<(), VerificationError> {
221+
fn prove_and_verify_pcs<
222+
B: BackendForChannel<Blake2sMerkleChannel>,
223+
const STORE_COEFFS: bool,
224+
>() -> Result<(), VerificationError> {
223225
const N_COLS: usize = 10;
224226
const LIFTING_LOG_SIZE: u32 = 8;
225227

@@ -231,7 +233,9 @@ mod tests {
231233
);
232234
let mut commitment_scheme =
233235
CommitmentSchemeProver::<B, Blake2sMerkleChannel>::new(config, &twiddles);
234-
commitment_scheme.set_store_polynomials_coefficients();
236+
if STORE_COEFFS {
237+
commitment_scheme.set_store_polynomials_coefficients();
238+
}
235239
let polys = prepare_polys::<B, N_COLS, LIFTING_LOG_SIZE>();
236240
let sizes = polys.iter().map(|poly| poly.log_size()).collect_vec();
237241

@@ -261,10 +265,14 @@ mod tests {
261265

262266
#[test]
263267
fn test_pcs_prove_and_verify_cpu() {
264-
assert!(prove_and_verify_pcs::<CpuBackend>().is_ok());
268+
assert!(prove_and_verify_pcs::<CpuBackend, true>().is_ok());
265269
}
266270
#[test]
267271
fn test_pcs_prove_and_verify_simd() {
268-
assert!(prove_and_verify_pcs::<SimdBackend>().is_ok());
272+
assert!(prove_and_verify_pcs::<SimdBackend, true>().is_ok());
273+
}
274+
#[test]
275+
fn test_pcs_prove_and_verify_simd_with_barycentric() {
276+
assert!(prove_and_verify_pcs::<SimdBackend, false>().is_ok());
269277
}
270278
}

0 commit comments

Comments (0)