Skip to content

Commit 4e64fb2

Browse files
authored
[Feat] Native Multiple Observe for Poseidon2-based Challenger (#3)
1 parent 7046d8a commit 4e64fb2

File tree

12 files changed

+163
-17
lines changed

12 files changed

+163
-17
lines changed

crates/vm/src/arch/segment.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -281,16 +281,18 @@ impl<F: PrimeField32, VC: VmConfig<F>> ExecutionSegment<F, VC> {
281281
Some(SysPhantom::CtStart) =>
282282
{
283283
#[cfg(feature = "bench-metrics")]
284-
metrics
285-
.cycle_tracker
286-
.start(dsl_instr.cloned().unwrap_or("Default".to_string()))
284+
metrics.cycle_tracker.start(
285+
dsl_instr.cloned().unwrap_or("Default".to_string()),
286+
metrics.cycle_count,
287+
)
287288
}
288289
Some(SysPhantom::CtEnd) =>
289290
{
290291
#[cfg(feature = "bench-metrics")]
291-
metrics
292-
.cycle_tracker
293-
.end(dsl_instr.cloned().unwrap_or("Default".to_string()))
292+
metrics.cycle_tracker.end(
293+
dsl_instr.cloned().unwrap_or("Default".to_string()),
294+
metrics.cycle_count,
295+
)
294296
}
295297
_ => {}
296298
}

crates/vm/src/metrics/cycle_tracker/mod.rs

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
1+
/// Stats for a nested span in the execution segment that is tracked by the [`CycleTracker`].
2+
#[derive(Clone, Debug, Default)]
3+
pub struct SpanInfo {
4+
/// The name of the span.
5+
tag: String,
6+
/// The cycle count at which the span starts.
7+
start: usize,
8+
}
9+
110
#[derive(Clone, Debug, Default)]
211
pub struct CycleTracker {
312
/// Stack of span names, with most recent at the end
4-
stack: Vec<String>,
13+
stack: Vec<SpanInfo>,
14+
/// Depth of the stack.
15+
depth: usize,
516
}
617

718
impl CycleTracker {
@@ -11,23 +22,33 @@ impl CycleTracker {
1122

1223
/// Starts a new cycle tracker span for the given name.
1324
/// If a span already exists for the given name, it ends the existing span and pushes a new one to the vec.
14-
pub fn start(&mut self, mut name: String) {
25+
pub fn start(&mut self, mut name: String, cycles_count: usize) {
1526
// hack to remove "CT-" prefix
1627
if name.starts_with("CT-") {
1728
name = name.split_off(3);
1829
}
19-
self.stack.push(name);
30+
self.stack.push(SpanInfo {
31+
tag: name.clone(),
32+
start: cycles_count,
33+
});
34+
let padding = "│ ".repeat(self.depth);
35+
tracing::info!("{}┌╴{}", padding, name);
36+
self.depth += 1;
2037
}
2138

2239
/// Ends the cycle tracker span for the given name.
2340
/// If no span exists for the given name, it panics.
24-
pub fn end(&mut self, mut name: String) {
41+
pub fn end(&mut self, mut name: String, cycles_count: usize) {
2542
// hack to remove "CT-" prefix
2643
if name.starts_with("CT-") {
2744
name = name.split_off(3);
2845
}
29-
let stack_top = self.stack.pop();
30-
assert_eq!(stack_top.unwrap(), name, "Stack top does not match name");
46+
let SpanInfo { tag, start } = self.stack.pop().unwrap();
47+
assert_eq!(tag, name, "Stack top does not match name");
48+
self.depth -= 1;
49+
let padding = "│ ".repeat(self.depth);
50+
let span_cycles = cycles_count - start;
51+
tracing::info!("{}└╴{} cycles", padding, span_cycles);
3152
}
3253

3354
/// Ends the current cycle tracker span.
@@ -37,7 +58,11 @@ impl CycleTracker {
3758

3859
/// Get full name of span with all parent names separated by ";" in flamegraph format
3960
pub fn get_full_name(&self) -> String {
40-
self.stack.join(";")
61+
self.stack
62+
.iter()
63+
.map(|span_info| span_info.tag.clone())
64+
.collect::<Vec<String>>()
65+
.join(";")
4166
}
4267
}
4368

extensions/native/circuit/src/extension.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ impl<F: PrimeField32> VmExtension<F> for Native {
200200
VerifyBatchOpcode::VERIFY_BATCH.global_opcode(),
201201
Poseidon2Opcode::PERM_POS2.global_opcode(),
202202
Poseidon2Opcode::COMP_POS2.global_opcode(),
203+
Poseidon2Opcode::MULTI_OBSERVE.global_opcode(),
203204
],
204205
)?;
205206

extensions/native/circuit/src/poseidon2/chip.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use openvm_circuit::{
99
use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
1010
use openvm_native_compiler::{
1111
conversion::AS,
12-
Poseidon2Opcode::{COMP_POS2, PERM_POS2},
12+
Poseidon2Opcode::{COMP_POS2, PERM_POS2, MULTI_OBSERVE},
1313
VerifyBatchOpcode::VERIFY_BATCH,
1414
};
1515
use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubAir, Poseidon2SubChip};
@@ -485,6 +485,43 @@ impl<F: PrimeField32, const SBOX_REGISTERS: usize> InstructionExecutor<F>
485485
initial_log_height: initial_log_height as usize,
486486
top_level,
487487
});
488+
} else if instruction.opcode == MULTI_OBSERVE.global_opcode() {
489+
let &Instruction {
490+
a: output_register,
491+
b: input_register_1,
492+
c: input_register_2,
493+
d: data_address_space,
494+
e: register_address_space,
495+
f: input_register_3,
496+
..
497+
} = instruction;
498+
499+
let (_, sponge_ptr) = memory.read_cell(register_address_space, output_register);
500+
let (_, arr_ptr) = memory.read_cell(register_address_space, input_register_2);
501+
502+
let init_pos_read = memory.read_cell(register_address_space, input_register_1);
503+
let mut pos = init_pos_read.1.as_canonical_u32() as usize;
504+
505+
let len_read = memory.read_cell(register_address_space, input_register_3);
506+
let len = len_read.1.as_canonical_u32() as usize;
507+
508+
for i in 0..len {
509+
let mod_pos = pos % CHUNK;
510+
let n_read = memory.read_cell(data_address_space, arr_ptr + F::from_canonical_usize(i));
511+
let n_f = n_read.1;
512+
513+
memory.write_cell(data_address_space, sponge_ptr + F::from_canonical_usize(mod_pos), n_f);
514+
pos += 1;
515+
516+
if pos % CHUNK == 0 {
517+
let (_, sponge_state) = memory.read::<{CHUNK * 2}>(data_address_space, sponge_ptr);
518+
let output = self.subchip.permute(sponge_state);
519+
memory.write::<{CHUNK * 2}>(data_address_space, sponge_ptr, std::array::from_fn(|i| output[i]));
520+
}
521+
}
522+
523+
let mod_pos = pos % CHUNK;
524+
memory.write_cell(register_address_space, input_register_1, F::from_canonical_usize(mod_pos));
488525
} else {
489526
unreachable!()
490527
}
@@ -501,6 +538,8 @@ impl<F: PrimeField32, const SBOX_REGISTERS: usize> InstructionExecutor<F>
501538
String::from("PERM_POS2")
502539
} else if opcode == COMP_POS2.global_opcode().as_usize() {
503540
String::from("COMP_POS2")
541+
} else if opcode == MULTI_OBSERVE.global_opcode().as_usize() {
542+
String::from("MULTI_OBSERVE")
504543
} else {
505544
unreachable!("unsupported opcode: {}", opcode)
506545
}

