Skip to content

Commit 055f502

Browse files
committed
misc: refactor MaskRepresentation dynamic allocation
1 parent 38d287f commit 055f502

File tree

4 files changed

+127
-48
lines changed

4 files changed

+127
-48
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ceno_zkvm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ p3.workspace = true
2626
rayon.workspace = true
2727
serde.workspace = true
2828
serde_json.workspace = true
29+
smallvec.workspace = true
2930
sumcheck.workspace = true
3031
transcript.workspace = true
3132
whir.workspace = true

ceno_zkvm/src/precompiles/lookup_keccakf.rs

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ use crate::{
4545
instructions::riscv::insn_base::{StateInOut, WriteMEM},
4646
precompiles::{
4747
SelectorTypeLayout,
48-
utils::{MaskRepresentation, not8_expr, set_slice_felts_from_u64 as push_instance},
48+
utils::{Mask, MaskRepresentation, not8_expr, set_slice_felts_from_u64 as push_instance},
4949
},
5050
scheme::utils::gkr_witness,
5151
};
@@ -162,6 +162,29 @@ pub struct KeccakLayout<E: ExtensionField> {
162162
pub n_challenges: usize,
163163
}
164164

165+
const ROTATION_WITNESS_LEN: usize = 196;
166+
const C_TEMP_SPLIT_SIZES: [usize; 8] = [15, 1, 15, 1, 15, 1, 15, 1];
167+
const BYTE_SPLIT_SIZES: [usize; 8] = [8; 8];
168+
169+
#[inline(always)]
170+
fn split_mask_to_bytes(value: u64) -> [u64; 8] {
171+
value.to_le_bytes().map(|b| b as u64)
172+
}
173+
174+
#[inline(always)]
175+
fn split_mask_to_array<const N: usize>(value: u64, sizes: &[usize; N]) -> [u64; N] {
176+
let mut out = [0u64; N];
177+
if N == 8 && sizes.iter().all(|&s| s == 8) {
178+
out.copy_from_slice(&split_mask_to_bytes(value));
179+
return out;
180+
}
181+
let values = MaskRepresentation::from_mask(Mask::new(64, value))
182+
.convert(sizes)
183+
.values();
184+
out.copy_from_slice(values.as_slice());
185+
out
186+
}
187+
165188
impl<E: ExtensionField> KeccakLayout<E> {
166189
fn new(cb: &mut CircuitBuilder<E>, params: KeccakParams) -> Self {
167190
// allocate witnesses, fixed, and eqs
@@ -639,14 +662,6 @@ where
639662

640663
let num_instances = phase1.instances.len();
641664

642-
fn conv64to8(input: u64) -> [u64; 8] {
643-
MaskRepresentation::new(vec![(64, input).into()])
644-
.convert(vec![8; 8])
645-
.values()
646-
.try_into()
647-
.unwrap()
648-
}
649-
650665
// keccak instance full rounds (24 rounds + 8 round padding) as chunk size
651666
// we need to do assignment on respective 31 cyclic group index
652667
wits.values
@@ -729,7 +744,7 @@ where
729744
let mut state8 = [[[0u64; 8]; 5]; 5];
730745
for x in 0..5 {
731746
for y in 0..5 {
732-
state8[x][y] = conv64to8(state64[x][y]);
747+
state8[x][y] = split_mask_to_array(state64[x][y], &BYTE_SPLIT_SIZES);
733748
}
734749
}
735750

@@ -744,14 +759,14 @@ where
744759

745760
for i in 0..5 {
746761
c_aux64[i][0] = state64[0][i];
747-
c_aux8[i][0] = conv64to8(c_aux64[i][0]);
762+
c_aux8[i][0] = split_mask_to_array(c_aux64[i][0], &BYTE_SPLIT_SIZES);
748763
for j in 1..5 {
749764
c_aux64[i][j] = state64[j][i] ^ c_aux64[i][j - 1];
750765
for k in 0..8 {
751766
lk_multiplicity
752767
.lookup_xor_byte(c_aux8[i][j - 1][k], state8[j][i][k]);
753768
}
754-
c_aux8[i][j] = conv64to8(c_aux64[i][j]);
769+
c_aux8[i][j] = split_mask_to_array(c_aux64[i][j], &BYTE_SPLIT_SIZES);
755770
}
756771
}
757772

@@ -760,25 +775,23 @@ where
760775

761776
for x in 0..5 {
762777
c64[x] = c_aux64[x][4];
763-
c8[x] = conv64to8(c64[x]);
778+
c8[x] = split_mask_to_array(c64[x], &BYTE_SPLIT_SIZES);
764779
}
765780

766781
let mut c_temp = [[0u64; 8]; 5];
767782
for i in 0..5 {
768-
let rep = MaskRepresentation::new(vec![(64, c64[i]).into()])
769-
.convert(vec![15, 1, 15, 1, 15, 1, 15, 1])
770-
.values();
771-
for (j, size) in [15, 1, 15, 1, 15, 1, 15, 1].iter().enumerate() {
772-
lk_multiplicity.assert_const_range(rep[j], *size);
783+
let chunks = split_mask_to_array(c64[i], &C_TEMP_SPLIT_SIZES);
784+
for (chunk, size) in chunks.iter().zip(C_TEMP_SPLIT_SIZES.iter()) {
785+
lk_multiplicity.assert_const_range(*chunk, *size);
773786
}
774-
c_temp[i] = rep.try_into().unwrap();
787+
c_temp[i] = chunks;
775788
}
776789

777790
let mut crot64 = [0u64; 5];
778791
let mut crot8 = [[0u64; 8]; 5];
779792
for i in 0..5 {
780793
crot64[i] = c64[i].rotate_left(1);
781-
crot8[i] = conv64to8(crot64[i]);
794+
crot8[i] = split_mask_to_array(crot64[i], &BYTE_SPLIT_SIZES);
782795
}
783796

784797
let mut d64 = [0u64; 5];
@@ -791,30 +804,31 @@ where
791804
crot8[(x + 1) % 5][k],
792805
);
793806
}
794-
d8[x] = conv64to8(d64[x]);
807+
d8[x] = split_mask_to_array(d64[x], &BYTE_SPLIT_SIZES);
795808
}
796809

