Skip to content

Commit 878d5d5

Browse files
committed
Parallelized add_mul trace gen.
1 parent 69d1822 commit 878d5d5

File tree

1 file changed

+113
-147
lines changed
  • crates/prover/src/components/add_mul_opcode

1 file changed

+113
-147
lines changed
Lines changed: 113 additions & 147 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,15 @@
11
use itertools::{zip_eq, Itertools};
22
use num_traits::One;
3+
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
4+
use stwo_air_utils::trace::component_trace::ComponentTrace;
5+
use stwo_air_utils_derive::{IterMut, ParIterMut, Uninitialized};
36
use stwo_prover::constraint_framework::logup::LogupTraceGenerator;
47
use stwo_prover::constraint_framework::Relation;
58
use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES};
69
use stwo_prover::core::backend::simd::qm31::PackedQM31;
710
use stwo_prover::core::backend::simd::SimdBackend;
8-
use stwo_prover::core::backend::{Col, Column};
911
use stwo_prover::core::fields::m31::M31;
1012
use stwo_prover::core::pcs::TreeBuilder;
11-
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
12-
use stwo_prover::core::poly::BitReversedOrder;
1313
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;
1414

1515
use super::component::{Claim, InteractionClaim, INSTRUCTION_BASE};
@@ -63,51 +63,48 @@ impl ClaimGenerator {
6363
tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>,
6464
memory_trace_generator: &mut memory::ClaimGenerator,
6565
) -> (Claim, InteractionClaimGenerator) {
66-
let (trace, interaction_claim_generator) =
67-
write_trace_simd(&self.inputs, memory_trace_generator);
68-
interaction_claim_generator.memory.iter().for_each(|c| {
66+
let (trace, lookup_data) = write_trace_simd(&self.inputs, memory_trace_generator);
67+
lookup_data.memory.iter().for_each(|c| {
6968
c.iter()
7069
.for_each(|v| memory_trace_generator.add_inputs_simd(&v[0]))
7170
});
72-
tree_builder.extend_evals(trace);
73-
let claim = Claim {
74-
n_rows: self.inputs.len() * N_LANES,
75-
};
76-
(claim, interaction_claim_generator)
71+
tree_builder.extend_evals(trace.to_evals());
72+
let n_rows = self.inputs.len() * N_LANES;
73+
(
74+
Claim { n_rows },
75+
InteractionClaimGenerator {
76+
n_rows,
77+
lookup_data,
78+
},
79+
)
7780
}
7881
}
7982

80-
pub struct InteractionClaimGenerator {
81-
pub n_rows: usize,
83+
#[derive(Debug, Uninitialized, IterMut, ParIterMut)]
84+
pub struct LookupData {
8285
pub memory: [Vec<[PackedM31; N_MEMORY_ELEMS]>; N_MEMORY_LOOKUPS],
8386
pub state: [Vec<[PackedM31; STATE_SIZE]>; N_STATE_LOOKUPS],
8487
}
85-
impl InteractionClaimGenerator {
86-
pub fn with_capacity(capacity: usize) -> Self {
87-
Self {
88-
n_rows: capacity,
89-
memory: [
90-
Vec::with_capacity(capacity),
91-
Vec::with_capacity(capacity),
92-
Vec::with_capacity(capacity),
93-
Vec::with_capacity(capacity),
94-
],
95-
state: [Vec::with_capacity(capacity), Vec::with_capacity(capacity)],
96-
}
97-
}
9888

89+
#[derive(Debug)]
90+
pub struct InteractionClaimGenerator {
91+
pub n_rows: usize,
92+
pub lookup_data: LookupData,
93+
}
94+
95+
impl InteractionClaimGenerator {
9996
pub fn write_interaction_trace(
10097
&self,
10198
tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>,
10299
memory_relation: &MemoryRelation,
103100
state_relation: &StateRelation,
104101
) -> InteractionClaim {
105-
let log_size = self.memory[0].len().ilog2() + LOG_N_LANES;
102+
let log_size = std::cmp::max(self.n_rows.next_power_of_two().ilog2(), LOG_N_LANES);
106103
let mut logup_gen = LogupTraceGenerator::new(log_size);
107104

108105
let mut col0 = logup_gen.new_col();
109-
let state_use = &self.state[0];
110-
let read_pc = &self.memory[0];
106+
let state_use = &self.lookup_data.state[0];
107+
let read_pc = &self.lookup_data.memory[0];
111108
for (i, (x, y)) in zip_eq(state_use, read_pc).enumerate() {
112109
let denom_x: PackedQM31 = state_relation.combine(x);
113110
let denom_y: PackedQM31 = memory_relation.combine(y);
@@ -117,8 +114,8 @@ impl InteractionClaimGenerator {
117114
col0.finalize_col();
118115

119116
let mut col1 = logup_gen.new_col();
120-
let read_dst = &self.memory[1];
121-
let read_lhs = &self.memory[2];
117+
let read_dst = &self.lookup_data.memory[1];
118+
let read_lhs = &self.lookup_data.memory[2];
122119
for (i, (x, y)) in zip_eq(read_dst, read_lhs).enumerate() {
123120
let denom_x: PackedQM31 = memory_relation.combine(x);
124121
let denom_y: PackedQM31 = memory_relation.combine(y);
@@ -128,8 +125,8 @@ impl InteractionClaimGenerator {
128125
col1.finalize_col();
129126

130127
let mut col2 = logup_gen.new_col();
131-
let read_rhs = &self.memory[3];
132-
let state_yield = &self.state[1];
128+
let read_rhs = &self.lookup_data.memory[3];
129+
let state_yield = &self.lookup_data.state[1];
133130
for (i, (x, y)) in zip_eq(read_rhs, state_yield).enumerate() {
134131
let denom_x: PackedQM31 = memory_relation.combine(x);
135132
let denom_y: PackedQM31 = state_relation.combine(y);
@@ -151,119 +148,88 @@ impl InteractionClaimGenerator {
151148
fn write_trace_simd(
152149
inputs: &[PackedVmState],
153150
memory_trace_generator: &memory::ClaimGenerator,
154-
) -> (
155-
Vec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
156-
InteractionClaimGenerator,
157-
) {
158-
let n_trace_columns = N_TRACE_COLUMNS;
159-
let mut trace_values = (0..n_trace_columns)
160-
.map(|_| Col::<SimdBackend, M31>::zeros(inputs.len() * N_LANES))
161-
.collect_vec();
162-
let mut sub_components_inputs = InteractionClaimGenerator::with_capacity(inputs.len());
163-
inputs.iter().enumerate().for_each(|(i, input)| {
164-
write_trace_row(
165-
&mut trace_values,
166-
input,
167-
i,
168-
&mut sub_components_inputs,
169-
memory_trace_generator,
170-
);
171-
});
172-
173-
let trace = trace_values
174-
.into_iter()
175-
.map(|eval| {
176-
// TODO(Ohad): Support non-power of 2 inputs.
177-
let domain = CanonicCoset::new(
178-
eval.len()
179-
.checked_ilog2()
180-
.expect("Input is not a power of 2!"),
181-
)
182-
.circle_domain();
183-
CircleEvaluation::<SimdBackend, M31, BitReversedOrder>::new(domain, eval)
184-
})
185-
.collect_vec();
186-
187-
(trace, sub_components_inputs)
188-
}
151+
) -> (ComponentTrace<N_TRACE_COLUMNS>, LookupData) {
152+
let log_n_packed_rows = inputs.len().ilog2();
153+
let log_size = log_n_packed_rows + LOG_N_LANES;
154+
let (mut trace, mut lookup_data) = unsafe {
155+
(
156+
ComponentTrace::<N_TRACE_COLUMNS>::uninitialized(log_size),
157+
LookupData::uninitialized(log_n_packed_rows),
158+
)
159+
};
160+
161+
trace
162+
.par_iter_mut()
163+
.zip(inputs.par_iter())
164+
.zip(lookup_data.par_iter_mut())
165+
.for_each(|((row, input), lookup_data)| {
166+
// Initial state
167+
*row[0] = input.pc;
168+
*row[1] = input.ap;
169+
*row[2] = input.fp;
170+
*lookup_data.state[0] = [input.pc, input.ap, input.fp];
171+
172+
// Flags
173+
let [opcode, off0, off1, off2] = memory_trace_generator
174+
.deduce_output(input.pc)
175+
.into_packed_m31s();
176+
*lookup_data.memory[0] = [input.pc, opcode, off0, off1, off2];
177+
178+
let [op_type, reg0, reg1, reg2, appp] =
179+
decode_opcode(INSTRUCTION_BASE, opcode, [2, 2, 2, 2, 2]);
180+
181+
*row[3] = op_type;
182+
*row[4] = reg0;
183+
*row[5] = reg1;
184+
*row[6] = reg2;
185+
*row[7] = appp;
186+
187+
// Offsets
188+
*row[8] = off0;
189+
*row[9] = off1;
190+
*row[10] = off2;
191+
192+
// Addresses
193+
let dst_addr = Selector::select(&reg0, [&input.ap, &input.fp]) + off0;
194+
let lhs_addr = Selector::select(&reg1, [&input.ap, &input.fp]) + off1;
195+
let rhs_addr = Selector::select(&reg2, [&input.ap, &input.fp]) + off2;
196+
197+
*row[11] = dst_addr;
198+
*row[12] = lhs_addr;
199+
*row[13] = rhs_addr;
200+
201+
// Values
202+
let [dst0, dst1, dst2, dst3] = memory_trace_generator
203+
.deduce_output(dst_addr)
204+
.into_packed_m31s();
205+
let [lhs0, lhs1, lhs2, lhs3] = memory_trace_generator
206+
.deduce_output(lhs_addr)
207+
.into_packed_m31s();
208+
let [rhs0, rhs1, rhs2, rhs3] = memory_trace_generator
209+
.deduce_output(rhs_addr)
210+
.into_packed_m31s();
211+
212+
*row[14] = dst0;
213+
*row[15] = dst1;
214+
*row[16] = dst2;
215+
*row[17] = dst3;
216+
217+
*row[18] = lhs0;
218+
*row[19] = lhs1;
219+
*row[20] = lhs2;
220+
*row[21] = lhs3;
221+
222+
*row[22] = rhs0;
223+
*row[23] = rhs1;
224+
*row[24] = rhs2;
225+
*row[25] = rhs3;
226+
227+
*lookup_data.memory[1] = [dst_addr, dst0, dst1, dst2, dst3];
228+
*lookup_data.memory[2] = [lhs_addr, lhs0, lhs1, lhs2, lhs3];
229+
*lookup_data.memory[3] = [rhs_addr, rhs0, rhs1, rhs2, rhs3];
230+
231+
*lookup_data.state[1] = [input.pc + PackedM31::one(), input.ap + appp, input.fp];
232+
});
189233

190-
// Add / Mul trace row:
191-
// | State (3) | flags (5) | offsets (3) | addrs (3) | values (3 * 4) |
192-
fn write_trace_row(
193-
trace: &mut [Col<SimdBackend, M31>],
194-
input: &PackedVmState,
195-
row_index: usize,
196-
interaction_claim_generator: &mut InteractionClaimGenerator,
197-
memory_trace_generator: &memory::ClaimGenerator,
198-
) {
199-
// Initial state
200-
trace[0].data[row_index] = input.pc;
201-
trace[1].data[row_index] = input.ap;
202-
trace[2].data[row_index] = input.fp;
203-
interaction_claim_generator.state[0].push([input.pc, input.ap, input.fp]);
204-
205-
// Flags
206-
let [opcode, off0, off1, off2] = memory_trace_generator
207-
.deduce_output(input.pc)
208-
.into_packed_m31s();
209-
interaction_claim_generator.memory[0].push([input.pc, opcode, off0, off1, off2]);
210-
211-
let [op_type, reg0, reg1, reg2, appp] =
212-
decode_opcode(INSTRUCTION_BASE, opcode, [2, 2, 2, 2, 2]);
213-
214-
trace[3].data[row_index] = op_type;
215-
trace[4].data[row_index] = reg0;
216-
trace[5].data[row_index] = reg1;
217-
trace[6].data[row_index] = reg2;
218-
trace[7].data[row_index] = appp;
219-
220-
// Offsets
221-
trace[8].data[row_index] = off0;
222-
trace[9].data[row_index] = off1;
223-
trace[10].data[row_index] = off2;
224-
225-
// Addresses
226-
let dst_addr = Selector::select(&reg0, [&input.ap, &input.fp]) + off0;
227-
let lhs_addr = Selector::select(&reg1, [&input.ap, &input.fp]) + off1;
228-
let rhs_addr = Selector::select(&reg2, [&input.ap, &input.fp]) + off2;
229-
230-
trace[11].data[row_index] = dst_addr;
231-
trace[12].data[row_index] = lhs_addr;
232-
trace[13].data[row_index] = rhs_addr;
233-
234-
// Values
235-
let [dst0, dst1, dst2, dst3] = memory_trace_generator
236-
.deduce_output(dst_addr)
237-
.into_packed_m31s();
238-
let [lhs0, lhs1, lhs2, lhs3] = memory_trace_generator
239-
.deduce_output(lhs_addr)
240-
.into_packed_m31s();
241-
let [rhs0, rhs1, rhs2, rhs3] = memory_trace_generator
242-
.deduce_output(rhs_addr)
243-
.into_packed_m31s();
244-
245-
trace[14].data[row_index] = dst0;
246-
trace[15].data[row_index] = dst1;
247-
trace[16].data[row_index] = dst2;
248-
trace[17].data[row_index] = dst3;
249-
250-
trace[18].data[row_index] = lhs0;
251-
trace[19].data[row_index] = lhs1;
252-
trace[20].data[row_index] = lhs2;
253-
trace[21].data[row_index] = lhs3;
254-
255-
trace[22].data[row_index] = rhs0;
256-
trace[23].data[row_index] = rhs1;
257-
trace[24].data[row_index] = rhs2;
258-
trace[25].data[row_index] = rhs3;
259-
260-
interaction_claim_generator.memory[1].push([dst_addr, dst0, dst1, dst2, dst3]);
261-
interaction_claim_generator.memory[2].push([lhs_addr, lhs0, lhs1, lhs2, lhs3]);
262-
interaction_claim_generator.memory[3].push([rhs_addr, rhs0, rhs1, rhs2, rhs3]);
263-
264-
interaction_claim_generator.state[1].push([
265-
input.pc + PackedM31::one(),
266-
input.ap + appp,
267-
input.fp,
268-
]);
234+
(trace, lookup_data)
269235
}

0 commit comments

Comments
 (0)