Skip to content

Commit c8e2ac3

Browse files
Use BaseColumnPool in CommitmentSchemeProver::commit to avoid allocation during polynomial evaluation.
Add BaseColumnPool for Col<B, BaseField> buffers, add evaluate_into to PolyOps trait (CPU and SIMD backends), and thread the pool through commit → CommitmentTreeProver::new → evaluate_polynomials. Buffers are pre-taken before parallel iteration. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 9593293 commit c8e2ac3

File tree

8 files changed

+205
-32
lines changed

8 files changed

+205
-32
lines changed

crates/stwo/benches/pcs.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use stwo::core::poly::circle::CanonicCoset;
99
use stwo::core::vcs_lifted::blake2_merkle::Blake2sMerkleChannel;
1010
use stwo::prover::backend::simd::SimdBackend;
1111
use stwo::prover::backend::{BackendForChannel, CpuBackend};
12+
use stwo::prover::mempool::BaseColumnPool;
1213
use stwo::prover::poly::circle::CircleEvaluation;
1314
use stwo::prover::poly::twiddles::TwiddleTree;
1415
use stwo::prover::poly::BitReversedOrder;
@@ -35,6 +36,7 @@ fn benched_fn<B: BackendForChannel<Blake2sMerkleChannel>>(
3536
twiddles,
3637
false,
3738
None,
39+
&mut BaseColumnPool::new(),
3840
);
3941
}
4042

crates/stwo/src/core/utils.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,38 @@
11
use core::iter::Peekable;
2+
use core::ops::{Deref, DerefMut};
23

34
use std_shims::Vec;
45

56
use super::fields::Field;
67

8+
/// An enum that either borrows mutably or owns a value.
9+
/// Useful when a struct can optionally receive an external `&mut T` but also needs a fallback owned
10+
/// instance.
11+
pub enum MaybeOwned<'a, T> {
12+
Borrowed(&'a mut T),
13+
Owned(T),
14+
}
15+
16+
impl<T> Deref for MaybeOwned<'_, T> {
17+
type Target = T;
18+
19+
fn deref(&self) -> &T {
20+
match self {
21+
MaybeOwned::Borrowed(r) => r,
22+
MaybeOwned::Owned(ref v) => v,
23+
}
24+
}
25+
}
26+
27+
impl<T> DerefMut for MaybeOwned<'_, T> {
28+
fn deref_mut(&mut self) -> &mut T {
29+
match self {
30+
MaybeOwned::Borrowed(r) => r,
31+
MaybeOwned::Owned(ref mut v) => v,
32+
}
33+
}
34+
}
35+
736
pub trait IteratorMutExt<'a, T: 'a>: Iterator<Item = &'a mut T> {
837
fn assign(self, other: impl IntoIterator<Item = T>)
938
where

