Skip to content

Commit b11662a

Browse files
committed
tracegen
1 parent 102bd1f commit b11662a

File tree

9 files changed

+342
-18
lines changed

9 files changed

+342
-18
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#pragma once
2+
3+
#include "primitives/constants.h"
4+
#include "system/memory/offline_checker.cuh"
5+
6+
using namespace native;
7+
8+
// Layout of the row-type-specific column region when it is read as a header row.
// COL_SPECIFIC_WIDTH is sized to hold the largest of the *SpecificCols layouts,
// so this struct is overlaid on that region rather than stored separately.
// Field order defines the column order and must not be changed.
template <typename T>
struct HeaderSpecificCols {
    // NOTE(review): presumably the program counter of the instruction — confirm.
    T pc;
    // NOTE(review): presumably the five register values read by the header — confirm.
    T registers[5];
    // Aux columns for the seven memory reads issued on a header row.
    MemoryReadAuxCols<T> read_records[7];
    // Aux columns for the single memory write issued on a header row.
    MemoryWriteAuxCols<T, EXT_DEG> write_records;
};
14+
15+
// Layout of the row-type-specific column region when it is read as a product row.
// Field order defines the column order and must not be changed.
template <typename T>
struct ProdSpecificCols {
    // NOTE(review): presumably the address the product data is read from — confirm.
    T data_ptr;
    // Two extension-field-sized groups of cells (EXT_DEG * 2 base cells).
    T p[EXT_DEG * 2];
    // Aux columns for the two memory reads of a product row.
    MemoryReadAuxCols<T> read_records[2];
    // One extension-field-sized group of cells.
    T p_evals[EXT_DEG];
    // Aux columns for the single memory write of a product row.
    MemoryWriteAuxCols<T, EXT_DEG> write_record;
    // One extension-field-sized group of cells.
    T eval_rlc[EXT_DEG];
};
23+
24+
// Layout of the row-type-specific column region when it is read as a logup row.
// Field order defines the column order and must not be changed.
template <typename T>
struct LogupSpecificCols {
    // NOTE(review): presumably the address the logup data is read from — confirm.
    T data_ptr;
    // Four extension-field-sized groups of cells (EXT_DEG * 4 base cells).
    T pq[EXT_DEG * 4];
    // Aux columns for the two memory reads of a logup row.
    MemoryReadAuxCols<T> read_records[2];
    // One extension-field-sized group of cells each.
    T p_evals[EXT_DEG];
    T q_evals[EXT_DEG];
    // Aux columns for the two memory writes of a logup row.
    MemoryWriteAuxCols<T, EXT_DEG> write_records[2];
    // One extension-field-sized group of cells.
    T eval_rlc[EXT_DEG];
};
33+
34+
// Compile-time maximum of two values of the same type.
// Returns `a` only when `a > b`; when the two compare equal, `b` is returned.
template <typename T>
constexpr T constexpr_max(T a, T b) {
    if (a > b) {
        return a;
    }
    return b;
}
37+
38+
// Number of columns reserved for the row-type-specific region: the largest of
// the three *SpecificCols layouts. Instantiating with uint8_t makes sizeof()
// count one byte per column cell.
constexpr size_t COL_SPECIFIC_WIDTH = constexpr_max(
    constexpr_max(sizeof(HeaderSpecificCols<uint8_t>), sizeof(ProdSpecificCols<uint8_t>)),
    sizeof(LogupSpecificCols<uint8_t>)
);
42+
43+
// Column layout of one native sumcheck trace row.
// Field order defines the column order and must not be changed.
template <typename T>
struct NativeSumcheckCols {
    // Row-type flags. The tracegen code compares header_row, prod_row and
    // logup_row (in that order) against one to decide how `specific` is read.
    T header_row;
    T prod_row;
    T logup_row;
    // Set to one on padding rows by the tracegen kernel.
    T is_end;

    T prod_continued;
    T logup_continued;

    T prod_in_round_evaluation;
    T prod_next_round_evaluation;
    T logup_in_round_evaluation;
    T logup_next_round_evaluation;

    T prod_acc;
    T logup_acc;

    // start_timestamp is the base timestamp for the memory accesses of this row;
    // the header-row write is stamped with last_timestamp - 1.
    T first_timestamp;
    T start_timestamp;
    T last_timestamp;

    T register_ptrs[5];

    T ctx[EXT_DEG * 2];

