Skip to content

Commit c006da1

Browse files
authored
Feat: tracegen for multi_observe (#15)
* wip * finish
1 parent 1451be1 commit c006da1

File tree

6 files changed

+118
-3
lines changed

6 files changed

+118
-3
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

extensions/native/circuit/cuda/include/native/poseidon2.cuh

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,43 @@ template <typename T> struct SimplePoseidonSpecificCols {
6262
MemoryWriteAuxCols<T, CHUNK> write_data_2;
6363
};
6464

65+
template <typename T> struct MultiObserveCols {
66+
T pc;
67+
T final_timestamp_increment;
68+
T state_ptr;
69+
T input_ptr;
70+
T init_pos;
71+
T len;
72+
T input_register_1;
73+
T input_register_2;
74+
T input_register_3;
75+
T output_register;
76+
T is_first;
77+
T is_last;
78+
T curr_len;
79+
T start_idx;
80+
T end_idx;
81+
T aux_after_start[CHUNK];
82+
T aux_before_end[CHUNK];
83+
T aux_read_enabled[CHUNK];
84+
MemoryReadAuxCols<T> read_data[CHUNK];
85+
MemoryWriteAuxCols<T, 1> write_data[CHUNK];
86+
T data[CHUNK];
87+
T should_permute;
88+
MemoryWriteAuxCols<T, CHUNK * 2> write_sponge_state;
89+
MemoryWriteAuxCols<T, 1> write_final_idx;
90+
T final_idx;
91+
};
92+
6593
template <typename T> constexpr T constexpr_max(T a, T b) { return a > b ? a : b; }
6694

6795
constexpr size_t COL_SPECIFIC_WIDTH = constexpr_max(
6896
sizeof(TopLevelSpecificCols<uint8_t>),
6997
constexpr_max(
7098
sizeof(InsideRowSpecificCols<uint8_t>),
71-
sizeof(SimplePoseidonSpecificCols<uint8_t>)
99+
constexpr_max(
100+
sizeof(SimplePoseidonSpecificCols<uint8_t>),
101+
sizeof(MultiObserveCols<uint8_t>)
102+
)
72103
)
73104
);

extensions/native/circuit/cuda/src/poseidon2.cu

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ template <typename T, size_t SBOX_REGISTERS> struct NativePoseidon2Cols {
2222
T incorporate_sibling;
2323
T inside_row;
2424
T simple;
25+
T multi_observe_row;
2526

2627
T end_inside_row;
2728
T end_top_level;
@@ -38,7 +39,7 @@ template <typename T, size_t SBOX_REGISTERS> struct NativePoseidon2Cols {
3839
};
3940

4041
__device__ void mem_fill_base(
41-
MemoryAuxColsFactory mem_helper,
42+
MemoryAuxColsFactory &mem_helper,
4243
uint32_t timestamp,
4344
RowSlice base_aux
4445
) {
@@ -58,6 +59,8 @@ template <size_t SBOX_REGISTERS> struct Poseidon2Wrapper {
5859
) {
5960
if (row[COL_INDEX(Cols, simple)] == Fp::one()) {
6061
fill_simple_chunk(row, range_checker, timestamp_max_bits);
62+
} else if (row[COL_INDEX(Cols, multi_observe_row)] == Fp::one()) {
63+
fill_multi_observe_chunk(row, range_checker, timestamp_max_bits);
6164
} else {
6265
fill_verify_batch_chunk(row, range_checker, timestamp_max_bits);
6366
}
@@ -335,6 +338,66 @@ template <size_t SBOX_REGISTERS> struct Poseidon2Wrapper {
335338
}
336339
}
337340
}
341+
342+
__device__ static void fill_multi_observe_chunk(
343+
RowSlice row,
344+
VariableRangeChecker range_checker,
345+
uint32_t timestamp_max_bits
346+
) {
347+
MemoryAuxColsFactory mem_helper(range_checker, timestamp_max_bits);
348+
Poseidon2Row head_row(row);
349+
uint32_t num_rows = head_row.export_col()[0].asUInt32();
350+
351+
for (uint32_t idx = 0; idx < num_rows; ++idx) {
352+
RowSlice curr_row = row.shift_row(idx);
353+
fill_inner(curr_row);
354+
fill_multi_observe_specific(curr_row, mem_helper);
355+
}
356+
}
357+
358+
__device__ static void fill_multi_observe_specific(
359+
RowSlice row,
360+
MemoryAuxColsFactory &mem_helper
361+
) {
362+
RowSlice specific = row.slice_from(COL_INDEX(Cols, specific));
363+
if (specific[COL_INDEX(MultiObserveCols, is_first)] == Fp::one()) {
364+
uint32_t very_start_timestamp =
365+
row[COL_INDEX(Cols, very_first_timestamp)].asUInt32();
366+
for (uint32_t i = 0; i < 4; ++i) {
367+
mem_fill_base(
368+
mem_helper,
369+
very_start_timestamp + i,
370+
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[i].base))
371+
);
372+
}
373+
} else {
374+
uint32_t start_timestamp = row[COL_INDEX(Cols, start_timestamp)].asUInt32();
375+
uint32_t chunk_start =
376+
specific[COL_INDEX(MultiObserveCols, start_idx)].asUInt32();
377+
uint32_t chunk_end =
378+
specific[COL_INDEX(MultiObserveCols, end_idx)].asUInt32();
379+
for (uint32_t j = chunk_start; j < chunk_end; ++j) {
380+
mem_fill_base(
381+
mem_helper,
382+
start_timestamp,
383+
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base))
384+
);
385+
mem_fill_base(
386+
mem_helper,
387+
start_timestamp + 1,
388+
specific.slice_from(COL_INDEX(MultiObserveCols, write_data[j].base))
389+
);
390+
start_timestamp += 2;
391+
}
392+
if (chunk_end >= CHUNK) {
393+
mem_fill_base(
394+
mem_helper,
395+
start_timestamp,
396+
specific.slice_from(COL_INDEX(MultiObserveCols, write_sponge_state.base))
397+
);
398+
}
399+
}
400+
}
338401
};
339402

