|
| 1 | +use itertools::{zip_eq, Itertools}; |
| 2 | +use num_traits::One; |
| 3 | +use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; |
| 4 | +use stwo_air_utils::trace::component_trace::ComponentTrace; |
| 5 | +use stwo_air_utils_derive::{IterMut, ParIterMut, Uninitialized}; |
| 6 | +use stwo_prover::constraint_framework::logup::LogupTraceGenerator; |
| 7 | +use stwo_prover::constraint_framework::Relation; |
| 8 | +use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; |
| 9 | +use stwo_prover::core::backend::simd::qm31::PackedQM31; |
| 10 | +use stwo_prover::core::backend::simd::SimdBackend; |
| 11 | +use stwo_prover::core::fields::m31::M31; |
| 12 | +use stwo_prover::core::pcs::TreeBuilder; |
| 13 | +use stwo_prover::core::utils::bit_reverse_coset_to_circle_domain_order; |
| 14 | +use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; |
| 15 | + |
| 16 | +use super::component::{Claim, InteractionClaim}; |
| 17 | +use crate::components::addap_jmpabs_jmprel_opcode::component::INSTRUCTION_BASE; |
| 18 | +use crate::components::memory; |
| 19 | +use crate::relations::{MemoryRelation, StateRelation, N_MEMORY_ELEMS, STATE_SIZE}; |
| 20 | +use crate::utils::prover::decode_opcode; |
| 21 | +use crate::utils::types::{CasmState, PackedCasmState}; |
| 22 | +use crate::utils::{Selector, SelectorTrait}; |
| 23 | + |
| 24 | +const N_TRACE_COLUMNS: usize = 7; |
| 25 | +const N_MEMORY_LOOKUPS: usize = 1; |
| 26 | +const N_STATE_LOOKUPS: usize = 2; |
| 27 | + |
| 28 | +#[derive(Debug)] |
| 29 | +pub struct ClaimGenerator { |
| 30 | + pub inputs: Vec<PackedCasmState>, |
| 31 | +} |
| 32 | +impl ClaimGenerator { |
| 33 | + pub fn new(mut inputs: Vec<CasmState>) -> Self { |
| 34 | + assert!(!inputs.is_empty()); |
| 35 | + |
| 36 | + // TODO(spapini): Split to multiple components. |
| 37 | + let size = inputs.len().next_power_of_two(); |
| 38 | + inputs.resize(size, inputs[0]); |
| 39 | + |
| 40 | + let inputs = inputs |
| 41 | + .into_iter() |
| 42 | + .array_chunks::<N_LANES>() |
| 43 | + .map(|chunk| PackedCasmState { |
| 44 | + pc: PackedM31::from_array(std::array::from_fn(|i| { |
| 45 | + M31::from_u32_unchecked(chunk[i].pc) |
| 46 | + })), |
| 47 | + ap: PackedM31::from_array(std::array::from_fn(|i| { |
| 48 | + M31::from_u32_unchecked(chunk[i].ap) |
| 49 | + })), |
| 50 | + fp: PackedM31::from_array(std::array::from_fn(|i| { |
| 51 | + M31::from_u32_unchecked(chunk[i].fp) |
| 52 | + })), |
| 53 | + }) |
| 54 | + .collect_vec(); |
| 55 | + Self { inputs } |
| 56 | + } |
| 57 | + |
| 58 | + pub fn write_trace( |
| 59 | + mut self, |
| 60 | + tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, |
| 61 | + memory_trace_generator: &mut memory::ClaimGenerator, |
| 62 | + ) -> (Claim, InteractionClaimGenerator) { |
| 63 | + let (trace, lookup_data) = write_trace_simd(&self.inputs, memory_trace_generator); |
| 64 | + |
| 65 | + let n_rows = self.inputs.len(); |
| 66 | + assert_ne!(n_rows, 0); |
| 67 | + let size = std::cmp::max(n_rows.next_power_of_two(), N_LANES); |
| 68 | + let need_padding = n_rows != size; |
| 69 | + |
| 70 | + if need_padding { |
| 71 | + self.inputs |
| 72 | + .resize(size, self.inputs.first().unwrap().clone()); |
| 73 | + bit_reverse_coset_to_circle_domain_order(&mut self.inputs); |
| 74 | + } |
| 75 | + lookup_data.memory.iter().for_each(|c| { |
| 76 | + c.iter() |
| 77 | + .for_each(|v| memory_trace_generator.add_inputs_simd(&v[0])) |
| 78 | + }); |
| 79 | + tree_builder.extend_evals(trace.to_evals()); |
| 80 | + ( |
| 81 | + Claim { n_rows }, |
| 82 | + InteractionClaimGenerator { |
| 83 | + n_rows, |
| 84 | + lookup_data, |
| 85 | + }, |
| 86 | + ) |
| 87 | + } |
| 88 | +} |
| 89 | +#[derive(Debug, Uninitialized, IterMut, ParIterMut)] |
| 90 | +pub struct LookupData { |
| 91 | + pub memory: [Vec<[PackedM31; N_MEMORY_ELEMS]>; N_MEMORY_LOOKUPS], |
| 92 | + pub state: [Vec<[PackedM31; STATE_SIZE]>; N_STATE_LOOKUPS], |
| 93 | +} |
| 94 | + |
| 95 | +#[derive(Debug)] |
| 96 | +pub struct InteractionClaimGenerator { |
| 97 | + pub n_rows: usize, |
| 98 | + pub lookup_data: LookupData, |
| 99 | +} |
| 100 | + |
| 101 | +impl InteractionClaimGenerator { |
| 102 | + pub fn write_interaction_trace( |
| 103 | + &self, |
| 104 | + tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, |
| 105 | + memory_relation: &MemoryRelation, |
| 106 | + state_relation: &StateRelation, |
| 107 | + ) -> InteractionClaim { |
| 108 | + let log_size = std::cmp::max(self.n_rows.next_power_of_two().ilog2(), LOG_N_LANES); |
| 109 | + let mut logup_gen = LogupTraceGenerator::new(log_size); |
| 110 | + |
| 111 | + let mut col0 = logup_gen.new_col(); |
| 112 | + let state_use = &self.lookup_data.state[0]; |
| 113 | + let read_pc = &self.lookup_data.memory[0]; |
| 114 | + for (i, (x, y)) in zip_eq(state_use, read_pc).enumerate() { |
| 115 | + let denom_x: PackedQM31 = state_relation.combine(x); |
| 116 | + let denom_y: PackedQM31 = memory_relation.combine(y); |
| 117 | + |
| 118 | + col0.write_frac(i, denom_x + denom_y, denom_x * denom_y) |
| 119 | + } |
| 120 | + col0.finalize_col(); |
| 121 | + |
| 122 | + let mut col_gen = logup_gen.new_col(); |
| 123 | + let state_yield = &self.lookup_data.state[1]; |
| 124 | + for (i, values) in state_yield.iter().enumerate() { |
| 125 | + let denom: PackedQM31 = state_relation.combine(values); |
| 126 | + col_gen.write_frac(i, -PackedQM31::one(), denom); |
| 127 | + } |
| 128 | + col_gen.finalize_col(); |
| 129 | + |
| 130 | + let (trace, total_sum, claimed_sum) = if self.n_rows == 1 << log_size { |
| 131 | + let (trace, claimed_sum) = logup_gen.finalize_last(); |
| 132 | + (trace, claimed_sum, None) |
| 133 | + } else { |
| 134 | + let (trace, [total_sum, claimed_sum]) = |
| 135 | + logup_gen.finalize_at([(1 << log_size) - 1, self.n_rows - 1]); |
| 136 | + (trace, total_sum, Some((claimed_sum, self.n_rows - 1))) |
| 137 | + }; |
| 138 | + tree_builder.extend_evals(trace); |
| 139 | + |
| 140 | + InteractionClaim { |
| 141 | + log_size, |
| 142 | + logup_sums: (total_sum, claimed_sum), |
| 143 | + } |
| 144 | + } |
| 145 | +} |
| 146 | + |
| 147 | +// add_ap_ trace row: |
| 148 | +// pc | ap | fp | trit (addap, jmp_abs, jmp_rel) | imm | res_pc | res_ap |
| 149 | +fn write_trace_simd( |
| 150 | + inputs: &[PackedCasmState], |
| 151 | + memory_trace_generator: &memory::ClaimGenerator, |
| 152 | +) -> (ComponentTrace<N_TRACE_COLUMNS>, LookupData) { |
| 153 | + let log_n_packed_rows = inputs.len().ilog2(); |
| 154 | + let log_size = log_n_packed_rows + LOG_N_LANES; |
| 155 | + let (mut trace, mut lookup_data) = unsafe { |
| 156 | + ( |
| 157 | + ComponentTrace::<N_TRACE_COLUMNS>::uninitialized(log_size), |
| 158 | + LookupData::uninitialized(log_n_packed_rows), |
| 159 | + ) |
| 160 | + }; |
| 161 | + |
| 162 | + trace |
| 163 | + .par_iter_mut() |
| 164 | + .zip(inputs.par_iter()) |
| 165 | + .zip(lookup_data.par_iter_mut()) |
| 166 | + .for_each(|((row, opcode_input), lookup_data)| { |
| 167 | + // Initial state. |
| 168 | + let pc = opcode_input.pc; |
| 169 | + let ap = opcode_input.ap; |
| 170 | + let fp = opcode_input.fp; |
| 171 | + *row[0] = pc; |
| 172 | + *row[1] = ap; |
| 173 | + *row[2] = fp; |
| 174 | + *lookup_data.state[0] = [pc, ap, fp]; |
| 175 | + |
| 176 | + // Decode insturction. |
| 177 | + let [opcode, imm, off0, off1] = |
| 178 | + memory_trace_generator.deduce_output(pc).into_packed_m31s(); |
| 179 | + *lookup_data.memory[0] = [pc, opcode, imm, off0, off1]; |
| 180 | + |
| 181 | + let [op_type] = decode_opcode(INSTRUCTION_BASE, opcode, [3]); |
| 182 | + |
| 183 | + *row[3] = op_type; |
| 184 | + *row[4] = imm; |
| 185 | + |
| 186 | + // Calc new state. |
| 187 | + let new_pc = Selector::select( |
| 188 | + &op_type, |
| 189 | + [&(pc + PackedM31::broadcast(M31::one())), &imm, &(pc + imm)], |
| 190 | + ); |
| 191 | + let new_ap = Selector::select(&op_type, [&(ap + imm), &ap, &ap]); |
| 192 | + |
| 193 | + *row[5] = new_pc; |
| 194 | + *row[6] = new_ap; |
| 195 | + |
| 196 | + *lookup_data.state[1] = [new_pc, new_ap, fp]; |
| 197 | + }); |
| 198 | + |
| 199 | + (trace, lookup_data) |
| 200 | +} |
0 commit comments