@@ -634,33 +634,42 @@ impl<S: Security> CotSender for OtExtensionSender<S> {
634634 ) -> Result < ( ) , Self :: Error >
635635 where
636636 B : Buf < Block > ,
637- F : FnMut ( usize , Block ) -> Block + Send ,
637+ F : FnMut ( usize , Block ) -> Block + Send + ' static ,
638638 {
639639 let mut r_ots: B :: BufKind < [ Block ; 2 ] > = B :: BufKind :: zeroed ( ots. len ( ) ) ;
640640 self . send_into ( & mut r_ots) . await ?;
641- let mut send_buf: Vec < Block > = Vec :: zeroed ( COR_CHUNK_SIZE ) ;
642641 let ( mut tx, _) = self . conn . byte_stream ( ) . await ?;
643642 // TODO benchmark if this should use spawn_compute, I suspect not, since we're
644643 // not doing that much work per byte sent
645- for ( chunk_idx, ( ot_chunk, rot_chunk) ) in ots
646- . chunks_mut ( send_buf. len ( ) )
647- . zip ( r_ots. chunks ( send_buf. len ( ) ) )
648- . enumerate ( )
649- {
650- for ( idx, ( ( ot, r_ot) , correction) ) in ot_chunk
651- . iter_mut ( )
652- . zip ( rot_chunk)
653- . zip ( & mut send_buf)
644+ let mut ots_owned = mem:: take ( ots) ;
645+ let ( ch_s, mut ch_r) = tokio:: sync:: mpsc:: channel ( 5 ) ;
646+ let jh = spawn_compute ( move || {
647+ for ( chunk_idx, ( ot_chunk, rot_chunk) ) in ots_owned
648+ . chunks_mut ( COR_CHUNK_SIZE )
649+ . zip ( r_ots. chunks ( COR_CHUNK_SIZE ) )
654650 . enumerate ( )
655651 {
656- * ot = r_ot[ 0 ] ;
657- * correction = r_ot[ 1 ] ^ correlation ( chunk_idx * COR_CHUNK_SIZE + idx, r_ot[ 0 ] ) ;
652+ let mut send_buf: Vec < Block > = Vec :: zeroed ( ot_chunk. len ( ) ) ;
653+ for ( idx, ( ( ot, r_ot) , correction) ) in ot_chunk
654+ . iter_mut ( )
655+ . zip ( rot_chunk)
656+ . zip ( & mut send_buf)
657+ . enumerate ( )
658+ {
659+ * ot = r_ot[ 0 ] ;
660+ * correction = r_ot[ 1 ] ^ correlation ( chunk_idx * COR_CHUNK_SIZE + idx, r_ot[ 0 ] ) ;
661+ }
662+ ch_s. blocking_send ( send_buf) . unwrap ( ) ;
658663 }
659- tx. write_all ( bytemuck:: must_cast_slice_mut (
660- & mut send_buf[ ..ot_chunk. len ( ) ] ,
661- ) )
662- . await ?;
664+ ots_owned
665+ } ) ;
666+
667+ while let Some ( buf) = ch_r. recv ( ) . await {
668+ tx. write_all ( bytemuck:: must_cast_slice ( & buf) ) . await . unwrap ( ) ;
663669 }
670+
671+ * ots = jh. await . unwrap ( ) ;
672+
664673 Ok ( ( ) )
665674 }
666675}
@@ -677,21 +686,41 @@ impl<S: Security> CotReceiver for OtExtensionReceiver<S> {
677686 B : Buf < Block > ,
678687 {
679688 self . receive_into ( ots, choices) . await ?;
680- let mut recv_buf: Vec < Block > = Vec :: zeroed ( COR_CHUNK_SIZE ) ;
681689 let ( _, mut rx) = self . conn . byte_stream ( ) . await ?;
682- for ( ot_chunk, choice_chunk) in ots
683- . chunks_mut ( COR_CHUNK_SIZE )
684- . zip ( choices. chunks ( COR_CHUNK_SIZE ) )
685- {
686- rx. read_exact ( bytemuck:: must_cast_slice_mut (
687- & mut recv_buf[ ..ot_chunk. len ( ) ] ,
688- ) )
689- . await ?;
690- for ( ( ot, correction) , choice) in ot_chunk. iter_mut ( ) . zip ( & recv_buf) . zip ( choice_chunk) {
691- let use_correction = Block :: conditional_select ( & Block :: ZERO , & Block :: ONES , * choice) ;
692- * ot ^= use_correction & * correction;
690+ let ( ch_s, mut ch_r) = tokio:: sync:: mpsc:: channel ( 5 ) ;
691+ let batch_block_sizes = iter:: repeat_n ( COR_CHUNK_SIZE , ots. len ( ) / COR_CHUNK_SIZE ) ;
692+ let last_batch_size = ots. len ( ) % COR_CHUNK_SIZE ;
693+ let batch_block_sizes = batch_block_sizes. chain ( if last_batch_size != 0 {
694+ Some ( last_batch_size)
695+ } else {
696+ None
697+ } ) ;
698+ let mut ots_owned = mem:: take ( ots) ;
699+ let choices = choices. to_owned ( ) ;
700+ let jh = spawn_compute ( move || {
701+ for ( ot_chunk, choice_chunk) in ots_owned
702+ . chunks_mut ( COR_CHUNK_SIZE )
703+ . zip ( choices. chunks ( COR_CHUNK_SIZE ) )
704+ {
705+ let recv_buf: Vec < Block > = ch_r. blocking_recv ( ) . unwrap ( ) ;
706+ for ( ( ot, correction) , choice) in
707+ ot_chunk. iter_mut ( ) . zip ( & recv_buf) . zip ( choice_chunk)
708+ {
709+ let use_correction =
710+ Block :: conditional_select ( & Block :: ZERO , & Block :: ONES , * choice) ;
711+ * ot ^= use_correction & * correction;
712+ }
693713 }
714+ ots_owned
715+ } ) ;
716+
717+ for batch_size in batch_block_sizes {
718+ let mut recv_buf: Vec < Block > = Vec :: zeroed ( batch_size) ;
719+ rx. read_exact ( bytemuck:: must_cast_slice_mut ( & mut recv_buf) )
720+ . await ?;
721+ ch_s. send ( recv_buf) . await . unwrap ( ) ;
694722 }
723+ * ots = jh. await . unwrap ( ) ;
695724 Ok ( ( ) )
696725 }
697726}
@@ -810,7 +839,7 @@ mod tests {
810839 #[ tokio:: test]
811840 async fn test_correlated_extension ( ) {
812841 let _g = init_tracing ( ) ;
813- const COUNT : usize = 128 ;
842+ const COUNT : usize = DEFAULT_OT_BATCH_SIZE ;
814843 let ( c1, c2) = local_conn ( ) . await . unwrap ( ) ;
815844 let rng1 = StdRng :: seed_from_u64 ( 42 ) ;
816845 let mut rng2 = StdRng :: seed_from_u64 ( 24 ) ;
0 commit comments