Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/prover/src/cairo_air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ use tracing::{span, Level};

use crate::components::memory::component::{Claim, InteractionClaim};
use crate::components::memory::{ClaimGenerator, Component as MemoryComponent, Eval};
use crate::input::instructions::VmState;
use crate::input::CairoInput;
use crate::relations::MemoryRelation;
use crate::utils::types::CasmState;

#[derive(Serialize, Deserialize)]
pub struct CairoProof<H: MerkleHasher> {
Expand All @@ -36,8 +36,8 @@ pub struct CairoProof<H: MerkleHasher> {
pub struct CairoClaim {
// Common claim values.
pub public_memory: Vec<(M31, QM31)>,
pub initial_state: VmState,
pub final_state: VmState,
pub initial_state: CasmState,
pub final_state: CasmState,

pub addr_to_value: Claim,
// pub ret: Vec<ret_opcode::Claim>,
Expand Down
6 changes: 3 additions & 3 deletions crates/prover/src/components/add_mul_opcode/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ mod tests {
use super::*;
use crate::components::add_mul_opcode::component::INSTRUCTION_BASE;
use crate::components::memory;
use crate::input::instructions::VmState;
use crate::relations;
use crate::utils::types::CasmState;

#[test]
fn test_add_mul_opcode() {
Expand Down Expand Up @@ -79,15 +79,15 @@ mod tests {
let claim_generator = ClaimGenerator::new(
chain!(
vec![
VmState {
CasmState {
pc: 0,
ap: 2,
fp: 4,
};
128
],
vec![
VmState {
CasmState {
pc: 1,
ap: 2,
fp: 4,
Expand Down
19 changes: 6 additions & 13 deletions crates/prover/src/components/add_mul_opcode/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,19 @@ use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;
use super::component::{Claim, InteractionClaim, INSTRUCTION_BASE};
use crate::components::add_mul_opcode::component::N_TRACE_COLUMNS;
use crate::components::memory;
use crate::input::instructions::VmState;
use crate::relations::{MemoryRelation, StateRelation, N_MEMORY_ELEMS, STATE_SIZE};
use crate::utils::prover::decode_opcode;
use crate::utils::types::{CasmState, PackedCasmState};
use crate::utils::{Selector, SelectorTrait};

const N_MEMORY_LOOKUPS: usize = 4;
const N_STATE_LOOKUPS: usize = 2;

// TODO(Ohad): take from prover_types and remove.
pub struct PackedVmState {
pub pc: PackedM31,
pub ap: PackedM31,
pub fp: PackedM31,
}

pub struct ClaimGenerator {
pub inputs: Vec<PackedVmState>,
pub inputs: Vec<PackedCasmState>,
}
impl ClaimGenerator {
pub fn new(mut inputs: Vec<VmState>) -> Self {
pub fn new(mut inputs: Vec<CasmState>) -> Self {
assert!(!inputs.is_empty());

// TODO(spapini): Split to multiple components.
Expand All @@ -44,7 +37,7 @@ impl ClaimGenerator {
let inputs = inputs
.into_iter()
.array_chunks::<N_LANES>()
.map(|chunk| PackedVmState {
.map(|chunk| PackedCasmState {
pc: PackedM31::from_array(std::array::from_fn(|i| {
M31::from_u32_unchecked(chunk[i].pc)
})),
Expand Down Expand Up @@ -157,7 +150,7 @@ impl InteractionClaimGenerator {
}

fn write_trace_simd(
inputs: &[PackedVmState],
inputs: &[PackedCasmState],
memory_trace_generator: &memory::ClaimGenerator,
) -> (
Vec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
Expand Down Expand Up @@ -199,7 +192,7 @@ fn write_trace_simd(
// | State (3) | flags (5) | offsets (3) | addrs (3) | values (3 * 4) |
fn write_trace_row(
trace: &mut [Col<SimdBackend, M31>],
input: &PackedVmState,
input: &PackedCasmState,
row_index: usize,
interaction_claim_generator: &mut InteractionClaimGenerator,
memory_trace_generator: &memory::ClaimGenerator,
Expand Down
13 changes: 2 additions & 11 deletions crates/prover/src/components/ret_opcode/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,19 @@ use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;

use super::component::{Claim, InteractionClaim, RET_INSTRUCTION};
use crate::components::memory;
use crate::input::instructions::VmState;
use crate::relations::{MemoryRelation, StateRelation, N_MEMORY_ELEMS, STATE_SIZE};
use crate::utils::types::{CasmState, PackedCasmState};

const N_TRACE_COLUMNS: usize = 5;

const N_MEMORY_LOOKUPS: usize = 3;
const N_STATE_LOOKUPS: usize = 2;

// TODO(Ohad): take from prover_types and remove.
#[derive(Debug, Clone)]
pub struct PackedCasmState {
pub pc: PackedM31,
pub ap: PackedM31,
pub fp: PackedM31,
}

#[derive(Debug)]
pub struct ClaimGenerator {
pub inputs: Vec<PackedCasmState>,
}
impl ClaimGenerator {
pub fn new(mut inputs: Vec<VmState>) -> Self {
pub fn new(mut inputs: Vec<CasmState>) -> Self {
assert!(!inputs.is_empty());

// TODO(spapini): Split to multiple components.
Expand Down
50 changes: 16 additions & 34 deletions crates/prover/src/input/instructions.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,7 @@
use serde::{Deserialize, Serialize};

use super::decode::Instruction;
use super::mem::{MemoryBuilder, MemoryValue};
use super::vm_import::TraceEntry;

// TODO(spapini): Move this:
#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
pub struct VmState {
pub pc: u32,
pub ap: u32,
pub fp: u32,
}
impl From<TraceEntry> for VmState {
fn from(entry: TraceEntry) -> Self {
Self {
pc: entry.pc as u32,
ap: entry.ap as u32,
fp: entry.fp as u32,
}
}
}
use crate::utils::types::CasmState;

// TODO(yuval/alonT): consider making the indexing mechanism more explicit in the code).
/// The instructions usage in the input, split to Stwo opcodes.
Expand All @@ -30,44 +12,44 @@ impl From<TraceEntry> for VmState {
/// Note: for the flag "fp/ap", true means fp-based and false means ap-based.
#[derive(Debug, Default)]
pub struct Instructions {
pub initial_state: VmState,
pub final_state: VmState,
pub initial_state: CasmState,
pub final_state: CasmState,

/// ret.
pub ret: Vec<VmState>,
pub ret: Vec<CasmState>,

/// ap += imm.
pub add_ap: Vec<VmState>,
pub add_ap: Vec<CasmState>,

/// jump rel imm.
/// Flags: ap++?.
pub jmp_rel_imm: [Vec<VmState>; 2],
pub jmp_rel_imm: [Vec<CasmState>; 2],

/// jump abs [fp/ap + offset].
/// Flags: fp/ap, ap++?.
pub jmp_abs: [Vec<VmState>; 4],
pub jmp_abs: [Vec<CasmState>; 4],

/// call rel imm.
pub call_rel_imm: Vec<VmState>,
pub call_rel_imm: Vec<CasmState>,

/// call abs [fp/ap + offset].
/// Flags: fp/ap.
pub call_abs: [Vec<VmState>; 2],
pub call_abs: [Vec<CasmState>; 2],

/// jump rel imm if [fp/ap + offset] != 0.
/// Flags: fp/ap, taken?, ap++?.
pub jnz_imm: [Vec<VmState>; 8],
pub jnz_imm: [Vec<CasmState>; 8],

/// - [fp/ap + offset0] = [fp/ap + offset2]
pub mov_mem: Vec<VmState>,
pub mov_mem: Vec<CasmState>,

/// - [fp/ap + offset0] = [[fp/ap + offset1] + offset2]
pub deref: Vec<VmState>,
pub deref: Vec<CasmState>,

/// - [fp/ap + offset0] = imm
pub push_imm: Vec<VmState>,
pub push_imm: Vec<CasmState>,

pub generic: Vec<VmState>,
pub generic: Vec<CasmState>,
}
impl Instructions {
pub fn from_iter(mut iter: impl Iterator<Item = TraceEntry>, mem: &mut MemoryBuilder) -> Self {
Expand All @@ -86,8 +68,8 @@ impl Instructions {
res
}

fn push_instr(&mut self, mem: &mut MemoryBuilder, state: VmState) {
let VmState { ap, fp, pc } = state;
fn push_instr(&mut self, mem: &mut MemoryBuilder, state: CasmState) {
let CasmState { ap, fp, pc } = state;
let instruction = mem.get_inst(pc);
let instruction = Instruction::decode(instruction);
match instruction {
Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod component;
pub mod prover;
pub mod types;

use std::ops::{Add, Mul, Sub};

Expand Down
27 changes: 27 additions & 0 deletions crates/prover/src/utils/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use serde::{Deserialize, Serialize};
use stwo_prover::core::backend::simd::m31::PackedM31;

use crate::input::vm_import::TraceEntry;

#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
pub struct CasmState {
pub pc: u32,
pub ap: u32,
pub fp: u32,
}
impl From<TraceEntry> for CasmState {
fn from(entry: TraceEntry) -> Self {
Self {
pc: entry.pc as u32,
ap: entry.ap as u32,
fp: entry.fp as u32,
}
}
}

#[derive(Debug, Clone)]
pub struct PackedCasmState {
pub pc: PackedM31,
pub ap: PackedM31,
pub fp: PackedM31,
}
Loading