extensions/native/circuit/src/poseidon2/tests.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,8 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester<BabyBearBlak
434434
}
435435
PERM_POS2 => {
436436
tester.write(e, lhs, data);
437-
}
437+
},
438+
MULTI_OBSERVE => {}
438439
}
439440

440441
tester.execute(&mut chip, &instruction);
@@ -449,6 +450,7 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester<BabyBearBlak
449450
let actual = tester.read::<{ 2 * CHUNK }>(e, dst);
450451
assert_eq!(hash, actual);
451452
}
453+
MULTI_OBSERVE => {}
452454
}
453455
}
454456
tester.build().load(chip).finalize()

extensions/native/compiler/src/asm/compiler.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,12 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCo
489489
DslIr::HintBitsF(var, len) => {
490490
self.push(AsmInstruction::HintBits(var.fp(), len), debug_info);
491491
}
492+
DslIr::Poseidon2MultiObserve(dst, init_pos, arr_ptr, len) => {
493+
self.push(
494+
AsmInstruction::Poseidon2MultiObserve(dst.fp(), init_pos.fp(), arr_ptr.fp(), len.get_var().fp()),
495+
debug_info,
496+
);
497+
},
492498
DslIr::Poseidon2PermuteBabyBear(dst, src) => match (dst, src) {
493499
(Array::Dyn(dst, _), Array::Dyn(src, _)) => self.push(
494500
AsmInstruction::Poseidon2Permute(dst.fp(), src.fp()),

extensions/native/compiler/src/asm/instruction.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ pub enum AsmInstruction<F, EF> {
108108
/// Halt.
109109
Halt,
110110

111+
/// Absorbs multiple base elements into a duplex transcript with Poseidon2 permutation
112+
/// (sponge_state, init_pos, arr_ptr, len)
113+
/// Returns the final index position of hash sponge
114+
Poseidon2MultiObserve(i32, i32, i32, i32),
115+
111116
/// Perform a Poseidon2 permutation on state starting at address `lhs`
112117
/// and store new state at `rhs`.
113118
/// (a, b) are pointers to (lhs, rhs).
@@ -331,6 +336,9 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
331336
AsmInstruction::Trap => write!(f, "trap"),
332337
AsmInstruction::Halt => write!(f, "halt"),
333338
AsmInstruction::HintBits(src, len) => write!(f, "hint_bits ({})fp, {}", src, len),
339+
AsmInstruction::Poseidon2MultiObserve(dst, init_pos, arr, len) => {
340+
write!(f, "poseidon2_multi_observe ({})fp, ({})fp ({})fp ({})fp", dst, init_pos, arr, len)
341+
}
334342
AsmInstruction::Poseidon2Permute(dst, lhs) => {
335343
write!(f, "poseidon2_permute ({})fp, ({})fp", dst, lhs)
336344
}

extensions/native/compiler/src/constraints/halo2/compiler.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,11 +492,11 @@ impl<C: Config + Debug> Halo2ConstraintCompiler<C> {
492492
}
493493
DslIr::CycleTrackerStart(_name) => {
494494
#[cfg(feature = "bench-metrics")]
495-
cell_tracker.start(_name);
495+
cell_tracker.start(_name, 0);
496496
}
497497
DslIr::CycleTrackerEnd(_name) => {
498498
#[cfg(feature = "bench-metrics")]
499-
cell_tracker.end(_name);
499+
cell_tracker.end(_name, 0);
500500
}
501501
DslIr::CircuitPublish(val, index) => {
502502
public_values[index] = vars[&val.0];

extensions/native/compiler/src/conversion/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,18 @@ fn convert_instruction<F: PrimeField32, EF: ExtensionField<F>>(
440440
AS::Native,
441441
AS::Native,
442442
)],
443+
AsmInstruction::Poseidon2MultiObserve(dst, init, arr, len) => vec![
444+
Instruction {
445+
opcode: options.opcode_with_offset(Poseidon2Opcode::MULTI_OBSERVE),
446+
a: i32_f(dst),
447+
b: i32_f(init),
448+
c: i32_f(arr),
449+
d: AS::Native.to_field(),
450+
e: AS::Native.to_field(),
451+
f: i32_f(len),
452+
g: F::ZERO,
453+
}
454+
],
443455
AsmInstruction::Poseidon2Compress(dst, src1, src2) => vec![inst(
444456
options.opcode_with_offset(Poseidon2Opcode::COMP_POS2),
445457
i32_f(dst),

extensions/native/compiler/src/ir/instructions.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,13 @@ pub enum DslIr<C: Config> {
196196
/// Permutes an array of Bn254 elements using Poseidon2 (output = p2_permute(array)). Should only
197197
/// be used when target is a circuit.
198198
CircuitPoseidon2Permute([Var<C::N>; 3]),
199+
/// Absorbs an array of baby bear elements into a duplex transcript with Poseidon2 permutations (output = p2_multi_observe(array, els)).
200+
Poseidon2MultiObserve(
201+
Ptr<C::N>, // sponge_state
202+
Var<C::N>, // initial input_ptr position
203+
Ptr<C::N>, // input array (base elements)
204+
Usize<C::N>, // len of els
205+
),
199206

200207
// Miscellaneous instructions.
201208
/// Prints a variable.

0 commit comments

Comments
 (0)