Skip to content

Commit d73de3e

Browse files
authored
Fix: write final pos back (#16)
* write final pos back
* apply change for GPU
* fix multi_observe so that it behaves the same as observe_slice
* larger test case
1 parent c006da1 commit d73de3e

File tree

9 files changed

+86
-40
lines changed

9 files changed

+86
-40
lines changed

extensions/native/circuit/cuda/include/native/poseidon2.cuh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ template <typename T> struct MultiObserveCols {
8787
T should_permute;
8888
MemoryWriteAuxCols<T, CHUNK * 2> write_sponge_state;
8989
MemoryWriteAuxCols<T, 1> write_final_idx;
90-
T final_idx;
9190
};
9291

9392
template <typename T> constexpr T constexpr_max(T a, T b) { return a > b ? a : b; }

extensions/native/circuit/cuda/src/poseidon2.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,14 @@ template <size_t SBOX_REGISTERS> struct Poseidon2Wrapper {
395395
start_timestamp,
396396
specific.slice_from(COL_INDEX(MultiObserveCols, write_sponge_state.base))
397397
);
398+
start_timestamp += 1;
399+
}
400+
if (specific[COL_INDEX(MultiObserveCols, is_last)] == Fp::one()) {
401+
mem_fill_base(
402+
mem_helper,
403+
start_timestamp,
404+
specific.slice_from(COL_INDEX(MultiObserveCols, write_final_idx.base))
405+
);
398406
}
399407
}
400408
}

extensions/native/circuit/src/poseidon2/air.rs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,6 @@ impl<AB: InteractionBuilder, const SBOX_REGISTERS: usize> Air<AB>
728728
should_permute,
729729
write_sponge_state,
730730
write_final_idx,
731-
final_idx,
732731
input_register_1,
733732
input_register_2,
734733
input_register_3,
@@ -830,6 +829,16 @@ impl<AB: InteractionBuilder, const SBOX_REGISTERS: usize> Air<AB>
830829
}
831830

832831
for i in 0..CHUNK {
832+
builder
833+
.when(multi_observe_row)
834+
.assert_bool(aux_after_start[i]);
835+
builder
836+
.when(multi_observe_row)
837+
.assert_bool(aux_before_end[i]);
838+
builder
839+
.when(multi_observe_row)
840+
.when(is_first)
841+
.assert_zero(aux_read_enabled[i]);
833842
builder
834843
.when(multi_observe_row)
835844
.assert_eq(aux_after_start[i] * aux_before_end[i], aux_read_enabled[i]);
@@ -889,19 +898,22 @@ impl<AB: InteractionBuilder, const SBOX_REGISTERS: usize> Air<AB>
889898
.assert_eq(*a, *b);
890899
});
891900

892-
/*
901+
builder
902+
.when(multi_observe_row)
903+
.when(aux_read_enabled[CHUNK - 1])
904+
.assert_one(should_permute);
905+
906+
// final_idx is 0 when aux_read_enabled[CHUNK-1] is set and end_idx otherwise:
// final_idx = aux_read_enabled[CHUNK-1] * 0 + (1 - aux_read_enabled[CHUNK-1]) * end_idx
907+
let final_idx = aux_read_enabled[CHUNK - 1] * AB::Expr::ZERO
908+
+ (AB::Expr::ONE - aux_read_enabled[CHUNK - 1]) * end_idx;
893909
self.memory_bridge
894910
.write(
895-
MemoryAddress::new(
896-
self.address_space,
897-
input_register_1,
898-
),
911+
MemoryAddress::new(self.address_space, input_register_1),
899912
[final_idx],
900-
start_timestamp + is_first * AB::F::from_canonical_usize(4) + (end_idx - start_idx) * AB::F::TWO + should_permute * AB::F::TWO,
901-
&write_final_idx
913+
start_timestamp + (end_idx - start_idx) * AB::F::TWO + should_permute,
914+
&write_final_idx,
902915
)
903916
.eval(builder, multi_observe_row * is_last);
904-
*/
905917

906918
// Field transitions
907919
builder

