Skip to content

Commit efebeec

Browse files
committed
Updated ret opcode test.
1 parent c18b4a3 commit efebeec

File tree

3 files changed

+184
-26
lines changed

3 files changed

+184
-26
lines changed

crates/prover/src/components/ret_opcode/component.rs

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use num_traits::One;
22
use serde::{Deserialize, Serialize};
3-
use stwo_prover::constraint_framework::{EvalAtRow, FrameworkComponent, RelationEntry};
3+
use stwo_prover::constraint_framework::{
4+
EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry,
5+
};
46
use stwo_prover::core::channel::Channel;
57
use stwo_prover::core::fields::m31::M31;
68
use stwo_prover::core::fields::qm31::SecureField;
@@ -12,7 +14,7 @@ use crate::utils::component::log_size;
1214

1315
pub const RET_N_TRACE_CELLS: usize = 5;
1416
// TODO(alont): set instruction bases to not overlap
15-
pub const RET_INSTRUCTION: M31 = M31::from_u32_unchecked(0);
17+
pub const INSTRUCTION_BASE: M31 = M31::from_u32_unchecked(0);
1618
pub type Component = FrameworkComponent<Eval>;
1719

1820
#[derive(Clone)]
@@ -21,17 +23,26 @@ pub struct Eval {
2123
pub memory_lookup: MemoryRelation,
2224
pub state_lookup: StateRelation,
2325
}
24-
2526
impl Eval {
26-
pub fn log_size(&self) -> u32 {
27+
pub fn new(claim: Claim, memory_lookup: MemoryRelation, state_lookup: StateRelation) -> Self {
28+
Self {
29+
claim: claim.clone(),
30+
memory_lookup,
31+
state_lookup,
32+
}
33+
}
34+
}
35+
36+
impl FrameworkEval for Eval {
37+
fn log_size(&self) -> u32 {
2738
log_size(self.claim.n_rows)
2839
}
2940

30-
pub fn max_constraint_log_degree_bound(&self) -> u32 {
41+
fn max_constraint_log_degree_bound(&self) -> u32 {
3142
self.log_size() + 1
3243
}
3344

34-
pub fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
45+
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
3546
// Initial state.
3647
let state = std::array::from_fn(|_| eval.next_trace_mask());
3748
// Use initial state.
@@ -42,7 +53,7 @@ impl Eval {
4253
eval.add_to_relation(RelationEntry::new(
4354
&self.memory_lookup,
4455
E::EF::one(),
45-
&[pc, RET_INSTRUCTION.into()],
56+
&[pc, INSTRUCTION_BASE.into()],
4657
));
4758

4859
// FP - 1

crates/prover/src/components/ret_opcode/mod.rs

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,124 @@ pub mod prover;
33

44
pub use component::{Claim, Component, Eval, InteractionClaim};
55
pub use prover::ClaimGenerator;
6+
7+
#[cfg(test)]
8+
mod tests {
9+
10+
use itertools::Itertools;
11+
use num_traits::Zero;
12+
use stwo_prover::constraint_framework::{
13+
FrameworkComponent, FrameworkEval, TraceLocationAllocator,
14+
};
15+
use stwo_prover::core::backend::simd::qm31::PackedSecureField;
16+
use stwo_prover::core::backend::simd::SimdBackend;
17+
use stwo_prover::core::channel::Blake2sChannel;
18+
use stwo_prover::core::fields::m31::M31;
19+
use stwo_prover::core::fields::qm31::QM31;
20+
use stwo_prover::core::pcs::{CommitmentSchemeProver, PcsConfig};
21+
use stwo_prover::core::poly::circle::{CanonicCoset, PolyOps};
22+
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;
23+
24+
use super::*;
25+
use crate::components::memory;
26+
use crate::components::ret_opcode::component::INSTRUCTION_BASE;
27+
use crate::input::instructions::VmState;
28+
use crate::relations;
29+
30+
#[test]
31+
fn test_ret_opcode() {
32+
const LOG_HEIGHT: u32 = 8;
33+
const LOG_BLOWUP_FACTOR: u32 = 1;
34+
35+
// Initialize at pc=0, ap=fp=3 with:
36+
// pc -> 0: ret
37+
// 1: 1234
38+
// 2: 5678
39+
// fp -> 3: 0
40+
let mut memory_claim_generator = memory::ClaimGenerator {
41+
values: vec![PackedSecureField::from_array([
42+
QM31::from_m31_array([INSTRUCTION_BASE, M31(0), M31(0), M31(0)]),
43+
QM31::from_m31_array([M31(1234), M31(1), M31(2), M31(1)]),
44+
QM31::from_m31_array([M31(5678), M31(1), M31(2), M31(1)]),
45+
QM31::zero(),
46+
QM31::zero(),
47+
QM31::zero(),
48+
QM31::zero(),
49+
QM31::zero(),
50+
QM31::zero(),
51+
QM31::zero(),
52+
QM31::zero(),
53+
QM31::zero(),
54+
QM31::zero(),
55+
QM31::zero(),
56+
QM31::zero(),
57+
QM31::zero(),
58+
])],
59+
// Dummy multiplicities
60+
multiplicities: vec![1; 16],
61+
};
62+
63+
let claim_generator = ClaimGenerator::new(vec![
64+
VmState {
65+
pc: 0,
66+
ap: 3,
67+
fp: 3,
68+
};
69+
256
70+
]);
71+
72+
let twiddles = SimdBackend::precompute_twiddles(
73+
CanonicCoset::new(LOG_HEIGHT + LOG_BLOWUP_FACTOR)
74+
.circle_domain()
75+
.half_coset,
76+
);
77+
78+
let channel = &mut Blake2sChannel::default();
79+
let config = PcsConfig::default();
80+
let commitment_scheme =
81+
&mut CommitmentSchemeProver::<SimdBackend, Blake2sMerkleChannel>::new(
82+
config, &twiddles,
83+
);
84+
85+
// Preprocessed.
86+
let tree_builder = commitment_scheme.tree_builder();
87+
tree_builder.commit(channel);
88+
89+
let mut tree_builder = commitment_scheme.tree_builder();
90+
let (claim, interaction_claim_generator) =
91+
claim_generator.write_trace(&mut tree_builder, &mut memory_claim_generator);
92+
93+
tree_builder.commit(channel);
94+
let mut tree_builder = commitment_scheme.tree_builder();
95+
96+
let memory_relation = relations::MemoryRelation::draw(channel);
97+
let state_relation = relations::StateRelation::draw(channel);
98+
let interaction_claim = interaction_claim_generator.write_interaction_trace(
99+
&mut tree_builder,
100+
&memory_relation,
101+
&state_relation,
102+
);
103+
tree_builder.commit(channel);
104+
105+
let trace_location_allocator = &mut TraceLocationAllocator::default();
106+
let component = FrameworkComponent::new(
107+
trace_location_allocator,
108+
Eval::new(claim, memory_relation, state_relation),
109+
interaction_claim.claimed_sum,
110+
);
111+
112+
let trace_polys = commitment_scheme
113+
.trees
114+
.as_ref()
115+
.map(|t| t.polynomials.iter().cloned().collect_vec());
116+
117+
stwo_prover::constraint_framework::assert_constraints(
118+
&trace_polys,
119+
CanonicCoset::new(LOG_HEIGHT),
120+
|eval| {
121+
component.evaluate(eval);
122+
},
123+
interaction_claim.claimed_sum,
124+
)
125+
}
126+
}

crates/prover/src/components/ret_opcode/prover.rs

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use stwo_prover::core::pcs::TreeBuilder;
1313
use stwo_prover::core::utils::bit_reverse_coset_to_circle_domain_order;
1414
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;
1515

16-
use super::component::{Claim, InteractionClaim, RET_INSTRUCTION};
16+
use super::component::{Claim, InteractionClaim, INSTRUCTION_BASE};
1717
use crate::components::memory;
1818
use crate::input::instructions::VmState;
1919
use crate::relations::{MemoryRelation, StateRelation, N_MEMORY_ELEMS, STATE_SIZE};
@@ -77,7 +77,7 @@ impl ClaimGenerator {
7777
) -> (Claim, InteractionClaimGenerator) {
7878
let (trace, lookup_data) = write_trace_simd(&self.inputs, memory_trace_generator);
7979

80-
let n_rows = self.inputs.len();
80+
let n_rows = self.inputs.len() * N_LANES;
8181
assert_ne!(n_rows, 0);
8282

8383
lookup_data.memory.iter().for_each(|c| {
@@ -116,9 +116,13 @@ impl InteractionClaimGenerator {
116116
let log_size = std::cmp::max(self.n_rows.next_power_of_two().ilog2(), LOG_N_LANES);
117117
let mut logup_gen = LogupTraceGenerator::new(log_size);
118118

119-
let mut col0 = logup_gen.new_col();
120119
let state_use = &self.lookup_data.state[0];
121120
let read_pc = &self.lookup_data.memory[0];
121+
let fp_minus_one = &self.lookup_data.memory[1];
122+
let fp_minus_two = &self.lookup_data.memory[2];
123+
let state_yield = &self.lookup_data.state[1];
124+
125+
let mut col0 = logup_gen.new_col();
122126
for (i, (x, y)) in zip_eq(state_use, read_pc).enumerate() {
123127
let denom_x: PackedQM31 = state_relation.combine(x);
124128
let denom_y: PackedQM31 = memory_relation.combine(y);
@@ -127,13 +131,22 @@ impl InteractionClaimGenerator {
127131
}
128132
col0.finalize_col();
129133

130-
let mut col_gen = logup_gen.new_col();
131-
let state_yield = &self.lookup_data.state[1];
132-
for (i, values) in state_yield.iter().enumerate() {
133-
let denom: PackedQM31 = state_relation.combine(values);
134-
col_gen.write_frac(i, -PackedQM31::one(), denom);
134+
let mut col1 = logup_gen.new_col();
135+
for (i, (x, y)) in zip_eq(fp_minus_one, fp_minus_two).enumerate() {
136+
let denom_x: PackedQM31 = memory_relation.combine(x);
137+
let denom_y: PackedQM31 = memory_relation.combine(y);
138+
139+
col1.write_frac(i, denom_x + denom_y, denom_x * denom_y)
140+
}
141+
col1.finalize_col();
142+
143+
let mut col2 = logup_gen.new_col();
144+
for (i, x) in state_yield.iter().enumerate() {
145+
let denom_x: PackedQM31 = state_relation.combine(x);
146+
147+
col2.write_frac(i, -PackedQM31::one(), denom_x)
135148
}
136-
col_gen.finalize_col();
149+
col2.finalize_col();
137150

138151
let (trace, claimed_sum) = logup_gen.finalize_last();
139152
tree_builder.extend_evals(trace);
@@ -164,13 +177,13 @@ fn write_trace_simd(
164177
.par_iter_mut()
165178
.zip(inputs.par_iter())
166179
.zip(lookup_data.par_iter_mut())
167-
.for_each(|((row, ret_opcode_input), lookup_data)| {
168-
let col0_pc = ret_opcode_input.pc;
180+
.for_each(|((row, input), lookup_data)| {
181+
let col0_pc = input.pc;
169182
*row[0] = col0_pc;
170183
// Not added to memory inputs: `ap` not part of constraint yet.
171-
let col1_ap = ret_opcode_input.ap;
184+
let col1_ap = input.ap;
172185
*row[1] = col1_ap;
173-
let col2_fp = ret_opcode_input.fp;
186+
let col2_fp = input.fp;
174187
*row[2] = col2_fp;
175188
let mem_fp_minus_one = memory_trace_generator
176189
.deduce_output((col2_fp) - (PackedM31::broadcast(M31::one())));
@@ -184,18 +197,31 @@ fn write_trace_simd(
184197

185198
*lookup_data.memory[0] = std::array::from_fn(|i| match i {
186199
0 => col0_pc,
187-
1 => PackedM31::broadcast(RET_INSTRUCTION),
200+
1 => PackedM31::broadcast(INSTRUCTION_BASE),
188201
_ => PackedM31::zero(),
189202
});
190203

191-
let [v0, v1, v2, v3] = mem_fp_minus_one.into_packed_m31s();
192-
*lookup_data.memory[1] = [col2_fp - PackedM31::broadcast(M31::one()), v0, v1, v2, v3];
193-
194-
let [v0, v1, v2, v3] = mem_fp_minus_two.into_packed_m31s();
195-
*lookup_data.memory[2] = [col2_fp - PackedM31::broadcast(M31::from(2)), v0, v1, v2, v3];
204+
let [new_pc, _, _, _] = mem_fp_minus_one.into_packed_m31s();
205+
*lookup_data.memory[1] = [
206+
col2_fp - PackedM31::broadcast(M31::one()),
207+
new_pc,
208+
PackedM31::zero(),
209+
PackedM31::zero(),
210+
PackedM31::zero(),
211+
];
212+
213+
let [new_fp, _, _, _] = mem_fp_minus_two.into_packed_m31s();
214+
*lookup_data.memory[2] = [
215+
col2_fp - PackedM31::broadcast(M31::from(2)),
216+
new_fp,
217+
PackedM31::zero(),
218+
PackedM31::zero(),
219+
PackedM31::zero(),
220+
];
196221

197222
let col4 = mem_fp_minus_two;
198223
*row[4] = col4.into_packed_m31s()[0];
224+
*lookup_data.state[1] = [new_pc, input.ap, new_fp];
199225
});
200226

201227
(trace, lookup_data)

0 commit comments

Comments
 (0)