|
| 1 | +use itertools::{chain, Itertools}; |
| 2 | +use num_traits::One; |
| 3 | +use serde::{Deserialize, Serialize}; |
| 4 | +use stwo_prover::constraint_framework::{ |
| 5 | + EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry, |
| 6 | +}; |
| 7 | +use stwo_prover::core::backend::simd::m31::LOG_N_LANES; |
| 8 | +use stwo_prover::core::channel::Channel; |
| 9 | +use stwo_prover::core::fields::m31::M31; |
| 10 | +use stwo_prover::core::fields::qm31::SecureField; |
| 11 | +use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE; |
| 12 | +use stwo_prover::core::pcs::TreeVec; |
| 13 | + |
| 14 | +use crate::relations::{MemoryRelation, StateRelation}; |
| 15 | +use crate::utils::component::{decode_opcode, is_bit}; |
| 16 | +use crate::utils::{Selector, SelectorTrait}; |
| 17 | + |
| 18 | +pub const N_TRACE_CELLS: usize = 19; |
| 19 | + |
| 20 | +// Assumes INSTRUCTION_BASE=K such that: |
| 21 | +/// ` |
| 22 | +/// addap_imm = K |
| 23 | +/// jmp_abs_imm = K + 1 |
| 24 | +/// jmp_rel_imm = K + 2 |
| 25 | +/// ` |
| 26 | +// TODO: organize opcodes so that K will work as detailed above, instead of just 0. |
| 27 | +pub const INSTRUCTION_BASE: M31 = M31::from_u32_unchecked(0); |
| 28 | + |
| 29 | +pub type Component = FrameworkComponent<Eval>; |
| 30 | + |
| 31 | +pub struct Eval { |
| 32 | + pub claim: Claim, |
| 33 | + pub memory_lookup: MemoryRelation, |
| 34 | + pub state_lookup: StateRelation, |
| 35 | +} |
| 36 | + |
| 37 | +impl Eval { |
| 38 | + pub fn new(claim: Claim, memory_lookup: MemoryRelation, state_lookup: StateRelation) -> Self { |
| 39 | + Self { |
| 40 | + claim: claim.clone(), |
| 41 | + memory_lookup, |
| 42 | + state_lookup, |
| 43 | + } |
| 44 | + } |
| 45 | +} |
| 46 | +impl FrameworkEval for Eval { |
| 47 | + fn log_size(&self) -> u32 { |
| 48 | + std::cmp::max(self.claim.n_rows.next_power_of_two().ilog2(), LOG_N_LANES) |
| 49 | + } |
| 50 | + |
| 51 | + fn max_constraint_log_degree_bound(&self) -> u32 { |
| 52 | + self.log_size() + 1 |
| 53 | + } |
| 54 | + |
| 55 | + fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E { |
| 56 | + let state = std::array::from_fn(|_| eval.next_trace_mask()); |
| 57 | + // Use initial state. |
| 58 | + eval.add_to_relation(RelationEntry::new(&self.state_lookup, E::EF::one(), &state)); |
| 59 | + let [pc, ap, fp] = state; |
| 60 | + |
| 61 | + // Assert flags are in range. |
| 62 | + let [op_type, reg] = std::array::from_fn(|_| eval.next_trace_mask()); |
| 63 | + eval.add_constraint(is_bit::<E>(&op_type)); |
| 64 | + eval.add_constraint(is_bit::<E>(®)); |
| 65 | + |
| 66 | + // Check instruction. |
| 67 | + let opcode = decode_opcode( |
| 68 | + INSTRUCTION_BASE.into(), |
| 69 | + &[ |
| 70 | + (op_type.clone(), 2), // [jmp abs, jmp rel] |
| 71 | + (reg.clone(), 2), // [ap, fp] |
| 72 | + ], |
| 73 | + ); |
| 74 | + |
| 75 | + let [off, imm] = std::array::from_fn(|_| eval.next_trace_mask()); |
| 76 | + println!("1"); |
| 77 | + eval.add_to_relation(RelationEntry::new( |
| 78 | + &self.memory_lookup, |
| 79 | + E::EF::one(), |
| 80 | + &[ |
| 81 | + pc.clone(), |
| 82 | + opcode.clone(), |
| 83 | + off.clone(), |
| 84 | + imm.clone(), |
| 85 | + ], |
| 86 | + )); |
| 87 | + println!("2"); |
| 88 | + // Compute address. |
| 89 | + let addr = eval.next_trace_mask(); |
| 90 | + eval.add_constraint( |
| 91 | + addr.clone() - (Selector::select(®, [&(ap.clone()), &(fp.clone())]) + off), |
| 92 | + ); |
| 93 | + |
| 94 | + println!("3"); |
| 95 | + let addr_val_arr: [E::F; 4] = std::array::from_fn(|_| eval.next_trace_mask()); |
| 96 | + eval.add_to_relation(RelationEntry::new( |
| 97 | + &self.memory_lookup, |
| 98 | + E::EF::one(), |
| 99 | + &chain!([addr], addr_val_arr.clone()).collect_vec(), |
| 100 | + )); |
| 101 | + let val = E::combine_ef(addr_val_arr); |
| 102 | + |
| 103 | + |
| 104 | + // Check jnz condition. |
| 105 | + let maybe_inverse_val = E::combine_ef(std::array::from_fn(|_| eval.next_trace_mask())); |
| 106 | + println!("4"); |
| 107 | + let flag = eval.next_trace_mask(); |
| 108 | + eval.add_constraint(is_bit::<E>(&flag)); |
| 109 | + |
| 110 | + // flag == 0 iff val == 0 iff val is not invertible <=> |
| 111 | + // ==> 0 = val * (1 - flag) + (1 - val * val^{-1}) * flag |
| 112 | + println!("5"); |
| 113 | + eval.add_constraint( |
| 114 | + val.clone() * (E::F::one() - flag.clone()) |
| 115 | + + (E::EF::one() - val.clone() * maybe_inverse_val) * flag.clone(), |
| 116 | + ); |
| 117 | + |
| 118 | + // Assert new pc. |
| 119 | + // The relative branch when taken is obvious. |
| 120 | + println!("5.1"); |
| 121 | + let jmp_target_if_taken = |
| 122 | + &Selector::select(&op_type, [&imm, &(pc.clone() + imm.clone())]); |
| 123 | + |
| 124 | + println!("5.2"); |
| 125 | + let new_pc = eval.next_trace_mask(); |
| 126 | + |
| 127 | + println!("6"); |
| 128 | + eval.add_constraint( |
| 129 | + new_pc.clone() - Selector::select(&(E::F::one()-flag), [&(pc + E::F::one()), jmp_target_if_taken]), |
| 130 | + ); |
| 131 | + |
| 132 | + // Yield final state. |
| 133 | + let new_state = [new_pc, ap, fp]; |
| 134 | + println!("7"); |
| 135 | + eval.add_to_relation(RelationEntry::new( |
| 136 | + &self.state_lookup, |
| 137 | + -E::EF::one(), |
| 138 | + &new_state, |
| 139 | + )); |
| 140 | + |
| 141 | + println!("8"); |
| 142 | + eval.finalize_logup_in_pairs(); |
| 143 | + println!("9"); |
| 144 | + eval |
| 145 | + } |
| 146 | +} |
| 147 | + |
| 148 | +#[derive(Copy, Clone, Serialize, Deserialize)] |
| 149 | +pub struct Claim { |
| 150 | + pub n_rows: usize, |
| 151 | +} |
| 152 | + |
| 153 | +impl Claim { |
| 154 | + pub fn log_sizes(&self) -> TreeVec<Vec<u32>> { |
| 155 | + let log_size = std::cmp::max(self.n_rows.next_power_of_two().ilog2(), LOG_N_LANES); |
| 156 | + let preprocessed_log_sizes = vec![log_size]; |
| 157 | + let trace_log_sizes = vec![log_size; N_TRACE_CELLS]; |
| 158 | + let interaction_log_sizes = vec![log_size; SECURE_EXTENSION_DEGREE * 3]; |
| 159 | + TreeVec::new(vec![ |
| 160 | + preprocessed_log_sizes, |
| 161 | + trace_log_sizes, |
| 162 | + interaction_log_sizes, |
| 163 | + ]) |
| 164 | + } |
| 165 | + |
| 166 | + pub fn mix_into(&self, channel: &mut impl Channel) { |
| 167 | + channel.mix_u64(self.n_rows as u64); |
| 168 | + } |
| 169 | +} |
| 170 | + |
| 171 | +#[derive(Clone, Serialize, Deserialize)] |
| 172 | +pub struct InteractionClaim { |
| 173 | + pub log_size: u32, |
| 174 | + pub claimed_sum: SecureField, |
| 175 | +} |
| 176 | +impl InteractionClaim { |
| 177 | + pub fn mix_into(&self, channel: &mut impl Channel) { |
| 178 | + channel.mix_felts(&[self.claimed_sum]); |
| 179 | + } |
| 180 | +} |
0 commit comments