|
| 1 | +use itertools::{zip_eq, Itertools}; |
| 2 | +use num_traits::One; |
| 3 | +use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; |
| 4 | +use stwo_air_utils::trace::component_trace::ComponentTrace; |
| 5 | +use stwo_air_utils_derive::{IterMut, ParIterMut, Uninitialized}; |
| 6 | +use stwo_prover::constraint_framework::logup::LogupTraceGenerator; |
| 7 | +use stwo_prover::constraint_framework::Relation; |
| 8 | +use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; |
| 9 | +use stwo_prover::core::backend::simd::qm31::PackedQM31; |
| 10 | +use stwo_prover::core::backend::simd::SimdBackend; |
| 11 | +use stwo_prover::core::fields::m31::M31; |
| 12 | +use stwo_prover::core::pcs::TreeBuilder; |
| 13 | +use stwo_prover::core::utils::bit_reverse_coset_to_circle_domain_order; |
| 14 | +use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; |
| 15 | + |
| 16 | +use super::component::{Claim, InteractionClaim, INSTRUCTION_BASE}; |
| 17 | +use crate::components::add_mul_opcode::component::N_TRACE_COLUMNS; |
| 18 | +use crate::components::memory; |
| 19 | +use crate::relations::{MemoryRelation, StateRelation, N_MEMORY_ELEMS, STATE_SIZE}; |
| 20 | +use crate::utils::prover::decode_opcode; |
| 21 | +use crate::utils::types::{CasmState, PackedCasmState}; |
| 22 | +use crate::utils::{Selector, SelectorTrait}; |
| 23 | + |
| 24 | +const N_MEMORY_LOOKUPS: usize = 3; |
| 25 | +const N_STATE_LOOKUPS: usize = 2; |
| 26 | + |
| 27 | +#[derive(Debug)] |
| 28 | +pub struct ClaimGenerator { |
| 29 | + pub inputs: Vec<PackedCasmState>, |
| 30 | +} |
| 31 | +impl ClaimGenerator { |
| 32 | + pub fn new(mut inputs: Vec<CasmState>) -> Self { |
| 33 | + assert!(!inputs.is_empty()); |
| 34 | + |
| 35 | + // TODO(spapini): Split to multiple components. |
| 36 | + let size = inputs.len().next_power_of_two(); |
| 37 | + inputs.resize(size, inputs[0]); |
| 38 | + |
| 39 | + let inputs = inputs |
| 40 | + .into_iter() |
| 41 | + .array_chunks::<N_LANES>() |
| 42 | + .map(|chunk| PackedCasmState { |
| 43 | + pc: PackedM31::from_array(std::array::from_fn(|i| { |
| 44 | + M31::from_u32_unchecked(chunk[i].pc) |
| 45 | + })), |
| 46 | + ap: PackedM31::from_array(std::array::from_fn(|i| { |
| 47 | + M31::from_u32_unchecked(chunk[i].ap) |
| 48 | + })), |
| 49 | + fp: PackedM31::from_array(std::array::from_fn(|i| { |
| 50 | + M31::from_u32_unchecked(chunk[i].fp) |
| 51 | + })), |
| 52 | + }) |
| 53 | + .collect_vec(); |
| 54 | + Self { inputs } |
| 55 | + } |
| 56 | + |
| 57 | + pub fn write_trace( |
| 58 | + mut self, |
| 59 | + tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, |
| 60 | + memory_trace_generator: &mut memory::ClaimGenerator, |
| 61 | + ) -> (Claim, InteractionClaimGenerator) { |
| 62 | + let (trace, lookup_data) = write_trace_simd(&self.inputs, memory_trace_generator); |
| 63 | + |
| 64 | + let n_rows = self.inputs.len(); |
| 65 | + assert_ne!(n_rows, 0); |
| 66 | + let size = std::cmp::max(n_rows.next_power_of_two(), N_LANES); |
| 67 | + let need_padding = n_rows != size; |
| 68 | + |
| 69 | + if need_padding { |
| 70 | + self.inputs |
| 71 | + .resize(size, self.inputs.first().unwrap().clone()); |
| 72 | + bit_reverse_coset_to_circle_domain_order(&mut self.inputs); |
| 73 | + } |
| 74 | + lookup_data.memory.iter().for_each(|c| { |
| 75 | + c.iter() |
| 76 | + .for_each(|v| memory_trace_generator.add_inputs_simd(&v[0])) |
| 77 | + }); |
| 78 | + tree_builder.extend_evals(trace.to_evals()); |
| 79 | + ( |
| 80 | + Claim { n_rows }, |
| 81 | + InteractionClaimGenerator { |
| 82 | + n_rows, |
| 83 | + lookup_data, |
| 84 | + }, |
| 85 | + ) |
| 86 | + } |
| 87 | +} |
| 88 | + |
| 89 | +#[derive(Debug, Uninitialized, IterMut, ParIterMut)] |
| 90 | +pub struct LookupData { |
| 91 | + pub memory: [Vec<[PackedM31; N_MEMORY_ELEMS]>; N_MEMORY_LOOKUPS], |
| 92 | + pub state: [Vec<[PackedM31; STATE_SIZE]>; N_STATE_LOOKUPS], |
| 93 | +} |
| 94 | + |
| 95 | +pub struct InteractionClaimGenerator { |
| 96 | + pub n_rows: usize, |
| 97 | + pub lookup_data: LookupData, |
| 98 | +} |
| 99 | + |
| 100 | +impl InteractionClaimGenerator { |
| 101 | + pub fn with_capacity(capacity: usize) -> Self { |
| 102 | + Self { |
| 103 | + n_rows: capacity, |
| 104 | + memory: [ |
| 105 | + Vec::with_capacity(capacity), |
| 106 | + Vec::with_capacity(capacity), |
| 107 | + Vec::with_capacity(capacity), |
| 108 | + Vec::with_capacity(capacity), |
| 109 | + ], |
| 110 | + state: [Vec::with_capacity(capacity), Vec::with_capacity(capacity)], |
| 111 | + } |
| 112 | + } |
| 113 | + |
| 114 | + pub fn write_interaction_trace( |
| 115 | + &self, |
| 116 | + tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, |
| 117 | + memory_relation: &MemoryRelation, |
| 118 | + state_relation: &StateRelation, |
| 119 | + ) -> InteractionClaim { |
| 120 | + let log_size = std::cmp::max(self.n_rows.next_power_of_two().ilog2(), LOG_N_LANES); |
| 121 | + let mut logup_gen = LogupTraceGenerator::new(log_size); |
| 122 | + |
| 123 | + let mut col0 = logup_gen.new_col(); |
| 124 | + let state_use = &self.lookup_data.state[0]; |
| 125 | + let read_pc = &self.lookup_data.memory[0]; |
| 126 | + for (i, (x, y)) in zip_eq(state_use, read_pc).enumerate() { |
| 127 | + let denom_x: PackedQM31 = state_relation.combine(x); |
| 128 | + let denom_y: PackedQM31 = memory_relation.combine(y); |
| 129 | + |
| 130 | + col0.write_frac(i, denom_x + denom_y, denom_x * denom_y) |
| 131 | + } |
| 132 | + col0.finalize_col(); |
| 133 | + |
| 134 | + let mut col_gen = logup_gen.new_col(); |
| 135 | + let state_yield = &self.lookup_data.state[1]; |
| 136 | + for (i, values) in state_yield.iter().enumerate() { |
| 137 | + let denom: PackedQM31 = state_relation.combine(values); |
| 138 | + col_gen.write_frac(i, -PackedQM31::one(), denom); |
| 139 | + } |
| 140 | + col_gen.finalize_col(); |
| 141 | + |
| 142 | + let (trace, total_sum, claimed_sum) = if self.n_rows == 1 << log_size { |
| 143 | + let (trace, claimed_sum) = logup_gen.finalize_last(); |
| 144 | + (trace, claimed_sum, None) |
| 145 | + } else { |
| 146 | + let (trace, [total_sum, claimed_sum]) = |
| 147 | + logup_gen.finalize_at([(1 << log_size) - 1, self.n_rows - 1]); |
| 148 | + (trace, total_sum, Some((claimed_sum, self.n_rows - 1))) |
| 149 | + }; |
| 150 | + tree_builder.extend_evals(trace); |
| 151 | + |
| 152 | + InteractionClaim { |
| 153 | + log_size, |
| 154 | + logup_sums: (total_sum, claimed_sum), |
| 155 | + } |
| 156 | + } |
| 157 | +} |
| 158 | + |
| 159 | +// Add / Mul trace row: |
| 160 | +// | State (3) | flags (4) | offsets (2) | imm (1) | addrs (2) | values (2 * 4) | |
| 161 | +fn write_trace_simd( |
| 162 | + inputs: &[PackedCasmState], |
| 163 | + memory_trace_generator: &memory::ClaimGenerator, |
| 164 | +) -> (ComponentTrace<N_TRACE_COLUMNS>, LookupData) { |
| 165 | + let log_n_packed_rows = inputs.len().ilog2(); |
| 166 | + let log_size = log_n_packed_rows + LOG_N_LANES; |
| 167 | + let (mut trace, mut lookup_data) = unsafe { |
| 168 | + ( |
| 169 | + ComponentTrace::<N_TRACE_COLUMNS>::uninitialized(log_size), |
| 170 | + LookupData::uninitialized(log_n_packed_rows), |
| 171 | + ) |
| 172 | + }; |
| 173 | + |
| 174 | + trace |
| 175 | + .par_iter_mut() |
| 176 | + .zip(inputs.par_iter()) |
| 177 | + .zip(lookup_data.par_iter_mut()) |
| 178 | + .for_each(|((row, opcode_input), lookup_data)| { |
| 179 | + // Initial state. |
| 180 | + let pc = opcode_input.pc; |
| 181 | + let ap = opcode_input.ap; |
| 182 | + let fp = opcode_input.fp; |
| 183 | + *row[0] = pc; |
| 184 | + *row[1] = ap; |
| 185 | + *row[2] = fp; |
| 186 | + *lookup_data.state[0] = [pc, ap, fp]; |
| 187 | + |
| 188 | + // Decode insturction. |
| 189 | + let [opcode, off0, off1, imm] = |
| 190 | + memory_trace_generator.deduce_output(pc).into_packed_m31s(); |
| 191 | + *lookup_data.memory[0] = [pc, opcode, off0, off1, imm]; |
| 192 | + |
| 193 | + let [op_type, lhs_flag, rhs_flag, appp] = |
| 194 | + decode_opcode(INSTRUCTION_BASE, opcode, [2, 2, 2, 2]); |
| 195 | + *row[3] = op_type; |
| 196 | + *row[4] = lhs_flag; |
| 197 | + *row[5] = rhs_flag; |
| 198 | + *row[6] = appp; |
| 199 | + |
| 200 | + // Offsets |
| 201 | + *row[8] = off0; |
| 202 | + *row[9] = off1; |
| 203 | + *row[10] = imm; |
| 204 | + |
| 205 | + // Addresses |
| 206 | + let lhs_addr = Selector::select(&lhs_flag, [&ap, &fp]) + off0; |
| 207 | + let rhs_addr = Selector::select(&rhs_flag, [&ap, &fp]) + off1; |
| 208 | + |
| 209 | + *row[11] = lhs_addr; |
| 210 | + *row[12] = rhs_addr; |
| 211 | + |
| 212 | + let [lhs0, lhs1, lhs2, lhs3] = memory_trace_generator |
| 213 | + .deduce_output(lhs_addr) |
| 214 | + .into_packed_m31s(); |
| 215 | + let [rhs0, rhs1, rhs2, rhs3] = memory_trace_generator |
| 216 | + .deduce_output(rhs_addr) |
| 217 | + .into_packed_m31s(); |
| 218 | + |
| 219 | + *row[13] = lhs0; |
| 220 | + *row[14] = lhs1; |
| 221 | + *row[15] = lhs2; |
| 222 | + *row[16] = lhs3; |
| 223 | + |
| 224 | + *row[17] = rhs0; |
| 225 | + *row[18] = rhs1; |
| 226 | + *row[18] = rhs2; |
| 227 | + *row[20] = rhs3; |
| 228 | + |
| 229 | + *lookup_data.memory[0] = [pc, opcode, imm, off0, off1]; |
| 230 | + *lookup_data.memory[1] = [lhs_addr, lhs0, lhs1, lhs2, lhs3]; |
| 231 | + *lookup_data.memory[2] = [rhs_addr, rhs0, rhs1, rhs2, rhs3]; |
| 232 | + |
| 233 | + *lookup_data.state[1] = [pc + PackedM31::one(), ap + appp, fp]; |
| 234 | + }); |
| 235 | + |
| 236 | + (trace, lookup_data) |
| 237 | +} |
0 commit comments