340403
template <size_t SBOX_REGISTERS>

extensions/native/circuit/src/poseidon2/cuda.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ impl<const SBOX_REGISTERS: usize> Chip<DenseRecordArena, GpuBackend>
5353
chunk_start.push(row_idx as u32);
5454
if cols.simple.is_one() {
5555
row_idx += 1;
56+
} else if cols.multi_observe_row.is_one() {
57+
let num_rows = cols.inner.export.as_canonical_u32() as usize;
58+
row_idx += num_rows;
5659
} else {
5760
let num_non_inside_row = cols.inner.export.as_canonical_u32() as usize;
5861
let non_inside_start = start + (num_non_inside_row - 1) * width;

extensions/native/recursion/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ repository.workspace = true
88

99
[dependencies]
1010
openvm-stark-backend = { workspace = true }
11+
openvm-cuda-backend = { workspace = true, optional = true }
1112
openvm-native-circuit = { workspace = true, features = ["test-utils"] }
1213
openvm-native-compiler = { workspace = true }
1314
openvm-native-compiler-derive = { workspace = true }
@@ -58,4 +59,4 @@ parallel = ["openvm-stark-backend/parallel"]
5859
mimalloc = ["openvm-stark-backend/mimalloc"]
5960
jemalloc = ["openvm-stark-backend/jemalloc"]
6061
nightly-features = ["openvm-circuit/nightly-features"]
61-
cuda = ["openvm-circuit/cuda", "openvm-native-circuit/cuda"]
62+
cuda = ["openvm-circuit/cuda", "openvm-native-circuit/cuda", "dep:openvm-cuda-backend"]

extensions/native/recursion/tests/recursion.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ use openvm_circuit::{
66
},
77
utils::{air_test_impl, TestStarkEngine},
88
};
9+
#[cfg(feature = "cuda")]
10+
use openvm_cuda_backend::engine::GpuBabyBearPoseidon2Engine;
911
use openvm_native_circuit::{
1012
execute_program_with_config, test_native_config, NativeBuilder, NativeConfig,
1113
};
@@ -211,8 +213,22 @@ fn test_multi_observe() {
211213
config.system.memory_config.max_access_adapter_n = 16;
212214

213215
let vb = NativeBuilder::default();
216+
#[cfg(not(feature = "cuda"))]
214217
air_test_impl::<BabyBearPoseidon2Engine, _>(fri_params, vb, config, program, vec![], 1, true)
215218
.unwrap();
219+
#[cfg(feature = "cuda")]
220+
{
221+
air_test_impl::<GpuBabyBearPoseidon2Engine, _>(
222+
fri_params,
223+
vb,
224+
config,
225+
program,
226+
vec![],
227+
1,
228+
true,
229+
)
230+
.unwrap();
231+
}
216232
}
217233

218234
fn build_test_program<C: Config>(builder: &mut Builder<C>) {

0 commit comments

Comments
 (0)