Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions crates/examples/src/wide_fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,72 @@ mod tests {
}
}

/// Same as [test_wide_fib_prove_with_blake] but with FRI fold step > 1.
#[test]
fn test_wide_fib_prove_with_blake_with_fri_jumps() {
for log_n_instances in 4..=8 {
let mut config = PcsConfig::default();
// Test different steps.
config.fri_config.line_fold_step = if (4..6).contains(&log_n_instances) {
2
} else {
3
};
// Precompute twiddles.
let twiddles = SimdBackend::precompute_twiddles(
CanonicCoset::new(log_n_instances + 1 + config.fri_config.log_blowup_factor)
.circle_domain()
.half_coset,
);

// Setup protocol.
let prover_channel = &mut Blake2sM31Channel::default();
let mut commitment_scheme = CommitmentSchemeProver::<
SimdBackend,
Blake2sM31MerkleChannel,
>::new(config, &twiddles);

// Preprocessed trace
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(vec![]);
tree_builder.commit(prover_channel);

// Trace.
let trace =
generate_trace::<FIB_SEQUENCE_LENGTH, _>(&generate_test_inputs(log_n_instances));
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(trace);
tree_builder.commit(prover_channel);

// Prove constraints.
let component = WideFibonacciComponent::new(
&mut TraceLocationAllocator::default(),
WideFibonacciEval::<FIB_SEQUENCE_LENGTH> {
log_n_rows: log_n_instances,
},
SecureField::zero(),
);

let proof = prove::<SimdBackend, Blake2sM31MerkleChannel>(
&[&component],
prover_channel,
commitment_scheme,
)
.unwrap();

// Verify.
let verifier_channel = &mut Blake2sM31Channel::default();
let commitment_scheme =
&mut CommitmentSchemeVerifier::<Blake2sM31MerkleChannel>::new(config);

// Retrieve the expected column sizes in each commitment interaction, from the AIR.
let sizes = component.trace_log_degree_bounds();
commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel);
commitment_scheme.commit(proof.commitments[1], &sizes[1], verifier_channel);
verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap();
}
}

#[test]
#[cfg(not(target_arch = "wasm32"))]
fn test_wide_fib_prove_with_poseidon() {
Expand Down
116 changes: 90 additions & 26 deletions crates/stwo/src/prover/backend/simd/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use std::array;
use std::simd::{u32x16, u32x8};

use num_traits::Zero;
#[cfg(feature = "parallel")]
use rayon::iter::{IndexedParallelIterator, ParallelIterator};

use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use super::SimdBackend;
Expand All @@ -22,6 +24,8 @@ use crate::prover::poly::twiddles::TwiddleTree;
use crate::prover::poly::BitReversedOrder;
use crate::prover::secure_column::SecureColumnByCoords;

const FOLD_LINE_CHUNK_SIZE: usize = 128;

// TODO(andrew) Is this optimized?
impl FriOps for SimdBackend {
fn fold_line(
Expand All @@ -30,39 +34,99 @@ impl FriOps for SimdBackend {
twiddles: &TwiddleTree<Self>,
fold_step: u32,
) -> LineEvaluation<Self> {
// TODO(Leo): remove in next PRs.
assert_eq!(fold_step, 1, "FRI jumps not yet supported in SIMD backend.");
assert!(fold_step >= 1, "fold_step must be positive.");

let log_size = eval.len().ilog2();
if log_size <= LOG_N_LANES {
let eval = fold_line_cpu(&eval.to_cpu(), alpha);
// Fallback to cpu if the log size is too small.
if log_size < LOG_N_LANES + fold_step {
let mut folding_alpha = alpha;
let mut eval = fold_line_cpu(&eval.to_cpu(), folding_alpha);
for _ in 0..fold_step - 1 {
folding_alpha = folding_alpha * folding_alpha;
eval = fold_line_cpu(&eval, folding_alpha)
}
return LineEvaluation::new(eval.domain(), eval.values.into_iter().collect());
}
let mut alphas = vec![];
let mut folding_alpha = alpha;
for _ in 0..fold_step {
alphas.push(folding_alpha);
folding_alpha = folding_alpha * folding_alpha;
}

let domain = eval.domain();
let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0];

let all_twiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles);
let mut folded_values =
unsafe { SecureColumnByCoords::<Self>::uninitialized(1 << (log_size - 1)) };

for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) {
let value = {
let twiddle_dbl = u32x16::from_array(array::from_fn(|i| unsafe {
*itwiddles.get_unchecked(vec_index * 16 + i)
}));
let val0 = unsafe { eval.values.packed_at(vec_index * 2) }.into_packed_m31s();
let val1 = unsafe { eval.values.packed_at(vec_index * 2 + 1) }.into_packed_m31s();
let pairs: [_; 4] = array::from_fn(|i| {
let (a, b) = val0[i].deinterleave(val1[i]);
simd_ibutterfly(a, b, twiddle_dbl)
});
let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0));
let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1));
val0 + PackedSecureField::broadcast(alpha) * val1
};
unsafe { folded_values.set_packed(vec_index, value) };
}
unsafe { SecureColumnByCoords::uninitialized(1 << (log_size - fold_step)) };

