Skip to content

Commit 9db86bd

Browse files
committed
Add jnz component
1 parent f0f2489 commit 9db86bd

File tree

4 files changed

+561
-0
lines changed

4 files changed

+561
-0
lines changed
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
use itertools::{chain, Itertools};
2+
use num_traits::One;
3+
use serde::{Deserialize, Serialize};
4+
use stwo_prover::constraint_framework::{
5+
EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry,
6+
};
7+
use stwo_prover::core::backend::simd::m31::LOG_N_LANES;
8+
use stwo_prover::core::channel::Channel;
9+
use stwo_prover::core::fields::m31::M31;
10+
use stwo_prover::core::fields::qm31::SecureField;
11+
use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
12+
use stwo_prover::core::pcs::TreeVec;
13+
14+
use crate::relations::{MemoryRelation, StateRelation};
15+
use crate::utils::component::{decode_opcode, is_bit};
16+
use crate::utils::{Selector, SelectorTrait};
17+
18+
pub const N_TRACE_CELLS: usize = 19;
19+
20+
// Assumes INSTRUCTION_BASE=K such that:
21+
/// `
22+
/// addap_imm = K
23+
/// jmp_abs_imm = K + 1
24+
/// jmp_rel_imm = K + 2
25+
/// `
26+
// TODO: organize opcodes so that K will work as detailed above, instead of just 0.
27+
pub const INSTRUCTION_BASE: M31 = M31::from_u32_unchecked(0);
28+
29+
pub type Component = FrameworkComponent<Eval>;
30+
31+
pub struct Eval {
32+
pub claim: Claim,
33+
pub memory_lookup: MemoryRelation,
34+
pub state_lookup: StateRelation,
35+
}
36+
37+
impl FrameworkEval for Eval {
38+
fn log_size(&self) -> u32 {
39+
std::cmp::max(self.claim.n_rows.next_power_of_two().ilog2(), LOG_N_LANES)
40+
}
41+
42+
fn max_constraint_log_degree_bound(&self) -> u32 {
43+
self.log_size() + 1
44+
}
45+
46+
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
47+
let state = std::array::from_fn(|_| eval.next_trace_mask());
48+
// Use initial state.
49+
eval.add_to_relation(RelationEntry::new(&self.state_lookup, E::EF::one(), &state));
50+
let [pc, ap, fp] = state;
51+
52+
// Assert flags are in range.
53+
let [op_type, reg] = std::array::from_fn(|_| eval.next_trace_mask());
54+
eval.add_constraint(is_bit::<E>(&op_type));
55+
eval.add_constraint(is_bit::<E>(&reg));
56+
57+
// Check instruction.
58+
let opcode = decode_opcode(
59+
INSTRUCTION_BASE.into(),
60+
&[
61+
(op_type.clone(), 2), // [jmp abs, jmp rel]
62+
(reg.clone(), 2), // [ap, fp]
63+
],
64+
);
65+
66+
let [off, imm] = std::array::from_fn(|_| eval.next_trace_mask());
67+
68+
eval.add_to_relation(RelationEntry::new(
69+
&self.memory_lookup,
70+
E::EF::one(),
71+
&[
72+
pc.clone(),
73+
opcode.clone(),
74+
off.clone(),
75+
imm.clone(),
76+
],
77+
));
78+
// Compute address.
79+
let addr = eval.next_trace_mask();
80+
eval.add_constraint(
81+
addr.clone() - (Selector::select(&reg, [&(ap.clone()), &(fp.clone())]) + off),
82+
);
83+
84+
let addr_val_arr: [E::F; 4] = std::array::from_fn(|_| eval.next_trace_mask());
85+
eval.add_to_relation(RelationEntry::new(
86+
&self.memory_lookup,
87+
E::EF::one(),
88+
&chain!([addr], addr_val_arr.clone()).collect_vec(),
89+
));
90+
let val = E::combine_ef(addr_val_arr);
91+
92+
93+
// Check jnz condition.
94+
let maybe_inverse_val = E::combine_ef(std::array::from_fn(|_| eval.next_trace_mask()));
95+
let flag = eval.next_trace_mask();
96+
eval.add_constraint(is_bit::<E>(&flag));
97+
98+
// flag == 0 iff val == 0 iff val is not invertible <=>
99+
// ==> 0 = val * (1 - flag) + (1 - val * val^{-1}) * flag
100+
eval.add_constraint(
101+
val.clone() * (E::F::one() - flag.clone())
102+
+ (E::EF::one() - val.clone() * maybe_inverse_val) * flag.clone(),
103+
);
104+
105+
// Assert new pc.
106+
// The relative branch when taken is obvious.
107+
let jmp_target_if_taken =
108+
&Selector::select(&op_type, [&imm, &(pc.clone() + imm.clone())]);
109+
110+
let new_pc = eval.next_trace_mask();
111+
112+
eval.add_constraint(
113+
new_pc.clone() - Selector::select(&flag, [&(pc + E::F::one()), jmp_target_if_taken]),
114+
);
115+
116+
// Yield final state.
117+
let new_state = [new_pc, ap, fp];
118+
eval.add_to_relation(RelationEntry::new(
119+
&self.state_lookup,
120+
-E::EF::one(),
121+
&new_state,
122+
));
123+
124+
eval.finalize_logup_in_pairs();
125+
eval
126+
}
127+
}
128+
129+
#[derive(Copy, Clone, Serialize, Deserialize)]
130+
pub struct Claim {
131+
pub n_rows: usize,
132+
}
133+
134+
impl Claim {
135+
pub fn log_sizes(&self) -> TreeVec<Vec<u32>> {
136+
let log_size = std::cmp::max(self.n_rows.next_power_of_two().ilog2(), LOG_N_LANES);
137+
let preprocessed_log_sizes = vec![log_size];
138+
let trace_log_sizes = vec![log_size; N_TRACE_CELLS];
139+
let interaction_log_sizes = vec![log_size; SECURE_EXTENSION_DEGREE * 3];
140+
TreeVec::new(vec![
141+
preprocessed_log_sizes,
142+
trace_log_sizes,
143+
interaction_log_sizes,
144+
])
145+
}
146+
147+
pub fn mix_into(&self, channel: &mut impl Channel) {
148+
channel.mix_u64(self.n_rows as u64);
149+
}
150+
}
151+
152+
#[derive(Clone, Serialize, Deserialize)]
153+
pub struct InteractionClaim {
154+
pub log_size: u32,
155+
pub claimed_sum: SecureField,
156+
}
157+
impl InteractionClaim {
158+
pub fn mix_into(&self, channel: &mut impl Channel) {
159+
channel.mix_felts(&[self.claimed_sum]);
160+
}
161+
}
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
pub mod component;
2+
pub mod prover;
3+
4+
pub use component::{Claim, Component, Eval, InteractionClaim};
5+
pub use prover::ClaimGenerator;
6+
7+
#[cfg(test)]
8+
mod tests {
9+
10+
use itertools::{chain, Itertools};
11+
use num_traits::Zero;
12+
use rand::rngs::SmallRng;
13+
use rand::{Rng, SeedableRng};
14+
use stwo_prover::constraint_framework::{
15+
FrameworkComponent, FrameworkEval, TraceLocationAllocator,
16+
};
17+
use stwo_prover::core::backend::simd::qm31::PackedSecureField;
18+
use stwo_prover::core::backend::simd::SimdBackend;
19+
use stwo_prover::core::channel::Blake2sChannel;
20+
use stwo_prover::core::fields::m31::M31;
21+
use stwo_prover::core::fields::qm31::QM31;
22+
use stwo_prover::core::pcs::{CommitmentSchemeProver, PcsConfig};
23+
use stwo_prover::core::poly::circle::{CanonicCoset, PolyOps};
24+
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;
25+
26+
use super::*;
27+
use crate::components::jnz_opcode::component::INSTRUCTION_BASE;
28+
use crate::components::memory;
29+
use crate::input::instructions::VmState;
30+
use crate::relations;
31+
32+
#[test]
33+
fn test_jnz_opcode() {
34+
const LOG_HEIGHT: u32 = 8;
35+
const LOG_BLOWUP_FACTOR: u32 = 1;
36+
37+
let mut rng = SmallRng::seed_from_u64(0);
38+
39+
#[allow(clippy::erasing_op, clippy::identity_op)]
40+
let jnz_rel_ap = INSTRUCTION_BASE + M31(1 + 0 * 2);
41+
#[allow(clippy::erasing_op, clippy::identity_op)]
42+
let jnz_abs_fp = INSTRUCTION_BASE + M31(0 + 1 * 2);
43+
44+
#[allow(clippy::erasing_op, clippy::identity_op)]
45+
let add_ap_ap_fp = INSTRUCTION_BASE + M31(0 + 0 * 2 + 1 * 4 + 1 * 8 + 0 * 16);
46+
#[allow(clippy::erasing_op, clippy::identity_op)]
47+
let mul_fp_fp_ap_appp = INSTRUCTION_BASE + M31(1 + 1 * 2 + 1 * 4 + 0 * 8 + 1 * 16);
48+
let x: QM31 = rng.gen();
49+
let y: QM31 = rng.gen();
50+
let z: QM31 = rng.gen();
51+
52+
// Initialize at pc=0, ap=2, fp=4 with:
53+
// pc -> 0: [ap] = [fp] + [fp]
54+
// 1: [fp + 1] = [fp + 2] * [ap + 1]; ap++
55+
// ap -> 2: 2X
56+
// 3: Y
57+
// fp -> 4: X
58+
// 5: Z * Y
59+
// 6: Z
60+
let mut memory_claim_generator = memory::ClaimGenerator {
61+
values: vec![PackedSecureField::from_array([
62+
QM31::from_m31_array([add_ap_ap_fp, M31(0), M31(0), M31(0)]),
63+
QM31::from_m31_array([mul_fp_fp_ap_appp, M31(1), M31(2), M31(1)]),
64+
x + x,
65+
y,
66+
x,
67+
z * y,
68+
z,
69+
QM31::zero(),
70+
QM31::zero(),
71+
QM31::zero(),
72+
QM31::zero(),
73+
QM31::zero(),
74+
QM31::zero(),
75+
QM31::zero(),
76+
QM31::zero(),
77+
QM31::zero(),
78+
])],
79+
// Dummy multiplicities
80+
multiplicities: vec![1; 16],
81+
};
82+
83+
let claim_generator = ClaimGenerator::new(
84+
chain!(
85+
vec![
86+
VmState {
87+
pc: 0,
88+
ap: 2,
89+
fp: 4,
90+
};
91+
128
92+
],
93+
vec![
94+
VmState {
95+
pc: 1,
96+
ap: 2,
97+
fp: 4,
98+
};
99+
128
100+
]
101+
)
102+
.collect_vec(),
103+
);
104+
105+
let twiddles = SimdBackend::precompute_twiddles(
106+
CanonicCoset::new(LOG_HEIGHT + LOG_BLOWUP_FACTOR)
107+
.circle_domain()
108+
.half_coset,
109+
);
110+
111+
let channel = &mut Blake2sChannel::default();
112+
let config = PcsConfig::default();
113+
let commitment_scheme =
114+
&mut CommitmentSchemeProver::<SimdBackend, Blake2sMerkleChannel>::new(
115+
config, &twiddles,
116+
);
117+
118+
// Preprocessed.
119+
let tree_builder = commitment_scheme.tree_builder();
120+
tree_builder.commit(channel);
121+
122+
let mut tree_builder = commitment_scheme.tree_builder();
123+
let (claim, interaction_claim_generator) =
124+
claim_generator.write_trace(&mut tree_builder, &mut memory_claim_generator);
125+
126+
tree_builder.commit(channel);
127+
let mut tree_builder = commitment_scheme.tree_builder();
128+
129+
let memory_relation = relations::MemoryRelation::draw(channel);
130+
let state_relation = relations::StateRelation::draw(channel);
131+
let interaction_claim = interaction_claim_generator.write_interaction_trace(
132+
&mut tree_builder,
133+
&memory_relation,
134+
&state_relation,
135+
);
136+
tree_builder.commit(channel);
137+
138+
let trace_location_allocator = &mut TraceLocationAllocator::default();
139+
let component = FrameworkComponent::new(
140+
trace_location_allocator,
141+
Eval::new(claim, memory_relation, state_relation),
142+
interaction_claim.claimed_sum,
143+
);
144+
145+
let trace_polys = commitment_scheme
146+
.trees
147+
.as_ref()
148+
.map(|t| t.polynomials.iter().cloned().collect_vec());
149+
150+
stwo_prover::constraint_framework::assert_constraints(
151+
&trace_polys,
152+
CanonicCoset::new(LOG_HEIGHT),
153+
|eval| {
154+
component.evaluate(eval);
155+
},
156+
interaction_claim.claimed_sum,
157+
)
158+
}
159+
}

0 commit comments

Comments
 (0)