Skip to content

Commit e31c5cf

Browse files
committed
use spawn-compute in correlated OT
1 parent 832bd23 commit e31c5cf

File tree

2 files changed

+61
-32
lines changed

2 files changed

+61
-32
lines changed

cryprot-ot/src/extension.rs

Lines changed: 59 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -634,33 +634,42 @@ impl<S: Security> CotSender for OtExtensionSender<S> {
634634
) -> Result<(), Self::Error>
635635
where
636636
B: Buf<Block>,
637-
F: FnMut(usize, Block) -> Block + Send,
637+
F: FnMut(usize, Block) -> Block + Send + 'static,
638638
{
639639
let mut r_ots: B::BufKind<[Block; 2]> = B::BufKind::zeroed(ots.len());
640640
self.send_into(&mut r_ots).await?;
641-
let mut send_buf: Vec<Block> = Vec::zeroed(COR_CHUNK_SIZE);
642641
let (mut tx, _) = self.conn.byte_stream().await?;
643642
// TODO benchmark if this should use spawn_compute, I suspect not, since we're
644643
// not doing that much work per byte sent
645-
for (chunk_idx, (ot_chunk, rot_chunk)) in ots
646-
.chunks_mut(send_buf.len())
647-
.zip(r_ots.chunks(send_buf.len()))
648-
.enumerate()
649-
{
650-
for (idx, ((ot, r_ot), correction)) in ot_chunk
651-
.iter_mut()
652-
.zip(rot_chunk)
653-
.zip(&mut send_buf)
644+
let mut ots_owned = mem::take(ots);
645+
let (ch_s, mut ch_r) = tokio::sync::mpsc::channel(5);
646+
let jh = spawn_compute(move || {
647+
for (chunk_idx, (ot_chunk, rot_chunk)) in ots_owned
648+
.chunks_mut(COR_CHUNK_SIZE)
649+
.zip(r_ots.chunks(COR_CHUNK_SIZE))
654650
.enumerate()
655651
{
656-
*ot = r_ot[0];
657-
*correction = r_ot[1] ^ correlation(chunk_idx * COR_CHUNK_SIZE + idx, r_ot[0]);
652+
let mut send_buf: Vec<Block> = Vec::zeroed(ot_chunk.len());
653+
for (idx, ((ot, r_ot), correction)) in ot_chunk
654+
.iter_mut()
655+
.zip(rot_chunk)
656+
.zip(&mut send_buf)
657+
.enumerate()
658+
{
659+
*ot = r_ot[0];
660+
*correction = r_ot[1] ^ correlation(chunk_idx * COR_CHUNK_SIZE + idx, r_ot[0]);
661+
}
662+
ch_s.blocking_send(send_buf).unwrap();
658663
}
659-
tx.write_all(bytemuck::must_cast_slice_mut(
660-
&mut send_buf[..ot_chunk.len()],
661-
))
662-
.await?;
664+
ots_owned
665+
});
666+
667+
while let Some(buf) = ch_r.recv().await {
668+
tx.write_all(bytemuck::must_cast_slice(&buf)).await.unwrap();
663669
}
670+
671+
*ots = jh.await.unwrap();
672+
664673
Ok(())
665674
}
666675
}
@@ -677,21 +686,41 @@ impl<S: Security> CotReceiver for OtExtensionReceiver<S> {
677686
B: Buf<Block>,
678687
{
679688
self.receive_into(ots, choices).await?;
680-
let mut recv_buf: Vec<Block> = Vec::zeroed(COR_CHUNK_SIZE);
681689
let (_, mut rx) = self.conn.byte_stream().await?;
682-
for (ot_chunk, choice_chunk) in ots
683-
.chunks_mut(COR_CHUNK_SIZE)
684-
.zip(choices.chunks(COR_CHUNK_SIZE))
685-
{
686-
rx.read_exact(bytemuck::must_cast_slice_mut(
687-
&mut recv_buf[..ot_chunk.len()],
688-
))
689-
.await?;
690-
for ((ot, correction), choice) in ot_chunk.iter_mut().zip(&recv_buf).zip(choice_chunk) {
691-
let use_correction = Block::conditional_select(&Block::ZERO, &Block::ONES, *choice);
692-
*ot ^= use_correction & *correction;
690+
let (ch_s, mut ch_r) = tokio::sync::mpsc::channel(5);
691+
let batch_block_sizes = iter::repeat_n(COR_CHUNK_SIZE, ots.len() / COR_CHUNK_SIZE);
692+
let last_batch_size = ots.len() % COR_CHUNK_SIZE;
693+
let batch_block_sizes = batch_block_sizes.chain(if last_batch_size != 0 {
694+
Some(last_batch_size)
695+
} else {
696+
None
697+
});
698+
let mut ots_owned = mem::take(ots);
699+
let choices = choices.to_owned();
700+
let jh = spawn_compute(move || {
701+
for (ot_chunk, choice_chunk) in ots_owned
702+
.chunks_mut(COR_CHUNK_SIZE)
703+
.zip(choices.chunks(COR_CHUNK_SIZE))
704+
{
705+
let recv_buf: Vec<Block> = ch_r.blocking_recv().unwrap();
706+
for ((ot, correction), choice) in
707+
ot_chunk.iter_mut().zip(&recv_buf).zip(choice_chunk)
708+
{
709+
let use_correction =
710+
Block::conditional_select(&Block::ZERO, &Block::ONES, *choice);
711+
*ot ^= use_correction & *correction;
712+
}
693713
}
714+
ots_owned
715+
});
716+
717+
for batch_size in batch_block_sizes {
718+
let mut recv_buf: Vec<Block> = Vec::zeroed(batch_size);
719+
rx.read_exact(bytemuck::must_cast_slice_mut(&mut recv_buf))
720+
.await?;
721+
ch_s.send(recv_buf).await.unwrap();
694722
}
723+
*ots = jh.await.unwrap();
695724
Ok(())
696725
}
697726
}
@@ -810,7 +839,7 @@ mod tests {
810839
#[tokio::test]
811840
async fn test_correlated_extension() {
812841
let _g = init_tracing();
813-
const COUNT: usize = 128;
842+
const COUNT: usize = DEFAULT_OT_BATCH_SIZE;
814843
let (c1, c2) = local_conn().await.unwrap();
815844
let rng1 = StdRng::seed_from_u64(42);
816845
let mut rng2 = StdRng::seed_from_u64(24);

cryprot-ot/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ pub trait CotSender: Connected + Send {
187187
correlation: F,
188188
) -> impl Future<Output = Result<Vec<Block>, Self::Error>> + Send
189189
where
190-
F: FnMut(usize, Block) -> Block + Send,
190+
F: FnMut(usize, Block) -> Block + Send + 'static,
191191
{
192192
async move {
193193
let mut ots = Vec::zeroed(count);
@@ -203,7 +203,7 @@ pub trait CotSender: Connected + Send {
203203
) -> impl Future<Output = Result<(), Self::Error>> + Send
204204
where
205205
B: Buf<Block>,
206-
F: FnMut(usize, Block) -> Block + Send;
206+
F: FnMut(usize, Block) -> Block + Send + 'static;
207207
}
208208

209209
pub trait CotReceiver: Connected + Send {

0 commit comments

Comments
 (0)