797810
let mut theta_state64 = state64;
798811
let mut theta_state8 = [[[0u64; 8]; 5]; 5];
799-
let mut rotation_witness = vec![];
812+
let mut rotation_witness = Vec::with_capacity(ROTATION_WITNESS_LEN);
800813

801814
for x in 0..5 {
802815
for y in 0..5 {
803816
theta_state64[y][x] ^= d64[x];
804817
for k in 0..8 {
805818
lk_multiplicity.lookup_xor_byte(state8[y][x][k], d8[x][k])
806819
}
807-
theta_state8[y][x] = conv64to8(theta_state64[y][x]);
820+
theta_state8[y][x] =
821+
split_mask_to_array(theta_state64[y][x], &BYTE_SPLIT_SIZES);
808822

809823
let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]);
810-
let rep =
811-
MaskRepresentation::new(vec![(64, theta_state64[y][x]).into()])
812-
.convert(sizes.clone())
824+
let rotation_chunks =
825+
MaskRepresentation::from_mask(Mask::new(64, theta_state64[y][x]))
826+
.convert(&sizes)
813827
.values();
814-
for (j, size) in sizes.iter().enumerate() {
815-
lk_multiplicity.assert_const_range(rep[j], *size);
828+
for (chunk, size) in rotation_chunks.iter().zip(sizes.iter()) {
829+
lk_multiplicity.assert_const_range(*chunk, *size);
816830
}
817-
rotation_witness.extend(rep);
831+
rotation_witness.extend(rotation_chunks);
818832
}
819833
}
820834
assert_eq!(rotation_witness.len(), rotation_witness_witin.len());
@@ -832,7 +846,8 @@ where
832846

833847
for x in 0..5 {
834848
for y in 0..5 {
835-
rhopi_output8[x][y] = conv64to8(rhopi_output64[x][y]);
849+
rhopi_output8[x][y] =
850+
split_mask_to_array(rhopi_output64[x][y], &BYTE_SPLIT_SIZES);
836851
}
837852
}
838853

