Skip to content

Commit 0ee003a

Browse files
Use BaseColumnPool in CommitmentSchemeProver::commit to avoid allocation during polynomial evaluation.
Add BaseColumnPool for Col<B, BaseField> buffers, add evaluate_into to PolyOps trait (CPU and SIMD backends), and thread the pool through commit → CommitmentTreeProver::new → evaluate_polynomials. Buffers are pre-taken before parallel iteration. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 9593293 commit 0ee003a

File tree

7 files changed

+175
-32
lines changed

7 files changed

+175
-32
lines changed

crates/stwo/benches/pcs.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use stwo::core::poly::circle::CanonicCoset;
99
use stwo::core::vcs_lifted::blake2_merkle::Blake2sMerkleChannel;
1010
use stwo::prover::backend::simd::SimdBackend;
1111
use stwo::prover::backend::{BackendForChannel, CpuBackend};
12+
use stwo::prover::mempool::BaseColumnPool;
1213
use stwo::prover::poly::circle::CircleEvaluation;
1314
use stwo::prover::poly::twiddles::TwiddleTree;
1415
use stwo::prover::poly::BitReversedOrder;
@@ -35,6 +36,7 @@ fn benched_fn<B: BackendForChannel<Blake2sMerkleChannel>>(
3536
twiddles,
3637
false,
3738
None,
39+
&mut BaseColumnPool::new(),
3840
);
3941
}
4042

crates/stwo/src/prover/backend/cpu/circle.rs

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,40 +191,62 @@ impl PolyOps for CpuBackend {
191191
poly: &CircleCoefficients<Self>,
192192
domain: CircleDomain,
193193
twiddles: &TwiddleTree<Self>,
194+
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
195+
let buffer = vec![BaseField::zero(); domain.size()];
196+
Self::evaluate_into(poly, domain, twiddles, buffer)
197+
}
198+
199+
fn evaluate_into(
200+
poly: &CircleCoefficients<Self>,
201+
domain: CircleDomain,
202+
twiddles: &TwiddleTree<Self>,
203+
mut buffer: Col<Self, BaseField>,
194204
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
195205
assert!(domain.half_coset.is_doubling_of(twiddles.root_coset));
206+
assert_eq!(buffer.len(), domain.size());
196207

197-
let mut values = poly.extend(domain.log_size()).coeffs;
208+
// Copy the coefficients into the buffer and zero-pad the tail (equivalent to extending the polynomial).
209+
let poly_len = poly.coeffs.len();
210+
buffer[..poly_len].copy_from_slice(&poly.coeffs);
211+
for v in &mut buffer[poly_len..] {
212+
*v = BaseField::zero();
213+
}
198214

199215
if domain.log_size() == 1 {
200-
let (mut v0, mut v1) = (values[0], values[1]);
216+
let (mut v0, mut v1) = (buffer[0], buffer[1]);
201217
butterfly(&mut v0, &mut v1, domain.half_coset.initial.y);
202-
return CircleEvaluation::new(domain, vec![v0, v1]);
218+
buffer[0] = v0;
219+
buffer[1] = v1;
220+
return CircleEvaluation::new(domain, buffer);
203221
}
204222

205223
if domain.log_size() == 2 {
206-
let (mut v0, mut v1, mut v2, mut v3) = (values[0], values[1], values[2], values[3]);
224+
let (mut v0, mut v1, mut v2, mut v3) = (buffer[0], buffer[1], buffer[2], buffer[3]);
207225
let CirclePoint { x, y } = domain.half_coset.initial;
208226
butterfly(&mut v0, &mut v2, x);
209227
butterfly(&mut v1, &mut v3, x);
210228
butterfly(&mut v0, &mut v1, y);
211229
butterfly(&mut v2, &mut v3, -y);
212-
return CircleEvaluation::new(domain, vec![v0, v1, v2, v3]);
230+
buffer[0] = v0;
231+
buffer[1] = v1;
232+
buffer[2] = v2;
233+
buffer[3] = v3;
234+
return CircleEvaluation::new(domain, buffer);
213235
}
214236

215237
let line_twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);
216238
let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]);
217239

218240
for (layer, layer_twiddles) in line_twiddles.iter().enumerate().rev() {
219241
for (h, &t) in layer_twiddles.iter().enumerate() {
220-
fft_layer_loop(&mut values, layer + 1, h, t, butterfly);
242+
fft_layer_loop(&mut buffer, layer + 1, h, t, butterfly);
221243
}
222244
}
223245
for (h, t) in circle_twiddles.enumerate() {
224-
fft_layer_loop(&mut values, 0, h, t, butterfly);
246+
fft_layer_loop(&mut buffer, 0, h, t, butterfly);
225247
}
226248

