Skip to content

Commit 3e1c98c

Browse files
committed
wip
1 parent 1451be1 commit 3e1c98c

File tree

1 file changed

+64
-1
lines changed

1 file changed

+64
-1
lines changed

extensions/native/circuit/cuda/src/poseidon2.cu

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ template <typename T, size_t SBOX_REGISTERS> struct NativePoseidon2Cols {
2222
T incorporate_sibling;
2323
T inside_row;
2424
T simple;
25+
T multi_observe_row;
2526

2627
T end_inside_row;
2728
T end_top_level;
@@ -38,7 +39,7 @@ template <typename T, size_t SBOX_REGISTERS> struct NativePoseidon2Cols {
3839
};
3940

4041
__device__ void mem_fill_base(
41-
MemoryAuxColsFactory mem_helper,
42+
MemoryAuxColsFactory &mem_helper,
4243
uint32_t timestamp,
4344
RowSlice base_aux
4445
) {
@@ -58,6 +59,8 @@ template <size_t SBOX_REGISTERS> struct Poseidon2Wrapper {
5859
) {
5960
if (row[COL_INDEX(Cols, simple)] == Fp::one()) {
6061
fill_simple_chunk(row, range_checker, timestamp_max_bits);
62+
} else if (row[COL_INDEX(Cols, multi_observe_row)] == Fp::one()) {
63+
fill_multi_observe_chunk(row, range_checker, timestamp_max_bits);
6164
} else {
6265
fill_verify_batch_chunk(row, range_checker, timestamp_max_bits);
6366
}
@@ -335,6 +338,66 @@ template <size_t SBOX_REGISTERS> struct Poseidon2Wrapper {
335338
}
336339
}
337340
}
341+
342+
__device__ static void fill_multi_observe_chunk(
343+
RowSlice row,
344+
VariableRangeChecker range_checker,
345+
uint32_t timestamp_max_bits
346+
) {
347+
MemoryAuxColsFactory mem_helper(range_checker, timestamp_max_bits);
348+
Poseidon2Row head_row(row);
349+
uint32_t num_rows = head_row.export_col()[0].asUInt32();
350+
351+
for (uint32_t idx = 0; idx < num_rows; ++idx) {
352+
RowSlice curr_row = row.shift_row(idx);
353+
fill_inner(curr_row);
354+
fill_multi_observe_specific(curr_row, mem_helper);
355+
}
356+
}
357+
358+
__device__ static void fill_multi_observe_specific(
359+
RowSlice row,
360+
MemoryAuxColsFactory &mem_helper
361+
) {
362+
RowSlice specific = row.slice_from(COL_INDEX(Cols, specific));
363+
if (specific[COL_INDEX(MultiObserveCols, is_first)] == Fp::one()) {
364+
uint32_t very_start_timestamp =
365+
row[COL_INDEX(Cols, very_first_timestamp)].asUInt32();
366+
for (uint32_t i = 0; i < 4; ++i) {
367+
mem_fill_base(
368+
mem_helper,
369+
very_start_timestamp + i,
370+
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[i].base))
371+
);
372+
}
373+
} else {
374+
uint32_t start_timestamp = row[COL_INDEX(Cols, start_timestamp)].asUInt32();
375+
uint32_t chunk_start =
376+
specific[COL_INDEX(MultiObserveCols, start_idx)].asUInt32();
377+
uint32_t chunk_end =
378+
specific[COL_INDEX(MultiObserveCols, end_idx)].asUInt32();
379+
for (uint32_t j = chunk_start; j < chunk_end; ++j) {
380+
mem_fill_base(
381+
mem_helper,
382+
start_timestamp,
383+
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base))
384+
);
385+
mem_fill_base(
386+
mem_helper,
387+
start_timestamp + 1,
388+
specific.slice_from(COL_INDEX(MultiObserveCols, write_data[j].base))
389+
);
390+
start_timestamp += 2;
391+
}
392+
if (chunk_end >= CHUNK) {
393+
mem_fill_base(
394+
mem_helper,
395+
start_timestamp,
396+
specific.slice_from(COL_INDEX(MultiObserveCols, write_sponge_state.base))
397+
);
398+
}
399+
}
400+
}
338401
};
339402

340403
template <size_t SBOX_REGISTERS>

0 commit comments

Comments
 (0)