extensions/native/circuit/src/poseidon2/chip.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,7 @@ where
685685
pos += len;
686686
}
687687
}
688+
final_timestamp_inc += 1; // write back to init_pos_register
688689

689690
let allocated_rows = arena
690691
.alloc(MultiRowLayout::new(NativePoseidon2Metadata {
@@ -810,6 +811,15 @@ where
810811
multi_observe_cols.should_permute = F::ZERO;
811812
cols.inner.inputs.clone_from_slice(&permutation_input);
812813
}
814+
if i == num_chunks - 1 {
815+
let final_idx = F::from_canonical_usize(chunk_end % CHUNK);
816+
tracing_write_native_inplace(
817+
state.memory,
818+
init_pos_register.as_canonical_u32(),
819+
[final_idx],
820+
&mut multi_observe_cols.write_final_idx,
821+
);
822+
}
813823
}
814824
} else {
815825
unreachable!()
@@ -1213,6 +1223,14 @@ impl<F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Filler<F, SBOX
12131223
start_timestamp_u32,
12141224
multi_observe_cols.write_sponge_state.as_mut(),
12151225
);
1226+
start_timestamp_u32 += 1;
1227+
}
1228+
if row_idx == num_rows - 1 {
1229+
mem_fill_helper(
1230+
mem_helper,
1231+
start_timestamp_u32,
1232+
multi_observe_cols.write_final_idx.as_mut(),
1233+
);
12161234
}
12171235
}
12181236
}

extensions/native/circuit/src/poseidon2/columns.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,5 +242,4 @@ pub struct MultiObserveCols<T> {
242242

243243
// Final write back and registers
244244
pub write_final_idx: MemoryWriteAuxCols<T, 1>,
245-
pub final_idx: T,
246245
}

extensions/native/circuit/src/poseidon2/execution.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,9 @@ impl<F: PrimeField32, const SBOX_REGISTERS: usize> MeteredExecutor<F>
331331
{
332332
#[inline(always)]
333333
fn metered_pre_compute_size(&self) -> usize {
334-
std::cmp::max(
334+
max3(
335335
size_of::<E2PreCompute<Pos2PreCompute<F, SBOX_REGISTERS>>>(),
336+
size_of::<E2PreCompute<MultiObservePreCompute<F, SBOX_REGISTERS>>>(),
336337
size_of::<E2PreCompute<VerifyBatchPreCompute<F, SBOX_REGISTERS>>>(),
337338
)
338339
}
@@ -613,6 +614,7 @@ unsafe fn execute_multi_observe_e12_impl<
613614
pos += len;
614615
}
615616
}
617+
let final_idx = observation_chunks.last().map(|(_, end)| *end % CHUNK);
616618

617619
height += 1;
618620
let mut input_idx = 0;
@@ -631,6 +633,13 @@ unsafe fn execute_multi_observe_e12_impl<
631633

632634
height += 1;
633635
}
636+
if let Some(final_idx) = final_idx {
637+
exec_state.vm_write::<F, 1>(
638+
NATIVE_AS,
639+
pre_compute.init_pos_register,
640+
&[F::from_canonical_usize(final_idx)],
641+
);
642+
}
634643
*pc = pc.wrapping_add(DEFAULT_PC_STEP);
635644
*instret += 1;
636645

extensions/native/compiler/src/ir/poseidon.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ impl<C: Config> Builder<C> {
4242
len.clone(),
4343
));
4444

45+
// `init_pos` is automatically updated by the Poseidon2MultiObserve operation
4546
Usize::Var(init_pos)
4647
}
4748
},

extensions/native/recursion/src/challenger/duplex.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,24 @@ impl<C: Config> DuplexChallengerVariable<C> {
7777
}
7878
}
7979

