Skip to content

Commit 23c956d

Browse files
authored
Allow Reading from Hint Space for MULTI_OBSERVE (#41)
* rebase hint multi observe * adjust degree * fix * debug * adjust * adjust cuda * fix cuda * fix cuda * fix cuda * add debug utilities * add debug flags * remove debug flags * remove debug flag * remove debug flag
1 parent 4a49ffe commit 23c956d

File tree

18 files changed

+534
-157
lines changed

18 files changed

+534
-157
lines changed

crates/circuits/mod-builder/src/utils.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ pub fn biguint_to_limbs_vec(x: &BigUint, num_limbs: usize) -> Vec<u8> {
1111
.chain(std::iter::repeat(0u8))
1212
.take(num_limbs)
1313
.collect()
14-
}
14+
}

extensions/native/circuit/cuda/include/native/poseidon2.cuh

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,17 @@ template <typename T> struct SimplePoseidonSpecificCols {
6565
template <typename T> struct MultiObserveCols {
6666
T pc;
6767
T final_timestamp_increment;
68+
T state_ptr_register;
69+
T ctx_register;
70+
T input_ptr_register;
71+
T hint_id_register;
6872
T state_ptr;
73+
T ctx_ptr;
6974
T input_ptr;
70-
T init_pos;
71-
T len;
72-
T input_register_1;
73-
T input_register_2;
74-
T input_register_3;
75-
T output_register;
75+
T hint_id;
76+
T ctx[4];
77+
MemoryReadAuxCols<T> read_ctx;
78+
T chunk_ts_count;
7679
T is_first;
7780
T is_last;
7881
T curr_len;

extensions/native/circuit/cuda/src/poseidon2.cu

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ template <typename T, size_t SBOX_REGISTERS> struct NativePoseidon2Cols {
2424
T inside_row;
2525
T simple;
2626
T multi_observe_row;
27+
T not_hint_multi_observe;
2728

2829
T end_inside_row;
2930
T end_top_level;
@@ -355,31 +356,57 @@ template <size_t SBOX_REGISTERS> struct Poseidon2Wrapper {
355356
if (specific[COL_INDEX(MultiObserveCols, is_first)] == Fp::one()) {
356357
uint32_t very_start_timestamp =
357358
row[COL_INDEX(Cols, very_first_timestamp)].asUInt32();
358-
for (uint32_t i = 0; i < 4; ++i) {
359+
for (uint32_t i = 0; i < 3; ++i) {
359360
mem_fill_base(
360361
mem_helper,
361362
very_start_timestamp + i,
362363
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[i].base))
363364
);
364365
}
366+
mem_fill_base(
367+
mem_helper,
368+
very_start_timestamp + 3,
369+
specific.slice_from(COL_INDEX(MultiObserveCols, read_ctx.base))
370+
);
371+
mem_fill_base(
372+
mem_helper,
373+
very_start_timestamp + 4,
374+
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[3].base))
375+
);
376+
377+
// Zero-length MULTI_OBSERVE case: head row is both first and last.
378+
// The final ctx[0] writeback lives at row.start_timestamp.
379+
if (specific[COL_INDEX(MultiObserveCols, is_last)] == Fp::one()) {
380+
uint32_t start_timestamp = row[COL_INDEX(Cols, start_timestamp)].asUInt32();
381+
mem_fill_base(
382+
mem_helper,
383+
start_timestamp,
384+
specific.slice_from(COL_INDEX(MultiObserveCols, write_final_idx.base))
385+
);
386+
}
365387
} else {
366388
uint32_t start_timestamp = row[COL_INDEX(Cols, start_timestamp)].asUInt32();
367389
uint32_t chunk_start =
368390
specific[COL_INDEX(MultiObserveCols, start_idx)].asUInt32();
369391
uint32_t chunk_end =
370392
specific[COL_INDEX(MultiObserveCols, end_idx)].asUInt32();
393+
uint32_t is_hint =
394+
specific[COL_INDEX(MultiObserveCols, ctx[2])].asUInt32();
395+
uint32_t ts_per_element = 2 - is_hint;
371396
for (uint32_t j = chunk_start; j < chunk_end; ++j) {
397+
if (!is_hint) {
398+
mem_fill_base(
399+
mem_helper,
400+
start_timestamp,
401+
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base))
402+
);
403+
}
372404
mem_fill_base(
373405
mem_helper,
374-
start_timestamp,
375-
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base))
376-
);
377-
mem_fill_base(
378-
mem_helper,
379-
start_timestamp + 1,
406+
start_timestamp + (1 - is_hint),
380407
specific.slice_from(COL_INDEX(MultiObserveCols, write_data[j].base))
381408
);
382-
start_timestamp += 2;
409+
start_timestamp += ts_per_element;
383410
}
384411
if (chunk_end >= CHUNK) {
385412
mem_fill_base(

extensions/native/circuit/src/extension/cuda.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,19 +75,29 @@ impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Native>
7575
FriReducedOpeningChipGpu::new(range_checker.clone(), timestamp_max_bits);
7676
inventory.add_executor_chip(fri_reduced_opening);
7777

78-
inventory.next_air::<NativePoseidon2Air<BabyBear, 1>>()?;
79-
let poseidon2 = NativePoseidon2ChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits);
80-
inventory.add_executor_chip(poseidon2);
81-
8278
let hint_air: &HintSpaceProviderAir = inventory.next_air::<HintSpaceProviderAir>()?;
79+
let cpu_range_checker = range_checker
80+
.cpu_chip
81+
.clone()
82+
.expect("VariableRangeCheckerChipGPU is expected to be hybrid with cpu_chip");
8383
let cpu_chip = Arc::new(HintSpaceProviderChip::new(
8484
hint_air.hint_bus,
85-
range_checker.clone(),
85+
cpu_range_checker,
8686
timestamp_max_bits,
8787
));
88+
8889
let provider_gpu = HintSpaceProviderChipGpu::new(cpu_chip.clone());
8990
inventory.add_periphery_chip(provider_gpu);
9091

92+
inventory.next_air::<NativePoseidon2Air<BabyBear, 1>>()?;
93+
94+
let poseidon2 = NativePoseidon2ChipGpu::<1>::new_with_hint_space_provider(
95+
range_checker.clone(),
96+
timestamp_max_bits,
97+
cpu_chip.clone(),
98+
);
99+
inventory.add_executor_chip(poseidon2);
100+
91101
inventory.next_air::<NativeSumcheckAir>()?;
92102
let sumcheck =
93103
NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits, cpu_chip);

