Skip to content

Commit 5675eab

Browse files
Introduce ColumnPool and ProverMemPool for buffer reuse.
Add a memory pooling system that manages reusable SecureColumnByCoords buffers organized by log_size, avoiding repeated allocation/deallocation of large column buffers during proving. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 9593293 commit 5675eab

File tree

2 files changed

+207
-0
lines changed

2 files changed

+207
-0
lines changed

crates/stwo/src/prover/mempool.rs

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
//! Pre-allocated memory pools for the proving pipeline.
2+
//!
3+
//! The [`ColumnPool`] manages reusable [`SecureColumnByCoords`] buffers organized by log_size,
4+
//! avoiding repeated allocation/deallocation of large column buffers during proving.
5+
//! The [`ProverMemPool`] pre-allocates all needed buffers upfront based on the proving workload.
6+
7+
use std::collections::HashMap;
8+
9+
use crate::core::fields::m31::BaseField;
10+
use crate::prover::backend::ColumnOps;
11+
use crate::prover::secure_column::SecureColumnByCoords;
12+
13+
/// A pool of pre-allocated [`SecureColumnByCoords`] buffers, organized by log_size.
14+
pub struct ColumnPool<B: ColumnOps<BaseField>> {
15+
/// Map from log_size -> stack of available buffers.
16+
pools: HashMap<u32, Vec<SecureColumnByCoords<B>>>,
17+
}
18+
19+
impl<B: ColumnOps<BaseField>> ColumnPool<B> {
20+
/// Creates a new empty column pool.
21+
pub fn new() -> Self {
22+
Self {
23+
pools: HashMap::new(),
24+
}
25+
}
26+
27+
/// Pre-allocates `count` zero-initialized buffers of size `1 << log_size`.
28+
pub fn reserve(&mut self, log_size: u32, count: usize) {
29+
let pool = self.pools.entry(log_size).or_default();
30+
for _ in 0..count {
31+
pool.push(SecureColumnByCoords::zeros(1 << log_size));
32+
}
33+
}
34+
35+
/// Takes a buffer from the pool for the given `log_size`.
36+
///
37+
/// # Panics
38+
///
39+
/// Panics if no buffer of the requested size is available.
40+
pub fn take(&mut self, log_size: u32) -> SecureColumnByCoords<B> {
41+
self.pools
42+
.get_mut(&log_size)
43+
.and_then(|pool| pool.pop())
44+
.unwrap_or_else(|| panic!("ColumnPool: no buffer available for log_size={log_size}"))
45+
}
46+
47+
/// Takes a buffer from the pool, or allocates a new zero-initialized one if none is available.
48+
pub fn take_or_alloc(&mut self, log_size: u32) -> SecureColumnByCoords<B> {
49+
self.pools
50+
.get_mut(&log_size)
51+
.and_then(|pool| pool.pop())
52+
.unwrap_or_else(|| SecureColumnByCoords::zeros(1 << log_size))
53+
}
54+
55+
/// Returns a buffer to the pool. The caller is responsible for ensuring the buffer's log_size
56+
/// matches.
57+
pub fn give_back(&mut self, log_size: u32, buf: SecureColumnByCoords<B>) {
58+
debug_assert_eq!(buf.len(), 1 << log_size);
59+
self.pools.entry(log_size).or_default().push(buf);
60+
}
61+
62+
/// Takes a buffer from the pool, zeroing it before returning. Falls back to allocating a new
63+
/// zero-initialized buffer if none is available.
64+
pub fn take_zeroed(&mut self, log_size: u32) -> SecureColumnByCoords<B> {
65+
if let Some(mut buf) = self.pools.get_mut(&log_size).and_then(|pool| pool.pop()) {
66+
zero_secure_column(&mut buf);
67+
buf
68+
} else {
69+
SecureColumnByCoords::zeros(1 << log_size)
70+
}
71+
}
72+
73+
/// Returns the number of available buffers for a given log_size.
74+
pub fn available(&self, log_size: u32) -> usize {
75+
self.pools.get(&log_size).map_or(0, |pool| pool.len())
76+
}
77+
78+
/// Returns the total number of buffers across all sizes.
79+
pub fn total_available(&self) -> usize {
80+
self.pools.values().map(|pool| pool.len()).sum()
81+
}
82+
}
83+
84+
impl<B: ColumnOps<BaseField>> Default for ColumnPool<B> {
85+
fn default() -> Self {
86+
Self::new()
87+
}
88+
}
89+
90+
/// Zeroes out all columns in a [`SecureColumnByCoords`].
91+
fn zero_secure_column<B: ColumnOps<BaseField>>(col: &mut SecureColumnByCoords<B>) {
92+
let len = col.len();
93+
*col = SecureColumnByCoords::zeros(len);
94+
}
95+
96+
/// Pre-allocated memory for the entire proving pipeline.
97+
///
98+
/// Created by analyzing the component structure and PCS configuration before proving begins.
99+
/// Contains pools of reusable column buffers that various parts of the prover can draw from
100+
/// instead of allocating on-demand.
101+
pub struct ProverMemPool<B: ColumnOps<BaseField>> {
102+
/// Pool of reusable [`SecureColumnByCoords`] buffers.
103+
pub column_pool: ColumnPool<B>,
104+
}
105+
106+
impl<B: ColumnOps<BaseField>> ProverMemPool<B> {
107+
/// Creates a new workspace with an empty column pool.
108+
pub fn new() -> Self {
109+
Self {
110+
column_pool: ColumnPool::new(),
111+
}
112+
}
113+
114+
/// Creates a workspace with pre-allocated buffers based on the specified requirements.
115+
///
116+
/// `requirements` is a list of `(log_size, count)` pairs indicating how many buffers of each
117+
/// size to pre-allocate.
118+
pub fn with_requirements(requirements: &[(u32, usize)]) -> Self {
119+
let mut workspace = Self::new();
120+
for &(log_size, count) in requirements {
121+
workspace.column_pool.reserve(log_size, count);
122+
}
123+
workspace
124+
}
125+
}
126+
127+
impl<B: ColumnOps<BaseField>> Default for ProverMemPool<B> {
128+
fn default() -> Self {
129+
Self::new()
130+
}
131+
}
132+
133+
#[cfg(test)]
134+
mod tests {
135+
use super::*;
136+
use crate::prover::backend::CpuBackend;
137+
138+
#[test]
139+
fn test_column_pool_reserve_and_take() {
140+
let mut pool = ColumnPool::<CpuBackend>::new();
141+
pool.reserve(4, 3);
142+
assert_eq!(pool.available(4), 3);
143+
144+
let buf = pool.take(4);
145+
assert_eq!(buf.len(), 1 << 4);
146+
assert_eq!(pool.available(4), 2);
147+
}
148+
149+
#[test]
150+
fn test_column_pool_give_back() {
151+
let mut pool = ColumnPool::<CpuBackend>::new();
152+
pool.reserve(5, 1);
153+
let buf = pool.take(5);
154+
assert_eq!(pool.available(5), 0);
155+
156+
pool.give_back(5, buf);
157+
assert_eq!(pool.available(5), 1);
158+
}
159+
160+
#[test]
161+
fn test_column_pool_take_or_alloc() {
162+
let mut pool = ColumnPool::<CpuBackend>::new();
163+
164+
// No pre-allocated buffer, should allocate.
165+
let buf = pool.take_or_alloc(3);
166+
assert_eq!(buf.len(), 1 << 3);
167+
assert_eq!(pool.available(3), 0);
168+
169+
// Return and take again.
170+
pool.give_back(3, buf);
171+
assert_eq!(pool.available(3), 1);
172+
let _buf = pool.take_or_alloc(3);
173+
assert_eq!(pool.available(3), 0);
174+
}
175+
176+
#[test]
177+
fn test_column_pool_take_zeroed() {
178+
let mut pool = ColumnPool::<CpuBackend>::new();
179+
pool.reserve(4, 1);
180+
181+
let buf = pool.take_zeroed(4);
182+
assert_eq!(buf.len(), 1 << 4);
183+
// Verify all values are zero.
184+
for i in 0..buf.len() {
185+
assert!(buf.at(i).is_zero(), "non-zero at index {i}");
186+
}
187+
}
188+
189+
#[test]
190+
#[should_panic(expected = "no buffer available")]
191+
fn test_column_pool_take_panics_when_empty() {
192+
let mut pool = ColumnPool::<CpuBackend>::new();
193+
pool.take(4);
194+
}
195+
196+
#[test]
197+
fn test_prover_mempool_with_requirements() {
198+
let workspace = ProverMemPool::<CpuBackend>::with_requirements(&[(4, 2), (5, 3), (6, 1)]);
199+
assert_eq!(workspace.column_pool.available(4), 2);
200+
assert_eq!(workspace.column_pool.available(5), 3);
201+
assert_eq!(workspace.column_pool.available(6), 1);
202+
assert_eq!(workspace.column_pool.total_available(), 6);
203+
}
204+
205+
use num_traits::Zero;
206+
}

crates/stwo/src/prover/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ pub mod poly;
2424
pub mod secure_column;
2525
pub mod vcs;
2626
pub mod vcs_lifted;
27+
pub mod mempool;
2728

2829
pub fn prove<B: BackendForChannel<MC>, MC: MerkleChannel>(
2930
components: &[&dyn ComponentProver<B>],

0 commit comments

Comments
 (0)