Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/stwo/benches/pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use stwo::core::poly::circle::CanonicCoset;
use stwo::core::vcs_lifted::blake2_merkle::Blake2sMerkleChannel;
use stwo::prover::backend::simd::SimdBackend;
use stwo::prover::backend::{BackendForChannel, CpuBackend};
use stwo::prover::mempool::BaseColumnPool;
use stwo::prover::poly::circle::CircleEvaluation;
use stwo::prover::poly::twiddles::TwiddleTree;
use stwo::prover::poly::BitReversedOrder;
Expand All @@ -35,6 +36,7 @@ fn benched_fn<B: BackendForChannel<Blake2sMerkleChannel>>(
twiddles,
false,
None,
&mut BaseColumnPool::new(),
);
}

Expand Down
29 changes: 29 additions & 0 deletions crates/stwo/src/core/utils.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,38 @@
use core::iter::Peekable;
use core::ops::{Deref, DerefMut};

use std_shims::Vec;

use super::fields::Field;

/// An enum that either borrows mutably or owns a value.
/// Useful when a struct can optionally receive an external `&mut T` but also needs a fallback owned
/// instance.
pub enum MaybeOwned<'a, T> {
Borrowed(&'a mut T),
Owned(T),
}

impl<T> Deref for MaybeOwned<'_, T> {
type Target = T;

fn deref(&self) -> &T {
match self {
MaybeOwned::Borrowed(r) => r,
MaybeOwned::Owned(ref v) => v,
}
}
}

impl<T> DerefMut for MaybeOwned<'_, T> {
fn deref_mut(&mut self) -> &mut T {
match self {
MaybeOwned::Borrowed(r) => r,
MaybeOwned::Owned(ref mut v) => v,
}
}
}

pub trait IteratorMutExt<'a, T: 'a>: Iterator<Item = &'a mut T> {
fn assign(self, other: impl IntoIterator<Item = T>)
where
Expand Down
38 changes: 30 additions & 8 deletions crates/stwo/src/prover/backend/cpu/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,40 +191,62 @@ impl PolyOps for CpuBackend {
poly: &CircleCoefficients<Self>,
domain: CircleDomain,
twiddles: &TwiddleTree<Self>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
let buffer = vec![BaseField::zero(); domain.size()];
Self::evaluate_into(poly, domain, twiddles, buffer)
}

fn evaluate_into(
poly: &CircleCoefficients<Self>,
domain: CircleDomain,
twiddles: &TwiddleTree<Self>,
mut buffer: Col<Self, BaseField>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
assert!(domain.half_coset.is_doubling_of(twiddles.root_coset));
assert_eq!(buffer.len(), domain.size());

let mut values = poly.extend(domain.log_size()).coeffs;
// Copy extended coefficients into the buffer.
let poly_len = poly.coeffs.len();
buffer[..poly_len].copy_from_slice(&poly.coeffs);
for v in &mut buffer[poly_len..] {
*v = BaseField::zero();
}

if domain.log_size() == 1 {
let (mut v0, mut v1) = (values[0], values[1]);
let (mut v0, mut v1) = (buffer[0], buffer[1]);
butterfly(&mut v0, &mut v1, domain.half_coset.initial.y);
return CircleEvaluation::new(domain, vec![v0, v1]);
buffer[0] = v0;
buffer[1] = v1;
return CircleEvaluation::new(domain, buffer);
}

if domain.log_size() == 2 {
let (mut v0, mut v1, mut v2, mut v3) = (values[0], values[1], values[2], values[3]);
let (mut v0, mut v1, mut v2, mut v3) = (buffer[0], buffer[1], buffer[2], buffer[3]);
let CirclePoint { x, y } = domain.half_coset.initial;
butterfly(&mut v0, &mut v2, x);
butterfly(&mut v1, &mut v3, x);
butterfly(&mut v0, &mut v1, y);
butterfly(&mut v2, &mut v3, -y);
return CircleEvaluation::new(domain, vec![v0, v1, v2, v3]);
buffer[0] = v0;
buffer[1] = v1;
buffer[2] = v2;
buffer[3] = v3;
return CircleEvaluation::new(domain, buffer);
}

let line_twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);
let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]);

for (layer, layer_twiddles) in line_twiddles.iter().enumerate().rev() {
for (h, &t) in layer_twiddles.iter().enumerate() {
fft_layer_loop(&mut values, layer + 1, h, t, butterfly);
fft_layer_loop(&mut buffer, layer + 1, h, t, butterfly);
}
}
for (h, t) in circle_twiddles.enumerate() {
fft_layer_loop(&mut values, 0, h, t, butterfly);
fft_layer_loop(&mut buffer, 0, h, t, butterfly);
}

