Skip to content

Commit b0d2a1d

Browse files
author
Gilad Chase
committed
add add_mul_imm component
1 parent ff78be5 commit b0d2a1d

File tree

4 files changed

+189
-1
lines changed

4 files changed

+189
-1
lines changed
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
use itertools::{chain, Itertools};
2+
use num_traits::One;
3+
use serde::{Deserialize, Serialize};
4+
use stwo_prover::constraint_framework::logup::LogupSums;
5+
use stwo_prover::constraint_framework::{
6+
EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry,
7+
};
8+
use stwo_prover::core::backend::simd::m31::LOG_N_LANES;
9+
use stwo_prover::core::channel::Channel;
10+
use stwo_prover::core::fields::m31::M31;
11+
use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
12+
use stwo_prover::core::pcs::TreeVec;
13+
14+
use crate::relations::{MemoryRelation, StateRelation};
15+
use crate::utils::component::{decode_opcode, is_bit};
16+
use crate::utils::{Selector, SelectorTrait};
17+
18+
pub const N_TRACE_COLUMNS: usize = 20;
19+
// TODO(alont): set instruction bases to not overlap
20+
pub const INSTRUCTION_BASE: M31 = M31::from_u32_unchecked(0);
21+
22+
pub type Component = FrameworkComponent<Eval>;
23+
24+
#[derive(Clone)]
25+
pub struct Eval {
26+
pub claim: Claim,
27+
pub memory_lookup: MemoryRelation,
28+
pub state_lookup: StateRelation,
29+
}
30+
31+
impl Eval {
32+
pub fn new(claim: Claim, memory_lookup: MemoryRelation, state_lookup: StateRelation) -> Self {
33+
Self {
34+
claim: claim.clone(),
35+
memory_lookup,
36+
state_lookup,
37+
}
38+
}
39+
}
40+
41+
impl FrameworkEval for Eval {
42+
fn log_size(&self) -> u32 {
43+
std::cmp::max(self.claim.n_rows.next_power_of_two().ilog2(), LOG_N_LANES)
44+
}
45+
46+
fn max_constraint_log_degree_bound(&self) -> u32 {
47+
self.log_size() + 1
48+
}
49+
50+
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
51+
let state = std::array::from_fn(|_| eval.next_trace_mask());
52+
// Use initial state.
53+
eval.add_to_relation(RelationEntry::new(&self.state_lookup, E::EF::one(), &state));
54+
let [pc, ap, fp] = state;
55+
56+
// Assert flags are in range.
57+
let [op_type, lhs_flag, rhs_flag, appp] = std::array::from_fn(|_| eval.next_trace_mask());
58+
eval.add_constraint(is_bit::<E>(&op_type));
59+
eval.add_constraint(is_bit::<E>(&lhs_flag));
60+
eval.add_constraint(is_bit::<E>(&rhs_flag));
61+
eval.add_constraint(is_bit::<E>(&appp));
62+
63+
// Check instruction.
64+
let [off0, off1, imm] = std::array::from_fn(|_| eval.next_trace_mask());
65+
let opcode = decode_opcode(
66+
INSTRUCTION_BASE.into(),
67+
&[
68+
(op_type.clone(), 2), // [add, mul]
69+
(lhs_flag.clone(), 2), // [ap, fp]
70+
(rhs_flag.clone(), 2), // [ap, fp]
71+
(appp.clone(), 2), // [false, true]
72+
],
73+
);
74+
75+
eval.add_to_relation(RelationEntry::new(
76+
&self.memory_lookup,
77+
E::EF::one(),
78+
&[
79+
pc.clone(),
80+
opcode.clone(),
81+
off0.clone(),
82+
off1.clone(),
83+
imm.clone(),
84+
],
85+
));
86+
87+
// Compute addresses.
88+
let [lhs_address, rhs_address] = std::array::from_fn(|_| eval.next_trace_mask());
89+
90+
eval.add_constraint(
91+
lhs_address.clone() - (Selector::select(&lhs_flag, [&ap, &fp]) + off0.clone()),
92+
);
93+
eval.add_constraint(
94+
rhs_address.clone() - (Selector::select(&rhs_flag, [&ap, &fp]) + off1.clone()),
95+
);
96+
97+
// Read memory.
98+
let lhs_val_arr: [E::F; 4] = std::array::from_fn(|_| eval.next_trace_mask());
99+
let rhs_val_arr: [E::F; 4] = std::array::from_fn(|_| eval.next_trace_mask());
100+
101+
eval.add_to_relation(RelationEntry::new(
102+
&self.memory_lookup,
103+
E::EF::one(),
104+
&chain!([lhs_address], lhs_val_arr.clone()).collect_vec(),
105+
));
106+
107+
eval.add_to_relation(RelationEntry::new(
108+
&self.memory_lookup,
109+
E::EF::one(),
110+
&chain!([rhs_address], rhs_val_arr.clone()).collect_vec(),
111+
));
112+
113+
let lhs_val = E::combine_ef(lhs_val_arr);
114+
let rhs_val = E::combine_ef(rhs_val_arr);
115+
116+
// Apply operation.
117+
eval.add_constraint(
118+
lhs_val
119+
- (Selector::select(
120+
&E::EF::from(op_type),
121+
[
122+
&(rhs_val.clone() + imm.clone()),
123+
&(rhs_val.clone() * imm.clone()),
124+
],
125+
)),
126+
);
127+
128+
// Yield final state.
129+
let new_state = [pc + E::F::one(), ap + appp, fp];
130+
eval.add_to_relation(RelationEntry::new(
131+
&self.state_lookup,
132+
-E::EF::one(),
133+
&new_state,
134+
));
135+
136+
eval.finalize_logup_in_pairs();
137+
eval
138+
}
139+
}
140+
141+
#[derive(Clone, Serialize, Deserialize)]
142+
pub struct Claim {
143+
pub n_rows: usize,
144+
}
145+
146+
impl Claim {
147+
pub fn mix_into(&self, channel: &mut impl Channel) {
148+
channel.mix_u64(self.n_rows as u64);
149+
}
150+
151+
pub fn log_sizes(&self) -> TreeVec<Vec<u32>> {
152+
let log_size = std::cmp::max(self.n_rows.next_power_of_two().ilog2(), LOG_N_LANES);
153+
let preprocessed_log_sizes = vec![log_size];
154+
let interaction_1_log_sizes = vec![log_size; N_TRACE_COLUMNS];
155+
let interaction_2_log_sizes = vec![log_size; SECURE_EXTENSION_DEGREE * 3];
156+
TreeVec::new(vec![
157+
preprocessed_log_sizes,
158+
interaction_1_log_sizes,
159+
interaction_2_log_sizes,
160+
])
161+
}
162+
}
163+
164+
#[derive(Clone, Serialize, Deserialize)]
165+
pub struct InteractionClaim {
166+
pub log_size: u32,
167+
pub logup_sums: LogupSums,
168+
}
169+
170+
impl InteractionClaim {
171+
pub fn mix_into(&self, channel: &mut impl Channel) {
172+
let (total_sum, claimed_sum) = self.logup_sums;
173+
channel.mix_felts(&[total_sum]);
174+
if let Some(claimed_sum) = claimed_sum {
175+
channel.mix_felts(&[claimed_sum.0]);
176+
channel.mix_u64(claimed_sum.1 as u64);
177+
}
178+
}
179+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub mod component;

crates/prover/src/components/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
pub mod add_mul_imm_opcode;
12
pub mod add_mul_opcode;
23
pub mod addap_jmpabs_jmprel_opcode;
34
pub mod memory;

crates/prover/src/utils/component.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,14 @@ where
2020
opcode
2121
}
2222

23-
/// Assert that `flag` is a trit (a digit in {0,1,2}).
23+
/// Create a constraint asserting that `flag` is a bit.
24+
pub fn is_bit<E: EvalAtRow>(flag: &E::F) -> E::F {
25+
let f = || flag.clone();
26+
// f^2 - f
27+
f() * f() - f()
28+
}
29+
30+
/// Create a constraint asserting that `flag` is a trit (a digit in {0,1,2}).
2431
pub fn is_trit<E: EvalAtRow>(flag: &E::F) -> E::F {
2532
let two = E::F::from(BaseField::from_u32_unchecked(2));
2633
let three = E::F::from(BaseField::from_u32_unchecked(3));

0 commit comments

Comments
 (0)