Skip to content

Commit 83644d5

Browse files
committed
Add SIMD implementation of poseidon MerkleOpsLifted. (#1280)
1 parent 3d937f6 commit 83644d5

File tree

4 files changed

+200
-10
lines changed

4 files changed

+200
-10
lines changed

crates/stwo/src/core/pcs/verifier.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ impl<MC: MerkleChannel> CommitmentSchemeVerifier<MC> {
9090
.0
9191
.into_iter()
9292
.collect::<Result<(), _>>()?;
93-
9493
// Answer FRI queries.
9594
let samples = sampled_points.zip_cols(proof.sampled_values).map_cols(
9695
|(sampled_points, sampled_values)| {

crates/stwo/src/core/vcs_lifted/poseidon252_merkle.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ impl MerkleHasherLifted for Poseidon252MerkleHasher {
7373
}
7474

7575
pub fn poseidon_update(values: &[FieldElement252], state: &mut [FieldElement252; 3]) {
76-
debug_assert!(values.len().is_multiple_of(2));
7776
let mut iter = values.chunks_exact(2);
7877
for msg in iter.by_ref() {
7978
state[0] += msg[0];

crates/stwo/src/prover/backend/cpu/blake2s_lifted.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
#[cfg(feature = "parallel")]
2+
use rayon::iter::{IntoParallelIterator, ParallelIterator};
3+
14
use crate::core::fields::m31::BaseField;
25
use crate::core::vcs_lifted::merkle_hasher::MerkleHasherLifted;
6+
use crate::parallel_iter;
37
use crate::prover::backend::CpuBackend;
48
use crate::prover::vcs_lifted::ops::MerkleOpsLifted;
59

@@ -69,8 +73,8 @@ impl<H: MerkleHasherLifted> MerkleOpsLifted<H> for CpuBackend {
6973
}
7074

7175
fn build_next_layer(prev_layer: &Vec<H::Hash>) -> Vec<H::Hash> {
    // Each parent node hashes one consecutive pair of children, so the next
    // layer has half as many nodes as `prev_layer`.
    // NOTE(review): assumes `prev_layer.len()` is a power of two and >= 2;
    // `ilog2() - 1` would underflow (panic in debug) on a single-element
    // layer — confirm callers never pass the root layer here.
    let log_size: u32 = prev_layer.len().ilog2() - 1;
    // `parallel_iter!` presumably expands to a rayon parallel range when the
    // "parallel" feature is enabled and a plain sequential range otherwise
    // (see the cfg-gated rayon import at the top of this file).
    parallel_iter!(0..(1 << log_size))
        .map(|i| H::hash_children((prev_layer[2 * i], prev_layer[2 * i + 1])))
        .collect()
}
Lines changed: 194 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,209 @@
1-
use crate::core::fields::m31::BaseField;
1+
use itertools::Itertools;
2+
#[cfg(feature = "parallel")]
3+
use rayon::prelude::*;
4+
use starknet_ff::FieldElement as FieldElement252;
5+
6+
use crate::core::fields::m31::{BaseField, M31};
7+
use crate::core::vcs::poseidon252_merkle::{construct_felt252_from_m31s, ELEMENTS_IN_BLOCK};
28
use crate::core::vcs_lifted::merkle_hasher::MerkleHasherLifted;
3-
use crate::core::vcs_lifted::poseidon252_merkle::Poseidon252MerkleHasher;
9+
use crate::core::vcs_lifted::poseidon252_merkle::{
10+
poseidon_finalize, poseidon_update, Poseidon252MerkleHasher, ELEMENTS_IN_BUFFER,
11+
};
12+
#[cfg(feature = "parallel")]
13+
use crate::prover::backend::simd::m31::N_LANES;
414
use crate::prover::backend::simd::SimdBackend;
5-
use crate::prover::backend::Col;
15+
use crate::prover::backend::{Col, Column, CpuBackend};
616
use crate::prover::vcs_lifted::ops::MerkleOpsLifted;
717

8-
#[allow(unused)]
18+
/// TODO(Leo): the implementation below is not really vectorized because there is no poseidon hash
/// implementation in simd yet.
impl MerkleOpsLifted<Poseidon252MerkleHasher> for SimdBackend {
    /// Builds the leaf layer of the Merkle tree: one poseidon hash per pair of rows,
    /// absorbing the columns in chunks of `ELEMENTS_IN_BUFFER` columns at a time.
    ///
    /// Columns may have different lengths; shorter columns are effectively repeated
    /// (each row of a shorter column feeds several leaves, via the `log_ratio` index
    /// arithmetic below). NOTE(review): this assumes `columns` is sorted by length
    /// ascending (the last column is the longest) — confirm against callers.
    fn build_leaves(
        columns: &[&Col<Self, BaseField>],
    ) -> Col<Self, <Poseidon252MerkleHasher as MerkleHasherLifted>::Hash> {
        // No columns: a single default leaf hash.
        if columns.is_empty() {
            return vec![<Poseidon252MerkleHasher as MerkleHasherLifted>::Hash::default()];
        }
        // Columns shorter than a SIMD vector: fall back to the CPU implementation.
        if columns.first().unwrap().len() < N_LANES {
            let cpu_cols = columns.iter().map(|column| column.to_cpu()).collect_vec();
            return <CpuBackend as MerkleOpsLifted<Poseidon252MerkleHasher>>::build_leaves(
                &cpu_cols.iter().collect_vec(),
            );
        }
        let max_log_size: u32 = columns.last().unwrap().len().ilog2();
        let mut col_chunk_iter = columns.chunks(ELEMENTS_IN_BUFFER);
        // The last chunk may be partial and is handled separately with `poseidon_finalize_m31s`.
        // SAFETY: `columns` is non-empty (checked above), so `chunks` yields at least one chunk.
        let last_chunk = unsafe { col_chunk_iter.next_back().unwrap_unchecked() };

        // Preallocate working memory.
        // For every chunk of column, we go over all the rows, read from `prev_layer_states`, write
        // to `next_layer_states`, and then we swap them for the next chunk.
        let mut prev_layer_states: Vec<[FieldElement252; 3]> =
            vec![[FieldElement252::default(); 3]; 1 << (max_log_size)];
        let mut next_layer_states: Vec<[FieldElement252; 3]> =
            vec![[FieldElement252::default(); 3]; 1 << (max_log_size)];

        let mut prev_chunk_max_log_size = 1;
        for chunk_columns in &mut col_chunk_iter {
            let chunk_max_log_size: u32 = chunk_columns.iter().last().unwrap().len().ilog2();
            // Only the first `2^chunk_max_log_size` states are live for this chunk.
            let next_layer_state_slice = &mut next_layer_states[0..1 << chunk_max_log_size];
            // Compute the new states of the current layer.
            #[cfg(not(feature = "parallel"))]
            let iter_states = next_layer_state_slice.iter_mut();
            #[cfg(feature = "parallel")]
            let iter_states = next_layer_state_slice.par_iter_mut();

            iter_states.enumerate().for_each(|(i, curr_state)| {
                // Map state index `i` back to the index of the (smaller) previous
                // chunk's state it continues: drop `log_ratio + 1` low bits, then
                // reattach the parity bit (leaves come in even/odd row pairs).
                let log_ratio = chunk_max_log_size - prev_chunk_max_log_size;
                let mut prev_state: [FieldElement252; 3] =
                    prev_layer_states[(i >> (log_ratio + 1) << 1) + (i & 1)];
                // SAFETY-ish: M31 is a field element wrapper for which the all-zero
                // bit pattern is the valid zero value — NOTE(review): confirm M31
                // remains a plain integer newtype.
                let mut msgs: [M31; ELEMENTS_IN_BUFFER] = unsafe { std::mem::zeroed() };
                for (j, column) in chunk_columns.iter().enumerate() {
                    let log_size = column.len().ilog2();
                    // Shorter columns cover several states each; same index mapping
                    // as above, per column.
                    let log_ratio = chunk_max_log_size - log_size;
                    msgs[j] = column.at((i >> (log_ratio + 1) << 1) + (i & 1));
                }
                poseidon_update_m31s(&msgs, &mut prev_state);
                *curr_state = prev_state;
            });
            std::mem::swap(&mut prev_layer_states, &mut next_layer_states);
            prev_chunk_max_log_size = chunk_max_log_size;
        }

        // Final (possibly partial) chunk: absorb and finalize every state at full size.
        #[cfg(not(feature = "parallel"))]
        let iter_states = next_layer_states.iter_mut();
        #[cfg(feature = "parallel")]
        let iter_states = next_layer_states.par_iter_mut();

        iter_states.enumerate().for_each(|(i, curr_state)| {
            let log_ratio = max_log_size - prev_chunk_max_log_size;
            let prev_state: [FieldElement252; 3] =
                prev_layer_states[(i >> (log_ratio + 1) << 1) + (i & 1)];
            let mut msgs: [M31; ELEMENTS_IN_BUFFER] = unsafe { std::mem::zeroed() };
            for (j, column) in last_chunk.iter().enumerate() {
                let log_size = column.len().ilog2();
                let log_ratio = max_log_size - log_size;
                msgs[j] = column.at((i >> (log_ratio + 1) << 1) + (i & 1));
            }
            // Only the first `last_chunk.len()` message slots are meaningful here.
            *curr_state = poseidon_finalize_m31s(&msgs[..last_chunk.len()], prev_state);
        });
        // The leaf hash is the first word of each finalized state.
        next_layer_states.iter().map(|[fin, ..]| *fin).collect()
    }

    /// No SIMD poseidon is available yet (see TODO above), so inner layers are
    /// delegated to the CPU backend implementation.
    fn build_next_layer(
        prev_layer: &Col<Self, <Poseidon252MerkleHasher as MerkleHasherLifted>::Hash>,
    ) -> Col<Self, <Poseidon252MerkleHasher as MerkleHasherLifted>::Hash> {
        <CpuBackend as MerkleOpsLifted<Poseidon252MerkleHasher>>::build_next_layer(prev_layer)
    }
}
98+
99+
fn poseidon_update_m31s(msgs: &[M31; ELEMENTS_IN_BUFFER], prev_state: &mut [FieldElement252; 3]) {
100+
let field_elements: [FieldElement252; 2] = std::array::from_fn(|i| {
101+
construct_felt252_from_m31s(&msgs[i * ELEMENTS_IN_BLOCK..(i + 1) * ELEMENTS_IN_BLOCK])
102+
});
103+
poseidon_update(&field_elements, prev_state);
104+
}
105+
106+
fn poseidon_finalize_m31s(msgs: &[M31], prev_state: [FieldElement252; 3]) -> [FieldElement252; 3] {
107+
let field_elements: Vec<FieldElement252> = msgs
108+
.chunks(ELEMENTS_IN_BLOCK)
109+
.map(construct_felt252_from_m31s)
110+
.collect();
111+
poseidon_finalize(&field_elements, prev_state)
112+
}
113+
114+
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::FieldElement252;
    use crate::core::fields::m31::{BaseField, M31};
    use crate::core::vcs_lifted::poseidon252_merkle::Poseidon252MerkleHasher;
    use crate::prover::backend::simd::column::BaseColumn;
    use crate::prover::backend::simd::SimdBackend;
    use crate::prover::backend::CpuBackend;
    use crate::prover::vcs_lifted::ops::MerkleOpsLifted;
    use crate::prover::vcs_lifted::prover::MerkleProverLifted;

    // The SIMD `build_next_layer` delegates to the CPU backend, so the two must agree.
    #[test]
    fn test_build_next_layer() {
        const LOG_SIZE: u32 = 6;
        let layer: Vec<FieldElement252> = (0u32..1 << (LOG_SIZE + 1))
            .map(FieldElement252::from)
            .collect();
        assert_eq!(
            <CpuBackend as MerkleOpsLifted<Poseidon252MerkleHasher>>::build_next_layer(&layer),
            <SimdBackend as MerkleOpsLifted<Poseidon252MerkleHasher>>::build_next_layer(&layer)
        );
    }

    // Commits the same trace on both backends and returns (cpu_root, simd_root).
    fn prepare_poseidon_merkle_commit() -> (FieldElement252, FieldElement252) {
        const MAX_LOG_N_ROWS: u32 = 9;
        const N_COLS: u32 = 95;
        let mut cols: Vec<Vec<BaseField>> = (0..N_COLS)
            .map(|i| {
                (0..1 << MAX_LOG_N_ROWS)
                    .map(|j| M31::from(100 * i + j))
                    .collect_vec()
            })
            .collect();

        // Shrink the first 40 columns (in two different sizes) to test a
        // non-uniform sized trace.
        (0..20).for_each(|i| {
            cols[i] = (0..1 << (MAX_LOG_N_ROWS - 4))
                .map(M31::from_u32_unchecked)
                .collect_vec()
        });
        (20..40).for_each(|i| {
            cols[i] = (0..1 << (MAX_LOG_N_ROWS - 3))
                .map(M31::from_u32_unchecked)
                .collect_vec()
        });
        let cols_simd: Vec<BaseColumn> = cols
            .iter()
            .map(|c| BaseColumn::from_cpu(c.clone()))
            .collect();

        (
            MerkleProverLifted::<CpuBackend, Poseidon252MerkleHasher>::commit(
                cols.iter().collect(),
            )
            .root(),
            MerkleProverLifted::<SimdBackend, Poseidon252MerkleHasher>::commit(
                cols_simd.iter().collect(),
            )
            .root(),
        )
    }

    #[test]
    fn test_poseidon_merkle_commit() {
        let (cpu_root, simd_root) = prepare_poseidon_merkle_commit();
        assert_eq!(cpu_root, simd_root);
    }

    // Covers the `len() < N_LANES` fallback path in the SIMD `build_leaves`.
    #[test]
    fn test_small_columns_leaves() {
        for log_size in 2..9 {
            const N_COLS: usize = 2;
            let cols: Vec<Vec<BaseField>> = (0..N_COLS)
                .map(|i| {
                    (0..1 << log_size)
                        .map(|j| M31::from(100 * i + j))
                        .collect_vec()
                })
                .collect();
            let cols_simd: Vec<BaseColumn> = cols
                .iter()
                .map(|c| BaseColumn::from_cpu(c.clone()))
                .collect();

            assert_eq!(
                <CpuBackend as MerkleOpsLifted<Poseidon252MerkleHasher>>::build_leaves(
                    &cols.iter().collect::<Vec<_>>()
                ),
                <SimdBackend as MerkleOpsLifted<Poseidon252MerkleHasher>>::build_leaves(
                    &cols_simd.iter().collect::<Vec<_>>()
                )
            );
        }
    }
}

0 commit comments

Comments
 (0)