Skip to content

Commit ff78be5

Browse files
author
Gilad Chase
committed
Addap/jmprel/jmpabs add prover part
1 parent 92393ef commit ff78be5

File tree

3 files changed

+206
-5
lines changed

3 files changed

+206
-5
lines changed

crates/prover/src/components/addap_jmpabs_jmprel_opcode/component.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ use crate::utils::{Selector, SelectorTrait};
1717
pub const N_TRACE_CELLS: usize = 7;
1818

1919
// Assumes INSTRUCTION_BASE=K such that:
20-
// ```
21-
// addap_imm = K
22-
// jmp_abs_imm = K + 1
23-
// jmp_rel_imm = K + 2
24-
/// ```
20+
/// `
21+
/// addap_imm = K
22+
/// jmp_abs_imm = K + 1
23+
/// jmp_rel_imm = K + 2
24+
/// `
2525
// TODO: organize opcodes so that K will work as detailed above, instead of just 0.
2626
pub const INSTRUCTION_BASE: M31 = M31::from_u32_unchecked(0);
2727

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pub mod component;
2+
pub mod prover;
23
pub use component::{Claim, Component, Eval, InteractionClaim};
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
use itertools::{zip_eq, Itertools};
2+
use num_traits::One;
3+
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
4+
use stwo_air_utils::trace::component_trace::ComponentTrace;
5+
use stwo_air_utils_derive::{IterMut, ParIterMut, Uninitialized};
6+
use stwo_prover::constraint_framework::logup::LogupTraceGenerator;
7+
use stwo_prover::constraint_framework::Relation;
8+
use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES};
9+
use stwo_prover::core::backend::simd::qm31::PackedQM31;
10+
use stwo_prover::core::backend::simd::SimdBackend;
11+
use stwo_prover::core::fields::m31::M31;
12+
use stwo_prover::core::pcs::TreeBuilder;
13+
use stwo_prover::core::utils::bit_reverse_coset_to_circle_domain_order;
14+
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;
15+
16+
use super::component::{Claim, InteractionClaim};
17+
use crate::components::addap_jmpabs_jmprel_opcode::component::INSTRUCTION_BASE;
18+
use crate::components::memory;
19+
use crate::relations::{MemoryRelation, StateRelation, N_MEMORY_ELEMS, STATE_SIZE};
20+
use crate::utils::prover::decode_opcode;
21+
use crate::utils::types::{CasmState, PackedCasmState};
22+
use crate::utils::{Selector, SelectorTrait};
23+
24+
const N_TRACE_COLUMNS: usize = 7;
25+
const N_MEMORY_LOOKUPS: usize = 1;
26+
const N_STATE_LOOKUPS: usize = 2;
27+
28+
#[derive(Debug)]
29+
pub struct ClaimGenerator {
30+
pub inputs: Vec<PackedCasmState>,
31+
}
32+
impl ClaimGenerator {
33+
pub fn new(mut inputs: Vec<CasmState>) -> Self {
34+
assert!(!inputs.is_empty());
35+
36+
// TODO(spapini): Split to multiple components.
37+
let size = inputs.len().next_power_of_two();
38+
inputs.resize(size, inputs[0]);
39+
40+
let inputs = inputs
41+
.into_iter()
42+
.array_chunks::<N_LANES>()
43+
.map(|chunk| PackedCasmState {
44+
pc: PackedM31::from_array(std::array::from_fn(|i| {
45+
M31::from_u32_unchecked(chunk[i].pc)
46+
})),
47+
ap: PackedM31::from_array(std::array::from_fn(|i| {
48+
M31::from_u32_unchecked(chunk[i].ap)
49+
})),
50+
fp: PackedM31::from_array(std::array::from_fn(|i| {
51+
M31::from_u32_unchecked(chunk[i].fp)
52+
})),
53+
})
54+
.collect_vec();
55+
Self { inputs }
56+
}
57+
58+
pub fn write_trace(
59+
mut self,
60+
tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>,
61+
memory_trace_generator: &mut memory::ClaimGenerator,
62+
) -> (Claim, InteractionClaimGenerator) {
63+
let (trace, lookup_data) = write_trace_simd(&self.inputs, memory_trace_generator);
64+
65+
let n_rows = self.inputs.len();
66+
assert_ne!(n_rows, 0);
67+
let size = std::cmp::max(n_rows.next_power_of_two(), N_LANES);
68+
let need_padding = n_rows != size;
69+
70+
if need_padding {
71+
self.inputs
72+
.resize(size, self.inputs.first().unwrap().clone());
73+
bit_reverse_coset_to_circle_domain_order(&mut self.inputs);
74+
}
75+
lookup_data.memory.iter().for_each(|c| {
76+
c.iter()
77+
.for_each(|v| memory_trace_generator.add_inputs_simd(&v[0]))
78+
});
79+
tree_builder.extend_evals(trace.to_evals());
80+
(
81+
Claim { n_rows },
82+
InteractionClaimGenerator {
83+
n_rows,
84+
lookup_data,
85+
},
86+
)
87+
}
88+
}
89+
#[derive(Debug, Uninitialized, IterMut, ParIterMut)]
90+
pub struct LookupData {
91+
pub memory: [Vec<[PackedM31; N_MEMORY_ELEMS]>; N_MEMORY_LOOKUPS],
92+
pub state: [Vec<[PackedM31; STATE_SIZE]>; N_STATE_LOOKUPS],
93+
}
94+
95+
#[derive(Debug)]
96+
pub struct InteractionClaimGenerator {
97+
pub n_rows: usize,
98+
pub lookup_data: LookupData,
99+
}
100+
101+
impl InteractionClaimGenerator {
102+
pub fn write_interaction_trace(
103+
&self,
104+
tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>,
105+
memory_relation: &MemoryRelation,
106+
state_relation: &StateRelation,
107+
) -> InteractionClaim {
108+
let log_size = std::cmp::max(self.n_rows.next_power_of_two().ilog2(), LOG_N_LANES);
109+
let mut logup_gen = LogupTraceGenerator::new(log_size);
110+
111+
let mut col0 = logup_gen.new_col();
112+
let state_use = &self.lookup_data.state[0];
113+
let read_pc = &self.lookup_data.memory[0];
114+
for (i, (x, y)) in zip_eq(state_use, read_pc).enumerate() {
115+
let denom_x: PackedQM31 = state_relation.combine(x);
116+
let denom_y: PackedQM31 = memory_relation.combine(y);
117+
118+
col0.write_frac(i, denom_x + denom_y, denom_x * denom_y)
119+
}
120+
col0.finalize_col();
121+
122+
let mut col_gen = logup_gen.new_col();
123+
let state_yield = &self.lookup_data.state[1];
124+
for (i, values) in state_yield.iter().enumerate() {
125+
let denom: PackedQM31 = state_relation.combine(values);
126+
col_gen.write_frac(i, -PackedQM31::one(), denom);
127+
}
128+
col_gen.finalize_col();
129+
130+
let (trace, total_sum, claimed_sum) = if self.n_rows == 1 << log_size {
131+
let (trace, claimed_sum) = logup_gen.finalize_last();
132+
(trace, claimed_sum, None)
133+
} else {
134+
let (trace, [total_sum, claimed_sum]) =
135+
logup_gen.finalize_at([(1 << log_size) - 1, self.n_rows - 1]);
136+
(trace, total_sum, Some((claimed_sum, self.n_rows - 1)))
137+
};
138+
tree_builder.extend_evals(trace);
139+
140+
InteractionClaim {
141+
log_size,
142+
logup_sums: (total_sum, claimed_sum),
143+
}
144+
}
145+
}
146+
147+
// add_ap_ trace row:
148+
// pc | ap | fp | trit (addap, jmp_abs, jmp_rel) | imm | res_pc | res_ap
149+
fn write_trace_simd(
150+
inputs: &[PackedCasmState],
151+
memory_trace_generator: &memory::ClaimGenerator,
152+
) -> (ComponentTrace<N_TRACE_COLUMNS>, LookupData) {
153+
let log_n_packed_rows = inputs.len().ilog2();
154+
let log_size = log_n_packed_rows + LOG_N_LANES;
155+
let (mut trace, mut lookup_data) = unsafe {
156+
(
157+
ComponentTrace::<N_TRACE_COLUMNS>::uninitialized(log_size),
158+
LookupData::uninitialized(log_n_packed_rows),
159+
)
160+
};
161+
162+
trace
163+
.par_iter_mut()
164+
.zip(inputs.par_iter())
165+
.zip(lookup_data.par_iter_mut())
166+
.for_each(|((row, opcode_input), lookup_data)| {
167+
// Initial state.
168+
let pc = opcode_input.pc;
169+
let ap = opcode_input.ap;
170+
let fp = opcode_input.fp;
171+
*row[0] = pc;
172+
*row[1] = ap;
173+
*row[2] = fp;
174+
*lookup_data.state[0] = [pc, ap, fp];
175+
176+
// Decode insturction.
177+
let [opcode, imm, off0, off1] =
178+
memory_trace_generator.deduce_output(pc).into_packed_m31s();
179+
*lookup_data.memory[0] = [pc, opcode, imm, off0, off1];
180+
181+
let [op_type] = decode_opcode(INSTRUCTION_BASE, opcode, [3]);
182+
183+
*row[3] = op_type;
184+
*row[4] = imm;
185+
186+
// Calc new state.
187+
let new_pc = Selector::select(
188+
&op_type,
189+
[&(pc + PackedM31::broadcast(M31::one())), &imm, &(pc + imm)],
190+
);
191+
let new_ap = Selector::select(&op_type, [&(ap + imm), &ap, &ap]);
192+
193+
*row[5] = new_pc;
194+
*row[6] = new_ap;
195+
196+
*lookup_data.state[1] = [new_pc, new_ap, fp];
197+
});
198+
199+
(trace, lookup_data)
200+
}

0 commit comments

Comments
 (0)