Skip to content

Commit a2f1b05

Browse files
Use BaseColumnPool in CommitmentSchemeProver::commit to avoid allocation during polynomial evaluation.
Add BaseColumnPool for Col<B, BaseField> buffers, add evaluate_into to PolyOps trait (CPU and SIMD backends), and thread the pool through commit → CommitmentTreeProver::new → evaluate_polynomials. Buffers are pre-taken before parallel iteration. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 5675eab commit a2f1b05

File tree

5 files changed

+265
-8
lines changed

5 files changed

+265
-8
lines changed

crates/stwo/src/prover/backend/cpu/circle.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,59 @@ impl PolyOps for CpuBackend {
227227
CircleEvaluation::new(domain, values)
228228
}
229229

230+
fn evaluate_into(
231+
poly: &CircleCoefficients<Self>,
232+
domain: CircleDomain,
233+
twiddles: &TwiddleTree<Self>,
234+
mut buffer: Col<Self, BaseField>,
235+
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
236+
assert!(domain.half_coset.is_doubling_of(twiddles.root_coset));
237+
assert_eq!(buffer.len(), domain.size());
238+
239+
// Copy extended coefficients into the buffer.
240+
let poly_len = poly.coeffs.len();
241+
buffer[..poly_len].copy_from_slice(&poly.coeffs);
242+
for v in &mut buffer[poly_len..] {
243+
*v = BaseField::zero();
244+
}
245+
246+
if domain.log_size() == 1 {
247+
let (mut v0, mut v1) = (buffer[0], buffer[1]);
248+
butterfly(&mut v0, &mut v1, domain.half_coset.initial.y);
249+
buffer[0] = v0;
250+
buffer[1] = v1;
251+
return CircleEvaluation::new(domain, buffer);
252+
}
253+
254+
if domain.log_size() == 2 {
255+
let (mut v0, mut v1, mut v2, mut v3) = (buffer[0], buffer[1], buffer[2], buffer[3]);
256+
let CirclePoint { x, y } = domain.half_coset.initial;
257+
butterfly(&mut v0, &mut v2, x);
258+
butterfly(&mut v1, &mut v3, x);
259+
butterfly(&mut v0, &mut v1, y);
260+
butterfly(&mut v2, &mut v3, -y);
261+
buffer[0] = v0;
262+
buffer[1] = v1;
263+
buffer[2] = v2;
264+
buffer[3] = v3;
265+
return CircleEvaluation::new(domain, buffer);
266+
}
267+
268+
let line_twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);
269+
let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]);
270+
271+
for (layer, layer_twiddles) in line_twiddles.iter().enumerate().rev() {
272+
for (h, &t) in layer_twiddles.iter().enumerate() {
273+
fft_layer_loop(&mut buffer, layer + 1, h, t, butterfly);
274+
}
275+
}
276+
for (h, t) in circle_twiddles.enumerate() {
277+
fft_layer_loop(&mut buffer, 0, h, t, butterfly);
278+
}
279+
280+
CircleEvaluation::new(domain, buffer)
281+
}
282+
230283
fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {
231284
const CHUNK_LOG_SIZE: usize = 12;
232285
const CHUNK_SIZE: usize = 1 << CHUNK_LOG_SIZE;

crates/stwo/src/prover/backend/simd/circle.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,63 @@ impl PolyOps for SimdBackend {
442442
)
443443
}
444444

445+
fn evaluate_into(
446+
poly: &CircleCoefficients<Self>,
447+
domain: CircleDomain,
448+
twiddles: &TwiddleTree<Self>,
449+
mut buffer: Col<Self, BaseField>,
450+
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
451+
let _span = span!(Level::TRACE, "", class = "rFFT").entered();
452+
let log_size = domain.log_size();
453+
let fft_log_size = poly.log_size();
454+
assert!(
455+
log_size >= fft_log_size,
456+
"Can only evaluate on larger domains"
457+
);
458+
assert_eq!(buffer.len(), domain.size());
459+
460+
if fft_log_size < MIN_FFT_LOG_SIZE {
461+
let cpu_poly: CircleCoefficients<CpuBackend> =
462+
CircleCoefficients::new(poly.coeffs.to_cpu());
463+
let cpu_eval = cpu_poly.evaluate(domain);
464+
return CircleEvaluation::new(
465+
cpu_eval.domain,
466+
Col::<SimdBackend, BaseField>::from_iter(cpu_eval.values),
467+
);
468+
}
469+
470+
let twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);
471+
472+
// Evaluate on big domains by evaluating on several subdomains.
473+
let log_subdomains = log_size - fft_log_size;
474+
475+
for i in 0..(1 << log_subdomains) {
476+
// The subdomain twiddles are a slice of the large domain twiddles.
477+
let subdomain_twiddles = (0..(fft_log_size - 1))
478+
.map(|layer_i| {
479+
&twiddles[layer_i as usize]
480+
[i << (fft_log_size - 2 - layer_i)..(i + 1) << (fft_log_size - 2 - layer_i)]
481+
})
482+
.collect::<Vec<_>>();
483+
484+
// FFT from the coefficients buffer directly into the provided buffer.
485+
unsafe {
486+
rfft::fft(
487+
transmute::<*const PackedBaseField, *const u32>(poly.coeffs.data.as_ptr()),
488+
transmute::<*mut PackedBaseField, *mut u32>(
489+
buffer.data[i << (fft_log_size - LOG_N_LANES)
490+
..(i + 1) << (fft_log_size - LOG_N_LANES)]
491+
.as_mut_ptr(),
492+
),
493+
&subdomain_twiddles,
494+
fft_log_size as usize,
495+
);
496+
}
497+
}
498+
499+
CircleEvaluation::new(domain, buffer)
500+
}
501+
445502
/// Precomputes the (doubled) twiddles for a given coset tower.
446503
/// The twiddles are the x values of each coset in bit-reversed order.
447504
/// Note: the coset point are symmetrical over the x-axis so only the first half of the coset is