227-
CircleEvaluation::new(domain, values)
249+
CircleEvaluation::new(domain, buffer)
228250
}
229251

230252
fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {

crates/stwo/src/prover/backend/simd/circle.rs

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,17 @@ impl PolyOps for SimdBackend {
378378
poly: &CircleCoefficients<Self>,
379379
domain: CircleDomain,
380380
twiddles: &TwiddleTree<Self>,
381+
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
382+
// SAFETY: evaluate_into writes all values via FFT before they are read.
383+
let buffer = unsafe { Col::<Self, BaseField>::uninitialized(domain.size()) };
384+
Self::evaluate_into(poly, domain, twiddles, buffer)
385+
}
386+
387+
fn evaluate_into(
388+
poly: &CircleCoefficients<Self>,
389+
domain: CircleDomain,
390+
twiddles: &TwiddleTree<Self>,
391+
mut buffer: Col<Self, BaseField>,
381392
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
382393
let _span = span!(Level::TRACE, "", class = "rFFT").entered();
383394
let log_size = domain.log_size();
@@ -386,6 +397,7 @@ impl PolyOps for SimdBackend {
386397
log_size >= fft_log_size,
387398
"Can only evaluate on larger domains"
388399
);
400+
assert_eq!(buffer.len(), domain.size());
389401

390402
if fft_log_size < MIN_FFT_LOG_SIZE {
391403
let cpu_poly: CircleCoefficients<CpuBackend> =
@@ -399,16 +411,9 @@ impl PolyOps for SimdBackend {
399411

400412
let twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);
401413

402-
// Evaluate on a big domains by evaluating on several subdomains.
414+
// Evaluate on big domains by evaluating on several subdomains.
403415
let log_subdomains = log_size - fft_log_size;
404416

405-
// Allocate the destination buffer without initializing.
406-
let mut values = Vec::with_capacity(domain.size() >> LOG_N_LANES);
407-
#[allow(clippy::uninit_vec)]
408-
unsafe {
409-
values.set_len(domain.size() >> LOG_N_LANES)
410-
};
411-
412417
for i in 0..(1 << log_subdomains) {
413418
// The subdomain twiddles are a slice of the large domain twiddles.
414419
let subdomain_twiddles = (0..(fft_log_size - 1))
@@ -418,12 +423,12 @@ impl PolyOps for SimdBackend {
418423
})
419424
.collect::<Vec<_>>();
420425

421-
// FFT from the coefficients buffer to the values chunk.
426+
// FFT from the coefficients buffer directly into the provided buffer.
422427
unsafe {
423428
rfft::fft(
424429
transmute::<*const PackedBaseField, *const u32>(poly.coeffs.data.as_ptr()),
425430
transmute::<*mut PackedBaseField, *mut u32>(
426-
values[i << (fft_log_size - LOG_N_LANES)
431+
buffer.data[i << (fft_log_size - LOG_N_LANES)
427432
..(i + 1) << (fft_log_size - LOG_N_LANES)]
428433
.as_mut_ptr(),
429434
),
@@ -433,13 +438,7 @@ impl PolyOps for SimdBackend {
433438
}
434439
}
435440

436-
CircleEvaluation::new(
437-
domain,
438-
BaseColumn {
439-
data: values,
440-
length: domain.size(),
441-
},
442-
)
441+
CircleEvaluation::new(domain, buffer)
443442
}
444443

445444
/// Precomputes the (doubled) twiddles for a given coset tower.

crates/stwo/src/prover/mempool.rs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
//! Pre-allocated memory pools for the proving pipeline.
2+
//!
3+
//! The [`BaseColumnPool`] manages reusable [`Col<B, BaseField>`] buffers for polynomial evaluation,
4+
//! avoiding repeated allocation/deallocation of large column buffers during proving.
5+
6+
use std::collections::HashMap;
7+
8+
use crate::core::fields::m31::BaseField;
9+
use crate::prover::backend::{Col, Column, ColumnOps};
10+
11+
/// A pool of pre-allocated [`Col<B, BaseField>`] buffers, organized by log_size.
12+
///
13+
/// Used to avoid repeated allocation of evaluation buffers during polynomial commitment.
14+
pub struct BaseColumnPool<B: ColumnOps<BaseField>> {
15+
/// Map from log_size -> stack of available buffers.
16+
pools: HashMap<u32, Vec<Col<B, BaseField>>>,
17+
}
18+
19+
impl<B: ColumnOps<BaseField>> BaseColumnPool<B> {
20+
/// Creates a new empty base column pool.
21+
pub fn new() -> Self {
22+
Self {
23+
pools: HashMap::new(),
24+
}
25+
}
26+
27+
/// Pre-allocates `count` zero-initialized buffers of size `1 << log_size`.
28+
pub fn reserve(&mut self, log_size: u32, count: usize) {
29+
let pool = self.pools.entry(log_size).or_default();
30+
for _ in 0..count {
31+
pool.push(Col::<B, BaseField>::zeros(1 << log_size));
32+
}
33+
}
34+
35+
/// Takes a buffer from the pool for the given `log_size`.
36+
///
37+
/// # Panics
38+
///
39+
/// Panics if no buffer of the requested size is available.
40+
pub fn take(&mut self, log_size: u32) -> Col<B, BaseField> {
41+
self.pools
42+
.get_mut(&log_size)
43+
.and_then(|pool| pool.pop())
44+
.unwrap_or_else(|| {
45+
panic!("BaseColumnPool: no buffer available for log_size={log_size}")
46+
})
47+
}
48+
49+
/// Takes a buffer from the pool, or allocates a new *uninitialized* one if none is available; the contents must be fully overwritten before being read.
50+
pub fn take_or_alloc(&mut self, log_size: u32) -> Col<B, BaseField> {
51+
self.pools
52+
.get_mut(&log_size)
53+
.and_then(|pool| pool.pop())
54+
.unwrap_or_else(|| unsafe { Col::<B, BaseField>::uninitialized(1 << log_size) })
55+
}
56+
57+
/// Returns a buffer to the pool. The caller is responsible for ensuring the buffer's log_size
58+
/// matches; the length is only verified by a `debug_assert_eq!`, i.e. not in builds without debug assertions.
59+
pub fn give_back(&mut self, log_size: u32, buf: Col<B, BaseField>) {
60+
debug_assert_eq!(buf.len(), 1 << log_size);
61+
self.pools.entry(log_size).or_default().push(buf);
62+
}
63+
64+
/// Returns the number of available buffers for a given log_size.
65+
pub fn available(&self, log_size: u32) -> usize {
66+
self.pools.get(&log_size).map_or(0, |pool| pool.len())
67+
}
68+
69+
/// Returns the total number of buffers across all sizes.
70+
pub fn total_available(&self) -> usize {
71+
self.pools.values().map(|pool| pool.len()).sum()
72+
}
73+
}
74+
75+
impl<B: ColumnOps<BaseField>> Default for BaseColumnPool<B> {
76+
fn default() -> Self {
77+
Self::new()
78+
}
79+
}

crates/stwo/src/prover/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pub mod channel;
2020
pub mod fri;
2121
pub mod line;
2222
pub mod lookups;
23+
pub mod mempool;
2324
pub mod poly;
2425
pub mod secure_column;
2526
pub mod vcs;

crates/stwo/src/prover/pcs/mod.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use crate::core::ColumnVec;
2020
use crate::prover::air::component_prover::{Poly, Trace, WeightsHashMap};
2121
use crate::prover::backend::{BackendForChannel, Col};
2222
use crate::prover::fri::{FriDecommitResult, FriProver};
23+
use crate::prover::mempool::BaseColumnPool;
2324
use crate::prover::pcs::quotient_ops::compute_fri_quotients;
2425
use crate::prover::poly::circle::{CircleCoefficients, CircleEvaluation};
2526
use crate::prover::poly::twiddles::TwiddleTree;
@@ -34,6 +35,8 @@ pub struct CommitmentSchemeProver<'a, B: BackendForChannel<MC>, MC: MerkleChanne
3435
pub config: PcsConfig,
3536
twiddles: &'a TwiddleTree<B>,
3637
pub store_polynomials_coefficients: bool,
38+
/// Pre-allocated base field column pool for polynomial evaluation during commit.
39+
pub base_column_pool: BaseColumnPool<B>,
3740
}
3841
impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a, B, MC> {
3942
/// Creates a new empty commitment scheme prover with the given configuration and twiddles. The
@@ -44,6 +47,21 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
4447
config,
4548
twiddles,
4649
store_polynomials_coefficients: false,
50+
base_column_pool: BaseColumnPool::new(),
51+
}
52+
}
53+
54+
pub fn with_memory_pool(
55+
config: PcsConfig,
56+
twiddles: &'a TwiddleTree<B>,
57+
base_column_pool: BaseColumnPool<B>,
58+
) -> Self {
59+
CommitmentSchemeProver {
60+
trees: TreeVec::default(),
61+
config,
62+
twiddles,
63+
store_polynomials_coefficients: false,
64+
base_column_pool,
4765
}
4866
}
4967

@@ -62,6 +80,7 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
6280
self.twiddles,
6381
self.store_polynomials_coefficients,
6482
self.config.lifting_log_size,
83+
&mut self.base_column_pool,
6584
);
6685
self.trees.push(tree);
6786
}
@@ -319,13 +338,15 @@ impl<B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentTreeProver<B, MC> {
319338
twiddles: &TwiddleTree<B>,
320339
store_polynomials_coefficients: bool,
321340
lifting_log_size: Option<u32>,
341+
base_column_pool: &mut BaseColumnPool<B>,
322342
) -> Self {
323343
let span = span!(Level::INFO, "Extension").entered();
324344
let polynomials = B::evaluate_polynomials(
325345
polynomials,
326346
log_blowup_factor,
327347
twiddles,
328348
store_polynomials_coefficients,
349+
base_column_pool,
329350
);
330351
span.exit();
331352

crates/stwo/src/prover/poly/circle/ops.rs

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::core::poly::circle::{CanonicCoset, CircleDomain};
99
use crate::core::ColumnVec;
1010
use crate::prover::air::component_prover::Poly;
1111
use crate::prover::backend::{Col, ColumnOps};
12+
use crate::prover::mempool::BaseColumnPool;
1213
use crate::prover::poly::twiddles::TwiddleTree;
1314
use crate::prover::poly::BitReversedOrder;
1415

@@ -82,25 +83,43 @@ pub trait PolyOps: ColumnOps<BaseField> + ColumnOps<SecureField> + Sized {
8283
twiddles: &TwiddleTree<Self>,
8384
) -> CircleEvaluation<Self, BaseField, BitReversedOrder>;
8485

86+
/// Evaluates the polynomial at all points in the domain, writing results into the provided
87+
/// buffer instead of allocating a new one. The buffer must have size `domain.size()`.
88+
fn evaluate_into(
89+
poly: &CircleCoefficients<Self>,
90+
domain: CircleDomain,
91+
twiddles: &TwiddleTree<Self>,
92+
buffer: Col<Self, BaseField>,
93+
) -> CircleEvaluation<Self, BaseField, BitReversedOrder>;
94+
8595
fn evaluate_polynomials(
8696
polynomials: ColumnVec<CircleCoefficients<Self>>,
8797
log_blowup_factor: u32,
8898
twiddles: &TwiddleTree<Self>,
8999
store_polynomials_coefficients: bool,
100+
pool: &mut BaseColumnPool<Self>,
90101
) -> Vec<Poly<Self>>
91102
where
92103
Self: crate::prover::backend::Backend,
93104
{
105+
// Pre-take all buffers from the pool before the parallel section.
106+
let buffers: Vec<_> = polynomials
107+
.iter()
108+
.map(|poly_coeffs| {
109+
let log_eval_size = poly_coeffs.log_size() + log_blowup_factor;
110+
pool.take_or_alloc(log_eval_size)
111+
})
112+
.collect();
113+
94114
#[cfg(feature = "parallel")]
95-
let iter = polynomials.into_par_iter();
115+
let iter = polynomials.into_par_iter().zip(buffers.into_par_iter());
96116
#[cfg(not(feature = "parallel"))]
97-
let iter = polynomials.into_iter();
117+
let iter = polynomials.into_iter().zip(buffers);
98118

99-
iter.map(|poly_coeffs| {
100-
let evals = poly_coeffs.evaluate_with_twiddles(
101-
CanonicCoset::new(poly_coeffs.log_size() + log_blowup_factor).circle_domain(),
102-
twiddles,
103-
);
119+
iter.map(|(poly_coeffs, buffer)| {
120+
let domain =
121+
CanonicCoset::new(poly_coeffs.log_size() + log_blowup_factor).circle_domain();
122+
let evals = Self::evaluate_into(&poly_coeffs, domain, twiddles, buffer);
104123
Poly::new(store_polynomials_coefficients.then_some(poly_coeffs), evals)
105124
})
106125
.collect()

0 commit comments

Comments
 (0)