    T prod_nested_len;
    T logup_nested_len;

    T curr_prod_n;
    T curr_logup_n;

    T alpha[EXT_DEG];
    T challenges[EXT_DEG * 4];

    T max_round;
    // When one, prod/logup rows fill their second read and their write records.
    T within_round_limit;
    T should_acc;

    T eval_acc[EXT_DEG];

    // Row-type-specific region; read as HeaderSpecificCols, ProdSpecificCols or
    // LogupSpecificCols depending on the row-type flags above.
    T specific[COL_SPECIFIC_WIDTH];
};
85+
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#pragma once
2+
3+
#include "primitives/trace_access.h"
4+
#include "system/memory/controller.cuh"
5+
6+
// Completes a memory base-aux record in place.
//
// Reads the prev_timestamp value already stored in `base_aux` and forwards it,
// together with `timestamp`, to `mem_helper.fill` for the same columns.
//
//   mem_helper  helper that fills the aux columns
//   timestamp   timestamp of the current access
//   base_aux    slice positioned at the start of a MemoryBaseAuxCols record
__device__ __forceinline__ void mem_fill_base(
    MemoryAuxColsFactory &mem_helper,
    uint32_t timestamp,
    RowSlice base_aux
) {
    const uint32_t prev_timestamp =
        base_aux[COL_INDEX(MemoryBaseAuxCols, prev_timestamp)].asUInt32();
    mem_helper.fill(base_aux, prev_timestamp, timestamp);
}

extensions/native/circuit/cuda/src/poseidon2.cu

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "poseidon2-air/columns.cuh"
33
#include "poseidon2-air/params.cuh"
44
#include "poseidon2-air/tracegen.cuh"
5+
#include "native/utils.cuh"
56
#include "primitives/trace_access.h"
67
#include "system/memory/controller.cuh"
78

@@ -38,15 +39,6 @@ template <typename T, size_t SBOX_REGISTERS> struct NativePoseidon2Cols {
3839
T specific[COL_SPECIFIC_WIDTH];
3940
};
4041

