11use itertools:: { zip_eq, Itertools } ;
22use num_traits:: One ;
3+ use rayon:: iter:: { IndexedParallelIterator , IntoParallelRefIterator , ParallelIterator } ;
4+ use stwo_air_utils:: trace:: component_trace:: ComponentTrace ;
5+ use stwo_air_utils_derive:: { IterMut , ParIterMut , Uninitialized } ;
36use stwo_prover:: constraint_framework:: logup:: LogupTraceGenerator ;
47use stwo_prover:: constraint_framework:: Relation ;
58use stwo_prover:: core:: backend:: simd:: m31:: { PackedM31 , LOG_N_LANES , N_LANES } ;
69use stwo_prover:: core:: backend:: simd:: qm31:: PackedQM31 ;
710use stwo_prover:: core:: backend:: simd:: SimdBackend ;
8- use stwo_prover:: core:: backend:: { Col , Column } ;
911use stwo_prover:: core:: fields:: m31:: M31 ;
1012use stwo_prover:: core:: pcs:: TreeBuilder ;
11- use stwo_prover:: core:: poly:: circle:: { CanonicCoset , CircleEvaluation } ;
12- use stwo_prover:: core:: poly:: BitReversedOrder ;
1313use stwo_prover:: core:: vcs:: blake2_merkle:: Blake2sMerkleChannel ;
1414
1515use super :: component:: { Claim , InteractionClaim , INSTRUCTION_BASE } ;
@@ -63,51 +63,48 @@ impl ClaimGenerator {
6363 tree_builder : & mut TreeBuilder < ' _ , ' _ , SimdBackend , Blake2sMerkleChannel > ,
6464 memory_trace_generator : & mut memory:: ClaimGenerator ,
6565 ) -> ( Claim , InteractionClaimGenerator ) {
66- let ( trace, interaction_claim_generator) =
67- write_trace_simd ( & self . inputs , memory_trace_generator) ;
68- interaction_claim_generator. memory . iter ( ) . for_each ( |c| {
66+ let ( trace, lookup_data) = write_trace_simd ( & self . inputs , memory_trace_generator) ;
67+ lookup_data. memory . iter ( ) . for_each ( |c| {
6968 c. iter ( )
7069 . for_each ( |v| memory_trace_generator. add_inputs_simd ( & v[ 0 ] ) )
7170 } ) ;
72- tree_builder. extend_evals ( trace) ;
73- let claim = Claim {
74- n_rows : self . inputs . len ( ) * N_LANES ,
75- } ;
76- ( claim, interaction_claim_generator)
71+ tree_builder. extend_evals ( trace. to_evals ( ) ) ;
72+ let n_rows = self . inputs . len ( ) * N_LANES ;
73+ (
74+ Claim { n_rows } ,
75+ InteractionClaimGenerator {
76+ n_rows,
77+ lookup_data,
78+ } ,
79+ )
7780 }
7881}
7982
80- pub struct InteractionClaimGenerator {
81- pub n_rows : usize ,
83+ # [ derive ( Debug , Uninitialized , IterMut , ParIterMut ) ]
84+ pub struct LookupData {
8285 pub memory : [ Vec < [ PackedM31 ; N_MEMORY_ELEMS ] > ; N_MEMORY_LOOKUPS ] ,
8386 pub state : [ Vec < [ PackedM31 ; STATE_SIZE ] > ; N_STATE_LOOKUPS ] ,
8487}
85- impl InteractionClaimGenerator {
86- pub fn with_capacity ( capacity : usize ) -> Self {
87- Self {
88- n_rows : capacity,
89- memory : [
90- Vec :: with_capacity ( capacity) ,
91- Vec :: with_capacity ( capacity) ,
92- Vec :: with_capacity ( capacity) ,
93- Vec :: with_capacity ( capacity) ,
94- ] ,
95- state : [ Vec :: with_capacity ( capacity) , Vec :: with_capacity ( capacity) ] ,
96- }
97- }
9888
89+ #[ derive( Debug ) ]
90+ pub struct InteractionClaimGenerator {
91+ pub n_rows : usize ,
92+ pub lookup_data : LookupData ,
93+ }
94+
95+ impl InteractionClaimGenerator {
9996 pub fn write_interaction_trace (
10097 & self ,
10198 tree_builder : & mut TreeBuilder < ' _ , ' _ , SimdBackend , Blake2sMerkleChannel > ,
10299 memory_relation : & MemoryRelation ,
103100 state_relation : & StateRelation ,
104101 ) -> InteractionClaim {
105- let log_size = self . memory [ 0 ] . len ( ) . ilog2 ( ) + LOG_N_LANES ;
102+ let log_size = std :: cmp :: max ( self . n_rows . next_power_of_two ( ) . ilog2 ( ) , LOG_N_LANES ) ;
106103 let mut logup_gen = LogupTraceGenerator :: new ( log_size) ;
107104
108105 let mut col0 = logup_gen. new_col ( ) ;
109- let state_use = & self . state [ 0 ] ;
110- let read_pc = & self . memory [ 0 ] ;
106+ let state_use = & self . lookup_data . state [ 0 ] ;
107+ let read_pc = & self . lookup_data . memory [ 0 ] ;
111108 for ( i, ( x, y) ) in zip_eq ( state_use, read_pc) . enumerate ( ) {
112109 let denom_x: PackedQM31 = state_relation. combine ( x) ;
113110 let denom_y: PackedQM31 = memory_relation. combine ( y) ;
@@ -117,8 +114,8 @@ impl InteractionClaimGenerator {
117114 col0. finalize_col ( ) ;
118115
119116 let mut col1 = logup_gen. new_col ( ) ;
120- let read_dst = & self . memory [ 1 ] ;
121- let read_lhs = & self . memory [ 2 ] ;
117+ let read_dst = & self . lookup_data . memory [ 1 ] ;
118+ let read_lhs = & self . lookup_data . memory [ 2 ] ;
122119 for ( i, ( x, y) ) in zip_eq ( read_dst, read_lhs) . enumerate ( ) {
123120 let denom_x: PackedQM31 = memory_relation. combine ( x) ;
124121 let denom_y: PackedQM31 = memory_relation. combine ( y) ;
@@ -128,8 +125,8 @@ impl InteractionClaimGenerator {
128125 col1. finalize_col ( ) ;
129126
130127 let mut col2 = logup_gen. new_col ( ) ;
131- let read_rhs = & self . memory [ 3 ] ;
132- let state_yield = & self . state [ 1 ] ;
128+ let read_rhs = & self . lookup_data . memory [ 3 ] ;
129+ let state_yield = & self . lookup_data . state [ 1 ] ;
133130 for ( i, ( x, y) ) in zip_eq ( read_rhs, state_yield) . enumerate ( ) {
134131 let denom_x: PackedQM31 = memory_relation. combine ( x) ;
135132 let denom_y: PackedQM31 = state_relation. combine ( y) ;
@@ -151,119 +148,88 @@ impl InteractionClaimGenerator {
151148fn write_trace_simd (
152149 inputs : & [ PackedVmState ] ,
153150 memory_trace_generator : & memory:: ClaimGenerator ,
154- ) -> (
155- Vec < CircleEvaluation < SimdBackend , M31 , BitReversedOrder > > ,
156- InteractionClaimGenerator ,
157- ) {
158- let n_trace_columns = N_TRACE_COLUMNS ;
159- let mut trace_values = ( 0 ..n_trace_columns)
160- . map ( |_| Col :: < SimdBackend , M31 > :: zeros ( inputs. len ( ) * N_LANES ) )
161- . collect_vec ( ) ;
162- let mut sub_components_inputs = InteractionClaimGenerator :: with_capacity ( inputs. len ( ) ) ;
163- inputs. iter ( ) . enumerate ( ) . for_each ( |( i, input) | {
164- write_trace_row (
165- & mut trace_values,
166- input,
167- i,
168- & mut sub_components_inputs,
169- memory_trace_generator,
170- ) ;
171- } ) ;
172-
173- let trace = trace_values
174- . into_iter ( )
175- . map ( |eval| {
176- // TODO(Ohad): Support non-power of 2 inputs.
177- let domain = CanonicCoset :: new (
178- eval. len ( )
179- . checked_ilog2 ( )
180- . expect ( "Input is not a power of 2!" ) ,
181- )
182- . circle_domain ( ) ;
183- CircleEvaluation :: < SimdBackend , M31 , BitReversedOrder > :: new ( domain, eval)
184- } )
185- . collect_vec ( ) ;
186-
187- ( trace, sub_components_inputs)
188- }
151+ ) -> ( ComponentTrace < N_TRACE_COLUMNS > , LookupData ) {
152+ let log_n_packed_rows = inputs. len ( ) . ilog2 ( ) ;
153+ let log_size = log_n_packed_rows + LOG_N_LANES ;
154+ let ( mut trace, mut lookup_data) = unsafe {
155+ (
156+ ComponentTrace :: < N_TRACE_COLUMNS > :: uninitialized ( log_size) ,
157+ LookupData :: uninitialized ( log_n_packed_rows) ,
158+ )
159+ } ;
160+
161+ trace
162+ . par_iter_mut ( )
163+ . zip ( inputs. par_iter ( ) )
164+ . zip ( lookup_data. par_iter_mut ( ) )
165+ . for_each ( |( ( row, input) , lookup_data) | {
166+ // Initial state
167+ * row[ 0 ] = input. pc ;
168+ * row[ 1 ] = input. ap ;
169+ * row[ 2 ] = input. fp ;
170+ * lookup_data. state [ 0 ] = [ input. pc , input. ap , input. fp ] ;
171+
172+ // Flags
173+ let [ opcode, off0, off1, off2] = memory_trace_generator
174+ . deduce_output ( input. pc )
175+ . into_packed_m31s ( ) ;
176+ * lookup_data. memory [ 0 ] = [ input. pc , opcode, off0, off1, off2] ;
177+
178+ let [ op_type, reg0, reg1, reg2, appp] =
179+ decode_opcode ( INSTRUCTION_BASE , opcode, [ 2 , 2 , 2 , 2 , 2 ] ) ;
180+
181+ * row[ 3 ] = op_type;
182+ * row[ 4 ] = reg0;
183+ * row[ 5 ] = reg1;
184+ * row[ 6 ] = reg2;
185+ * row[ 7 ] = appp;
186+
187+ // Offsets
188+ * row[ 8 ] = off0;
189+ * row[ 9 ] = off1;
190+ * row[ 10 ] = off2;
191+
192+ // Addresses
193+ let dst_addr = Selector :: select ( & reg0, [ & input. ap , & input. fp ] ) + off0;
194+ let lhs_addr = Selector :: select ( & reg1, [ & input. ap , & input. fp ] ) + off1;
195+ let rhs_addr = Selector :: select ( & reg2, [ & input. ap , & input. fp ] ) + off2;
196+
197+ * row[ 11 ] = dst_addr;
198+ * row[ 12 ] = lhs_addr;
199+ * row[ 13 ] = rhs_addr;
200+
201+ // Values
202+ let [ dst0, dst1, dst2, dst3] = memory_trace_generator
203+ . deduce_output ( dst_addr)
204+ . into_packed_m31s ( ) ;
205+ let [ lhs0, lhs1, lhs2, lhs3] = memory_trace_generator
206+ . deduce_output ( lhs_addr)
207+ . into_packed_m31s ( ) ;
208+ let [ rhs0, rhs1, rhs2, rhs3] = memory_trace_generator
209+ . deduce_output ( rhs_addr)
210+ . into_packed_m31s ( ) ;
211+
212+ * row[ 14 ] = dst0;
213+ * row[ 15 ] = dst1;
214+ * row[ 16 ] = dst2;
215+ * row[ 17 ] = dst3;
216+
217+ * row[ 18 ] = lhs0;
218+ * row[ 19 ] = lhs1;
219+ * row[ 20 ] = lhs2;
220+ * row[ 21 ] = lhs3;
221+
222+ * row[ 22 ] = rhs0;
223+ * row[ 23 ] = rhs1;
224+ * row[ 24 ] = rhs2;
225+ * row[ 25 ] = rhs3;
226+
227+ * lookup_data. memory [ 1 ] = [ dst_addr, dst0, dst1, dst2, dst3] ;
228+ * lookup_data. memory [ 2 ] = [ lhs_addr, lhs0, lhs1, lhs2, lhs3] ;
229+ * lookup_data. memory [ 3 ] = [ rhs_addr, rhs0, rhs1, rhs2, rhs3] ;
230+
231+ * lookup_data. state [ 1 ] = [ input. pc + PackedM31 :: one ( ) , input. ap + appp, input. fp ] ;
232+ } ) ;
189233
190- // Add / Mul trace row:
191- // | State (3) | flags (5) | offsets (3) | addrs (3) | values (3 * 4) |
192- fn write_trace_row (
193- trace : & mut [ Col < SimdBackend , M31 > ] ,
194- input : & PackedVmState ,
195- row_index : usize ,
196- interaction_claim_generator : & mut InteractionClaimGenerator ,
197- memory_trace_generator : & memory:: ClaimGenerator ,
198- ) {
199- // Initial state
200- trace[ 0 ] . data [ row_index] = input. pc ;
201- trace[ 1 ] . data [ row_index] = input. ap ;
202- trace[ 2 ] . data [ row_index] = input. fp ;
203- interaction_claim_generator. state [ 0 ] . push ( [ input. pc , input. ap , input. fp ] ) ;
204-
205- // Flags
206- let [ opcode, off0, off1, off2] = memory_trace_generator
207- . deduce_output ( input. pc )
208- . into_packed_m31s ( ) ;
209- interaction_claim_generator. memory [ 0 ] . push ( [ input. pc , opcode, off0, off1, off2] ) ;
210-
211- let [ op_type, reg0, reg1, reg2, appp] =
212- decode_opcode ( INSTRUCTION_BASE , opcode, [ 2 , 2 , 2 , 2 , 2 ] ) ;
213-
214- trace[ 3 ] . data [ row_index] = op_type;
215- trace[ 4 ] . data [ row_index] = reg0;
216- trace[ 5 ] . data [ row_index] = reg1;
217- trace[ 6 ] . data [ row_index] = reg2;
218- trace[ 7 ] . data [ row_index] = appp;
219-
220- // Offsets
221- trace[ 8 ] . data [ row_index] = off0;
222- trace[ 9 ] . data [ row_index] = off1;
223- trace[ 10 ] . data [ row_index] = off2;
224-
225- // Addresses
226- let dst_addr = Selector :: select ( & reg0, [ & input. ap , & input. fp ] ) + off0;
227- let lhs_addr = Selector :: select ( & reg1, [ & input. ap , & input. fp ] ) + off1;
228- let rhs_addr = Selector :: select ( & reg2, [ & input. ap , & input. fp ] ) + off2;
229-
230- trace[ 11 ] . data [ row_index] = dst_addr;
231- trace[ 12 ] . data [ row_index] = lhs_addr;
232- trace[ 13 ] . data [ row_index] = rhs_addr;
233-
234- // Values
235- let [ dst0, dst1, dst2, dst3] = memory_trace_generator
236- . deduce_output ( dst_addr)
237- . into_packed_m31s ( ) ;
238- let [ lhs0, lhs1, lhs2, lhs3] = memory_trace_generator
239- . deduce_output ( lhs_addr)
240- . into_packed_m31s ( ) ;
241- let [ rhs0, rhs1, rhs2, rhs3] = memory_trace_generator
242- . deduce_output ( rhs_addr)
243- . into_packed_m31s ( ) ;
244-
245- trace[ 14 ] . data [ row_index] = dst0;
246- trace[ 15 ] . data [ row_index] = dst1;
247- trace[ 16 ] . data [ row_index] = dst2;
248- trace[ 17 ] . data [ row_index] = dst3;
249-
250- trace[ 18 ] . data [ row_index] = lhs0;
251- trace[ 19 ] . data [ row_index] = lhs1;
252- trace[ 20 ] . data [ row_index] = lhs2;
253- trace[ 21 ] . data [ row_index] = lhs3;
254-
255- trace[ 22 ] . data [ row_index] = rhs0;
256- trace[ 23 ] . data [ row_index] = rhs1;
257- trace[ 24 ] . data [ row_index] = rhs2;
258- trace[ 25 ] . data [ row_index] = rhs3;
259-
260- interaction_claim_generator. memory [ 1 ] . push ( [ dst_addr, dst0, dst1, dst2, dst3] ) ;
261- interaction_claim_generator. memory [ 2 ] . push ( [ lhs_addr, lhs0, lhs1, lhs2, lhs3] ) ;
262- interaction_claim_generator. memory [ 3 ] . push ( [ rhs_addr, rhs0, rhs1, rhs2, rhs3] ) ;
263-
264- interaction_claim_generator. state [ 1 ] . push ( [
265- input. pc + PackedM31 :: one ( ) ,
266- input. ap + appp,
267- input. fp ,
268- ] ) ;
234+ ( trace, lookup_data)
269235}
0 commit comments