CircleEvaluation::new(domain, values)
CircleEvaluation::new(domain, buffer)
}

fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {
Expand Down
33 changes: 16 additions & 17 deletions crates/stwo/src/prover/backend/simd/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,17 @@ impl PolyOps for SimdBackend {
poly: &CircleCoefficients<Self>,
domain: CircleDomain,
twiddles: &TwiddleTree<Self>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
// SAFETY: evaluate_into writes all values via FFT before they are read.
let buffer = unsafe { Col::<Self, BaseField>::uninitialized(domain.size()) };
Self::evaluate_into(poly, domain, twiddles, buffer)
}

fn evaluate_into(
poly: &CircleCoefficients<Self>,
domain: CircleDomain,
twiddles: &TwiddleTree<Self>,
mut buffer: Col<Self, BaseField>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
let _span = span!(Level::TRACE, "", class = "rFFT").entered();
let log_size = domain.log_size();
Expand All @@ -386,6 +397,7 @@ impl PolyOps for SimdBackend {
log_size >= fft_log_size,
"Can only evaluate on larger domains"
);
assert_eq!(buffer.len(), domain.size());

if fft_log_size < MIN_FFT_LOG_SIZE {
let cpu_poly: CircleCoefficients<CpuBackend> =
Expand All @@ -399,16 +411,9 @@ impl PolyOps for SimdBackend {

let twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);

// Evaluate on a big domains by evaluating on several subdomains.
// Evaluate on big domains by evaluating on several subdomains.
let log_subdomains = log_size - fft_log_size;

// Allocate the destination buffer without initializing.
let mut values = Vec::with_capacity(domain.size() >> LOG_N_LANES);
#[allow(clippy::uninit_vec)]
unsafe {
values.set_len(domain.size() >> LOG_N_LANES)
};

for i in 0..(1 << log_subdomains) {
// The subdomain twiddles are a slice of the large domain twiddles.
let subdomain_twiddles = (0..(fft_log_size - 1))
Expand All @@ -418,12 +423,12 @@ impl PolyOps for SimdBackend {
})
.collect::<Vec<_>>();

// FFT from the coefficients buffer to the values chunk.
// FFT from the coefficients buffer directly into the provided buffer.
unsafe {
rfft::fft(
transmute::<*const PackedBaseField, *const u32>(poly.coeffs.data.as_ptr()),
transmute::<*mut PackedBaseField, *mut u32>(
values[i << (fft_log_size - LOG_N_LANES)
buffer.data[i << (fft_log_size - LOG_N_LANES)
..(i + 1) << (fft_log_size - LOG_N_LANES)]
.as_mut_ptr(),
),
Expand All @@ -433,13 +438,7 @@ impl PolyOps for SimdBackend {
}
}

CircleEvaluation::new(
domain,
BaseColumn {
data: values,
length: domain.size(),
},
)
CircleEvaluation::new(domain, buffer)
}

/// Precomputes the (doubled) twiddles for a given coset tower.
Expand Down
79 changes: 79 additions & 0 deletions crates/stwo/src/prover/mempool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
//! Pre-allocated memory pools for the proving pipeline.
//!
//! The [`BaseColumnPool`] manages reusable [`Col<B, BaseField>`] buffers for polynomial evaluation,
//! avoiding repeated allocation/deallocation of large column buffers during proving.

use std::collections::HashMap;

use crate::core::fields::m31::BaseField;
use crate::prover::backend::{Col, Column, ColumnOps};

/// A pool of pre-allocated [`Col<B, BaseField>`] buffers, organized by log_size.
///
/// Used to avoid repeated allocation of evaluation buffers during polynomial commitment.
pub struct BaseColumnPool<B: ColumnOps<BaseField>> {
/// Map from log_size -> stack of available buffers.
pools: HashMap<u32, Vec<Col<B, BaseField>>>,
}

impl<B: ColumnOps<BaseField>> BaseColumnPool<B> {
/// Creates a new empty base column pool.
pub fn new() -> Self {
Self {
pools: HashMap::new(),
}
}

/// Pre-allocates `count` zero-initialized buffers of size `1 << log_size`.
pub fn reserve(&mut self, log_size: u32, count: usize) {
let pool = self.pools.entry(log_size).or_default();
for _ in 0..count {
pool.push(Col::<B, BaseField>::zeros(1 << log_size));
}
}

/// Takes a buffer from the pool for the given `log_size`.
///
/// # Panics
///
/// Panics if no buffer of the requested size is available.
pub fn take(&mut self, log_size: u32) -> Col<B, BaseField> {
self.pools
.get_mut(&log_size)
.and_then(|pool| pool.pop())
.unwrap_or_else(|| {
panic!("BaseColumnPool: no buffer available for log_size={log_size}")
})
}

/// Takes a buffer from the pool, or allocates a new zero-initialized one if none is available.
pub fn take_or_alloc(&mut self, log_size: u32) -> Col<B, BaseField> {
self.pools
.get_mut(&log_size)
.and_then(|pool| pool.pop())
.unwrap_or_else(|| unsafe { Col::<B, BaseField>::uninitialized(1 << log_size) })
}

/// Returns a buffer to the pool. The caller is responsible for ensuring the buffer's log_size
/// matches.
pub fn give_back(&mut self, log_size: u32, buf: Col<B, BaseField>) {
debug_assert_eq!(buf.len(), 1 << log_size);
self.pools.entry(log_size).or_default().push(buf);
}

/// Returns the number of available buffers for a given log_size.
pub fn available(&self, log_size: u32) -> usize {
self.pools.get(&log_size).map_or(0, |pool| pool.len())
}

/// Returns the total number of buffers across all sizes.
pub fn total_available(&self) -> usize {
self.pools.values().map(|pool| pool.len()).sum()
}
}

impl<B: ColumnOps<BaseField>> Default for BaseColumnPool<B> {
fn default() -> Self {
Self::new()
}
}
1 change: 1 addition & 0 deletions crates/stwo/src/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub mod channel;
pub mod fri;
pub mod line;
pub mod lookups;
pub mod mempool;
pub mod poly;
pub mod secure_column;
pub mod vcs;
Expand Down
22 changes: 22 additions & 0 deletions crates/stwo/src/prover/pcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ use crate::core::pcs::quotients::{
use crate::core::pcs::utils::prepare_preprocessed_query_positions;
use crate::core::pcs::{PcsConfig, TreeSubspan, TreeVec};
use crate::core::poly::circle::CanonicCoset;
use crate::core::utils::MaybeOwned;
use crate::core::vcs_lifted::merkle_hasher::MerkleHasherLifted;
use crate::core::vcs_lifted::verifier::ExtendedMerkleDecommitmentLifted;
use crate::core::ColumnVec;
use crate::prover::air::component_prover::{Poly, Trace, WeightsHashMap};
use crate::prover::backend::{BackendForChannel, Col};
use crate::prover::fri::{FriDecommitResult, FriProver};
use crate::prover::mempool::BaseColumnPool;
use crate::prover::pcs::quotient_ops::compute_fri_quotients;
use crate::prover::poly::circle::{CircleCoefficients, CircleEvaluation};
use crate::prover::poly::twiddles::TwiddleTree;
Expand All @@ -34,6 +36,8 @@ pub struct CommitmentSchemeProver<'a, B: BackendForChannel<MC>, MC: MerkleChanne
pub config: PcsConfig,
twiddles: &'a TwiddleTree<B>,
pub store_polynomials_coefficients: bool,
/// Pre-allocated base field column pool for polynomial evaluation during commit.
pub base_column_pool: MaybeOwned<'a, BaseColumnPool<B>>,
}
impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a, B, MC> {
/// Creates a new empty commitment scheme prover with the given configuration and twiddles. The
Expand All @@ -44,6 +48,21 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
config,
twiddles,
store_polynomials_coefficients: false,
base_column_pool: MaybeOwned::Owned(BaseColumnPool::new()),
}
}

pub fn with_memory_pool(
config: PcsConfig,
twiddles: &'a TwiddleTree<B>,
base_column_pool: &'a mut BaseColumnPool<B>,
) -> Self {
CommitmentSchemeProver {
trees: TreeVec::default(),
config,
twiddles,
store_polynomials_coefficients: false,
base_column_pool: MaybeOwned::Borrowed(base_column_pool),
}
}

Expand All @@ -62,6 +81,7 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
self.twiddles,
self.store_polynomials_coefficients,
self.config.lifting_log_size,
&mut self.base_column_pool,
);
self.trees.push(tree);
}
Expand Down Expand Up @@ -319,13 +339,15 @@ impl<B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentTreeProver<B, MC> {
twiddles: &TwiddleTree<B>,
store_polynomials_coefficients: bool,
lifting_log_size: Option<u32>,
base_column_pool: &mut BaseColumnPool<B>,
) -> Self {
let span = span!(Level::INFO, "Extension").entered();
let polynomials = B::evaluate_polynomials(
polynomials,
log_blowup_factor,
twiddles,
store_polynomials_coefficients,
base_column_pool,
);
span.exit();

Expand Down
Loading