@@ -849,7 +864,8 @@ where
849864
rhopi_output8[y][(x + 2) % 5][k],
850865
);
851866
}
852-
nonlinear8[y][x] = conv64to8(nonlinear64[y][x]);
867+
nonlinear8[y][x] =
868+
split_mask_to_array(nonlinear64[y][x], &BYTE_SPLIT_SIZES);
853869
}
854870
}
855871

@@ -862,7 +878,8 @@ where
862878
lk_multiplicity
863879
.lookup_xor_byte(rhopi_output8[y][x][k], nonlinear8[y][x][k]);
864880
}
865-
chi_output8[y][x] = conv64to8(chi_output64[y][x]);
881+
chi_output8[y][x] =
882+
split_mask_to_array(chi_output64[y][x], &BYTE_SPLIT_SIZES);
866883
}
867884
}
868885

@@ -873,13 +890,14 @@ where
873890
iota_output64[0][0] ^= RC[round];
874891

875892
for k in 0..8 {
876-
let rc8 = conv64to8(RC[round]);
893+
let rc8 = split_mask_to_array(RC[round], &BYTE_SPLIT_SIZES);
877894
lk_multiplicity.lookup_xor_byte(chi_output8[0][0][k], rc8[k]);
878895
}
879896

880897
for x in 0..5 {
881898
for y in 0..5 {
882-
iota_output8[x][y] = conv64to8(iota_output64[x][y]);
899+
iota_output8[x][y] =
900+
split_mask_to_array(iota_output64[x][y], &BYTE_SPLIT_SIZES);
883901
}
884902
}
885903

@@ -1027,8 +1045,8 @@ pub fn run_lookup_keccakf<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>
10271045

