Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions extensions/native/circuit/cuda/src/poseidon2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -355,31 +355,52 @@ template <size_t SBOX_REGISTERS> struct Poseidon2Wrapper {
if (specific[COL_INDEX(MultiObserveCols, is_first)] == Fp::one()) {
uint32_t very_start_timestamp =
row[COL_INDEX(Cols, very_first_timestamp)].asUInt32();
for (uint32_t i = 0; i < 4; ++i) {
// 3 register reads at timestamps +0, +1, +2
for (uint32_t i = 0; i < 3; ++i) {
mem_fill_base(
mem_helper,
very_start_timestamp + i,
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[i].base))
);
}
// 1 context array read at timestamp +3
mem_fill_base(
mem_helper,
very_start_timestamp + 3,
specific.slice_from(COL_INDEX(MultiObserveCols, read_ctx.base))
);
// 1 hint_id register read at timestamp +4 (reuse spare read_data[3] on head row)
mem_fill_base(
mem_helper,
very_start_timestamp + 4,
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[3].base))
);
} else {
uint32_t start_timestamp = row[COL_INDEX(Cols, start_timestamp)].asUInt32();
uint32_t chunk_start =
specific[COL_INDEX(MultiObserveCols, start_idx)].asUInt32();
uint32_t chunk_end =
specific[COL_INDEX(MultiObserveCols, end_idx)].asUInt32();
// is_hint = ctx[2]
uint32_t is_hint =
specific[COL_INDEX(MultiObserveCols, ctx[2])].asUInt32();
uint32_t ts_per_element = 2 - is_hint;
for (uint32_t j = chunk_start; j < chunk_end; ++j) {
if (!is_hint) {
// Non-hint mode: fill read_data aux
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base))
);
}
// Write timestamp: start_timestamp + (1 - is_hint) for non-hint, start_timestamp for hint
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base))
);
mem_fill_base(
mem_helper,
start_timestamp + 1,
start_timestamp + (1 - is_hint),
specific.slice_from(COL_INDEX(MultiObserveCols, write_data[j].base))
);
start_timestamp += 2;
start_timestamp += ts_per_element;
}
if (chunk_end >= CHUNK) {
mem_fill_base(
Expand Down
11 changes: 6 additions & 5 deletions extensions/native/circuit/src/extension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,6 @@ where
inventory.add_executor_chip(fri_reduced_opening);

inventory.next_air::<NativePoseidon2Air<Val<SC>, 1>>()?;
let poseidon2 = NativePoseidon2Chip::<_, 1>::new(
NativePoseidon2Filler::new(Poseidon2Config::default()),
mem_helper.clone(),
);
inventory.add_executor_chip(poseidon2);

let hint_bus = inventory.airs().system().hint_bridge.hint_bus();
let hint_space_provider = Arc::new(HintSpaceProviderChip::new(
Expand All @@ -379,6 +374,12 @@ where
timestamp_max_bits,
));

let poseidon2 = NativePoseidon2Chip::<_, 1>::new(
NativePoseidon2Filler::new(Poseidon2Config::default(), hint_space_provider.clone()),
mem_helper.clone(),
);
inventory.add_executor_chip(poseidon2);

inventory.next_air::<HintSpaceProviderAir>()?;
inventory.add_periphery_chip(hint_space_provider.clone());

Expand Down
126 changes: 93 additions & 33 deletions extensions/native/circuit/src/poseidon2/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -713,10 +713,16 @@ impl<AB: InteractionBuilder, const SBOX_REGISTERS: usize> Air<AB>
let &MultiObserveCols {
pc,
final_timestamp_increment,
state_ptr_register,
ctx_register,
input_ptr_register,
hint_id_register,
state_ptr,
ctx_ptr,
input_ptr,
init_pos,
len,
hint_id,
ctx,
read_ctx,
is_first,
is_last,
curr_len,
Expand All @@ -731,35 +737,38 @@ impl<AB: InteractionBuilder, const SBOX_REGISTERS: usize> Air<AB>
should_permute,
write_sponge_state,
write_final_idx,
input_register_1,
input_register_2,
input_register_3,
output_register,
} = multi_observe_specific;

// Alias context values
let init_pos = ctx[0];
let len = ctx[1];
let is_hint = ctx[2];

builder.when(multi_observe_row).assert_bool(is_first);
builder.when(multi_observe_row).assert_bool(is_last);
builder.when(multi_observe_row).assert_bool(should_permute);
builder.when(multi_observe_row).assert_bool(is_hint);

self.execution_bridge
.execute_and_increment_pc(
AB::F::from_canonical_usize(MULTI_OBSERVE.global_opcode().as_usize()),
[
output_register.into(),
input_register_1.into(),
input_register_2.into(),
state_ptr_register.into(),
ctx_register.into(),
input_ptr_register.into(),
self.address_space.into(),
self.address_space.into(),
input_register_3.into(),
hint_id_register.into(),
],
ExecutionState::new(pc, very_first_timestamp),
final_timestamp_increment,
)
.eval(builder, multi_observe_row * is_first);

// Head row: 3 register reads + 1 context array read + 1 hint_id register read
self.memory_bridge
.read(
MemoryAddress::new(self.address_space, output_register),
MemoryAddress::new(self.address_space, state_ptr_register),
[state_ptr],
very_first_timestamp,
&read_data[0],
Expand All @@ -768,50 +777,82 @@ impl<AB: InteractionBuilder, const SBOX_REGISTERS: usize> Air<AB>

self.memory_bridge
.read(
MemoryAddress::new(self.address_space, input_register_1),
[init_pos],
MemoryAddress::new(self.address_space, ctx_register),
[ctx_ptr],
very_first_timestamp + AB::F::ONE,
&read_data[1],
)
.eval(builder, multi_observe_row * is_first);

self.memory_bridge
.read(
MemoryAddress::new(self.address_space, input_register_2),
MemoryAddress::new(self.address_space, input_ptr_register),
[input_ptr],
very_first_timestamp + AB::F::TWO,
&read_data[2],
)
.eval(builder, multi_observe_row * is_first);

// Read context array: [init_pos, len, is_hint, reserved] from ctx_ptr
self.memory_bridge
.read(
MemoryAddress::new(self.address_space, input_register_3),
[len],
MemoryAddress::new(self.address_space, ctx_ptr),
ctx,
very_first_timestamp + AB::F::from_canonical_usize(3),
&read_ctx,
)
.eval(builder, multi_observe_row * is_first);

// Read hint_id from register (reuse spare read_data[3] on head row)
self.memory_bridge
.read(
MemoryAddress::new(self.address_space, hint_id_register),
[hint_id],
very_first_timestamp + AB::F::from_canonical_usize(4),
&read_data[3],
)
.eval(builder, multi_observe_row * is_first);

// ts_per_element = 2 - is_hint (non-hint: read+write=2, hint: write-only=1)
let is_hint_expr: AB::Expr = is_hint.into();
let ts_per_element: AB::Expr = AB::Expr::TWO - is_hint_expr.clone();
for i in 0..CHUNK {
let i_var = AB::F::from_canonical_usize(i);
let i_var: AB::Expr = AB::F::from_canonical_usize(i).into();
let start_idx_expr: AB::Expr = start_idx.into();
let element_start_ts: AB::Expr =
start_timestamp.into() + (i_var.clone() - start_idx_expr.clone()) * ts_per_element.clone();

// Non-hint mode: read from memory
self.memory_bridge
.read(
MemoryAddress::new(
self.address_space,
input_ptr + curr_len + i_var - start_idx,
input_ptr + curr_len + i_var.clone() - start_idx_expr.clone(),
),
[data[i]],
start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO,
element_start_ts.clone(),
&read_data[i],
)
.eval(builder, multi_observe_row * aux_read_enabled[i]);
.eval(
builder,
multi_observe_row * aux_read_enabled[i] * (AB::Expr::ONE - is_hint_expr.clone()),
);

// Hint mode: lookup from hint space
self.hint_bridge.lookup(
builder,
hint_id,
curr_len + i_var.clone() - start_idx_expr.clone(),
data[i],
multi_observe_row * aux_read_enabled[i] * is_hint_expr.clone(),
);

// Write to sponge state (always, for both modes)
self.memory_bridge
.write(
MemoryAddress::new(self.address_space, state_ptr + i_var),
[data[i]],
start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO + AB::F::ONE,
element_start_ts + (AB::Expr::ONE - is_hint_expr.clone()),
&write_data[i],
)
.eval(builder, multi_observe_row * aux_read_enabled[i]);
Expand Down Expand Up @@ -885,7 +926,7 @@ impl<AB: InteractionBuilder, const SBOX_REGISTERS: usize> Air<AB>
.write(
MemoryAddress::new(self.address_space, state_ptr),
full_sponge_output,
start_timestamp + (end_idx - start_idx) * AB::F::TWO,
start_timestamp + (end_idx - start_idx) * (AB::Expr::TWO - is_hint_expr.clone()),
&write_sponge_state,
)
.eval(builder, multi_observe_row * should_permute);
Expand All @@ -909,11 +950,12 @@ impl<AB: InteractionBuilder, const SBOX_REGISTERS: usize> Air<AB>
// final_idx = aux_read_enabled[CHUNK-1] * 0 + (1 - aux_read_enabled[CHUNK-1]) * end_idx
let final_idx = aux_read_enabled[CHUNK - 1] * AB::Expr::ZERO
+ (AB::Expr::ONE - aux_read_enabled[CHUNK - 1]) * end_idx;
// Write final_idx back to ctx[0] (ctx_ptr address)
self.memory_bridge
.write(
MemoryAddress::new(self.address_space, input_register_1),
MemoryAddress::new(self.address_space, ctx_ptr),
[final_idx],
start_timestamp + (end_idx - start_idx) * AB::F::TWO + should_permute,
start_timestamp + (end_idx - start_idx) * (AB::Expr::TWO - is_hint_expr) + should_permute,
&write_final_idx,
)
.eval(builder, multi_observe_row * is_last);
Expand Down Expand Up @@ -962,41 +1004,59 @@ impl<AB: InteractionBuilder, const SBOX_REGISTERS: usize> Air<AB>
builder
.when(next.multi_observe_row)
.when(not(next_multi_observe_specific.is_first))
.assert_eq(init_pos, next_multi_observe_specific.init_pos);
.assert_eq(init_pos, next_multi_observe_specific.ctx[0]);

builder
.when(next.multi_observe_row)
.when(not(next_multi_observe_specific.is_first))
.assert_eq(len, next_multi_observe_specific.len);
.assert_eq(len, next_multi_observe_specific.ctx[1]);

builder
.when(next.multi_observe_row)
.when(not(next_multi_observe_specific.is_first))
.assert_eq(
state_ptr_register,
next_multi_observe_specific.state_ptr_register,
);

builder
.when(next.multi_observe_row)
.when(not(next_multi_observe_specific.is_first))
.assert_eq(
input_register_1,
next_multi_observe_specific.input_register_1,
ctx_register,
next_multi_observe_specific.ctx_register,
);

builder
.when(next.multi_observe_row)
.when(not(next_multi_observe_specific.is_first))
.assert_eq(
input_register_2,
next_multi_observe_specific.input_register_2,
input_ptr_register,
next_multi_observe_specific.input_ptr_register,
);

builder
.when(next.multi_observe_row)
.when(not(next_multi_observe_specific.is_first))
.assert_eq(ctx_ptr, next_multi_observe_specific.ctx_ptr);

builder
.when(next.multi_observe_row)
.when(not(next_multi_observe_specific.is_first))
.assert_eq(hint_id, next_multi_observe_specific.hint_id);

builder
.when(next.multi_observe_row)
.when(not(next_multi_observe_specific.is_first))
.assert_eq(
input_register_3,
next_multi_observe_specific.input_register_3,
hint_id_register,
next_multi_observe_specific.hint_id_register,
);

builder
.when(next.multi_observe_row)
.when(not(next_multi_observe_specific.is_first))
.assert_eq(output_register, next_multi_observe_specific.output_register);
.assert_eq(is_hint, next_multi_observe_specific.ctx[2]);

// Timestamp constraints
builder
Expand Down
Loading
Loading