11use std:: mem:: { self } ;
22
3- use bytemuck:: cast_slice_mut;
3+ use bytemuck:: { cast_slice , cast_slice_mut} ;
44use cryprot_core:: { Block , utils:: log2_ceil} ;
55use fastdivide:: DividerU64 ;
66
@@ -13,7 +13,7 @@ pub(crate) struct ExpanderModd {
1313 rng : FastAesRng ,
1414 mod_val : u64 ,
1515 idx : usize ,
16- vals : Box < [ u64 ; RAND_U64_VALS ] > ,
16+ vals : Box < [ Block ; RAND_BLOCKS ] > ,
1717 mod_divider : DividerU64 ,
1818 m_is_pow2 : bool ,
1919 m_pow2_mask : u64 ,
@@ -22,11 +22,12 @@ pub(crate) struct ExpanderModd {
2222
2323impl ExpanderModd {
2424 pub ( crate ) fn new ( seed : Block , m : u64 ) -> Self {
25+ let vals = Box :: new ( [ Block :: ZERO ; RAND_BLOCKS ] ) ;
2526 let mut expander = ExpanderModd {
2627 rng : FastAesRng :: new ( seed) ,
2728 mod_val : 0 ,
2829 idx : 0 ,
29- vals : Box :: new ( [ 0 ; RAND_U64_VALS ] ) ,
30+ vals,
3031 mod_divider : DividerU64 :: divide_by ( 1 ) , // Dummy initial value
3132 m_is_pow2 : false ,
3233 m_pow2_mask : 0 ,
@@ -51,12 +52,14 @@ impl ExpanderModd {
5152
5253 #[ inline( always) ]
5354 pub ( crate ) fn get ( & mut self ) -> usize {
54- if self . idx == self . vals . len ( ) {
55+ // RAND_U64_VALS is equal to len of vals: &[u64]
56+ if self . idx == RAND_U64_VALS {
5557 self . refill ( ) ;
5658 }
57- // SAFETY: self.idx is always < self.vals.len(). If self.idx == self.vals.len(),
58- // it is set to to 0 in self.refill()
59- let val = unsafe { * self . vals . get_unchecked ( self . idx ) } ;
59+ let vals: & [ u64 ] = cast_slice ( & self . vals [ ..] ) ;
60+ // SAFETY: self.idx is always < RAND_U64_VALS == vals.len(). If self.idx ==
61+ // RAND_U64_VALS, it is set to to 0 in self.refill()
62+ let val = unsafe { * vals. get_unchecked ( self . idx ) } ;
6063 self . idx += 1 ;
6164 val as usize
6265 }
@@ -67,14 +70,15 @@ impl ExpanderModd {
6770 self . rng . refill ( ) ;
6871
6972 let src = self . rng . blocks ( ) ;
70- let dest: & mut [ Block ] = cast_slice_mut ( & mut self . vals [ ..] ) ;
73+ let dest: & mut [ Block ] = & mut self . vals [ ..] ;
7174 if self . m_is_pow2 {
7275 for ( dest, src) in dest. iter_mut ( ) . zip ( src) {
7376 * dest = * src & self . m_pow2_mask_blk ;
7477 }
7578 } else {
7679 dest. copy_from_slice ( src) ;
77- for chunk in self . vals . chunks_mut ( 32 ) {
80+ let vals: & mut [ u64 ] = cast_slice_mut ( & mut self . vals [ ..] ) ;
81+ for chunk in vals. chunks_mut ( 32 ) {
7882 Self :: do_mod32 ( chunk, & self . mod_divider , self . mod_val ) ;
7983 }
8084 }
0 commit comments