10281046
let span = entered_span!("instances", profiling_2 = true);
10291047
for state in &states {
1030-
let state_mask64 = MaskRepresentation::from(state.iter().map(|e| (64, *e)).collect_vec());
1031-
let state_mask32 = state_mask64.convert(vec![32; 50]);
1048+
let state_mask64 = MaskRepresentation::from_masks(state.iter().map(|&e| Mask::new(64, e)));
1049+
let state_mask32 = state_mask64.convert(&[32usize; 50]);
10321050

10331051
let instance = KeccakInstance {
10341052
state: KeccakStateInstance {

ceno_zkvm/src/precompiles/utils.rs

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use gkr_iop::circuit_builder::expansion_expr;
33
use itertools::Itertools;
44
use multilinear_extensions::{Expression, ToExpr};
55
use p3::field::FieldAlgebra;
6+
use smallvec::SmallVec;
67

78
pub fn not8_expr<E: ExtensionField>(expr: Expression<E>) -> Expression<E> {
89
E::BaseField::from_canonical_u8(0xFF).expr() - expr
@@ -52,9 +53,11 @@ impl Mask {
5253
}
5354
}
5455

56+
const MASK_INLINE_CAPACITY: usize = 32;
57+
5558
#[derive(Debug, Clone, PartialEq, Eq)]
5659
pub struct MaskRepresentation {
57-
pub rep: Vec<Mask>,
60+
pub rep: SmallVec<[Mask; MASK_INLINE_CAPACITY]>,
5861
}
5962

6063
impl From<Mask> for (usize, u64) {
@@ -78,14 +81,45 @@ impl From<MaskRepresentation> for Vec<(usize, u64)> {
7881
impl From<Vec<(usize, u64)>> for MaskRepresentation {
7982
fn from(tuples: Vec<(usize, u64)>) -> Self {
8083
MaskRepresentation {
81-
rep: tuples.into_iter().map(|tuple| tuple.into()).collect(),
84+
rep: tuples.into_iter().map(Into::into).collect(),
85+
}
86+
}
87+
}
88+
89+
impl FromIterator<(usize, u64)> for MaskRepresentation {
90+
fn from_iter<I: IntoIterator<Item = (usize, u64)>>(iter: I) -> Self {
91+
MaskRepresentation {
92+
rep: iter.into_iter().map(Into::into).collect(),
93+
}
94+
}
95+
}
96+
97+
impl FromIterator<Mask> for MaskRepresentation {
98+
fn from_iter<I: IntoIterator<Item = Mask>>(iter: I) -> Self {
99+
MaskRepresentation {
100+
rep: iter.into_iter().collect(),
82101
}
83102
}
84103
}
85104

86105
impl MaskRepresentation {
87106
pub fn new(masks: Vec<Mask>) -> Self {
88-
Self { rep: masks }
107+
Self { rep: masks.into() }
108+
}
109+
110+
pub fn from_mask(mask: Mask) -> Self {
111+
let mut rep = SmallVec::new();
112+
rep.push(mask);
113+
Self { rep }
114+
}
115+
116+
pub fn from_masks<I>(masks: I) -> Self
117+
where
118+
I: IntoIterator<Item = Mask>,
119+
{
120+
Self {
121+
rep: masks.into_iter().collect(),
122+
}
89123
}
90124

91125
pub fn from_bits(bits: Vec<u64>, sizes: Vec<usize>) -> Self {
@@ -99,7 +133,7 @@ impl MaskRepresentation {
99133
}
100134
masks.push(Mask::new(size, mask));
101135
}
102-
Self { rep: masks }
136+
Self { rep: masks.into() }
103137
}
104138

105139
pub fn to_bits(&self) -> Vec<u64> {
@@ -109,17 +143,42 @@ impl MaskRepresentation {
109143
.collect()
110144
}
111145

112-
pub fn convert(&self, new_sizes: Vec<usize>) -> Self {
113-
let bits = self.to_bits();
114-
Self::from_bits(bits, new_sizes)
146+
pub fn convert(&self, new_sizes: &[usize]) -> Self {
147+
let mut rep = SmallVec::<[Mask; MASK_INLINE_CAPACITY]>::with_capacity(new_sizes.len());
148+
let mut src_index = 0;
149+
let mut src_bit = 0;
150+
for &size in new_sizes {
151+
let mut value = 0u64;
152+
for bit_pos in 0..size {
153+
let mut bit_value = 0u64;
154+
while src_index < self.rep.len() {
155+
let mask = &self.rep[src_index];
156+
if src_bit < mask.size {
157+
bit_value = (mask.value >> src_bit) & 1;
158+
src_bit += 1;
159+
if src_bit == mask.size {
160+
src_index += 1;
161+
src_bit = 0;
162+
}
163+
break;
164+
} else {
165+
src_index += 1;
166+
src_bit = 0;
167+
}
168+
}
169+
value |= bit_value << bit_pos;
170+
}
171+
rep.push(Mask::new(size, value));
172+
}
173+
Self { rep }
115174
}
116175

117-
pub fn values(&self) -> Vec<u64> {
118-
self.rep.iter().map(|m| m.value).collect_vec()
176+
pub fn values(&self) -> SmallVec<[u64; MASK_INLINE_CAPACITY]> {
177+
self.rep.iter().map(|m| m.value).collect()
119178
}
120179

121180
pub fn masks(&self) -> Vec<Mask> {
122-
self.rep.clone()
181+
self.rep.to_vec()
123182
}
124183
}
125184

@@ -150,8 +209,8 @@ mod tests {
150209
let bits = vec![1, 0, 1, 1, 0, 1, 0, 0];
151210
let sizes = vec![3, 5];
152211
let mask_rep = MaskRepresentation::from_bits(bits.clone(), sizes.clone());
153-
let new_sizes = vec![4, 4];
154-
let new_mask_rep = mask_rep.convert(new_sizes);
212+
let new_sizes = [4, 4];
213+
let new_mask_rep = mask_rep.convert(&new_sizes);
155214
assert_eq!(new_mask_rep.rep.len(), 2);
156215
assert_eq!(new_mask_rep.rep[0], Mask::new(4, 0b1101));
157216
assert_eq!(new_mask_rep.rep[1], Mask::new(4, 0b0010));

0 commit comments

Comments
 (0)