Skip to content

Commit 6439d8f

Browse files
author
Gilad Chase
committed
add rest of add mul imm opcode
1 parent b0d2a1d commit 6439d8f

File tree

2 files changed

+241
-0
lines changed

2 files changed

+241
-0
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
11
pub mod component;
2+
pub mod prover;
3+
4+
pub use component::{Claim, Component, Eval, InteractionClaim};
5+
pub use prover::ClaimGenerator;
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
use itertools::{zip_eq, Itertools};
2+
use num_traits::One;
3+
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
4+
use stwo_air_utils::trace::component_trace::ComponentTrace;
5+
use stwo_air_utils_derive::{IterMut, ParIterMut, Uninitialized};
6+
use stwo_prover::constraint_framework::logup::LogupTraceGenerator;
7+
use stwo_prover::constraint_framework::Relation;
8+
use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES};
9+
use stwo_prover::core::backend::simd::qm31::PackedQM31;
10+
use stwo_prover::core::backend::simd::SimdBackend;
11+
use stwo_prover::core::fields::m31::M31;
12+
use stwo_prover::core::pcs::TreeBuilder;
13+
use stwo_prover::core::utils::bit_reverse_coset_to_circle_domain_order;
14+
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;
15+
16+
use super::component::{Claim, InteractionClaim, INSTRUCTION_BASE};
17+
use crate::components::add_mul_opcode::component::N_TRACE_COLUMNS;
18+
use crate::components::memory;
19+
use crate::relations::{MemoryRelation, StateRelation, N_MEMORY_ELEMS, STATE_SIZE};
20+
use crate::utils::prover::decode_opcode;
21+
use crate::utils::types::{CasmState, PackedCasmState};
22+
use crate::utils::{Selector, SelectorTrait};
23+
24+
/// Number of memory lookups recorded per packed trace row: the instruction word read at
/// `pc`, plus the lhs and rhs operand reads.
const N_MEMORY_LOOKUPS: usize = 3;
/// Number of state lookups recorded per packed trace row: the state the opcode starts from
/// and the state it produces.
const N_STATE_LOOKUPS: usize = 2;
26+
27+
#[derive(Debug)]
28+
pub struct ClaimGenerator {
29+
pub inputs: Vec<PackedCasmState>,
30+
}
31+
impl ClaimGenerator {
32+
pub fn new(mut inputs: Vec<CasmState>) -> Self {
33+
assert!(!inputs.is_empty());
34+
35+
// TODO(spapini): Split to multiple components.
36+
let size = inputs.len().next_power_of_two();
37+
inputs.resize(size, inputs[0]);
38+
39+
let inputs = inputs
40+
.into_iter()
41+
.array_chunks::<N_LANES>()
42+
.map(|chunk| PackedCasmState {
43+
pc: PackedM31::from_array(std::array::from_fn(|i| {
44+
M31::from_u32_unchecked(chunk[i].pc)
45+
})),
46+
ap: PackedM31::from_array(std::array::from_fn(|i| {
47+
M31::from_u32_unchecked(chunk[i].ap)
48+
})),
49+
fp: PackedM31::from_array(std::array::from_fn(|i| {
50+
M31::from_u32_unchecked(chunk[i].fp)
51+
})),
52+
})
53+
.collect_vec();
54+
Self { inputs }
55+
}
56+
57+
pub fn write_trace(
58+
mut self,
59+
tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>,
60+
memory_trace_generator: &mut memory::ClaimGenerator,
61+
) -> (Claim, InteractionClaimGenerator) {
62+
let (trace, lookup_data) = write_trace_simd(&self.inputs, memory_trace_generator);
63+
64+
let n_rows = self.inputs.len();
65+
assert_ne!(n_rows, 0);
66+
let size = std::cmp::max(n_rows.next_power_of_two(), N_LANES);
67+
let need_padding = n_rows != size;
68+
69+
if need_padding {
70+
self.inputs
71+
.resize(size, self.inputs.first().unwrap().clone());
72+
bit_reverse_coset_to_circle_domain_order(&mut self.inputs);
73+
}
74+
lookup_data.memory.iter().for_each(|c| {
75+
c.iter()
76+
.for_each(|v| memory_trace_generator.add_inputs_simd(&v[0]))
77+
});
78+
tree_builder.extend_evals(trace.to_evals());
79+
(
80+
Claim { n_rows },
81+
InteractionClaimGenerator {
82+
n_rows,
83+
lookup_data,
84+
},
85+
)
86+
}
87+
}
88+
89+
#[derive(Debug, Uninitialized, IterMut, ParIterMut)]
90+
pub struct LookupData {
91+
pub memory: [Vec<[PackedM31; N_MEMORY_ELEMS]>; N_MEMORY_LOOKUPS],
92+
pub state: [Vec<[PackedM31; STATE_SIZE]>; N_STATE_LOOKUPS],
93+
}
94+
95+
pub struct InteractionClaimGenerator {
96+
pub n_rows: usize,
97+
pub lookup_data: LookupData,
98+
}
99+
100+
impl InteractionClaimGenerator {
101+
pub fn with_capacity(capacity: usize) -> Self {
102+
Self {
103+
n_rows: capacity,
104+
memory: [
105+
Vec::with_capacity(capacity),
106+
Vec::with_capacity(capacity),
107+
Vec::with_capacity(capacity),
108+
Vec::with_capacity(capacity),
109+
],
110+
state: [Vec::with_capacity(capacity), Vec::with_capacity(capacity)],
111+
}
112+
}
113+
114+
pub fn write_interaction_trace(
115+
&self,
116+
tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>,
117+
memory_relation: &MemoryRelation,
118+
state_relation: &StateRelation,
119+
) -> InteractionClaim {
120+
let log_size = std::cmp::max(self.n_rows.next_power_of_two().ilog2(), LOG_N_LANES);
121+
let mut logup_gen = LogupTraceGenerator::new(log_size);
122+
123+
let mut col0 = logup_gen.new_col();
124+
let state_use = &self.lookup_data.state[0];
125+
let read_pc = &self.lookup_data.memory[0];
126+
for (i, (x, y)) in zip_eq(state_use, read_pc).enumerate() {
127+
let denom_x: PackedQM31 = state_relation.combine(x);
128+
let denom_y: PackedQM31 = memory_relation.combine(y);
129+
130+
col0.write_frac(i, denom_x + denom_y, denom_x * denom_y)
131+
}
132+
col0.finalize_col();
133+
134+
let mut col_gen = logup_gen.new_col();
135+
let state_yield = &self.lookup_data.state[1];
136+
for (i, values) in state_yield.iter().enumerate() {
137+
let denom: PackedQM31 = state_relation.combine(values);
138+
col_gen.write_frac(i, -PackedQM31::one(), denom);
139+
}
140+
col_gen.finalize_col();
141+
142+
let (trace, total_sum, claimed_sum) = if self.n_rows == 1 << log_size {
143+
let (trace, claimed_sum) = logup_gen.finalize_last();
144+
(trace, claimed_sum, None)
145+
} else {
146+
let (trace, [total_sum, claimed_sum]) =
147+
logup_gen.finalize_at([(1 << log_size) - 1, self.n_rows - 1]);
148+
(trace, total_sum, Some((claimed_sum, self.n_rows - 1)))
149+
};
150+
tree_builder.extend_evals(trace);
151+
152+
InteractionClaim {
153+
log_size,
154+
logup_sums: (total_sum, claimed_sum),
155+
}
156+
}
157+
}
158+
159+
// Add / Mul trace row:
160+
// | State (3) | flags (4) | offsets (2) | imm (1) | addrs (2) | values (2 * 4) |
161+
fn write_trace_simd(
162+
inputs: &[PackedCasmState],
163+
memory_trace_generator: &memory::ClaimGenerator,
164+
) -> (ComponentTrace<N_TRACE_COLUMNS>, LookupData) {
165+
let log_n_packed_rows = inputs.len().ilog2();
166+
let log_size = log_n_packed_rows + LOG_N_LANES;
167+
let (mut trace, mut lookup_data) = unsafe {
168+
(
169+
ComponentTrace::<N_TRACE_COLUMNS>::uninitialized(log_size),
170+
LookupData::uninitialized(log_n_packed_rows),
171+
)
172+
};
173+
174+
trace
175+
.par_iter_mut()
176+
.zip(inputs.par_iter())
177+
.zip(lookup_data.par_iter_mut())
178+
.for_each(|((row, opcode_input), lookup_data)| {
179+
// Initial state.
180+
let pc = opcode_input.pc;
181+
let ap = opcode_input.ap;
182+
let fp = opcode_input.fp;
183+
*row[0] = pc;
184+
*row[1] = ap;
185+
*row[2] = fp;
186+
*lookup_data.state[0] = [pc, ap, fp];
187+
188+
// Decode insturction.
189+
let [opcode, off0, off1, imm] =
190+
memory_trace_generator.deduce_output(pc).into_packed_m31s();
191+
*lookup_data.memory[0] = [pc, opcode, off0, off1, imm];
192+
193+
let [op_type, lhs_flag, rhs_flag, appp] =
194+
decode_opcode(INSTRUCTION_BASE, opcode, [2, 2, 2, 2]);
195+
*row[3] = op_type;
196+
*row[4] = lhs_flag;
197+
*row[5] = rhs_flag;
198+
*row[6] = appp;
199+
200+
// Offsets
201+
*row[8] = off0;
202+
*row[9] = off1;
203+
*row[10] = imm;
204+
205+
// Addresses
206+
let lhs_addr = Selector::select(&lhs_flag, [&ap, &fp]) + off0;
207+
let rhs_addr = Selector::select(&rhs_flag, [&ap, &fp]) + off1;
208+
209+
*row[11] = lhs_addr;
210+
*row[12] = rhs_addr;
211+
212+
let [lhs0, lhs1, lhs2, lhs3] = memory_trace_generator
213+
.deduce_output(lhs_addr)
214+
.into_packed_m31s();
215+
let [rhs0, rhs1, rhs2, rhs3] = memory_trace_generator
216+
.deduce_output(rhs_addr)
217+
.into_packed_m31s();
218+
219+
*row[13] = lhs0;
220+
*row[14] = lhs1;
221+
*row[15] = lhs2;
222+
*row[16] = lhs3;
223+
224+
*row[17] = rhs0;
225+
*row[18] = rhs1;
226+
*row[18] = rhs2;
227+
*row[20] = rhs3;
228+
229+
*lookup_data.memory[0] = [pc, opcode, imm, off0, off1];
230+
*lookup_data.memory[1] = [lhs_addr, lhs0, lhs1, lhs2, lhs3];
231+
*lookup_data.memory[2] = [rhs_addr, rhs0, rhs1, rhs2, rhs3];
232+
233+
*lookup_data.state[1] = [pc + PackedM31::one(), ap + appp, fp];
234+
});
235+
236+
(trace, lookup_data)
237+
}

0 commit comments

Comments
 (0)