Skip to content

Commit 939cd1e

Browse files
committed
wip9
1 parent cbbe3c1 commit 939cd1e

File tree

2 files changed

+92
-67
lines changed

2 files changed

+92
-67
lines changed

extensions/native/circuit/src/sumcheck/chip.rs

Lines changed: 62 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ use crate::{
2727
};
2828

2929
pub(crate) const CONTEXT_ARR_BASE_LEN: usize = EXT_DEG * 2;
30-
const CURRENT_LAYER_MODE: u32 = 1;
31-
const NEXT_LAYER_MODE: u32 = 0;
30+
pub(crate) const CURRENT_LAYER_MODE: u32 = 1;
31+
pub(crate) const NEXT_LAYER_MODE: u32 = 0;
3232

3333
pub(crate) fn calculate_3d_ext_idx(
3434
inner_inner_len: u32,
@@ -144,13 +144,15 @@ where
144144
.alloc(MultiRowLayout::new(NativeSumcheckMetadata { num_rows }))
145145
.0;
146146

147-
let mut cur_timestamp = 0;
147+
let mut cur_timestamp = state.memory.timestamp();
148148
// head row
149149
let head_row: &mut NativeSumcheckCols<F> = &mut rows[0];
150150
let head_specific: &mut HeaderSpecificCols<F> =
151151
head_row.specific[..HeaderSpecificCols::<F>::width()].borrow_mut();
152152

153153
head_row.header_row = F::ONE;
154+
head_row.first_timestamp = F::from_canonical_u32(cur_timestamp);
155+
head_row.start_timestamp = F::from_canonical_u32(cur_timestamp);
154156

155157
head_specific.pc = F::from_canonical_u32(*state.pc);
156158

@@ -210,6 +212,7 @@ where
210212

211213
// all rows share same register values, ctx, challenges
212214
for row in rows.iter_mut() {
215+
// c1, c2 are same during the entire execution
213216
row.challenges[EXT_DEG..3 * EXT_DEG].copy_from_slice(&challenges[EXT_DEG..3 * EXT_DEG]);
214217
row.alpha = alpha;
215218
row.ctx = ctx;
@@ -231,8 +234,8 @@ where
231234
prod_row.specific[..ProdSpecificCols::<F>::width()].borrow_mut();
232235

233236
prod_row.prod_row = F::ONE;
234-
prod_row.curr_prod_n = F::from_canonical_usize(i);
235-
prod_row.start_timestamp = F::from_canonical_usize(cur_timestamp);
237+
prod_row.curr_prod_n = F::from_canonical_usize(i + 1); // curr_prod_n starts from 1
238+
prod_row.start_timestamp = F::from_canonical_u32(cur_timestamp);
236239

237240
// read max_round
238241
let [max_round]: [F; 1] = tracing_read_native_helper(
@@ -245,8 +248,9 @@ where
245248
prod_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc);
246249
prod_row.max_round = max_round;
247250

251+
let max_round = max_round.as_canonical_u32();
248252
// round starts from 0
249-
if round < max_round.as_canonical_u32() - 1 {
253+
if round < max_round - 1 {
250254
prod_row.within_round_limit = F::ONE;
251255
let start = calculate_3d_ext_idx(
252256
prod_specs_inner_inner_len,
@@ -296,20 +300,19 @@ where
296300
eval,
297301
&mut prod_specific.write_record,
298302
);
299-
cur_timestamp += 1;
303+
cur_timestamp += 2;
300304

301305
let eval_rlc = FieldExtension::multiply(alpha_acc, eval);
302306
prod_specific.eval_rlc = eval_rlc;
303307

304-
if mode == NEXT_LAYER_MODE && round < max_round.as_canonical_u32() - 2 {
308+
if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 {
305309
eval_acc = FieldExtension::add(eval_acc, eval_rlc);
306310
prod_row.should_acc = F::ONE;
307311
prod_row.eval_acc = eval_acc;
308312
}
309313
}
310314

311315
alpha_acc = FieldExtension::multiply(alpha_acc, alpha);
312-
prod_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc);
313316
}
314317

315318
// logup rows
@@ -318,8 +321,8 @@ where
318321
logup_row.specific[..LogupSpecificCols::<F>::width()].borrow_mut();
319322

320323
logup_row.logup_row = F::ONE;
321-
logup_row.curr_logup_n = F::from_canonical_usize(i);
322-
logup_row.start_timestamp = F::from_canonical_usize(cur_timestamp);
324+
logup_row.curr_logup_n = F::from_canonical_usize(i + 1); // curr_logup_n starts from 1
325+
logup_row.start_timestamp = F::from_canonical_u32(cur_timestamp);
323326

324327
let [max_round]: [F; 1] = tracing_read_native_helper(
325328
state.memory,
@@ -334,7 +337,8 @@ where
334337
logup_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc);
335338
logup_row.challenges[2 * EXT_DEG..(3 * EXT_DEG)].copy_from_slice(&alpha_denominator);
336339

337-
if round < max_round.as_canonical_u32() - 1 {
340+
let max_round = max_round.as_canonical_u32();
341+
if round < max_round - 1 {
338342
logup_row.within_round_limit = F::ONE;
339343
let start = calculate_3d_ext_idx(
340344
prod_specs_inner_inner_len,
@@ -410,13 +414,13 @@ where
410414
);
411415
cur_timestamp += 3; // 1 read, 2 writes
412416

413-
let eval = FieldExtension::add(
417+
let eval_rlc = FieldExtension::add(
414418
FieldExtension::multiply(alpha_numerator, p_eval),
415419
FieldExtension::multiply(alpha_denominator, q_eval),
416420
);
417-
logup_specific.eval_rlc = eval;
418-
if mode == NEXT_LAYER_MODE && round < max_round.as_canonical_u32() - 2 {
419-
eval_acc = FieldExtension::add(eval_acc, eval);
421+
logup_specific.eval_rlc = eval_rlc;
422+
if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 {
423+
eval_acc = FieldExtension::add(eval_acc, eval_rlc);
420424
logup_row.should_acc = F::ONE;
421425
logup_row.logup_acc = F::ONE;
422426
logup_row.eval_acc = eval_acc;
@@ -427,6 +431,8 @@ where
427431
}
428432

429433
let head_row = &mut rows[0];
434+
head_row.last_timestamp = F::from_canonical_u32(cur_timestamp + 1);
435+
430436
let head_specific: &mut HeaderSpecificCols<F> =
431437
head_row.specific[..HeaderSpecificCols::<F>::width()].borrow_mut();
432438

@@ -452,66 +458,79 @@ impl<F: PrimeField32> TraceFiller<F> for NativeSumcheckFiller {
452458
fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
453459
let cols: &mut NativeSumcheckCols<F> = row_slice.borrow_mut();
454460
let start_timestamp = cols.start_timestamp.as_canonical_u32();
461+
let last_timestamp = cols.last_timestamp.as_canonical_u32();
455462

463+
println!("start_timestamp: {}, cols.header_row: {:?}, prod_row: {:?}, logup_row: {:?}", start_timestamp, cols.header_row, cols.prod_row, cols.logup_row);
456464
if cols.header_row == F::ONE {
457465
let header: &mut HeaderSpecificCols<F> =
458466
cols.specific[..HeaderSpecificCols::<F>::width()].borrow_mut();
459467

460468
for i in 0..7usize {
461469
mem_fill_helper(
462470
mem_helper,
463-
start_timestamp + i,
471+
start_timestamp + i as u32,
464472
header.read_records[i].as_mut(),
465473
);
466474
}
467475
mem_fill_helper(
468476
mem_helper,
469-
start_timestamp + 7,
477+
last_timestamp - 1,
470478
header.write_records.as_mut(),
471479
);
472480
} else if cols.prod_row == F::ONE {
473481
let prod_row_specific: &mut ProdSpecificCols<F> =
474482
cols.specific[..ProdSpecificCols::<F>::width()].borrow_mut();
475483

484+
// read max_round
476485
mem_fill_helper(
477486
mem_helper,
478487
start_timestamp,
479488
prod_row_specific.read_records[0].as_mut(),
480489
);
481-
mem_fill_helper(
482-
mem_helper,
483-
start_timestamp + 1,
484-
prod_row_specific.read_records[1].as_mut(),
485-
);
486-
mem_fill_helper(
487-
mem_helper,
488-
start_timestamp + 2,
489-
prod_row_specific.write_record.as_mut(),
490-
);
490+
if cols.within_round_limit == F::ONE {
491+
// read p1, p2
492+
mem_fill_helper(
493+
mem_helper,
494+
start_timestamp + 1,
495+
prod_row_specific.read_records[1].as_mut(),
496+
);
497+
// write p_eval
498+
mem_fill_helper(
499+
mem_helper,
500+
start_timestamp + 2,
501+
prod_row_specific.write_record.as_mut(),
502+
);
503+
}
491504
} else if cols.logup_row == F::ONE {
492505
let logup_row_specific: &mut LogupSpecificCols<F> =
493506
cols.specific[..LogupSpecificCols::<F>::width()].borrow_mut();
494507

508+
// read max_round
495509
mem_fill_helper(
496510
mem_helper,
497511
start_timestamp,
498512
logup_row_specific.read_records[0].as_mut(),
499513
);
500-
mem_fill_helper(
501-
mem_helper,
502-
start_timestamp + 1,
503-
logup_row_specific.read_records[1].as_mut(),
504-
);
505-
mem_fill_helper(
506-
mem_helper,
507-
start_timestamp + 2,
508-
logup_row_specific.write_records[0].as_mut(),
509-
);
510-
mem_fill_helper(
511-
mem_helper,
512-
start_timestamp + 3,
513-
logup_row_specific.write_records[1].as_mut(),
514-
);
514+
if cols.within_round_limit == F::ONE {
515+
// read p1, p2, q1, q2
516+
mem_fill_helper(
517+
mem_helper,
518+
start_timestamp + 1,
519+
logup_row_specific.read_records[1].as_mut(),
520+
);
521+
// write p_eval
522+
mem_fill_helper(
523+
mem_helper,
524+
start_timestamp + 2,
525+
logup_row_specific.write_records[0].as_mut(),
526+
);
527+
// write q_eval
528+
mem_fill_helper(
529+
mem_helper,
530+
start_timestamp + 3,
531+
logup_row_specific.write_records[1].as_mut(),
532+
);
533+
}
515534
}
516535
}
517536
}

extensions/native/circuit/src/sumcheck/execution.rs

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ use openvm_stark_backend::p3_field::PrimeField32;
1111
use crate::{
1212
field_extension::{FieldExtension, EXT_DEG},
1313
fri::elem_to_ext,
14-
sumcheck::chip::{calculate_3d_ext_idx, NativeSumcheckExecutor},
14+
sumcheck::chip::{
15+
calculate_3d_ext_idx, NativeSumcheckExecutor, CONTEXT_ARR_BASE_LEN, CURRENT_LAYER_MODE,
16+
NEXT_LAYER_MODE,
17+
},
1518
};
1619

1720
#[derive(AlignedBytesBorrow, Clone)]
@@ -209,7 +212,7 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
209212
.map(|x: F| x.as_canonical_u32());
210213
let [round, num_prod_spec, num_logup_spec, prod_specs_inner_len, prod_specs_inner_inner_len, logup_specs_inner_len, logup_specs_inner_inner_len, mode] =
211214
ctx;
212-
let challenges: [F; EXT_DEG * 3] =
215+
let challenges: [F; EXT_DEG * 4] =
213216
exec_state.vm_read(NATIVE_AS, challenges_ptr.as_canonical_u32());
214217
let alpha: [F; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap();
215218
let c1: [F; EXT_DEG] = challenges[EXT_DEG..EXT_DEG * 2].try_into().unwrap();
@@ -219,9 +222,10 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
219222
let mut alpha_acc = elem_to_ext(F::ONE);
220223
let mut eval_acc = elem_to_ext(F::ZERO);
221224

225+
let prod_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32;
222226
for i in 0..num_prod_spec {
223227
let [max_round]: [u32; 1] = exec_state
224-
.vm_read(NATIVE_AS, ctx_ptr_u32 + 8)
228+
.vm_read(NATIVE_AS, prod_offset + i)
225229
.map(|x: F| x.as_canonical_u32());
226230

227231
let start = calculate_3d_ext_idx(
@@ -238,17 +242,17 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
238242
let p2: [F; EXT_DEG] = ps[EXT_DEG..EXT_DEG * 2].try_into().unwrap();
239243

240244
let eval = match mode {
241-
1 => FieldExtension::multiply(p1, p2),
242-
0 => FieldExtension::add(
245+
CURRENT_LAYER_MODE => FieldExtension::multiply(p1, p2),
246+
NEXT_LAYER_MODE => FieldExtension::add(
243247
FieldExtension::multiply(p1, c1),
244248
FieldExtension::multiply(p2, c2),
245249
),
246-
_ => unreachable!("mode can only be 0 or 1"),
250+
_ => unreachable!("mode can only be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"),
247251
};
248252

249-
exec_state.vm_write(NATIVE_AS, r_evals_ptr_u32 + 1 + i, &eval);
253+
exec_state.vm_write(NATIVE_AS, r_evals_ptr_u32 + (1 + i) * EXT_DEG as u32, &eval);
250254

251-
if round + mode < max_round - 1 {
255+
if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 {
252256
// update eval_acc
253257
eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, eval));
254258
}
@@ -259,10 +263,11 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
259263
height += 1;
260264
}
261265

266+
let logup_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32 + num_prod_spec;
262267
for i in 0..num_logup_spec {
263268
// read max_round
264269
let [max_round]: [u32; 1] = exec_state
265-
.vm_read(NATIVE_AS, ctx_ptr_u32 + 8 + num_prod_spec + i)
270+
.vm_read(NATIVE_AS, logup_offset + i)
266271
.map(|x: F| x.as_canonical_u32());
267272
let start = calculate_3d_ext_idx(
268273
logup_specs_inner_len,
@@ -272,6 +277,9 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
272277
0,
273278
);
274279

280+
let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha);
281+
let alpha_numerator = alpha_acc;
282+
275283
if round < max_round - 1 {
276284
// read logup_evals
277285
let pqs: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, logup_evals_ptr + start);
@@ -282,23 +290,23 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
282290

283291
// compute p_eval and q_eval
284292
let p_eval = match mode {
285-
1 => FieldExtension::add(
293+
CURRENT_LAYER_MODE => FieldExtension::add(
286294
FieldExtension::multiply(p1, q2),
287295
FieldExtension::multiply(p2, q1),
288296
),
289-
0 => FieldExtension::add(
297+
NEXT_LAYER_MODE => FieldExtension::add(
290298
FieldExtension::multiply(p1, c1),
291299
FieldExtension::multiply(p2, c2),
292300
),
293-
_ => unreachable!("mode can only be 0 or 1"),
301+
_ => unreachable!("mode can only be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"),
294302
};
295303
let q_eval = match mode {
296-
1 => FieldExtension::multiply(q1, q2),
297-
0 => FieldExtension::add(
304+
CURRENT_LAYER_MODE => FieldExtension::multiply(q1, q2),
305+
NEXT_LAYER_MODE => FieldExtension::add(
298306
FieldExtension::multiply(q1, c1),
299307
FieldExtension::multiply(q2, c2),
300308
),
301-
_ => unreachable!("mode can only be 0 or 1"),
309+
_ => unreachable!("mode can only be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"),
302310
};
303311

304312
// write eval to r_evals
@@ -313,20 +321,18 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
313321
&q_eval,
314322
);
315323

316-
let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha);
317-
let alpha_numerator = alpha_acc;
318-
319-
if round + mode < max_round - 1 {
324+
let eval_rlc = FieldExtension::add(
325+
FieldExtension::multiply(alpha_numerator, p_eval),
326+
FieldExtension::multiply(alpha_denominator, q_eval),
327+
);
328+
if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 {
320329
// update eval_acc
321-
eval_acc = FieldExtension::add(
322-
FieldExtension::multiply(alpha_numerator, p_eval),
323-
FieldExtension::multiply(alpha_denominator, q_eval),
324-
);
330+
eval_acc = FieldExtension::add(eval_acc, eval_rlc);
325331
}
326332
}
327333

328334
// update alpha_acc
329-
alpha_acc = FieldExtension::multiply(alpha_acc, FieldExtension::multiply(alpha, alpha));
335+
alpha_acc = FieldExtension::multiply(alpha_denominator, alpha);
330336
height += 1;
331337
}
332338

0 commit comments

Comments
 (0)