diff --git a/crates/vm/src/arch/segment.rs b/crates/vm/src/arch/segment.rs index 56833b7e26..765ef2d3a4 100644 --- a/crates/vm/src/arch/segment.rs +++ b/crates/vm/src/arch/segment.rs @@ -281,16 +281,18 @@ impl> ExecutionSegment { Some(SysPhantom::CtStart) => { #[cfg(feature = "bench-metrics")] - metrics - .cycle_tracker - .start(dsl_instr.cloned().unwrap_or("Default".to_string())) + metrics.cycle_tracker.start( + dsl_instr.cloned().unwrap_or("Default".to_string()), + metrics.cycle_count, + ) } Some(SysPhantom::CtEnd) => { #[cfg(feature = "bench-metrics")] - metrics - .cycle_tracker - .end(dsl_instr.cloned().unwrap_or("Default".to_string())) + metrics.cycle_tracker.end( + dsl_instr.cloned().unwrap_or("Default".to_string()), + metrics.cycle_count, + ) } _ => {} } diff --git a/crates/vm/src/metrics/cycle_tracker/mod.rs b/crates/vm/src/metrics/cycle_tracker/mod.rs index 06cbe09193..1815bac404 100644 --- a/crates/vm/src/metrics/cycle_tracker/mod.rs +++ b/crates/vm/src/metrics/cycle_tracker/mod.rs @@ -1,7 +1,18 @@ +/// Stats for a nested span in the execution segment that is tracked by the [`CycleTracker`]. +#[derive(Clone, Debug, Default)] +pub struct SpanInfo { + /// The name of the span. + tag: String, + /// The cycle count at which the span starts. + start: usize, +} + #[derive(Clone, Debug, Default)] pub struct CycleTracker { /// Stack of span names, with most recent at the end - stack: Vec, + stack: Vec, + /// Depth of the stack. + depth: usize, } impl CycleTracker { @@ -11,23 +22,33 @@ impl CycleTracker { /// Starts a new cycle tracker span for the given name. /// If a span already exists for the given name, it ends the existing span and pushes a new one to the vec. - pub fn start(&mut self, mut name: String) { + pub fn start(&mut self, mut name: String, cycles_count: usize) { // hack to remove "CT-" prefix if name.starts_with("CT-") { name = name.split_off(3); } - self.stack.push(name); + self.stack.push(SpanInfo { + tag: name.clone(), + start: cycles_count, + }); + let padding = "│ ".repeat(self.depth); + tracing::info!("{}┌╴{}", padding, name); + self.depth += 1; } /// Ends the cycle tracker span for the given name. /// If no span exists for the given name, it panics. - pub fn end(&mut self, mut name: String) { + pub fn end(&mut self, mut name: String, cycles_count: usize) { // hack to remove "CT-" prefix if name.starts_with("CT-") { name = name.split_off(3); } - let stack_top = self.stack.pop(); - assert_eq!(stack_top.unwrap(), name, "Stack top does not match name"); + let SpanInfo { tag, start } = self.stack.pop().unwrap(); + assert_eq!(tag, name, "Stack top does not match name"); + self.depth -= 1; + let padding = "│ ".repeat(self.depth); + let span_cycles = cycles_count - start; + tracing::info!("{}└╴{} cycles", padding, span_cycles); } /// Ends the current cycle tracker span. @@ -37,7 +58,11 @@ impl CycleTracker { /// Get full name of span with all parent names separated by ";" in flamegraph format pub fn get_full_name(&self) -> String { - self.stack.join(";") + self.stack + .iter() + .map(|span_info| span_info.tag.clone()) + .collect::>() + .join(";") } } diff --git a/extensions/native/circuit/src/extension.rs b/extensions/native/circuit/src/extension.rs index fbaf11ff74..81cc1066f2 100644 --- a/extensions/native/circuit/src/extension.rs +++ b/extensions/native/circuit/src/extension.rs @@ -200,6 +200,7 @@ impl VmExtension for Native { VerifyBatchOpcode::VERIFY_BATCH.global_opcode(), Poseidon2Opcode::PERM_POS2.global_opcode(), Poseidon2Opcode::COMP_POS2.global_opcode(), + Poseidon2Opcode::MULTI_OBSERVE.global_opcode(), ], )?; diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 426b089a9c..f26db6ffe7 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -9,7 +9,7 @@ use openvm_circuit::{ use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{ conversion::AS, - Poseidon2Opcode::{COMP_POS2, PERM_POS2}, + Poseidon2Opcode::{COMP_POS2, PERM_POS2, MULTI_OBSERVE}, VerifyBatchOpcode::VERIFY_BATCH, }; use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubAir, Poseidon2SubChip}; @@ -485,6 +485,43 @@ impl InstructionExecutor initial_log_height: initial_log_height as usize, top_level, }); + } else if instruction.opcode == MULTI_OBSERVE.global_opcode() { + let &Instruction { + a: output_register, + b: input_register_1, + c: input_register_2, + d: data_address_space, + e: register_address_space, + f: input_register_3, + .. + } = instruction; + + let (_, sponge_ptr) = memory.read_cell(register_address_space, output_register); + let (_, arr_ptr) = memory.read_cell(register_address_space, input_register_2); + + let init_pos_read = memory.read_cell(register_address_space, input_register_1); + let mut pos = init_pos_read.1.as_canonical_u32() as usize; + + let len_read = memory.read_cell(register_address_space, input_register_3); + let len = len_read.1.as_canonical_u32() as usize; + + for i in 0..len { + let mod_pos = pos % CHUNK; + let n_read = memory.read_cell(data_address_space, arr_ptr + F::from_canonical_usize(i)); + let n_f = n_read.1; + + memory.write_cell(data_address_space, sponge_ptr + F::from_canonical_usize(mod_pos), n_f); + pos += 1; + + if pos % CHUNK == 0 { + let (_, sponge_state) = memory.read::<{CHUNK * 2}>(data_address_space, sponge_ptr); + let output = self.subchip.permute(sponge_state); + memory.write::<{CHUNK * 2}>(data_address_space, sponge_ptr, std::array::from_fn(|i| output[i])); + } + } + + let mod_pos = pos % CHUNK; + memory.write_cell(register_address_space, input_register_1, F::from_canonical_usize(mod_pos)); } else { unreachable!() } @@ -501,6 +538,8 @@ impl InstructionExecutor String::from("PERM_POS2") } else if opcode == COMP_POS2.global_opcode().as_usize() { String::from("COMP_POS2") + } else if opcode == MULTI_OBSERVE.global_opcode().as_usize() { + String::from("MULTI_OBSERVE") } else { unreachable!("unsupported opcode: {}", opcode) } diff --git a/extensions/native/circuit/src/poseidon2/tests.rs b/extensions/native/circuit/src/poseidon2/tests.rs index 32a0e483a3..6d703c549b 100644 --- a/extensions/native/circuit/src/poseidon2/tests.rs +++ b/extensions/native/circuit/src/poseidon2/tests.rs @@ -434,7 +434,8 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester { tester.write(e, lhs, data); - } + }, + MULTI_OBSERVE => {} } tester.execute(&mut chip, &instruction); @@ -449,6 +450,7 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester(e, dst); assert_eq!(hash, actual); } + MULTI_OBSERVE => {} } } tester.build().load(chip).finalize() diff --git a/extensions/native/compiler/src/asm/compiler.rs b/extensions/native/compiler/src/asm/compiler.rs index cff4a47288..11480a300a 100644 --- a/extensions/native/compiler/src/asm/compiler.rs +++ b/extensions/native/compiler/src/asm/compiler.rs @@ -489,6 +489,12 @@ impl + TwoAdicField> AsmCo DslIr::HintBitsF(var, len) => { self.push(AsmInstruction::HintBits(var.fp(), len), debug_info); } + DslIr::Poseidon2MultiObserve(dst, init_pos, arr_ptr, len) => { + self.push( + AsmInstruction::Poseidon2MultiObserve(dst.fp(), init_pos.fp(), arr_ptr.fp(), len.get_var().fp()), + debug_info, + ); + }, DslIr::Poseidon2PermuteBabyBear(dst, src) => match (dst, src) { (Array::Dyn(dst, _), Array::Dyn(src, _)) => self.push( AsmInstruction::Poseidon2Permute(dst.fp(), src.fp()), diff --git a/extensions/native/compiler/src/asm/instruction.rs b/extensions/native/compiler/src/asm/instruction.rs index bc5ce3d021..9f406afe3b 100644 --- a/extensions/native/compiler/src/asm/instruction.rs +++ b/extensions/native/compiler/src/asm/instruction.rs @@ -108,6 +108,11 @@ pub enum AsmInstruction { /// Halt. Halt, + /// Absorbs multiple base elements into a duplex transcript with Poseidon2 permutation + /// (sponge_state, init_pos, arr_ptr, len) + /// Returns the final index position of hash sponge + Poseidon2MultiObserve(i32, i32, i32, i32), + /// Perform a Poseidon2 permutation on state starting at address `lhs` /// and store new state at `rhs`. /// (a, b) are pointers to (lhs, rhs). @@ -331,6 +336,9 @@ impl> AsmInstruction { AsmInstruction::Trap => write!(f, "trap"), AsmInstruction::Halt => write!(f, "halt"), AsmInstruction::HintBits(src, len) => write!(f, "hint_bits ({})fp, {}", src, len), + AsmInstruction::Poseidon2MultiObserve(dst, init_pos, arr, len) => { + write!(f, "poseidon2_multi_observe ({})fp, ({})fp ({})fp ({})fp", dst, init_pos, arr, len) + } AsmInstruction::Poseidon2Permute(dst, lhs) => { write!(f, "poseidon2_permute ({})fp, ({})fp", dst, lhs) } diff --git a/extensions/native/compiler/src/constraints/halo2/compiler.rs b/extensions/native/compiler/src/constraints/halo2/compiler.rs index 404dc4cecd..a427a149a0 100644 --- a/extensions/native/compiler/src/constraints/halo2/compiler.rs +++ b/extensions/native/compiler/src/constraints/halo2/compiler.rs @@ -492,11 +492,11 @@ impl Halo2ConstraintCompiler { } DslIr::CycleTrackerStart(_name) => { #[cfg(feature = "bench-metrics")] - cell_tracker.start(_name); + cell_tracker.start(_name, 0); } DslIr::CycleTrackerEnd(_name) => { #[cfg(feature = "bench-metrics")] - cell_tracker.end(_name); + cell_tracker.end(_name, 0); } DslIr::CircuitPublish(val, index) => { public_values[index] = vars[&val.0]; diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index 0a202b69fa..d9cf82fd43 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -440,6 +440,18 @@ fn convert_instruction>( AS::Native, AS::Native, )], + AsmInstruction::Poseidon2MultiObserve(dst, init, arr, len) => vec![ + Instruction { + opcode: options.opcode_with_offset(Poseidon2Opcode::MULTI_OBSERVE), + a: i32_f(dst), + b: i32_f(init), + c: i32_f(arr), + d: AS::Native.to_field(), + e: AS::Native.to_field(), + f: i32_f(len), + g: F::ZERO, + } + ], AsmInstruction::Poseidon2Compress(dst, src1, src2) => vec![inst( options.opcode_with_offset(Poseidon2Opcode::COMP_POS2), i32_f(dst), diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index f3c3fd86f7..08a96b6f07 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -196,6 +196,13 @@ pub enum DslIr { /// Permutes an array of Bn254 elements using Poseidon2 (output = p2_permute(array)). Should only /// be used when target is a circuit. CircuitPoseidon2Permute([Var; 3]), + /// Absorbs an array of baby bear elements into a duplex transcript with Poseidon2 permutations (output = p2_multi_observe(array, els)). + Poseidon2MultiObserve( + Ptr, // sponge_state + Var, // initial input_ptr position + Ptr, // input array (base elements) + Usize, // len of els + ), // Miscellaneous instructions. /// Prints a variable. diff --git a/extensions/native/compiler/src/ir/poseidon.rs b/extensions/native/compiler/src/ir/poseidon.rs index 12ec526c89..ee9bc0d87d 100644 --- a/extensions/native/compiler/src/ir/poseidon.rs +++ b/extensions/native/compiler/src/ir/poseidon.rs @@ -1,6 +1,8 @@ use openvm_native_compiler_derive::iter_zip; use openvm_stark_backend::p3_field::FieldAlgebra; +use crate::ir::Variable; + use super::{Array, ArrayLike, Builder, Config, DslIr, Ext, Felt, MemIndex, Ptr, Usize, Var}; pub const DIGEST_SIZE: usize = 8; @@ -8,6 +10,47 @@ pub const HASH_RATE: usize = 8; pub const PERMUTATION_WIDTH: usize = 16; impl Builder { + /// Extends native VM ability to observe multiple base elements in one opcode operation + /// Absorbs elements sequentially at the RATE portion of sponge state and performs as many permutations as necessary. + /// Returns the index position of the next input_ptr. + /// + /// [Reference](https://docs.rs/p3-poseidon2/latest/p3_poseidon2/struct.Poseidon2.html) + pub fn poseidon2_multi_observe( + &mut self, + sponge_state: &Array>, + input_ptr: Ptr, + arr: &Array>, + ) -> Usize { + let buffer_size: Var = Var::uninit(self); + self.assign(&buffer_size, C::N::from_canonical_usize(HASH_RATE)); + + match sponge_state { + Array::Fixed(_) => { + panic!("Poseidon2 permutation is not allowed on fixed arrays"); + } + Array::Dyn(sponge_ptr, _) => { + match arr { + Array::Fixed(_) => { + panic!("Base elements input must be dynamic"); + } + Array::Dyn(ptr, len) => { + let init_pos: Var = Var::uninit(self); + self.assign(&init_pos, input_ptr.address - sponge_ptr.address); + + self.operations.push(DslIr::Poseidon2MultiObserve( + *sponge_ptr, + init_pos, + *ptr, + len.clone(), + )); + + Usize::Var(init_pos) + } + } + } + } + } + /// Applies the Poseidon2 permutation to the given array. /// /// [Reference](https://docs.rs/p3-poseidon2/latest/p3_poseidon2/struct.Poseidon2.html) diff --git a/extensions/native/compiler/src/lib.rs b/extensions/native/compiler/src/lib.rs index ef28b37139..66c786fbd9 100644 --- a/extensions/native/compiler/src/lib.rs +++ b/extensions/native/compiler/src/lib.rs @@ -184,6 +184,7 @@ pub enum NativePhantom { pub enum Poseidon2Opcode { PERM_POS2, COMP_POS2, + MULTI_OBSERVE, } /// Opcodes for FRI opening proofs.