crates/stwo/src/prover/mempool.rs

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
use std::collections::HashMap;
88

99
use crate::core::fields::m31::BaseField;
10-
use crate::prover::backend::ColumnOps;
10+
use crate::prover::backend::{Col, Column, ColumnOps};
1111
use crate::prover::secure_column::SecureColumnByCoords;
1212

1313
/// A pool of pre-allocated [`SecureColumnByCoords`] buffers, organized by log_size.
@@ -87,6 +87,76 @@ impl<B: ColumnOps<BaseField>> Default for ColumnPool<B> {
8787
}
8888
}
8989

90+
/// A pool of pre-allocated [`Col<B, BaseField>`] buffers, organized by log_size.
91+
///
92+
/// Used to avoid repeated allocation of evaluation buffers during polynomial commitment.
93+
pub struct BaseColumnPool<B: ColumnOps<BaseField>> {
94+
/// Map from log_size -> stack of available buffers.
95+
pools: HashMap<u32, Vec<Col<B, BaseField>>>,
96+
}
97+
98+
impl<B: ColumnOps<BaseField>> BaseColumnPool<B> {
99+
/// Creates a new empty base column pool.
100+
pub fn new() -> Self {
101+
Self {
102+
pools: HashMap::new(),
103+
}
104+
}
105+
106+
/// Pre-allocates `count` zero-initialized buffers of size `1 << log_size`.
107+
pub fn reserve(&mut self, log_size: u32, count: usize) {
108+
let pool = self.pools.entry(log_size).or_default();
109+
for _ in 0..count {
110+
pool.push(Col::<B, BaseField>::zeros(1 << log_size));
111+
}
112+
}
113+
114+
/// Takes a buffer from the pool for the given `log_size`.
115+
///
116+
/// # Panics
117+
///
118+
/// Panics if no buffer of the requested size is available.
119+
pub fn take(&mut self, log_size: u32) -> Col<B, BaseField> {
120+
self.pools
121+
.get_mut(&log_size)
122+
.and_then(|pool| pool.pop())
123+
.unwrap_or_else(|| {
124+
panic!("BaseColumnPool: no buffer available for log_size={log_size}")
125+
})
126+
}
127+
128+
/// Takes a buffer from the pool, or allocates a new zero-initialized one if none is available.
129+
pub fn take_or_alloc(&mut self, log_size: u32) -> Col<B, BaseField> {
130+
self.pools
131+
.get_mut(&log_size)
132+
.and_then(|pool| pool.pop())
133+
.unwrap_or_else(|| Col::<B, BaseField>::zeros(1 << log_size))
134+
}
135+
136+
/// Returns a buffer to the pool. The caller is responsible for ensuring the buffer's log_size
137+
/// matches.
138+
pub fn give_back(&mut self, log_size: u32, buf: Col<B, BaseField>) {
139+
debug_assert_eq!(buf.len(), 1 << log_size);
140+
self.pools.entry(log_size).or_default().push(buf);
141+
}
142+
143+
/// Returns the number of available buffers for a given log_size.
144+
pub fn available(&self, log_size: u32) -> usize {
145+
self.pools.get(&log_size).map_or(0, |pool| pool.len())
146+
}
147+
148+
/// Returns the total number of buffers across all sizes.
149+
pub fn total_available(&self) -> usize {
150+
self.pools.values().map(|pool| pool.len()).sum()
151+
}
152+
}
153+
154+
impl<B: ColumnOps<BaseField>> Default for BaseColumnPool<B> {
155+
fn default() -> Self {
156+
Self::new()
157+
}
158+
}
159+
90160
/// Zeroes out all columns in a [`SecureColumnByCoords`].
91161
fn zero_secure_column<B: ColumnOps<BaseField>>(col: &mut SecureColumnByCoords<B>) {
92162
let len = col.len();
@@ -101,13 +171,16 @@ fn zero_secure_column<B: ColumnOps<BaseField>>(col: &mut SecureColumnByCoords<B>
101171
pub struct ProverMemPool<B: ColumnOps<BaseField>> {
102172
/// Pool of reusable [`SecureColumnByCoords`] buffers.
103173
pub column_pool: ColumnPool<B>,
174+
/// Pool of reusable base field column buffers.
175+
pub base_column_pool: BaseColumnPool<B>,
104176
}
105177

106178
impl<B: ColumnOps<BaseField>> ProverMemPool<B> {
107179
/// Creates a new workspace in which both the secure-column pool and the base-column pool
/// start out empty.
108180
pub fn new() -> Self {
109181
Self {
110182
column_pool: ColumnPool::new(),
183+
base_column_pool: BaseColumnPool::new(),
111184
}
112185
}
113186

@@ -203,4 +276,49 @@ mod tests {
203276
}
204277

205278
use num_traits::Zero;
279+
280+
#[test]
fn test_base_column_pool_reserve_and_take() {
    const LOG_SIZE: u32 = 4;

    // Reserving three buffers makes exactly three available.
    let mut pool = BaseColumnPool::<CpuBackend>::new();
    pool.reserve(LOG_SIZE, 3);
    assert_eq!(pool.available(LOG_SIZE), 3);

    // Taking one yields a buffer of the requested length and leaves two behind.
    let taken = pool.take(LOG_SIZE);
    assert_eq!(taken.len(), 1 << LOG_SIZE);
    assert_eq!(pool.available(LOG_SIZE), 2);
}
290+
291+
#[test]
fn test_base_column_pool_give_back() {
    const LOG_SIZE: u32 = 5;

    // Drain the single reserved buffer.
    let mut pool = BaseColumnPool::<CpuBackend>::new();
    pool.reserve(LOG_SIZE, 1);
    let borrowed = pool.take(LOG_SIZE);
    assert_eq!(pool.available(LOG_SIZE), 0);

    // Handing it back makes it available again.
    pool.give_back(LOG_SIZE, borrowed);
    assert_eq!(pool.available(LOG_SIZE), 1);
}
301+
302+
#[test]
fn test_base_column_pool_take_or_alloc() {
    const LOG_SIZE: u32 = 3;
    let mut pool = BaseColumnPool::<CpuBackend>::new();

    // Nothing was reserved, so the buffer has to be freshly allocated and the pool stays empty.
    let fresh = pool.take_or_alloc(LOG_SIZE);
    assert_eq!(fresh.len(), 1 << LOG_SIZE);
    assert_eq!(pool.available(LOG_SIZE), 0);

    // Once handed back, the pooled buffer serves the next request.
    pool.give_back(LOG_SIZE, fresh);
    assert_eq!(pool.available(LOG_SIZE), 1);
    let _reused = pool.take_or_alloc(LOG_SIZE);
    assert_eq!(pool.available(LOG_SIZE), 0);
}
317+
318+
#[test]
#[should_panic(expected = "no buffer available")]
fn test_base_column_pool_take_panics_when_empty() {
    // Nothing was reserved, so the very first `take` must panic.
    let mut empty_pool = BaseColumnPool::<CpuBackend>::new();
    empty_pool.take(4);
}
206324
}

crates/stwo/src/prover/pcs/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use crate::core::ColumnVec;
2020
use crate::prover::air::component_prover::{Poly, Trace, WeightsHashMap};
2121
use crate::prover::backend::{BackendForChannel, Col};
2222
use crate::prover::fri::{FriDecommitResult, FriProver};
23+
use crate::prover::mempool::{BaseColumnPool, ColumnPool};
2324
use crate::prover::pcs::quotient_ops::compute_fri_quotients;
2425
use crate::prover::poly::circle::{CircleCoefficients, CircleEvaluation};
2526
use crate::prover::poly::twiddles::TwiddleTree;
@@ -34,6 +35,10 @@ pub struct CommitmentSchemeProver<'a, B: BackendForChannel<MC>, MC: MerkleChanne
3435
pub config: PcsConfig,
3536
twiddles: &'a TwiddleTree<B>,
3637
pub store_polynomials_coefficients: bool,
38+
/// Pre-allocated memory pool for the proving pipeline.
39+
pub mempool: ColumnPool<B>,
40+
/// Pre-allocated base field column pool for polynomial evaluation during commit.
41+
pub base_column_pool: BaseColumnPool<B>,
3742
}
3843
impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a, B, MC> {
3944
/// Creates a new empty commitment scheme prover with the given configuration and twiddles. The
@@ -44,6 +49,8 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
4449
config,
4550
twiddles,
4651
store_polynomials_coefficients: false,
52+
mempool: ColumnPool::new(),
53+
base_column_pool: BaseColumnPool::new(),
4754
}
4855
}
4956