crates/stwo/src/prover/backend/cpu/circle.rs

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,40 +191,62 @@ impl PolyOps for CpuBackend {
191191
poly: &CircleCoefficients<Self>,
192192
domain: CircleDomain,
193193
twiddles: &TwiddleTree<Self>,
194+
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
195+
let buffer = vec![BaseField::zero(); domain.size()];
196+
Self::evaluate_into(poly, domain, twiddles, buffer)
197+
}
198+
199+
fn evaluate_into(
200+
poly: &CircleCoefficients<Self>,
201+
domain: CircleDomain,
202+
twiddles: &TwiddleTree<Self>,
203+
mut buffer: Col<Self, BaseField>,
194204
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
195205
assert!(domain.half_coset.is_doubling_of(twiddles.root_coset));
206+
assert_eq!(buffer.len(), domain.size());
196207

197-
let mut values = poly.extend(domain.log_size()).coeffs;
208+
// Copy the coefficients into the buffer and zero-pad the tail (i.e. extend the polynomial).
209+
let poly_len = poly.coeffs.len();
210+
buffer[..poly_len].copy_from_slice(&poly.coeffs);
211+
for v in &mut buffer[poly_len..] {
212+
*v = BaseField::zero();
213+
}
198214

199215
if domain.log_size() == 1 {
200-
let (mut v0, mut v1) = (values[0], values[1]);
216+
let (mut v0, mut v1) = (buffer[0], buffer[1]);
201217
butterfly(&mut v0, &mut v1, domain.half_coset.initial.y);
202-
return CircleEvaluation::new(domain, vec![v0, v1]);
218+
buffer[0] = v0;
219+
buffer[1] = v1;
220+
return CircleEvaluation::new(domain, buffer);
203221
}
204222

205223
if domain.log_size() == 2 {
206-
let (mut v0, mut v1, mut v2, mut v3) = (values[0], values[1], values[2], values[3]);
224+
let (mut v0, mut v1, mut v2, mut v3) = (buffer[0], buffer[1], buffer[2], buffer[3]);
207225
let CirclePoint { x, y } = domain.half_coset.initial;
208226
butterfly(&mut v0, &mut v2, x);
209227
butterfly(&mut v1, &mut v3, x);
210228
butterfly(&mut v0, &mut v1, y);
211229
butterfly(&mut v2, &mut v3, -y);
212-
return CircleEvaluation::new(domain, vec![v0, v1, v2, v3]);
230+
buffer[0] = v0;
231+
buffer[1] = v1;
232+
buffer[2] = v2;
233+
buffer[3] = v3;
234+
return CircleEvaluation::new(domain, buffer);
213235
}
214236

215237
let line_twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);
216238
let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]);
217239

218240
for (layer, layer_twiddles) in line_twiddles.iter().enumerate().rev() {
219241
for (h, &t) in layer_twiddles.iter().enumerate() {
220-
fft_layer_loop(&mut values, layer + 1, h, t, butterfly);
242+
fft_layer_loop(&mut buffer, layer + 1, h, t, butterfly);
221243
}
222244
}
223245
for (h, t) in circle_twiddles.enumerate() {
224-
fft_layer_loop(&mut values, 0, h, t, butterfly);
246+
fft_layer_loop(&mut buffer, 0, h, t, butterfly);
225247
}
226248

227-
CircleEvaluation::new(domain, values)
249+
CircleEvaluation::new(domain, buffer)
228250
}
229251

