Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions crates/prover/src/components/ret_opcode/component.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use num_traits::One;
use serde::{Deserialize, Serialize};
use stwo_prover::constraint_framework::{EvalAtRow, FrameworkComponent, RelationEntry};
use stwo_prover::constraint_framework::{
EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry,
};
use stwo_prover::core::channel::Channel;
use stwo_prover::core::fields::m31::M31;
use stwo_prover::core::fields::qm31::SecureField;
Expand All @@ -12,7 +14,7 @@ use crate::utils::component::log_size;

pub const RET_N_TRACE_CELLS: usize = 5;
// TODO(alont): set instruction bases to not overlap
pub const RET_INSTRUCTION: M31 = M31::from_u32_unchecked(0);
pub const INSTRUCTION_BASE: M31 = M31::from_u32_unchecked(0);
pub type Component = FrameworkComponent<Eval>;

#[derive(Clone)]
Expand All @@ -21,17 +23,26 @@ pub struct Eval {
pub memory_lookup: MemoryRelation,
pub state_lookup: StateRelation,
}

impl Eval {
pub fn log_size(&self) -> u32 {
pub fn new(claim: Claim, memory_lookup: MemoryRelation, state_lookup: StateRelation) -> Self {
Self {
claim: claim.clone(),
memory_lookup,
state_lookup,
}
}
}

impl FrameworkEval for Eval {
fn log_size(&self) -> u32 {
log_size(self.claim.n_rows)
}

pub fn max_constraint_log_degree_bound(&self) -> u32 {
fn max_constraint_log_degree_bound(&self) -> u32 {
self.log_size() + 1
}

pub fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
// Initial state.
let state = std::array::from_fn(|_| eval.next_trace_mask());
// Use initial state.
Expand All @@ -42,7 +53,7 @@ impl Eval {
eval.add_to_relation(RelationEntry::new(
&self.memory_lookup,
E::EF::one(),
&[pc, RET_INSTRUCTION.into()],
&[pc, INSTRUCTION_BASE.into()],
));

// FP - 1
Expand Down
121 changes: 121 additions & 0 deletions crates/prover/src/components/ret_opcode/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,124 @@ pub mod prover;

pub use component::{Claim, Component, Eval, InteractionClaim};
pub use prover::ClaimGenerator;

#[cfg(test)]
mod tests {

use itertools::Itertools;
use num_traits::Zero;
use stwo_prover::constraint_framework::{
FrameworkComponent, FrameworkEval, TraceLocationAllocator,
};
use stwo_prover::core::backend::simd::qm31::PackedSecureField;
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::channel::Blake2sChannel;
use stwo_prover::core::fields::m31::M31;
use stwo_prover::core::fields::qm31::QM31;
use stwo_prover::core::pcs::{CommitmentSchemeProver, PcsConfig};
use stwo_prover::core::poly::circle::{CanonicCoset, PolyOps};
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;

use super::*;
use crate::components::memory;
use crate::components::ret_opcode::component::INSTRUCTION_BASE;
use crate::input::instructions::VmState;
use crate::relations;

#[test]
fn test_ret_opcode() {
const LOG_HEIGHT: u32 = 8;
const LOG_BLOWUP_FACTOR: u32 = 1;

// Initialize at pc=0, ap=fp=3 with:
// pc -> 0: ret
// 1: 1234
// 2: 5678
// fp -> 3: 0
let mut memory_claim_generator = memory::ClaimGenerator {
values: vec![PackedSecureField::from_array([
QM31::from_m31_array([INSTRUCTION_BASE, M31(0), M31(0), M31(0)]),
QM31::from_m31_array([M31(1234), M31(1), M31(2), M31(1)]),
QM31::from_m31_array([M31(5678), M31(1), M31(2), M31(1)]),
QM31::zero(),
QM31::zero(),
QM31::zero(),
QM31::zero(),
QM31::zero(),
QM31::zero(),
QM31::zero(),
QM31::zero(),
QM31::zero(),
QM31::zero(),
QM31::zero(),
QM31::zero(),
QM31::zero(),
])],
// Dummy multiplicities
multiplicities: vec![1; 16],
};

let claim_generator = ClaimGenerator::new(vec![
VmState {
pc: 0,
ap: 3,
fp: 3,
};
256
]);

let twiddles = SimdBackend::precompute_twiddles(
CanonicCoset::new(LOG_HEIGHT + LOG_BLOWUP_FACTOR)
.circle_domain()
.half_coset,
);

let channel = &mut Blake2sChannel::default();
let config = PcsConfig::default();
let commitment_scheme =
&mut CommitmentSchemeProver::<SimdBackend, Blake2sMerkleChannel>::new(
config, &twiddles,
);

// Preprocessed.
let tree_builder = commitment_scheme.tree_builder();
tree_builder.commit(channel);

let mut tree_builder = commitment_scheme.tree_builder();
let (claim, interaction_claim_generator) =
claim_generator.write_trace(&mut tree_builder, &mut memory_claim_generator);

tree_builder.commit(channel);
let mut tree_builder = commitment_scheme.tree_builder();

let memory_relation = relations::MemoryRelation::draw(channel);
let state_relation = relations::StateRelation::draw(channel);
let interaction_claim = interaction_claim_generator.write_interaction_trace(
&mut tree_builder,
&memory_relation,
&state_relation,
);
tree_builder.commit(channel);

let trace_location_allocator = &mut TraceLocationAllocator::default();
let component = FrameworkComponent::new(
trace_location_allocator,
Eval::new(claim, memory_relation, state_relation),
interaction_claim.claimed_sum,
);

let trace_polys = commitment_scheme
.trees
.as_ref()
.map(|t| t.polynomials.iter().cloned().collect_vec());

stwo_prover::constraint_framework::assert_constraints(
&trace_polys,
CanonicCoset::new(LOG_HEIGHT),
|eval| {
component.evaluate(eval);
},
interaction_claim.claimed_sum,
)
}
}
64 changes: 45 additions & 19 deletions crates/prover/src/components/ret_opcode/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use stwo_prover::core::pcs::TreeBuilder;
use stwo_prover::core::utils::bit_reverse_coset_to_circle_domain_order;
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;

use super::component::{Claim, InteractionClaim, RET_INSTRUCTION};
use super::component::{Claim, InteractionClaim, INSTRUCTION_BASE};
use crate::components::memory;
use crate::input::instructions::VmState;
use crate::relations::{MemoryRelation, StateRelation, N_MEMORY_ELEMS, STATE_SIZE};
Expand Down Expand Up @@ -77,7 +77,7 @@ impl ClaimGenerator {
) -> (Claim, InteractionClaimGenerator) {
let (trace, lookup_data) = write_trace_simd(&self.inputs, memory_trace_generator);

let n_rows = self.inputs.len();
let n_rows = self.inputs.len() * N_LANES;
assert_ne!(n_rows, 0);

lookup_data.memory.iter().for_each(|c| {
Expand Down Expand Up @@ -116,9 +116,13 @@ impl InteractionClaimGenerator {
let log_size = std::cmp::max(self.n_rows.next_power_of_two().ilog2(), LOG_N_LANES);
let mut logup_gen = LogupTraceGenerator::new(log_size);

let mut col0 = logup_gen.new_col();
let state_use = &self.lookup_data.state[0];
let read_pc = &self.lookup_data.memory[0];
let fp_minus_one = &self.lookup_data.memory[1];
let fp_minus_two = &self.lookup_data.memory[2];
let state_yield = &self.lookup_data.state[1];

let mut col0 = logup_gen.new_col();
for (i, (x, y)) in zip_eq(state_use, read_pc).enumerate() {
let denom_x: PackedQM31 = state_relation.combine(x);
let denom_y: PackedQM31 = memory_relation.combine(y);
Expand All @@ -127,13 +131,22 @@ impl InteractionClaimGenerator {
}
col0.finalize_col();

let mut col_gen = logup_gen.new_col();
let state_yield = &self.lookup_data.state[1];
for (i, values) in state_yield.iter().enumerate() {
let denom: PackedQM31 = state_relation.combine(values);
col_gen.write_frac(i, -PackedQM31::one(), denom);
let mut col1 = logup_gen.new_col();
for (i, (x, y)) in zip_eq(fp_minus_one, fp_minus_two).enumerate() {
let denom_x: PackedQM31 = memory_relation.combine(x);
let denom_y: PackedQM31 = memory_relation.combine(y);

col1.write_frac(i, denom_x + denom_y, denom_x * denom_y)
}
col1.finalize_col();

let mut col2 = logup_gen.new_col();
for (i, x) in state_yield.iter().enumerate() {
let denom_x: PackedQM31 = state_relation.combine(x);

col2.write_frac(i, -PackedQM31::one(), denom_x)
}
col_gen.finalize_col();
col2.finalize_col();

let (trace, claimed_sum) = logup_gen.finalize_last();
tree_builder.extend_evals(trace);
Expand Down Expand Up @@ -164,13 +177,13 @@ fn write_trace_simd(
.par_iter_mut()
.zip(inputs.par_iter())
.zip(lookup_data.par_iter_mut())
.for_each(|((row, ret_opcode_input), lookup_data)| {
let col0_pc = ret_opcode_input.pc;
.for_each(|((row, input), lookup_data)| {
let col0_pc = input.pc;
*row[0] = col0_pc;
// Not added to memory inputs: `ap` not part of constraint yet.
let col1_ap = ret_opcode_input.ap;
let col1_ap = input.ap;
*row[1] = col1_ap;
let col2_fp = ret_opcode_input.fp;
let col2_fp = input.fp;
*row[2] = col2_fp;
let mem_fp_minus_one = memory_trace_generator
.deduce_output((col2_fp) - (PackedM31::broadcast(M31::one())));
Expand All @@ -184,18 +197,31 @@ fn write_trace_simd(

*lookup_data.memory[0] = std::array::from_fn(|i| match i {
0 => col0_pc,
1 => PackedM31::broadcast(RET_INSTRUCTION),
1 => PackedM31::broadcast(INSTRUCTION_BASE),
_ => PackedM31::zero(),
});

let [v0, v1, v2, v3] = mem_fp_minus_one.into_packed_m31s();
*lookup_data.memory[1] = [col2_fp - PackedM31::broadcast(M31::one()), v0, v1, v2, v3];

let [v0, v1, v2, v3] = mem_fp_minus_two.into_packed_m31s();
*lookup_data.memory[2] = [col2_fp - PackedM31::broadcast(M31::from(2)), v0, v1, v2, v3];
let [new_pc, _, _, _] = mem_fp_minus_one.into_packed_m31s();
*lookup_data.memory[1] = [
col2_fp - PackedM31::broadcast(M31::one()),
new_pc,
PackedM31::zero(),
PackedM31::zero(),
PackedM31::zero(),
];

let [new_fp, _, _, _] = mem_fp_minus_two.into_packed_m31s();
*lookup_data.memory[2] = [
col2_fp - PackedM31::broadcast(M31::from(2)),
new_fp,
PackedM31::zero(),
PackedM31::zero(),
PackedM31::zero(),
];

let col4 = mem_fp_minus_two;
*row[4] = col4.into_packed_m31s()[0];
*lookup_data.state[1] = [new_pc, input.ap, new_fp];
});

(trace, lookup_data)
Expand Down
Loading