80+
// Observes multiple elements from an array.
81+
// This is equivalent to calling `observe` multiple times, but more efficient.
82+
pub fn observe_slice_opt(&self, builder: &mut Builder<C>, arr: &Array<C, Felt<C::F>>) {
83+
builder.if_ne(arr.len(), Usize::from(0)).then(|builder| {
84+
let next_pos = builder.poseidon2_multi_observe(&self.sponge_state, self.input_ptr, arr);
85+
86+
builder.assign(&self.input_ptr, self.io_empty_ptr + next_pos.clone());
87+
builder.if_ne(next_pos, Usize::from(0)).then_or_else(
88+
|builder| {
89+
builder.assign(&self.output_ptr, self.io_empty_ptr);
90+
},
91+
|builder| {
92+
builder.assign(&self.output_ptr, self.io_full_ptr);
93+
},
94+
);
95+
});
96+
}
97+
8098
fn sample(&self, builder: &mut Builder<C>) -> Felt<C::F> {
8199
builder
82100
.if_ne(self.input_ptr.address, self.io_empty_ptr.address)

extensions/native/recursion/tests/recursion.rs

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@ use openvm_native_compiler::{
1515
asm::{AsmBuilder, AsmCompiler},
1616
conversion::{convert_program, CompilerOptions},
1717
ir::{Array, Builder, Config, Felt},
18-
prelude::Usize,
1918
};
2019
use openvm_native_recursion::{
21-
challenger::{duplex::DuplexChallengerVariable, CanObserveVariable},
20+
challenger::{duplex::DuplexChallengerVariable, CanObserveVariable, CanSampleVariable},
2221
testing_utils::inner::run_recursive_test,
2322
};
2423
use openvm_stark_backend::{
@@ -192,7 +191,6 @@ fn test_multi_observe() {
192191
compiler.build(builder.operations);
193192
let asm_code = compiler.code();
194193

195-
// let program = Program::from_instructions(&instructions);
196194
let program: Program<_> = convert_program(asm_code, compilation_options);
197195

198196
let poseidon2_max_constraint_degree = 3;
@@ -232,17 +230,12 @@ fn test_multi_observe() {
232230
}
233231

234232
fn build_test_program<C: Config>(builder: &mut Builder<C>) {
235-
let sample_lens: Vec<usize> = vec![10, 2, 1, 3, 20];
233+
let sample_lens: Vec<usize> = vec![10, 2, 1, 0, 3, 20, 200, 400];
236234

237235
let mut rng = create_seeded_rng();
238-
let mut challenger = DuplexChallengerVariable::new(builder);
239236

240-
// Observe a setup label
241-
let label_f: Vec<u64> = vec![128, 3098, 192, 394, 1662, 928, 374, 281, 598, 182, 475, 729];
242-
for n in label_f {
243-
let f: Felt<C::F> = builder.constant(C::F::from_canonical_u64(n));
244-
challenger.observe(builder, f);
245-
}
237+
let mut c1 = DuplexChallengerVariable::new(builder);
238+
let mut c2 = DuplexChallengerVariable::new(builder);
246239

247240
for l in sample_lens {
248241
let sample_input: Array<C, Felt<C::F>> = builder.dyn_array(l);
@@ -251,24 +244,13 @@ fn build_test_program<C: Config>(builder: &mut Builder<C>) {
251244
builder.set(&sample_input, idx_vec[0], C::F::from_canonical_u32(f_u32));
252245
});
253246

254-
let next_input_ptr = builder.poseidon2_multi_observe(
255-
&challenger.sponge_state,
256-
challenger.input_ptr,
257-
&sample_input,
258-
);
247+
c1.observe_slice_opt(builder, &sample_input);
248+
c2.observe_slice(builder, sample_input);
249+
250+
let e1 = c1.sample(builder);
251+
let e2 = c2.sample(builder);
259252

260-
builder.assign(
261-
&challenger.input_ptr,
262-
challenger.io_empty_ptr + next_input_ptr.clone(),
263-
);
264-
builder.if_ne(next_input_ptr, Usize::from(0)).then_or_else(
265-
|builder| {
266-
builder.assign(&challenger.output_ptr, challenger.io_empty_ptr);
267-
},
268-
|builder| {
269-
builder.assign(&challenger.output_ptr, challenger.io_full_ptr);
270-
},
271-
);
253+
builder.assert_felt_eq(e1, e2);
272254
}
273255
builder.halt();
274256
}

0 commit comments

Comments
 (0)