Skip to content

Commit 097e353

Browse files
authored
Fix: sumcheck unit test failure (#19)
* fix sumcheck unit test * remove hardcoded constants * fix
1 parent 30e9eff commit 097e353

File tree

3 files changed

+120
-321
lines changed

3 files changed

+120
-321
lines changed

extensions/native/circuit/src/sumcheck/chip.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,8 @@ where
325325
let eval_rlc = FieldExtension::multiply(alpha_acc, eval);
326326
prod_specific.eval_rlc = eval_rlc;
327327

328-
if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 {
328+
let to_next_round = if mode == NEXT_LAYER_MODE { 1 } else { 0 };
329+
if round + to_next_round < max_round - 1 {
329330
eval_acc = FieldExtension::add(eval_acc, eval_rlc);
330331
prod_row.should_acc = F::ONE;
331332
prod_row.prod_acc = F::ONE;
@@ -445,7 +446,8 @@ where
445446
FieldExtension::multiply(alpha_denominator, q_eval),
446447
);
447448
logup_specific.eval_rlc = eval_rlc;
448-
if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 {
449+
let to_next_round = if mode == NEXT_LAYER_MODE { 1 } else { 0 };
450+
if round + to_next_round < max_round - 1 {
449451
eval_acc = FieldExtension::add(eval_acc, eval_rlc);
450452
logup_row.should_acc = F::ONE;
451453
logup_row.logup_acc = F::ONE;
@@ -480,12 +482,16 @@ where
480482
let specific: &mut ProdSpecificCols<F> =
481483
row.specific[..ProdSpecificCols::<F>::width()].borrow_mut();
482484

483-
eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc);
485+
if row.should_acc == F::ONE {
486+
eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc);
487+
}
484488
row.eval_acc = eval_acc;
485489
} else if row.logup_row == F::ONE {
486490
let specific: &mut LogupSpecificCols<F> =
487491
row.specific[..LogupSpecificCols::<F>::width()].borrow_mut();
488-
eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc);
492+
if row.should_acc == F::ONE {
493+
eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc);
494+
}
489495
row.eval_acc = eval_acc;
490496
}
491497
}

extensions/native/circuit/src/sumcheck/execution.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,8 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
252252

253253
exec_state.vm_write(NATIVE_AS, r_evals_ptr_u32 + (1 + i) * EXT_DEG as u32, &eval);
254254

255-
if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 {
255+
let to_next_round = if mode == NEXT_LAYER_MODE { 1 } else { 0 };
256+
if round + to_next_round < max_round - 1 {
256257
// update eval_acc
257258
eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, eval));
258259
}
@@ -270,8 +271,8 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
270271
.vm_read(NATIVE_AS, logup_offset + i)
271272
.map(|x: F| x.as_canonical_u32());
272273
let start = calculate_3d_ext_idx(
273-
logup_specs_inner_len,
274274
logup_specs_inner_inner_len,
275+
logup_specs_inner_len,
275276
i,
276277
round,
277278
0,
@@ -325,7 +326,8 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait>(
325326
FieldExtension::multiply(alpha_numerator, p_eval),
326327
FieldExtension::multiply(alpha_denominator, q_eval),
327328
);
328-
if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 {
329+
let to_next_round = if mode == NEXT_LAYER_MODE { 1 } else { 0 };
330+
if round + to_next_round < max_round - 1 {
329331
// update eval_acc
330332
eval_acc = FieldExtension::add(eval_acc, eval_rlc);
331333
}

0 commit comments

Comments
 (0)