#[cfg(not(feature = "parallel"))]
let folded_values_iter = folded_values.chunks_mut(FOLD_LINE_CHUNK_SIZE);
#[cfg(feature = "parallel")]
let folded_values_iter = folded_values.par_chunks_mut(FOLD_LINE_CHUNK_SIZE);

folded_values_iter
.enumerate()
.for_each(|(chunk_idx, mut dst_chunk)| {
let chunk_start = chunk_idx * FOLD_LINE_CHUNK_SIZE;
let mut layer_values: Vec<[PackedBaseField; 4]> =
Vec::with_capacity(1 << fold_step);
let mut next_layer_values: Vec<[PackedBaseField; 4]> =
Vec::with_capacity(1 << fold_step);
let packed_chunk_len = dst_chunk.0[0].0.len();

for local_i in 0..packed_chunk_len {
let i = chunk_start + local_i;

// Read the packed inputs needed for a full fold.
layer_values.clear();
let input_base = i << fold_step;
unsafe {
for j in 0..1 << fold_step {
layer_values
.push(eval.values.packed_at(input_base + j).into_packed_m31s());
}
}

for layer in 0..fold_step as usize {
let next_len = layer_values.len() / 2;
let itwiddles = all_twiddles[layer];
let alpha = alphas[layer];
next_layer_values.clear();
unsafe {
for j in 0..next_len {
let twiddle_dbl = u32x16::from_array(array::from_fn(|k| {
*itwiddles.get_unchecked((i * next_len + j) * 16 + k)
}));
let val0 = layer_values[2 * j];
let val1 = layer_values[2 * j + 1];
let pairs: [_; 4] = array::from_fn(|c| {
let (a, b) = val0[c].deinterleave(val1[c]);
simd_ibutterfly(a, b, twiddle_dbl)
});
let v0 = PackedSecureField::from_packed_m31s(array::from_fn(|c| {
pairs[c].0
}));
let v1 = PackedSecureField::from_packed_m31s(array::from_fn(|c| {
pairs[c].1
}));
next_layer_values.push(
(v0 + PackedSecureField::broadcast(alpha) * v1)
.into_packed_m31s(),
);
}
}
std::mem::swap(&mut layer_values, &mut next_layer_values);
}
let result = layer_values[0];

unsafe {
dst_chunk.set_packed(local_i, PackedSecureField::from_packed_m31s(result));
}
}
});

LineEvaluation::new(domain.double(), folded_values)
let new_domain = domain.repeated_double(fold_step);
LineEvaluation::new(new_domain, folded_values)
}

fn fold_circle_into_line(
Expand Down