Skip to content

Commit c7527b3

Browse files
committed
write final pos back
1 parent c006da1 commit c7527b3

File tree

4 files changed

+41
-11
lines changed

4 files changed

+41
-11
lines changed

extensions/native/circuit/src/poseidon2/air.rs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,6 @@ impl<AB: InteractionBuilder, const SBOX_REGISTERS: usize> Air<AB>
728728
should_permute,
729729
write_sponge_state,
730730
write_final_idx,
731-
final_idx,
732731
input_register_1,
733732
input_register_2,
734733
input_register_3,
@@ -830,6 +829,16 @@ impl<AB: InteractionBuilder, const SBOX_REGISTERS: usize> Air<AB>
830829
}
831830

832831
for i in 0..CHUNK {
832+
builder
833+
.when(multi_observe_row)
834+
.assert_bool(aux_after_start[i]);
835+
builder
836+
.when(multi_observe_row)
837+
.assert_bool(aux_before_end[i]);
838+
builder
839+
.when(multi_observe_row)
840+
.when(is_first)
841+
.assert_zero(aux_read_enabled[i]);
833842
builder
834843
.when(multi_observe_row)
835844
.assert_eq(aux_after_start[i] * aux_before_end[i], aux_read_enabled[i]);
@@ -889,19 +898,22 @@ impl<AB: InteractionBuilder, const SBOX_REGISTERS: usize> Air<AB>
889898
.assert_eq(*a, *b);
890899
});
891900

892-
/*
901+
builder
902+
.when(multi_observe_row)
903+
.when(aux_read_enabled[CHUNK - 1])
904+
.assert_one(should_permute);
905+
906+
// final_idx = aux_read_enabled[CHUNK-1] * 0 + (1 - aux_read_enabled[CHUNK-1]) * end_idx
907+
let final_idx = aux_read_enabled[CHUNK - 1] * AB::Expr::ZERO
908+
+ (AB::Expr::ONE - aux_read_enabled[CHUNK - 1]) * end_idx;
893909
self.memory_bridge
894910
.write(
895-
MemoryAddress::new(
896-
self.address_space,
897-
input_register_1,
898-
),
911+
MemoryAddress::new(self.address_space, input_register_1),
899912
[final_idx],
900-
start_timestamp + is_first * AB::F::from_canonical_usize(4) + (end_idx - start_idx) * AB::F::TWO + should_permute * AB::F::TWO,
901-
&write_final_idx
913+
start_timestamp + (end_idx - start_idx) * AB::F::TWO + should_permute,
914+
&write_final_idx,
902915
)
903916
.eval(builder, multi_observe_row * is_last);
904-
*/
905917

906918
// Field transitions
907919
builder

extensions/native/circuit/src/poseidon2/chip.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,7 @@ where
685685
pos += len;
686686
}
687687
}
688+
final_timestamp_inc += 1; // write back to init_pos_register
688689

689690
let allocated_rows = arena
690691
.alloc(MultiRowLayout::new(NativePoseidon2Metadata {
@@ -810,6 +811,15 @@ where
810811
multi_observe_cols.should_permute = F::ZERO;
811812
cols.inner.inputs.clone_from_slice(&permutation_input);
812813
}
814+
if i == num_chunks - 1 {
815+
let final_idx = F::from_canonical_usize(chunk_end % CHUNK);
816+
tracing_write_native_inplace(
817+
state.memory,
818+
init_pos_register.as_canonical_u32(),
819+
[final_idx],
820+
&mut multi_observe_cols.write_final_idx,
821+
);
822+
}
813823
}
814824
} else {
815825
unreachable!()
@@ -1213,6 +1223,14 @@ impl<F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Filler<F, SBOX
12131223
start_timestamp_u32,
12141224
multi_observe_cols.write_sponge_state.as_mut(),
12151225
);
1226+
start_timestamp_u32 += 1;
1227+
}
1228+
if row_idx == num_rows - 1 {
1229+
mem_fill_helper(
1230+
mem_helper,
1231+
start_timestamp_u32,
1232+
multi_observe_cols.write_final_idx.as_mut(),
1233+
);
12161234
}
12171235
}
12181236
}

extensions/native/circuit/src/poseidon2/columns.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,5 +242,4 @@ pub struct MultiObserveCols<T> {
242242

243243
// Final write back and registers
244244
pub write_final_idx: MemoryWriteAuxCols<T, 1>,
245-
pub final_idx: T,
246245
}

extensions/native/circuit/src/poseidon2/execution.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,9 @@ impl<F: PrimeField32, const SBOX_REGISTERS: usize> MeteredExecutor<F>
331331
{
332332
#[inline(always)]
333333
fn metered_pre_compute_size(&self) -> usize {
334-
std::cmp::max(
334+
max3(
335335
size_of::<E2PreCompute<Pos2PreCompute<F, SBOX_REGISTERS>>>(),
336+
size_of::<E2PreCompute<MultiObservePreCompute<F, SBOX_REGISTERS>>>(),
336337
size_of::<E2PreCompute<VerifyBatchPreCompute<F, SBOX_REGISTERS>>>(),
337338
)
338339
}

0 commit comments

Comments
 (0)