@@ -45,7 +45,7 @@ use crate::{
4545 instructions:: riscv:: insn_base:: { StateInOut , WriteMEM } ,
4646 precompiles:: {
4747 SelectorTypeLayout ,
48- utils:: { MaskRepresentation , not8_expr, set_slice_felts_from_u64 as push_instance} ,
48+ utils:: { Mask , MaskRepresentation , not8_expr, set_slice_felts_from_u64 as push_instance} ,
4949 } ,
5050 scheme:: utils:: gkr_witness,
5151} ;
@@ -162,6 +162,29 @@ pub struct KeccakLayout<E: ExtensionField> {
162162 pub n_challenges : usize ,
163163}
164164
165+ const ROTATION_WITNESS_LEN : usize = 196 ;
166+ const C_TEMP_SPLIT_SIZES : [ usize ; 8 ] = [ 15 , 1 , 15 , 1 , 15 , 1 , 15 , 1 ] ;
167+ const BYTE_SPLIT_SIZES : [ usize ; 8 ] = [ 8 ; 8 ] ;
168+
169+ #[ inline( always) ]
170+ fn split_mask_to_bytes ( value : u64 ) -> [ u64 ; 8 ] {
171+ value. to_le_bytes ( ) . map ( |b| b as u64 )
172+ }
173+
174+ #[ inline( always) ]
175+ fn split_mask_to_array < const N : usize > ( value : u64 , sizes : & [ usize ; N ] ) -> [ u64 ; N ] {
176+ let mut out = [ 0u64 ; N ] ;
177+ if N == 8 && sizes. iter ( ) . all ( |& s| s == 8 ) {
178+ out. copy_from_slice ( & split_mask_to_bytes ( value) ) ;
179+ return out;
180+ }
181+ let values = MaskRepresentation :: from_mask ( Mask :: new ( 64 , value) )
182+ . convert ( sizes)
183+ . values ( ) ;
184+ out. copy_from_slice ( values. as_slice ( ) ) ;
185+ out
186+ }
187+
165188impl < E : ExtensionField > KeccakLayout < E > {
166189 fn new ( cb : & mut CircuitBuilder < E > , params : KeccakParams ) -> Self {
167190 // allocate witnesses, fixed, and eqs
@@ -639,14 +662,6 @@ where
639662
640663 let num_instances = phase1. instances . len ( ) ;
641664
642- fn conv64to8 ( input : u64 ) -> [ u64 ; 8 ] {
643- MaskRepresentation :: new ( vec ! [ ( 64 , input) . into( ) ] )
644- . convert ( vec ! [ 8 ; 8 ] )
645- . values ( )
646- . try_into ( )
647- . unwrap ( )
648- }
649-
650665 // keccak instance full rounds (24 rounds + 8 round padding) as chunk size
651666 // we need to do assignment on respective 31 cyclic group index
652667 wits. values
@@ -729,7 +744,7 @@ where
729744 let mut state8 = [ [ [ 0u64 ; 8 ] ; 5 ] ; 5 ] ;
730745 for x in 0 ..5 {
731746 for y in 0 ..5 {
732- state8[ x] [ y] = conv64to8 ( state64[ x] [ y] ) ;
747+ state8[ x] [ y] = split_mask_to_array ( state64[ x] [ y] , & BYTE_SPLIT_SIZES ) ;
733748 }
734749 }
735750
@@ -744,14 +759,14 @@ where
744759
745760 for i in 0 ..5 {
746761 c_aux64[ i] [ 0 ] = state64[ 0 ] [ i] ;
747- c_aux8[ i] [ 0 ] = conv64to8 ( c_aux64[ i] [ 0 ] ) ;
762+ c_aux8[ i] [ 0 ] = split_mask_to_array ( c_aux64[ i] [ 0 ] , & BYTE_SPLIT_SIZES ) ;
748763 for j in 1 ..5 {
749764 c_aux64[ i] [ j] = state64[ j] [ i] ^ c_aux64[ i] [ j - 1 ] ;
750765 for k in 0 ..8 {
751766 lk_multiplicity
752767 . lookup_xor_byte ( c_aux8[ i] [ j - 1 ] [ k] , state8[ j] [ i] [ k] ) ;
753768 }
754- c_aux8[ i] [ j] = conv64to8 ( c_aux64[ i] [ j] ) ;
769+ c_aux8[ i] [ j] = split_mask_to_array ( c_aux64[ i] [ j] , & BYTE_SPLIT_SIZES ) ;
755770 }
756771 }
757772
@@ -760,25 +775,23 @@ where
760775
761776 for x in 0 ..5 {
762777 c64[ x] = c_aux64[ x] [ 4 ] ;
763- c8[ x] = conv64to8 ( c64[ x] ) ;
778+ c8[ x] = split_mask_to_array ( c64[ x] , & BYTE_SPLIT_SIZES ) ;
764779 }
765780
766781 let mut c_temp = [ [ 0u64 ; 8 ] ; 5 ] ;
767782 for i in 0 ..5 {
768- let rep = MaskRepresentation :: new ( vec ! [ ( 64 , c64[ i] ) . into( ) ] )
769- . convert ( vec ! [ 15 , 1 , 15 , 1 , 15 , 1 , 15 , 1 ] )
770- . values ( ) ;
771- for ( j, size) in [ 15 , 1 , 15 , 1 , 15 , 1 , 15 , 1 ] . iter ( ) . enumerate ( ) {
772- lk_multiplicity. assert_const_range ( rep[ j] , * size) ;
783+ let chunks = split_mask_to_array ( c64[ i] , & C_TEMP_SPLIT_SIZES ) ;
784+ for ( chunk, size) in chunks. iter ( ) . zip ( C_TEMP_SPLIT_SIZES . iter ( ) ) {
785+ lk_multiplicity. assert_const_range ( * chunk, * size) ;
773786 }
774- c_temp[ i] = rep . try_into ( ) . unwrap ( ) ;
787+ c_temp[ i] = chunks ;
775788 }
776789
777790 let mut crot64 = [ 0u64 ; 5 ] ;
778791 let mut crot8 = [ [ 0u64 ; 8 ] ; 5 ] ;
779792 for i in 0 ..5 {
780793 crot64[ i] = c64[ i] . rotate_left ( 1 ) ;
781- crot8[ i] = conv64to8 ( crot64[ i] ) ;
794+ crot8[ i] = split_mask_to_array ( crot64[ i] , & BYTE_SPLIT_SIZES ) ;
782795 }
783796
784797 let mut d64 = [ 0u64 ; 5 ] ;
@@ -791,30 +804,31 @@ where
791804 crot8[ ( x + 1 ) % 5 ] [ k] ,
792805 ) ;
793806 }
794- d8[ x] = conv64to8 ( d64[ x] ) ;
807+ d8[ x] = split_mask_to_array ( d64[ x] , & BYTE_SPLIT_SIZES ) ;
795808 }
796809
797810 let mut theta_state64 = state64;
798811 let mut theta_state8 = [ [ [ 0u64 ; 8 ] ; 5 ] ; 5 ] ;
799- let mut rotation_witness = vec ! [ ] ;
812+ let mut rotation_witness = Vec :: with_capacity ( ROTATION_WITNESS_LEN ) ;
800813
801814 for x in 0 ..5 {
802815 for y in 0 ..5 {
803816 theta_state64[ y] [ x] ^= d64[ x] ;
804817 for k in 0 ..8 {
805818 lk_multiplicity. lookup_xor_byte ( state8[ y] [ x] [ k] , d8[ x] [ k] )
806819 }
807- theta_state8[ y] [ x] = conv64to8 ( theta_state64[ y] [ x] ) ;
820+ theta_state8[ y] [ x] =
821+ split_mask_to_array ( theta_state64[ y] [ x] , & BYTE_SPLIT_SIZES ) ;
808822
809823 let ( sizes, _) = rotation_split ( ROTATION_CONSTANTS [ y] [ x] ) ;
810- let rep =
811- MaskRepresentation :: new ( vec ! [ ( 64 , theta_state64[ y] [ x] ) . into ( ) ] )
812- . convert ( sizes. clone ( ) )
824+ let rotation_chunks =
825+ MaskRepresentation :: from_mask ( Mask :: new ( 64 , theta_state64[ y] [ x] ) )
826+ . convert ( & sizes)
813827 . values ( ) ;
814- for ( j , size) in sizes . iter ( ) . enumerate ( ) {
815- lk_multiplicity. assert_const_range ( rep [ j ] , * size) ;
828+ for ( chunk , size) in rotation_chunks . iter ( ) . zip ( sizes . iter ( ) ) {
829+ lk_multiplicity. assert_const_range ( * chunk , * size) ;
816830 }
817- rotation_witness. extend ( rep ) ;
831+ rotation_witness. extend ( rotation_chunks ) ;
818832 }
819833 }
820834 assert_eq ! ( rotation_witness. len( ) , rotation_witness_witin. len( ) ) ;
@@ -832,7 +846,8 @@ where
832846
833847 for x in 0 ..5 {
834848 for y in 0 ..5 {
835- rhopi_output8[ x] [ y] = conv64to8 ( rhopi_output64[ x] [ y] ) ;
849+ rhopi_output8[ x] [ y] =
850+ split_mask_to_array ( rhopi_output64[ x] [ y] , & BYTE_SPLIT_SIZES ) ;
836851 }
837852 }
838853
@@ -849,7 +864,8 @@ where
849864 rhopi_output8[ y] [ ( x + 2 ) % 5 ] [ k] ,
850865 ) ;
851866 }
852- nonlinear8[ y] [ x] = conv64to8 ( nonlinear64[ y] [ x] ) ;
867+ nonlinear8[ y] [ x] =
868+ split_mask_to_array ( nonlinear64[ y] [ x] , & BYTE_SPLIT_SIZES ) ;
853869 }
854870 }
855871
@@ -862,7 +878,8 @@ where
862878 lk_multiplicity
863879 . lookup_xor_byte ( rhopi_output8[ y] [ x] [ k] , nonlinear8[ y] [ x] [ k] ) ;
864880 }
865- chi_output8[ y] [ x] = conv64to8 ( chi_output64[ y] [ x] ) ;
881+ chi_output8[ y] [ x] =
882+ split_mask_to_array ( chi_output64[ y] [ x] , & BYTE_SPLIT_SIZES ) ;
866883 }
867884 }
868885
@@ -873,13 +890,14 @@ where
873890 iota_output64[ 0 ] [ 0 ] ^= RC [ round] ;
874891
875892 for k in 0 ..8 {
876- let rc8 = conv64to8 ( RC [ round] ) ;
893+ let rc8 = split_mask_to_array ( RC [ round] , & BYTE_SPLIT_SIZES ) ;
877894 lk_multiplicity. lookup_xor_byte ( chi_output8[ 0 ] [ 0 ] [ k] , rc8[ k] ) ;
878895 }
879896
880897 for x in 0 ..5 {
881898 for y in 0 ..5 {
882- iota_output8[ x] [ y] = conv64to8 ( iota_output64[ x] [ y] ) ;
899+ iota_output8[ x] [ y] =
900+ split_mask_to_array ( iota_output64[ x] [ y] , & BYTE_SPLIT_SIZES ) ;
883901 }
884902 }
885903
@@ -1027,8 +1045,8 @@ pub fn run_lookup_keccakf<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>
10271045
10281046 let span = entered_span ! ( "instances" , profiling_2 = true ) ;
10291047 for state in & states {
1030- let state_mask64 = MaskRepresentation :: from ( state. iter ( ) . map ( |e| ( 64 , * e ) ) . collect_vec ( ) ) ;
1031- let state_mask32 = state_mask64. convert ( vec ! [ 32 ; 50 ] ) ;
1048+ let state_mask64 = MaskRepresentation :: from_masks ( state. iter ( ) . map ( |& e| Mask :: new ( 64 , e ) ) ) ;
1049+ let state_mask32 = state_mask64. convert ( & [ 32usize ; 50 ] ) ;
10321050
10331051 let instance = KeccakInstance {
10341052 state : KeccakStateInstance {
0 commit comments