41-
__device__ void mem_fill_base(
42-
MemoryAuxColsFactory &mem_helper,
43-
uint32_t timestamp,
44-
RowSlice base_aux
45-
) {
46-
uint32_t prev = base_aux[COL_INDEX(MemoryBaseAuxCols, prev_timestamp)].asUInt32();
47-
mem_helper.fill(base_aux, prev, timestamp);
48-
}
49-
5042
template <size_t SBOX_REGISTERS> struct Poseidon2Wrapper {
5143
template <typename T> using Cols = NativePoseidon2Cols<T, SBOX_REGISTERS>;
5244
using Poseidon2Row =
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
#include "launcher.cuh"
2+
#include "native/sumcheck.cuh"
3+
#include "native/utils.cuh"
4+
#include "primitives/trace_access.h"
5+
#include "system/memory/controller.cuh"
6+
7+
using namespace native;
8+
9+
// Fills the memory aux records inside the row-type-specific region of one
// sumcheck row. The row-type flags are tested in the order header, prod, logup;
// the first one equal to Fp::one() selects the layout. If none is set, nothing
// is filled.
//
// Timestamps used for each record:
//   header: reads 0..6 at start_timestamp + i, write at last_timestamp - 1
//   prod:   read 0 at start_timestamp; when within_round_limit is one,
//           read 1 at start_timestamp + 1 and the write at start_timestamp + 2
//   logup:  read 0 at start_timestamp; when within_round_limit is one,
//           read 1 at +1, write 0 at +2 and write 1 at +3
__device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_helper) {
    const uint32_t start_timestamp =
        row[COL_INDEX(NativeSumcheckCols, start_timestamp)].asUInt32();
    RowSlice specific = row.slice_from(COL_INDEX(NativeSumcheckCols, specific));

    // Fill the base-aux record located `base_offset` columns into `specific`.
    auto fill_aux = [&](uint32_t timestamp, size_t base_offset) {
        mem_fill_base(mem_helper, timestamp, specific.slice_from(base_offset));
    };
    // True when the boolean column at `col` holds one.
    auto is_set = [&](size_t col) { return row[col] == Fp::one(); };

    if (is_set(COL_INDEX(NativeSumcheckCols, header_row))) {
        for (uint32_t i = 0; i < 7; ++i) {
            fill_aux(start_timestamp + i, COL_INDEX(HeaderSpecificCols, read_records[i].base));
        }
        const uint32_t last_timestamp =
            row[COL_INDEX(NativeSumcheckCols, last_timestamp)].asUInt32();
        fill_aux(last_timestamp - 1, COL_INDEX(HeaderSpecificCols, write_records.base));
        return;
    }

    if (is_set(COL_INDEX(NativeSumcheckCols, prod_row))) {
        fill_aux(start_timestamp, COL_INDEX(ProdSpecificCols, read_records[0].base));
        if (is_set(COL_INDEX(NativeSumcheckCols, within_round_limit))) {
            fill_aux(start_timestamp + 1, COL_INDEX(ProdSpecificCols, read_records[1].base));
            fill_aux(start_timestamp + 2, COL_INDEX(ProdSpecificCols, write_record.base));
        }
        return;
    }

    if (is_set(COL_INDEX(NativeSumcheckCols, logup_row))) {
        fill_aux(start_timestamp, COL_INDEX(LogupSpecificCols, read_records[0].base));
        if (is_set(COL_INDEX(NativeSumcheckCols, within_round_limit))) {
            fill_aux(start_timestamp + 1, COL_INDEX(LogupSpecificCols, read_records[1].base));
            fill_aux(start_timestamp + 2, COL_INDEX(LogupSpecificCols, write_records[0].base));
            fill_aux(start_timestamp + 3, COL_INDEX(LogupSpecificCols, write_records[1].base));
        }
    }
}
70+
71+
// Trace-generation kernel for the native sumcheck chip: one thread per trace row.
//
// Threads whose index is >= height exit immediately.
// Rows with index < rows_used copy `width` values from the matching record
// (records + idx * width) into the trace row, then fill the memory aux records
// via fill_sumcheck_specific.
// Remaining rows are padding: zeroed over `width` columns with is_end set to one.
//
//   trace                   output trace; row idx is addressed as RowSlice(trace + idx, height)
//   height                  number of trace rows
//   width                   number of columns per row
//   records                 input records, `width` contiguous values per used row
//   rows_used               number of rows backed by a record
//   range_checker_ptr       passed to VariableRangeChecker
//   range_checker_num_bins  passed to VariableRangeChecker
//   timestamp_max_bits      passed to MemoryAuxColsFactory
__global__ void native_sumcheck_tracegen(
    Fp *trace,
    size_t height,
    size_t width,
    const Fp *records,
    size_t rows_used,
    uint32_t *range_checker_ptr,
    uint32_t range_checker_num_bins,
    uint32_t timestamp_max_bits
) {
    const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < height) {
        RowSlice row(trace + idx, height);

        if (idx >= rows_used) {
            // Padding row.
            row.fill_zero(0, width);
            COL_WRITE_VALUE(row, NativeSumcheckCols, is_end, Fp::one());
            return;
        }

        // Copy the record into the trace row, one column at a time.
        const Fp *record = records + idx * width;
        uint32_t col = 0;
        while (col < width) {
            row[col] = record[col];
            ++col;
        }

        MemoryAuxColsFactory mem_helper(
            VariableRangeChecker(range_checker_ptr, range_checker_num_bins), timestamp_max_bits
        );
        fill_sumcheck_specific(row, mem_helper);
    }
}
101+
102+
// C ABI entry point that launches native_sumcheck_tracegen over `height` rows.
//
// Preconditions (checked with assert, so only in builds where assert is active):
//   - (height & (height - 1)) == 0
//   - width equals the number of columns in NativeSumcheckCols
//
// All arguments are forwarded unchanged to the kernel. Returns the value
// produced by CHECK_KERNEL() after the launch.
extern "C" int _native_sumcheck_tracegen(
    Fp *d_trace,
    size_t height,
    size_t width,
    const Fp *d_records,
    size_t rows_used,
    uint32_t *d_range_checker,
    uint32_t range_checker_num_bins,
    uint32_t timestamp_max_bits
) {
    assert((height & (height - 1)) == 0);
    assert(width == sizeof(NativeSumcheckCols<uint8_t>));

    auto [grid_dim, block_dim] = kernel_launch_params(height);
    native_sumcheck_tracegen<<<grid_dim, block_dim>>>(
        d_trace, height, width,
        d_records, rows_used,
        d_range_checker, range_checker_num_bins,
        timestamp_max_bits
    );

    return CHECK_KERNEL();
}

