Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/stwo/src/core/pcs/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,20 +177,20 @@ impl<T> TreeVec<ColumnVec<Vec<T>>> {

pub fn prepare_preprocessed_query_positions(
query_positions: &[usize],
max_log_size: u32,
lifting_log_size: u32,
pp_max_log_size: u32,
) -> Vec<usize> {
if pp_max_log_size == 0 {
return vec![];
};
if max_log_size < pp_max_log_size {
if lifting_log_size < pp_max_log_size {
return query_positions
.iter()
.map(|pos| (pos >> 1 << (pp_max_log_size - max_log_size + 1)) + (pos & 1))
.map(|pos| (pos >> 1 << (pp_max_log_size - lifting_log_size + 1)) + (pos & 1))
.collect();
}
query_positions
.iter()
.map(|pos| (pos >> (max_log_size - pp_max_log_size + 1) << 1) + (pos & 1))
.map(|pos| (pos >> (lifting_log_size - pp_max_log_size + 1) << 1) + (pos & 1))
.collect()
}
10 changes: 5 additions & 5 deletions crates/stwo/src/prover/backend/cpu/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,26 +49,26 @@ impl QuotientOps for CpuBackend {
fn compute_quotients_and_combine(
accumulations: Vec<AccumulatedNumerators<Self>>,
) -> SecureEvaluation<Self, BitReversedOrder> {
let max_log_size = accumulations
let lifting_log_size = accumulations
.iter()
.map(|x| x.partial_numerators_acc.len())
.max()
.unwrap()
.ilog2();

let domain = CanonicCoset::new(max_log_size).circle_domain();
let domain = CanonicCoset::new(lifting_log_size).circle_domain();
let mut quotients: SecureColumnByCoords<CpuBackend> =
unsafe { SecureColumnByCoords::uninitialized(1 << max_log_size) };
unsafe { SecureColumnByCoords::uninitialized(1 << lifting_log_size) };
let sample_points: Vec<CirclePoint<SecureField>> =
accumulations.iter().map(|x| x.sample_point).collect();
// Populate `quotients`.
for row in 0..quotients.len() {
let domain_point = domain.at(bit_reverse_index(row, max_log_size));
let domain_point = domain.at(bit_reverse_index(row, lifting_log_size));
let inverses = denominator_inverses(&sample_points, domain_point);
let mut quotient = SecureField::zero();
for (acc, den_inv) in accumulations.iter().zip_eq(inverses) {
let mut full_numerator = SecureField::zero();
let log_ratio = max_log_size - acc.partial_numerators_acc.len().ilog2();
let log_ratio = lifting_log_size - acc.partial_numerators_acc.len().ilog2();
let lifted_idx = (row >> (log_ratio + 1) << 1) + (row & 1);

full_numerator += acc.partial_numerators_acc.at(lifted_idx)
Expand Down
1 change: 0 additions & 1 deletion crates/stwo/src/prover/backend/simd/poseidon252_lifted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use crate::core::vcs_lifted::merkle_hasher::MerkleHasherLifted;
use crate::core::vcs_lifted::poseidon252_merkle::{
poseidon_finalize, poseidon_update, Poseidon252MerkleHasher, ELEMENTS_IN_BUFFER,
};
#[cfg(feature = "parallel")]
use crate::prover::backend::simd::m31::N_LANES;
use crate::prover::backend::simd::SimdBackend;
use crate::prover::backend::{Col, Column, CpuBackend};
Expand Down
14 changes: 7 additions & 7 deletions crates/stwo/src/prover/backend/simd/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,36 +74,36 @@ impl QuotientOps for SimdBackend {
fn compute_quotients_and_combine(
accumulations: Vec<AccumulatedNumerators<Self>>,
) -> SecureEvaluation<Self, BitReversedOrder> {
let max_log_size = accumulations
let lifting_log_size = accumulations
.iter()
.map(|x| x.partial_numerators_acc.len())
.max()
.unwrap()
.ilog2();

let domain = CanonicCoset::new(max_log_size).circle_domain();
let domain = CanonicCoset::new(lifting_log_size).circle_domain();
let domain_points: Vec<CirclePoint<PackedBaseField>> =
CircleDomainBitRevIterator::new(domain).collect();
let mut quotients: SecureColumnByCoords<SimdBackend> =
unsafe { SecureColumnByCoords::uninitialized(1 << max_log_size) };
unsafe { SecureColumnByCoords::uninitialized(1 << lifting_log_size) };
let sample_points: Vec<CirclePoint<SecureField>> =
accumulations.iter().map(|x| x.sample_point).collect();
let denominators_inverses = denominator_inverses(&sample_points, domain);

// Populate `quotients`.
// TODO(Leo): make chunk size configurable.
#[cfg(not(feature = "parallel"))]
let iter = quotients.iter_mut(1).enumerate();
let iter = quotients.chunks_mut(1).enumerate();

#[cfg(feature = "parallel")]
let iter = quotients.chunks_mut(1).enumerate();
let iter = quotients.par_chunks_mut(1).enumerate();

iter.for_each(|(domain_idx, mut value_dst)| {
let mut quotient = PackedSecureField::zero();
for (acc, den_inv) in accumulations.iter().zip_eq(denominators_inverses.iter()) {
let mut full_numerator = PackedSecureField::zero();

let log_ratio = max_log_size - acc.partial_numerators_acc.len().ilog2();
let log_ratio = lifting_log_size - acc.partial_numerators_acc.len().ilog2();
let lifted_partial_numerator =
PackedSecureField::from_packed_m31s(std::array::from_fn(|j| {
let lifted_simd = to_lifted_simd(
Expand Down Expand Up @@ -147,7 +147,7 @@ fn denominator_inverses(
let domain_points = CircleDomainBitRevIterator::new(domain);

#[cfg(not(feature = "parallel"))]
let iter = domain_points.iter();
let iter = domain_points;

#[cfg(feature = "parallel")]
let iter = domain_points.par_iter();
Expand Down
28 changes: 17 additions & 11 deletions crates/stwo/src/prover/pcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
pub fn build_weights_hash_map(
&self,
sampled_points: &TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>>,
lifting_log_size: u32,
) -> WeightsHashMap<B>
where
Col<B, SecureField>: Send + Sync,
Expand All @@ -120,16 +121,20 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
};

let log_size = poly.evals.domain.log_size();

// For each sample point, compute the weights needed to evaluate the polynomial at
// the folded sample point.
// TODO(Leo): the computation `point.repeated_double(lifting_log_size - log_size)`
// is likely repeated a bunch of times in a typical flat air.
// Consider moving it outside the loop.
#[cfg(not(feature = "parallel"))]
points
.iter()
.for_each(|&point| compute_weights((log_size, point)));
points.iter().for_each(|&point| {
compute_weights((log_size, point.repeated_double(lifting_log_size - log_size)))
});

#[cfg(feature = "parallel")]
points
.par_iter()
.for_each(|&point| compute_weights((log_size, point)));
points.par_iter().for_each(|&point| {
compute_weights((log_size, point.repeated_double(lifting_log_size - log_size)))
});
});

weights_dashmap
Expand All @@ -147,12 +152,13 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
class = "EvaluateOutOfDomain"
)
.entered();

let lifting_log_size = self.trees.last().unwrap().commitment.layers.len() as u32 - 1;
let weights_hash_map = if self.store_polynomials_coefficients {
None
} else {
Some(self.build_weights_hash_map(&sampled_points))
Some(self.build_weights_hash_map(&sampled_points, lifting_log_size))
};
let max_log_size = self.trees.last().unwrap().commitment.layers.len() as u32 - 1;
let samples: TreeVec<Vec<Vec<PointSample>>> = self
.polynomials()
.zip_cols(&sampled_points)
Expand All @@ -162,7 +168,7 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
.map(|&point| PointSample {
point,
value: poly.eval_at_point(
point.repeated_double(max_log_size - poly.evals.domain.log_size()),
point.repeated_double(lifting_log_size - poly.evals.domain.log_size()),
weights_hash_map.as_ref(),
),
})
Expand Down Expand Up @@ -203,7 +209,7 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
// Build the query position tree.
let preprocessed_query_positions = prepare_preprocessed_query_positions(
&query_positions,
max_log_size,
lifting_log_size,
self.trees[0].commitment.layers.len() as u32 - 1,
);
let query_positions_tree = TreeVec::new(
Expand Down
18 changes: 13 additions & 5 deletions crates/stwo/src/prover/pcs/quotient_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,10 @@ mod tests {
polys
}

fn prove_and_verify_pcs<B: BackendForChannel<Blake2sMerkleChannel>>(
) -> Result<(), VerificationError> {
fn prove_and_verify_pcs<
B: BackendForChannel<Blake2sMerkleChannel>,
const STORE_COEFFS: bool,
>() -> Result<(), VerificationError> {
const N_COLS: usize = 10;
const LIFTING_LOG_SIZE: u32 = 8;

Expand All @@ -231,7 +233,9 @@ mod tests {
);
let mut commitment_scheme =
CommitmentSchemeProver::<B, Blake2sMerkleChannel>::new(config, &twiddles);
commitment_scheme.set_store_polynomials_coefficients();
if STORE_COEFFS {
commitment_scheme.set_store_polynomials_coefficients();
}
let polys = prepare_polys::<B, N_COLS, LIFTING_LOG_SIZE>();
let sizes = polys.iter().map(|poly| poly.log_size()).collect_vec();

Expand Down Expand Up @@ -261,10 +265,14 @@ mod tests {

#[test]
fn test_pcs_prove_and_verify_cpu() {
assert!(prove_and_verify_pcs::<CpuBackend>().is_ok());
assert!(prove_and_verify_pcs::<CpuBackend, true>().is_ok());
}
#[test]
fn test_pcs_prove_and_verify_simd() {
assert!(prove_and_verify_pcs::<SimdBackend>().is_ok());
assert!(prove_and_verify_pcs::<SimdBackend, true>().is_ok());
}
#[test]
fn test_pcs_prove_and_verify_simd_with_barycentric() {
assert!(prove_and_verify_pcs::<SimdBackend, false>().is_ok());
}
}
Loading