Skip to content

Commit fd6f654

Browse files
committed
Add jnz component
1 parent f0f2489 commit fd6f654

File tree

4 files changed

+574
-0
lines changed

4 files changed

+574
-0
lines changed
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
use itertools::{chain, Itertools};
2+
use num_traits::One;
3+
use serde::{Deserialize, Serialize};
4+
use stwo_prover::constraint_framework::{
5+
EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry,
6+
};
7+
use stwo_prover::core::backend::simd::m31::LOG_N_LANES;
8+
use stwo_prover::core::channel::Channel;
9+
use stwo_prover::core::fields::m31::M31;
10+
use stwo_prover::core::fields::qm31::SecureField;
11+
use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
12+
use stwo_prover::core::pcs::TreeVec;
13+
14+
use crate::relations::{MemoryRelation, StateRelation};
15+
use crate::utils::component::{decode_opcode, is_bit};
16+
use crate::utils::{Selector, SelectorTrait};
17+
18+
pub const N_TRACE_CELLS: usize = 19;
19+
20+
// Assumes INSTRUCTION_BASE=K such that:
21+
/// `
22+
/// addap_imm = K
23+
/// jmp_abs_imm = K + 1
24+
/// jmp_rel_imm = K + 2
25+
/// `
26+
// TODO: organize opcodes so that K will work as detailed above, instead of just 0.
27+
pub const INSTRUCTION_BASE: M31 = M31::from_u32_unchecked(0);
28+
29+
pub type Component = FrameworkComponent<Eval>;
30+
31+
pub struct Eval {
32+
pub claim: Claim,
33+
pub memory_lookup: MemoryRelation,
34+
pub state_lookup: StateRelation,
35+
}
36+
37+
impl Eval {
38+
pub fn new(claim: Claim, memory_lookup: MemoryRelation, state_lookup: StateRelation) -> Self {
39+
Self {
40+
claim: claim.clone(),
41+
memory_lookup,
42+
state_lookup,
43+
}
44+
}
45+
}
46+
impl FrameworkEval for Eval {
47+
fn log_size(&self) -> u32 {
48+
std::cmp::max(self.claim.n_rows.next_power_of_two().ilog2(), LOG_N_LANES)
49+
}
50+
51+
fn max_constraint_log_degree_bound(&self) -> u32 {
52+
self.log_size() + 1
53+
}
54+
55+
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
56+
let state = std::array::from_fn(|_| eval.next_trace_mask());
57+
// Use initial state.
58+
eval.add_to_relation(RelationEntry::new(&self.state_lookup, E::EF::one(), &state));
59+
let [pc, ap, fp] = state;
60+
61+
// Assert flags are in range.
62+
let [op_type, reg] = std::array::from_fn(|_| eval.next_trace_mask());
63+
eval.add_constraint(is_bit::<E>(&op_type));
64+
eval.add_constraint(is_bit::<E>(&reg));
65+
66+
// Check instruction.
67+
let opcode = decode_opcode(
68+
INSTRUCTION_BASE.into(),
69+
&[
70+
(op_type.clone(), 2), // [jmp abs, jmp rel]
71+
(reg.clone(), 2), // [ap, fp]
72+
],
73+
);
74+
75+
let [off, imm] = std::array::from_fn(|_| eval.next_trace_mask());
76+
println!("1");
77+
eval.add_to_relation(RelationEntry::new(
78+
&self.memory_lookup,
79+
E::EF::one(),
80+
&[
81+
pc.clone(),
82+
opcode.clone(),
83+
off.clone(),
84+
imm.clone(),
85+
],
86+
));
87+
println!("2");
88+
// Compute address.
89+
let addr = eval.next_trace_mask();
90+
eval.add_constraint(
91+
addr.clone() - (Selector::select(&reg, [&(ap.clone()), &(fp.clone())]) + off),
92+
);
93+
94+
println!("3");
95+
let addr_val_arr: [E::F; 4] = std::array::from_fn(|_| eval.next_trace_mask());
96+
eval.add_to_relation(RelationEntry::new(
97+
&self.memory_lookup,
98+
E::EF::one(),
99+
&chain!([addr], addr_val_arr.clone()).collect_vec(),
100+
));
101+
let val = E::combine_ef(addr_val_arr);
102+
103+
104+
// Check jnz condition.
105+
let maybe_inverse_val = E::combine_ef(std::array::from_fn(|_| eval.next_trace_mask()));
106+
println!("4");
107+
let flag = eval.next_trace_mask();
108+
eval.add_constraint(is_bit::<E>(&flag));
109+
110+
// flag == 0 iff val == 0 iff val is not invertible <=>
111+
// ==> 0 = val * (1 - flag) + (1 - val * val^{-1}) * flag
112+
println!("5");
113+
eval.add_constraint(
114+
val.clone() * (E::F::one() - flag.clone())
115+
+ (E::EF::one() - val.clone() * maybe_inverse_val) * flag.clone(),
116+
);
117+
118+
// Assert new pc.
119+
// The relative branch when taken is obvious.
120+
println!("5.1");
121+
let jmp_target_if_taken =
122+
&Selector::select(&op_type, [&imm, &(pc.clone() + imm.clone())]);
123+
124+
println!("5.2");
125+
let new_pc = eval.next_trace_mask();
126+
127+
println!("6");
128+
eval.add_constraint(
129+
new_pc.clone() - Selector::select(&(E::F::one()-flag), [&(pc + E::F::one()), jmp_target_if_taken]),
130+
);
131+
132+
// Yield final state.
133+
let new_state = [new_pc, ap, fp];
134+
println!("7");
135+
eval.add_to_relation(RelationEntry::new(
136+
&self.state_lookup,
137+
-E::EF::one(),
138+
&new_state,
139+
));
140+
141+
println!("8");
142+
eval.finalize_logup_in_pairs();
143+
println!("9");
144+
eval
145+
}
146+
}
147+
148+
#[derive(Copy, Clone, Serialize, Deserialize)]
149+
pub struct Claim {
150+
pub n_rows: usize,
151+
}
152+
153+
impl Claim {
154+
pub fn log_sizes(&self) -> TreeVec<Vec<u32>> {
155+
let log_size = std::cmp::max(self.n_rows.next_power_of_two().ilog2(), LOG_N_LANES);
156+
let preprocessed_log_sizes = vec![log_size];
157+
let trace_log_sizes = vec![log_size; N_TRACE_CELLS];
158+
let interaction_log_sizes = vec![log_size; SECURE_EXTENSION_DEGREE * 3];
159+
TreeVec::new(vec![
160+
preprocessed_log_sizes,
161+
trace_log_sizes,
162+
interaction_log_sizes,
163+
])
164+
}
165+
166+
pub fn mix_into(&self, channel: &mut impl Channel) {
167+
channel.mix_u64(self.n_rows as u64);
168+
}
169+
}
170+
171+
#[derive(Clone, Serialize, Deserialize)]
172+
pub struct InteractionClaim {
173+
pub log_size: u32,
174+
pub claimed_sum: SecureField,
175+
}
176+
impl InteractionClaim {
177+
pub fn mix_into(&self, channel: &mut impl Channel) {
178+
channel.mix_felts(&[self.claimed_sum]);
179+
}
180+
}
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
pub mod component;
2+
pub mod prover;
3+
4+
pub use component::{Claim, Component, Eval, InteractionClaim};
5+
pub use prover::ClaimGenerator;
6+
7+
#[cfg(test)]
8+
mod tests {
9+
10+
use itertools::{chain, Itertools};
11+
use num_traits::Zero;
12+
use rand::rngs::SmallRng;
13+
use rand::{Rng, SeedableRng};
14+
use stwo_prover::constraint_framework::{
15+
FrameworkComponent, FrameworkEval, TraceLocationAllocator,
16+
};
17+
use stwo_prover::core::backend::simd::qm31::PackedSecureField;
18+
use stwo_prover::core::backend::simd::SimdBackend;
19+
use stwo_prover::core::channel::Blake2sChannel;
20+
use stwo_prover::core::fields::m31::M31;
21+
use stwo_prover::core::fields::qm31::QM31;
22+
use stwo_prover::core::pcs::{CommitmentSchemeProver, PcsConfig};
23+
use stwo_prover::core::poly::circle::{CanonicCoset, PolyOps};
24+
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;
25+
26+
use super::*;
27+
use crate::components::jnz_opcode::component::INSTRUCTION_BASE;
28+
use crate::components::memory;
29+
use crate::input::instructions::VmState;
30+
use crate::relations;
31+
32+
#[test]
33+
fn test_jnz_opcode() {
34+
const LOG_HEIGHT: u32 = 8;
35+
const LOG_BLOWUP_FACTOR: u32 = 1;
36+
37+
let mut rng = SmallRng::seed_from_u64(0);
38+
39+
#[allow(clippy::erasing_op, clippy::identity_op)]
40+
let jnz_rel_ap = INSTRUCTION_BASE + M31(1 + 0 * 2);
41+
#[allow(clippy::erasing_op, clippy::identity_op)]
42+
let jnz_abs_fp = INSTRUCTION_BASE + M31(0 + 1 * 2);
43+
44+
///
45+
///
46+
/// jnz
47+
///
48+
/// ap: 8
49+
///
50+
///
51+
///
52+
let mut memory_claim_generator = memory::ClaimGenerator {
53+
values: vec![PackedSecureField::from_array([
54+
QM31::from_m31_array([jnz_rel_ap, M31(2), M31(3), M31(0)]),
55+
QM31::from_m31_array([jnz_abs_fp, M31(1), M31(5), M31(0)]),
56+
QM31::zero(),
57+
QM31::zero(),
58+
M31(2).into(),
59+
QM31::zero(),
60+
QM31::zero(),
61+
QM31::zero(),
62+
QM31::zero(),
63+
QM31::zero(),
64+
QM31::zero(),
65+
QM31::zero(),
66+
QM31::zero(),
67+
QM31::zero(),
68+
QM31::zero(),
69+
QM31::zero(),
70+
])],
71+
// Dummy multiplicities
72+
multiplicities: vec![1; 16],
73+
};
74+
75+
let claim_generator = ClaimGenerator::new(
76+
chain!(
77+
vec![
78+
VmState {
79+
pc: 0,
80+
ap: 2,
81+
fp: 3,
82+
};
83+
128
84+
],
85+
vec![
86+
VmState {
87+
pc: 1,
88+
ap: 2,
89+
fp: 3,
90+
};
91+
128
92+
]
93+
)
94+
.collect_vec(),
95+
);
96+
97+
let twiddles = SimdBackend::precompute_twiddles(
98+
CanonicCoset::new(LOG_HEIGHT + LOG_BLOWUP_FACTOR)
99+
.circle_domain()
100+
.half_coset,
101+
);
102+
103+
let channel = &mut Blake2sChannel::default();
104+
let config = PcsConfig::default();
105+
let commitment_scheme =
106+
&mut CommitmentSchemeProver::<SimdBackend, Blake2sMerkleChannel>::new(
107+
config, &twiddles,
108+
);
109+
110+
// Preprocessed.
111+
let tree_builder = commitment_scheme.tree_builder();
112+
tree_builder.commit(channel);
113+
114+
let mut tree_builder = commitment_scheme.tree_builder();
115+
let (claim, interaction_claim_generator) =
116+
claim_generator.write_trace(&mut tree_builder, &mut memory_claim_generator);
117+
118+
tree_builder.commit(channel);
119+
let mut tree_builder = commitment_scheme.tree_builder();
120+
121+
let memory_relation = relations::MemoryRelation::draw(channel);
122+
let state_relation = relations::StateRelation::draw(channel);
123+
let interaction_claim = interaction_claim_generator.write_interaction_trace(
124+
&mut tree_builder,
125+
&memory_relation,
126+
&state_relation,
127+
);
128+
tree_builder.commit(channel);
129+
130+
let trace_location_allocator = &mut TraceLocationAllocator::default();
131+
let component = FrameworkComponent::new(
132+
trace_location_allocator,
133+
Eval::new(claim, memory_relation, state_relation),
134+
interaction_claim.claimed_sum,
135+
);
136+
137+
let trace_polys = commitment_scheme
138+
.trees
139+
.as_ref()
140+
.map(|t| t.polynomials.iter().cloned().collect_vec());
141+
142+
stwo_prover::constraint_framework::assert_constraints(
143+
&trace_polys,
144+
CanonicCoset::new(LOG_HEIGHT),
145+
|eval| {
146+
component.evaluate(eval);
147+
},
148+
interaction_claim.claimed_sum,
149+
)
150+
}
151+
}

0 commit comments

Comments
 (0)