extensions/native/circuit/src/cuda_abi.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,44 @@ pub mod poseidon2_cuda {
235235
}
236236
}
237237

238+
pub mod sumcheck_cuda {
239+
use super::*;
240+
241+
extern "C" {
242+
pub fn _native_sumcheck_tracegen(
243+
d_trace: *mut F,
244+
height: usize,
245+
width: usize,
246+
d_records: *const F,
247+
rows_used: usize,
248+
d_range_checker: *mut u32,
249+
range_checker_max_bins: u32,
250+
timestamp_max_bits: u32,
251+
) -> i32;
252+
}
253+
254+
pub unsafe fn tracegen(
255+
d_trace: &DeviceBuffer<F>,
256+
height: usize,
257+
width: usize,
258+
d_records: &DeviceBuffer<F>,
259+
rows_used: usize,
260+
d_range_checker: &DeviceBuffer<F>,
261+
timestamp_max_bits: u32,
262+
) -> Result<(), CudaError> {
263+
CudaError::from_result(_native_sumcheck_tracegen(
264+
d_trace.as_mut_ptr(),
265+
height,
266+
width,
267+
d_records.as_ptr(),
268+
rows_used,
269+
d_range_checker.as_mut_ptr() as *mut u32,
270+
d_range_checker.len() as u32,
271+
timestamp_max_bits,
272+
))
273+
}
274+
}
275+
238276
pub mod native_loadstore_cuda {
239277
use super::*;
240278

extensions/native/circuit/src/extension/cuda.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use crate::{
1717
jal_rangecheck::{JalRangeCheckAir, JalRangeCheckGpu},
1818
loadstore::{NativeLoadStoreAir, NativeLoadStoreChipGpu},
1919
poseidon2::{air::NativePoseidon2Air, NativePoseidon2ChipGpu},
20+
sumcheck::{air::NativeSumcheckAir, NativeSumcheckChipGpu},
2021
CastFExtension, GpuBackend, Native,
2122
};
2223

@@ -75,6 +76,10 @@ impl VmProverExtension<GpuBabyBearPoseidon2Engine, DenseRecordArena, Native>
7576
let poseidon2 = NativePoseidon2ChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits);
7677
inventory.add_executor_chip(poseidon2);
7778

79+
inventory.next_air::<NativeSumcheckAir>()?;
80+
let sumcheck = NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits);
81+
inventory.add_executor_chip(sumcheck);
82+
7883
Ok(())
7984
}
8085
}

extensions/native/circuit/src/sumcheck/chip.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::borrow::BorrowMut;
33
use openvm_circuit::{
44
arch::{
55
CustomBorrow, ExecutionError, MultiRowLayout, MultiRowMetadata, PreflightExecutor,
6-
RecordArena, TraceFiller, VmChipWrapper, VmStateMut,
6+
RecordArena, SizedRecord, TraceFiller, VmChipWrapper, VmStateMut,
77
},
88
system::{
99
memory::{online::TracingMemory, MemoryAuxColsFactory},
@@ -76,14 +76,24 @@ impl<'a, F: PrimeField32>
7676
// Each instruction record consists solely of some number of contiguously
7777
// stored NativeSumcheckCols<...> structs, each of which corresponds to a
7878
// single trace row. Trace fillers don't actually need to know how many rows
79-
// each instruction uses, and can thus treat each NativePoseidon2Cols<...>
79+
// each instruction uses, and can thus treat each NativeSumcheckCols<...>
8080
// as a single record.
8181
NativeSumcheckRecordLayout {
8282
metadata: NativeSumcheckMetadata { num_rows: 1 },
8383
}
8484
}
8585
}
8686

87+
impl<F: PrimeField32> SizedRecord<NativeSumcheckRecordLayout> for NativeSumcheckRecordMut<'_, F> {
88+
fn size(layout: &NativeSumcheckRecordLayout) -> usize {
89+
layout.metadata.num_rows * size_of::<NativeSumcheckCols<F>>()
90+
}
91+
92+
fn alignment(_layout: &NativeSumcheckRecordLayout) -> usize {
93+
align_of::<NativeSumcheckCols<F>>()
94+
}
95+
}
96+
8797
#[derive(derive_new::new, Copy, Clone)]
8898
pub struct NativeSumcheckExecutor;
8999

0 commit comments

Comments
 (0)