extensions/native/circuit/src/extension/mod.rs

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -271,15 +271,6 @@ where
271271
);
272272
inventory.add_air(fri_reduced_opening);
273273

274-
let verify_batch = NativePoseidon2Air::<_, 1>::new(
275-
exec_bridge,
276-
memory_bridge,
277-
hint_bridge,
278-
VerifyBatchBus::new(inventory.new_bus_idx()),
279-
Poseidon2Config::default(),
280-
);
281-
inventory.add_air(verify_batch);
282-
283274
let hint_space_provider = HintSpaceProviderAir {
284275
hint_bus: hint_bridge.hint_bus(),
285276
lt_air: IsLtSubAir::new(
@@ -289,6 +280,15 @@ where
289280
};
290281
inventory.add_air(hint_space_provider);
291282

283+
let verify_batch = NativePoseidon2Air::<_, 1>::new(
284+
exec_bridge,
285+
memory_bridge,
286+
hint_bridge,
287+
VerifyBatchBus::new(inventory.new_bus_idx()),
288+
Poseidon2Config::default(),
289+
);
290+
inventory.add_air(verify_batch);
291+
292292
let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge, hint_bridge);
293293
inventory.add_air(tower_evaluate);
294294

@@ -365,13 +365,6 @@ where
365365
FriReducedOpeningChip::new(FriReducedOpeningFiller::new(), mem_helper.clone());
366366
inventory.add_executor_chip(fri_reduced_opening);
367367

368-
inventory.next_air::<NativePoseidon2Air<Val<SC>, 1>>()?;
369-
let poseidon2 = NativePoseidon2Chip::<_, 1>::new(
370-
NativePoseidon2Filler::new(Poseidon2Config::default()),
371-
mem_helper.clone(),
372-
);
373-
inventory.add_executor_chip(poseidon2);
374-
375368
let hint_bus = inventory.airs().system().hint_bridge.hint_bus();
376369
let hint_space_provider = Arc::new(HintSpaceProviderChip::new(
377370
hint_bus,
@@ -382,8 +375,17 @@ where
382375
inventory.next_air::<HintSpaceProviderAir>()?;
383376
inventory.add_periphery_chip(hint_space_provider.clone());
384377

378+
inventory.next_air::<NativePoseidon2Air<Val<SC>, 1>>()?;
379+
380+
let poseidon2 = NativePoseidon2Chip::<_, 1>::new(
381+
NativePoseidon2Filler::new(Poseidon2Config::default(), hint_space_provider.clone()),
382+
mem_helper.clone(),
383+
);
384+
inventory.add_executor_chip(poseidon2);
385+
386+
inventory.next_air::<NativeSumcheckAir>()?;
385387
let tower_verify = NativeSumcheckChip::new(
386-
NativeSumcheckFiller::new(hint_space_provider),
388+
NativeSumcheckFiller::new(hint_space_provider.clone()),
387389
mem_helper.clone(),
388390
);
389391
inventory.add_executor_chip(tower_verify);

extensions/native/circuit/src/fri/cuda.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ use openvm_cuda_common::copy::MemCopyH2D;
1313
use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
1414

1515
use super::{FriReducedOpeningRecordMut, OVERALL_WIDTH};
16-
use crate::cuda_abi::fri_cuda;
16+
use crate::{
17+
cuda_abi::fri_cuda,
18+
};
1719

1820
#[derive(new)]
1921
pub struct FriReducedOpeningChipGpu {

extensions/native/circuit/src/jal_rangecheck/cuda.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ use openvm_cuda_common::copy::MemCopyH2D;
1010
use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
1111

1212
use super::{JalRangeCheckCols, JalRangeCheckRecord};
13-
use crate::cuda_abi::native_jal_rangecheck_cuda;
13+
use crate::{
14+
cuda_abi::native_jal_rangecheck_cuda,
15+
};
1416

1517
#[derive(new)]
1618
pub struct JalRangeCheckGpu {

0 commit comments

Comments
 (0)