@@ -62,6 +69,7 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
6269
self.twiddles,
6370
self.store_polynomials_coefficients,
6471
self.config.lifting_log_size,
72+
&mut self.base_column_pool,
6573
);
6674
self.trees.push(tree);
6775
}
@@ -319,13 +327,15 @@ impl<B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentTreeProver<B, MC> {
319327
twiddles: &TwiddleTree<B>,
320328
store_polynomials_coefficients: bool,
321329
lifting_log_size: Option<u32>,
330+
base_column_pool: &mut BaseColumnPool<B>,
322331
) -> Self {
323332
let span = span!(Level::INFO, "Extension").entered();
324333
let polynomials = B::evaluate_polynomials(
325334
polynomials,
326335
log_blowup_factor,
327336
twiddles,
328337
store_polynomials_coefficients,
338+
base_column_pool,
329339
);
330340
span.exit();
331341

crates/stwo/src/prover/poly/circle/ops.rs

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::core::poly::circle::{CanonicCoset, CircleDomain};
99
use crate::core::ColumnVec;
1010
use crate::prover::air::component_prover::Poly;
1111
use crate::prover::backend::{Col, ColumnOps};
12+
use crate::prover::mempool::BaseColumnPool;
1213
use crate::prover::poly::twiddles::TwiddleTree;
1314
use crate::prover::poly::BitReversedOrder;
1415

@@ -82,25 +83,43 @@ pub trait PolyOps: ColumnOps<BaseField> + ColumnOps<SecureField> + Sized {
8283
twiddles: &TwiddleTree<Self>,
8384
) -> CircleEvaluation<Self, BaseField, BitReversedOrder>;
8485

86+
/// Evaluates the polynomial at all points in the domain, writing results into the provided
87+
/// buffer instead of allocating a new one. The buffer must have size `domain.size()`.
///
/// The buffer is taken by value and is intended to become the storage of the returned
/// evaluation; its previous contents are not meaningful to the result.
///
/// # Panics
///
/// The CPU and SIMD implementations assert `buffer.len() == domain.size()` and panic when the
/// lengths differ.
88+
fn evaluate_into(
89+
poly: &CircleCoefficients<Self>,
90+
domain: CircleDomain,
91+
twiddles: &TwiddleTree<Self>,
92+
buffer: Col<Self, BaseField>,
93+
) -> CircleEvaluation<Self, BaseField, BitReversedOrder>;
94+
8595
fn evaluate_polynomials(
8696
polynomials: ColumnVec<CircleCoefficients<Self>>,
8797
log_blowup_factor: u32,
8898
twiddles: &TwiddleTree<Self>,
8999
store_polynomials_coefficients: bool,
100+
pool: &mut BaseColumnPool<Self>,
90101
) -> Vec<Poly<Self>>
91102
where
92103
Self: crate::prover::backend::Backend,
93104
{
105+
// Pre-take all buffers from the pool before the parallel section.
106+
let buffers: Vec<_> = polynomials
107+
.iter()
108+
.map(|poly_coeffs| {
109+
let log_eval_size = poly_coeffs.log_size() + log_blowup_factor;
110+
pool.take_or_alloc(log_eval_size)
111+
})
112+
.collect();
113+
94114
#[cfg(feature = "parallel")]
95-
let iter = polynomials.into_par_iter();
115+
let iter = polynomials.into_par_iter().zip(buffers.into_par_iter());
96116
#[cfg(not(feature = "parallel"))]
97-
let iter = polynomials.into_iter();
117+
let iter = polynomials.into_iter().zip(buffers.into_iter());
98118

99-
iter.map(|poly_coeffs| {
100-
let evals = poly_coeffs.evaluate_with_twiddles(
101-
CanonicCoset::new(poly_coeffs.log_size() + log_blowup_factor).circle_domain(),
102-
twiddles,
103-
);
119+
iter.map(|(poly_coeffs, buffer)| {
120+
let domain =
121+
CanonicCoset::new(poly_coeffs.log_size() + log_blowup_factor).circle_domain();
122+
let evals = Self::evaluate_into(&poly_coeffs, domain, twiddles, buffer);
104123
Poly::new(store_polynomials_coefficients.then_some(poly_coeffs), evals)
105124
})
106125
.collect()

0 commit comments

Comments
 (0)