@@ -36,8 +36,8 @@ use tokio::{
3636use tracing:: Level ;
3737
3838use crate :: {
39- Connected , Malicious , MaliciousMarker , RotReceiver , RotSender , Security , SemiHonest ,
40- SemiHonestMarker ,
39+ Connected , CotReceiver , CotSender , Malicious , MaliciousMarker , RotReceiver , RotSender ,
40+ Security , SemiHonest , SemiHonestMarker ,
4141 base:: { self , SimplestOt } ,
4242 phase, random_choices,
4343} ;
@@ -448,8 +448,8 @@ impl<S: Security> RotReceiver for OtExtensionReceiver<S> {
448448 #[ tracing:: instrument( target = "cryprot_metrics" , level = Level :: TRACE , skip_all, fields( phase = phase:: OT_EXTENSION ) ) ]
449449 async fn receive_into (
450450 & mut self ,
451- choices : & [ Choice ] ,
452451 ots : & mut impl Buf < Block > ,
452+ choices : & [ Choice ] ,
453453 ) -> Result < ( ) , Self :: Error > {
454454 assert_eq ! ( choices. len( ) , ots. len( ) ) ;
455455 assert_eq ! (
@@ -621,6 +621,81 @@ impl<S: Security> RotReceiver for OtExtensionReceiver<S> {
621621 }
622622}
623623
624+ // should fit in one jumbo frame
625+ const COR_CHUNK_SIZE : usize = 8500 / Block :: BYTES ;
626+
627+ impl < S : Security > CotSender for OtExtensionSender < S > {
628+ type Error = Error ;
629+
630+ async fn correlated_send_into < B , F > (
631+ & mut self ,
632+ ots : & mut B ,
633+ mut correlation : F ,
634+ ) -> Result < ( ) , Self :: Error >
635+ where
636+ B : Buf < Block > ,
637+ F : FnMut ( usize , Block ) -> Block + Send ,
638+ {
639+ let mut r_ots: B :: BufKind < [ Block ; 2 ] > = B :: BufKind :: zeroed ( ots. len ( ) ) ;
640+ self . send_into ( & mut r_ots) . await ?;
641+ let mut send_buf: Vec < Block > = Vec :: zeroed ( COR_CHUNK_SIZE ) ;
642+ let ( mut tx, _) = self . conn . byte_stream ( ) . await ?;
643+ // TODO benchmark if this should use spawn_compute, I suspect not, since we're
644+ // not doing that much work per byte sent
645+ for ( chunk_idx, ( ot_chunk, rot_chunk) ) in ots
646+ . chunks_mut ( send_buf. len ( ) )
647+ . zip ( r_ots. chunks ( send_buf. len ( ) ) )
648+ . enumerate ( )
649+ {
650+ for ( idx, ( ( ot, r_ot) , correction) ) in ot_chunk
651+ . iter_mut ( )
652+ . zip ( rot_chunk)
653+ . zip ( & mut send_buf)
654+ . enumerate ( )
655+ {
656+ * ot = r_ot[ 0 ] ;
657+ * correction = r_ot[ 1 ] ^ correlation ( chunk_idx * COR_CHUNK_SIZE + idx, r_ot[ 0 ] ) ;
658+ }
659+ tx. write_all ( bytemuck:: must_cast_slice_mut (
660+ & mut send_buf[ ..ot_chunk. len ( ) ] ,
661+ ) )
662+ . await ?;
663+ }
664+ Ok ( ( ) )
665+ }
666+ }
667+
668+ impl < S : Security > CotReceiver for OtExtensionReceiver < S > {
669+ type Error = Error ;
670+
671+ async fn correlated_receive_into < B > (
672+ & mut self ,
673+ ots : & mut B ,
674+ choices : & [ Choice ] ,
675+ ) -> Result < ( ) , Self :: Error >
676+ where
677+ B : Buf < Block > ,
678+ {
679+ self . receive_into ( ots, choices) . await ?;
680+ let mut recv_buf: Vec < Block > = Vec :: zeroed ( COR_CHUNK_SIZE ) ;
681+ let ( _, mut rx) = self . conn . byte_stream ( ) . await ?;
682+ for ( ot_chunk, choice_chunk) in ots
683+ . chunks_mut ( COR_CHUNK_SIZE )
684+ . zip ( choices. chunks ( COR_CHUNK_SIZE ) )
685+ {
686+ rx. read_exact ( bytemuck:: must_cast_slice_mut (
687+ & mut recv_buf[ ..ot_chunk. len ( ) ] ,
688+ ) )
689+ . await ?;
690+ for ( ( ot, correction) , choice) in ot_chunk. iter_mut ( ) . zip ( & recv_buf) . zip ( choice_chunk) {
691+ let use_correction = Block :: conditional_select ( & Block :: ZERO , & Block :: ONES , * choice) ;
692+ * ot ^= use_correction & * correction;
693+ }
694+ }
695+ Ok ( ( ) )
696+ }
697+ }
698+
624699fn commit ( b : Block ) -> random_oracle:: Hash {
625700 random_oracle:: hash ( b. as_bytes ( ) )
626701}
@@ -651,11 +726,12 @@ impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error {
651726#[ cfg( test) ]
652727mod tests {
653728
729+ use cryprot_core:: Block ;
654730 use cryprot_net:: testing:: { init_tracing, local_conn} ;
655731 use rand:: { SeedableRng , rngs:: StdRng } ;
656732
657733 use crate :: {
658- MaliciousMarker , RotReceiver , RotSender ,
734+ CotReceiver , CotSender , MaliciousMarker , RotReceiver , RotSender ,
659735 extension:: {
660736 DEFAULT_OT_BATCH_SIZE , OtExtensionReceiver , OtExtensionSender ,
661737 SemiHonestOtExtensionReceiver , SemiHonestOtExtensionSender ,
@@ -730,4 +806,28 @@ mod tests {
730806 assert_eq ! ( r, s[ c. unwrap_u8( ) as usize ] ) ;
731807 }
732808 }
809+
810+ #[ tokio:: test]
811+ async fn test_correlated_extension ( ) {
812+ let _g = init_tracing ( ) ;
813+ const COUNT : usize = 128 ;
814+ let ( c1, c2) = local_conn ( ) . await . unwrap ( ) ;
815+ let rng1 = StdRng :: seed_from_u64 ( 42 ) ;
816+ let mut rng2 = StdRng :: seed_from_u64 ( 24 ) ;
817+ let choices = random_choices ( COUNT , & mut rng2) ;
818+ let mut sender = SemiHonestOtExtensionSender :: new_with_rng ( c1, rng1) ;
819+ let mut receiver = SemiHonestOtExtensionReceiver :: new_with_rng ( c2, rng2) ;
820+ let ( send_ots, recv_ots) = tokio:: try_join!(
821+ sender. correlated_send( COUNT , |_, b| b ^ Block :: ONES ) ,
822+ receiver. correlated_receive( & choices)
823+ )
824+ . unwrap ( ) ;
825+ for ( i, ( ( r, s) , c) ) in recv_ots. into_iter ( ) . zip ( send_ots) . zip ( choices) . enumerate ( ) {
826+ if bool:: from ( c) {
827+ assert_eq ! ( r ^ Block :: ONES , s, "Block {i}" ) ;
828+ } else {
829+ assert_eq ! ( r, s, "Block {i}" )
830+ }
831+ }
832+ }
733833}
0 commit comments