230252
fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {

crates/stwo/src/prover/backend/simd/circle.rs

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,17 @@ impl PolyOps for SimdBackend {
378378
poly: &CircleCoefficients<Self>,
379379
domain: CircleDomain,
380380
twiddles: &TwiddleTree<Self>,
381+
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
382+
// SAFETY: evaluate_into writes all values via FFT before they are read.
383+
let buffer = unsafe { Col::<Self, BaseField>::uninitialized(domain.size()) };
384+
Self::evaluate_into(poly, domain, twiddles, buffer)
385+
}
386+
387+
fn evaluate_into(
388+
poly: &CircleCoefficients<Self>,
389+
domain: CircleDomain,
390+
twiddles: &TwiddleTree<Self>,
391+
mut buffer: Col<Self, BaseField>,
381392
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
382393
let _span = span!(Level::TRACE, "", class = "rFFT").entered();
383394
let log_size = domain.log_size();
@@ -386,6 +397,7 @@ impl PolyOps for SimdBackend {
386397
log_size >= fft_log_size,
387398
"Can only evaluate on larger domains"
388399
);
400+
assert_eq!(buffer.len(), domain.size());
389401

390402
if fft_log_size < MIN_FFT_LOG_SIZE {
391403
let cpu_poly: CircleCoefficients<CpuBackend> =
@@ -399,16 +411,9 @@ impl PolyOps for SimdBackend {
399411

400412
let twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);
401413

402-
// Evaluate on a big domains by evaluating on several subdomains.
414+
// Evaluate on big domains by evaluating on several subdomains.
403415
let log_subdomains = log_size - fft_log_size;
404416

405-
// Allocate the destination buffer without initializing.
406-
let mut values = Vec::with_capacity(domain.size() >> LOG_N_LANES);
407-
#[allow(clippy::uninit_vec)]
408-
unsafe {
409-
values.set_len(domain.size() >> LOG_N_LANES)
410-
};
411-
412417
for i in 0..(1 << log_subdomains) {
413418
// The subdomain twiddles are a slice of the large domain twiddles.
414419
let subdomain_twiddles = (0..(fft_log_size - 1))
@@ -418,12 +423,12 @@ impl PolyOps for SimdBackend {
418423
})
419424
.collect::<Vec<_>>();
420425

421-
// FFT from the coefficients buffer to the values chunk.
426+
// FFT from the coefficients buffer directly into the provided buffer.
422427
unsafe {
423428
rfft::fft(
424429
transmute::<*const PackedBaseField, *const u32>(poly.coeffs.data.as_ptr()),
425430
transmute::<*mut PackedBaseField, *mut u32>(
426-
values[i << (fft_log_size - LOG_N_LANES)
431+
buffer.data[i << (fft_log_size - LOG_N_LANES)
427432
..(i + 1) << (fft_log_size - LOG_N_LANES)]
428433
.as_mut_ptr(),
429434
),
@@ -433,13 +438,7 @@ impl PolyOps for SimdBackend {
433438
}
434439
}
435440

436-
CircleEvaluation::new(
437-
domain,
438-
BaseColumn {
439-
data: values,
440-
length: domain.size(),
441-
},
442-
)
441+
CircleEvaluation::new(domain, buffer)
443442
}
444443

445444
/// Precomputes the (doubled) twiddles for a given coset tower.

crates/stwo/src/prover/mempool.rs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
//! Pre-allocated memory pools for the proving pipeline.
2+
//!
3+
//! The [`BaseColumnPool`] manages reusable [`Col<B, BaseField>`] buffers for polynomial evaluation,
4+
//! avoiding repeated allocation/deallocation of large column buffers during proving.
5+
6+
use std::collections::HashMap;
7+
8+
use crate::core::fields::m31::BaseField;
9+
use crate::prover::backend::{Col, Column, ColumnOps};
10+
11+
/// A pool of pre-allocated [`Col<B, BaseField>`] buffers, organized by log_size.
12+
///
13+
/// Used to avoid repeated allocation of evaluation buffers during polynomial commitment.
14+
pub struct BaseColumnPool<B: ColumnOps<BaseField>> {
15+
/// Map from log_size -> stack of available buffers.
16+
pools: HashMap<u32, Vec<Col<B, BaseField>>>,
17+
}
18+
19+
impl<B: ColumnOps<BaseField>> BaseColumnPool<B> {
20+
/// Creates a new empty base column pool.
21+
pub fn new() -> Self {
22+
Self {
23+
pools: HashMap::new(),
24+
}
25+
}
26+
27+
/// Pre-allocates `count` zero-initialized buffers of size `1 << log_size`.
28+
pub fn reserve(&mut self, log_size: u32, count: usize) {
29+
let pool = self.pools.entry(log_size).or_default();
30+
for _ in 0..count {
31+
pool.push(Col::<B, BaseField>::zeros(1 << log_size));
32+
}
33+
}
34+
35+
/// Takes a buffer from the pool for the given `log_size`.
36+
///
37+
/// # Panics
38+
///
39+
/// Panics if no buffer of the requested size is available.
40+
pub fn take(&mut self, log_size: u32) -> Col<B, BaseField> {
41+
self.pools
42+
.get_mut(&log_size)
43+
.and_then(|pool| pool.pop())
44+
.unwrap_or_else(|| {
45+
panic!("BaseColumnPool: no buffer available for log_size={log_size}")
46+
})
47+
}
48+
49+
/// Takes a buffer from the pool, or allocates a new uninitialized one if none is available.
50+
pub fn take_or_alloc(&mut self, log_size: u32) -> Col<B, BaseField> {
51+
self.pools
52+
.get_mut(&log_size)
53+
.and_then(|pool| pool.pop())
54+
.unwrap_or_else(|| unsafe { Col::<B, BaseField>::uninitialized(1 << log_size) })
55+
}
56+
57+
/// Returns a buffer to the pool. The caller is responsible for ensuring the buffer's log_size
58+
/// matches.
59+
pub fn give_back(&mut self, log_size: u32, buf: Col<B, BaseField>) {
60+
debug_assert_eq!(buf.len(), 1 << log_size);
61+
self.pools.entry(log_size).or_default().push(buf);
62+
}
63+
64+
/// Returns the number of available buffers for a given log_size.
65+
pub fn available(&self, log_size: u32) -> usize {
66+
self.pools.get(&log_size).map_or(0, |pool| pool.len())
67+
}
68+
69+
/// Returns the total number of buffers across all sizes.
70+
pub fn total_available(&self) -> usize {
71+
self.pools.values().map(|pool| pool.len()).sum()
72+
}
73+
}
74+
75+
impl<B: ColumnOps<BaseField>> Default for BaseColumnPool<B> {
76+
fn default() -> Self {
77+
Self::new()
78+
}
79+
}

crates/stwo/src/prover/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pub mod channel;
2020
pub mod fri;
2121
pub mod line;
2222
pub mod lookups;
23+
pub mod mempool;
2324
pub mod poly;
2425
pub mod secure_column;
2526
pub mod vcs;

crates/stwo/src/prover/pcs/mod.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ use crate::core::ColumnVec;
2020
use crate::prover::air::component_prover::{Poly, Trace, WeightsHashMap};
2121
use crate::prover::backend::{BackendForChannel, Col};
2222
use crate::prover::fri::{FriDecommitResult, FriProver};
23+
use crate::core::utils::MaybeOwned;
24+
use crate::prover::mempool::BaseColumnPool;
2325
use crate::prover::pcs::quotient_ops::compute_fri_quotients;
2426
use crate::prover::poly::circle::{CircleCoefficients, CircleEvaluation};
2527
use crate::prover::poly::twiddles::TwiddleTree;
@@ -34,6 +36,8 @@ pub struct CommitmentSchemeProver<'a, B: BackendForChannel<MC>, MC: MerkleChanne
3436
pub config: PcsConfig,
3537
twiddles: &'a TwiddleTree<B>,
3638
pub store_polynomials_coefficients: bool,
39+
/// Pre-allocated base field column pool for polynomial evaluation during commit.
40+
pub base_column_pool: MaybeOwned<'a, BaseColumnPool<B>>,
3741
}
3842
impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a, B, MC> {
3943
/// Creates a new empty commitment scheme prover with the given configuration and twiddles. The
@@ -44,6 +48,21 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
4448
config,
4549
twiddles,
4650
store_polynomials_coefficients: false,
51+
base_column_pool: MaybeOwned::Owned(BaseColumnPool::new()),
52+
}
53+
}
54+
55+
pub fn with_memory_pool(
56+
config: PcsConfig,
57+
twiddles: &'a TwiddleTree<B>,
58+
base_column_pool: &'a mut BaseColumnPool<B>,
59+
) -> Self {
60+
CommitmentSchemeProver {
61+
trees: TreeVec::default(),
62+
config,
63+
twiddles,
64+
store_polynomials_coefficients: false,
65+
base_column_pool: MaybeOwned::Borrowed(base_column_pool),
4766
}
4867
}
4968

@@ -62,6 +81,7 @@ impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a,
6281
self.twiddles,
6382
self.store_polynomials_coefficients,
6483
self.config.lifting_log_size,
84+
&mut self.base_column_pool,
6585
);
6686
self.trees.push(tree);
6787
}
@@ -319,13 +339,15 @@ impl<B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentTreeProver<B, MC> {
319339
twiddles: &TwiddleTree<B>,
320340
store_polynomials_coefficients: bool,
321341
lifting_log_size: Option<u32>,
342+
base_column_pool: &mut BaseColumnPool<B>,
322343
) -> Self {
323344
let span = span!(Level::INFO, "Extension").entered();
324345
let polynomials = B::evaluate_polynomials(
325346
polynomials,
326347
log_blowup_factor,
327348
twiddles,
328349
store_polynomials_coefficients,
350+
base_column_pool,
329351
);
330352
span.exit();
331353

0 commit comments

Comments
 (0)