Skip to content

Commit 183f6a5

Browse files
committed
Add correlated OT
1 parent fa2a76b commit 183f6a5

File tree

6 files changed

+198
-60
lines changed

6 files changed

+198
-60
lines changed

cryprot-core/src/buf.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use crate::alloc::{HugePageMemory, allocate_zeroed_vec};
1414
pub trait Buf<T>:
1515
Default + Debug + Deref<Target = [T]> + DerefMut + Send + Sync + 'static + private::Sealed
1616
{
17+
type BufKind<E>: Buf<E> where E: Zeroable + Clone + Default + Debug + Send + Sync + 'static;
1718
/// Create a new `Buf` of length `len` with all elements set to zero.
1819
///
1920
/// Implementations of this directly allocate zeroed memory and do not write
@@ -41,6 +42,8 @@ pub trait Buf<T>:
4142
}
4243

4344
impl<T: Zeroable + Clone + Default + Debug + Send + Sync + 'static> Buf<T> for Vec<T> {
45+
type BufKind<E> = Vec<E> where E: Zeroable + Clone + Default + Debug + Send + Sync + 'static;
46+
4447
fn zeroed(len: usize) -> Self {
4548
allocate_zeroed_vec(len)
4649
}
@@ -62,10 +65,12 @@ impl<T: Zeroable + Clone + Default + Debug + Send + Sync + 'static> Buf<T> for V
6265

6366
fn grow_zeroed(&mut self, new_size: usize) {
6467
self.resize(new_size, T::zeroed());
65-
}
68+
}
6669
}
6770

6871
impl<T: Zeroable + Clone + Default + Debug + Send + Sync + 'static> Buf<T> for HugePageMemory<T> {
72+
type BufKind<E> = HugePageMemory<E> where E: Zeroable + Clone + Default + Debug + Send + Sync + 'static;
73+
6974
fn zeroed(len: usize) -> Self {
7075
HugePageMemory::zeroed(len)
7176
}

cryprot-ot/benches/bench.rs

Lines changed: 32 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,13 @@ use criterion::{BatchSize, Criterion, criterion_group, criterion_main};
77
use cryprot_core::{Block, alloc::HugePageMemory};
88
use cryprot_net::testing::{init_bench_tracing, local_conn};
99
use cryprot_ot::{
10-
RotReceiver, RotSender,
11-
base::SimplestOt,
12-
extension::{
10+
base::SimplestOt, extension::{
1311
MaliciousOtExtensionReceiver, MaliciousOtExtensionSender, SemiHonestOtExtensionReceiver,
1412
SemiHonestOtExtensionSender,
15-
},
16-
random_choices,
17-
silent_ot::{
13+
}, random_choices, silent_ot::{
1814
MaliciousSilentOtReceiver, MaliciousSilentOtSender, SemiHonestSilentOtReceiver,
1915
SemiHonestSilentOtSender,
20-
},
16+
}, CotReceiver, CotSender, RotReceiver, RotSender
2117
};
2218
use rand::{SeedableRng, rngs::StdRng};
2319
use tokio::runtime::{self, Runtime};
@@ -99,7 +95,7 @@ fn criterion_benchmark(c: &mut Criterion) {
9995
}),
10096
tokio::spawn(async move {
10197
receiver
102-
.receive_into(&choices, &mut receiver_ots)
98+
.receive_into(&mut receiver_ots, &choices)
10399
.await
104100
.unwrap();
105101
receiver_ots
@@ -113,64 +109,50 @@ fn criterion_benchmark(c: &mut Criterion) {
113109
})
114110
});
115111

116-
g.bench_function(format!("2 parallel 2**{p} extension OTs"), |b| {
112+
g.bench_function(format!("2**{p} correlated extension OTs"), |b| {
117113
b.to_async(&rt).iter_custom(|iters| {
118114
let mut c11 = c1.sub_connection();
119115
let mut c22 = c2.sub_connection();
116+
120117
async move {
121118
let mut duration = Duration::ZERO;
119+
let mut sender_ots = HugePageMemory::zeroed(count);
120+
let mut receiver_ots = HugePageMemory::zeroed(count);
122121
for _ in 0..iters {
123-
let (
124-
mut sender1,
125-
mut receiver1,
126-
mut sender2,
127-
mut receiver2,
128-
choices1,
129-
choices2,
130-
) = {
122+
// setup not included in duration
123+
let (mut sender, mut receiver, choices) = {
131124
let mut rng1 = StdRng::seed_from_u64(42);
132-
let mut rng2 = StdRng::seed_from_u64(42 * 42);
133-
let choices1 = random_choices(count, &mut rng1);
134-
let choices2 = random_choices(count, &mut rng2);
135-
let mut sender1 = SemiHonestOtExtensionSender::new_with_rng(
136-
c11.sub_connection(),
137-
rng1.clone(),
138-
);
139-
let mut receiver1 = SemiHonestOtExtensionReceiver::new_with_rng(
140-
c22.sub_connection(),
141-
rng2.clone(),
142-
);
143-
144-
let mut sender2 =
125+
let rng2 = StdRng::seed_from_u64(42 * 42);
126+
let choices = random_choices(count, &mut rng1);
127+
let mut sender =
145128
SemiHonestOtExtensionSender::new_with_rng(c11.sub_connection(), rng1);
146-
let mut receiver2 =
129+
let mut receiver =
147130
SemiHonestOtExtensionReceiver::new_with_rng(c22.sub_connection(), rng2);
148-
149-
tokio::try_join!(
150-
sender1.do_base_ots(),
151-
receiver1.do_base_ots(),
152-
sender2.do_base_ots(),
153-
receiver2.do_base_ots()
154-
)
155-
.unwrap();
156-
(sender1, receiver1, sender2, receiver2, choices1, choices2)
131+
tokio::try_join!(sender.do_base_ots(), receiver.do_base_ots()).unwrap();
132+
(sender, receiver, choices)
157133
};
158134
let now = Instant::now();
159-
let jh1 = tokio::spawn(async move { sender1.send(count).await });
160-
let jh2 = tokio::spawn(async move { receiver1.receive(&choices1).await });
161-
let jh3 = tokio::spawn(async move { sender2.send(count).await });
162-
let jh4 = tokio::spawn(async move { receiver2.receive(&choices2).await });
163-
let (ot1, ot2, ot3, ot4) = tokio::try_join!(jh1, jh2, jh3, jh4).unwrap();
135+
(sender_ots, receiver_ots) = tokio::try_join!(
136+
tokio::spawn(async move {
137+
sender.correlated_send_into(&mut sender_ots, |_, b| b ^ Block::ONES).await.unwrap();
138+
sender_ots
139+
}),
140+
tokio::spawn(async move {
141+
receiver
142+
.correlated_receive_into(&mut receiver_ots, &choices)
143+
.await
144+
.unwrap();
145+
receiver_ots
146+
})
147+
)
148+
.unwrap();
164149
duration += now.elapsed();
165-
ot1.unwrap();
166-
ot2.unwrap();
167-
ot3.unwrap();
168-
ot4.unwrap();
169150
}
170151
duration
171152
}
172153
})
173154
});
155+
174156
g.finish();
175157

176158
let mut g = c.benchmark_group("malicious OT extension");
@@ -206,7 +188,7 @@ fn criterion_benchmark(c: &mut Criterion) {
206188
}),
207189
tokio::spawn(async move {
208190
receiver
209-
.receive_into(&choices, &mut receiver_ots)
191+
.receive_into(&mut receiver_ots, &choices)
210192
.await
211193
.unwrap();
212194
receiver_ots

cryprot-ot/src/adapter.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ impl<R: RandChoiceRotReceiver> RotReceiver for ChosenChoice<R> {
3131

3232
async fn receive_into(
3333
&mut self,
34-
choices: &[subtle::Choice],
3534
ots: &mut impl cryprot_core::buf::Buf<cryprot_core::Block>,
35+
choices: &[subtle::Choice],
3636
) -> Result<(), Self::Error> {
3737
let mut rand_choices = self
3838
.0

cryprot-ot/src/base.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ impl RotReceiver for SimplestOt {
114114
#[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::BASE_OT))]
115115
async fn receive_into(
116116
&mut self,
117-
choices: &[Choice],
118117
ots: &mut impl Buf<Block>,
118+
choices: &[Choice],
119119
) -> Result<(), Self::Error> {
120120
assert_eq!(choices.len(), ots.len());
121121
let (mut send, mut recv) = self.conn.byte_stream().await?;

cryprot-ot/src/extension.rs

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ use tokio::{
3636
use tracing::Level;
3737

3838
use crate::{
39-
Connected, Malicious, MaliciousMarker, RotReceiver, RotSender, Security, SemiHonest,
40-
SemiHonestMarker,
39+
Connected, CotReceiver, CotSender, Malicious, MaliciousMarker, RotReceiver, RotSender,
40+
Security, SemiHonest, SemiHonestMarker,
4141
base::{self, SimplestOt},
4242
phase, random_choices,
4343
};
@@ -448,8 +448,8 @@ impl<S: Security> RotReceiver for OtExtensionReceiver<S> {
448448
#[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::OT_EXTENSION))]
449449
async fn receive_into(
450450
&mut self,
451-
choices: &[Choice],
452451
ots: &mut impl Buf<Block>,
452+
choices: &[Choice],
453453
) -> Result<(), Self::Error> {
454454
assert_eq!(choices.len(), ots.len());
455455
assert_eq!(
@@ -621,6 +621,81 @@ impl<S: Security> RotReceiver for OtExtensionReceiver<S> {
621621
}
622622
}
623623

624+
// should fit in one jumbo frame
625+
const COR_CHUNK_SIZE: usize = 8500 / Block::BYTES;
626+
627+
impl<S: Security> CotSender for OtExtensionSender<S> {
628+
type Error = Error;
629+
630+
async fn correlated_send_into<B, F>(
631+
&mut self,
632+
ots: &mut B,
633+
mut correlation: F,
634+
) -> Result<(), Self::Error>
635+
where
636+
B: Buf<Block>,
637+
F: FnMut(usize, Block) -> Block + Send,
638+
{
639+
let mut r_ots: B::BufKind<[Block; 2]> = B::BufKind::zeroed(ots.len());
640+
self.send_into(&mut r_ots).await?;
641+
let mut send_buf: Vec<Block> = Vec::zeroed(COR_CHUNK_SIZE);
642+
let (mut tx, _) = self.conn.byte_stream().await?;
643+
// TODO benchmark if this should use spawn_compute, I suspect not, since we're
644+
// not doing that much work per byte sent
645+
for (chunk_idx, (ot_chunk, rot_chunk)) in ots
646+
.chunks_mut(send_buf.len())
647+
.zip(r_ots.chunks(send_buf.len()))
648+
.enumerate()
649+
{
650+
for (idx, ((ot, r_ot), correction)) in ot_chunk
651+
.iter_mut()
652+
.zip(rot_chunk)
653+
.zip(&mut send_buf)
654+
.enumerate()
655+
{
656+
*ot = r_ot[0];
657+
*correction = r_ot[1] ^ correlation(chunk_idx * COR_CHUNK_SIZE + idx, r_ot[0]);
658+
}
659+
tx.write_all(bytemuck::must_cast_slice_mut(
660+
&mut send_buf[..ot_chunk.len()],
661+
))
662+
.await?;
663+
}
664+
Ok(())
665+
}
666+
}
667+
668+
impl<S: Security> CotReceiver for OtExtensionReceiver<S> {
669+
type Error = Error;
670+
671+
async fn correlated_receive_into<B>(
672+
&mut self,
673+
ots: &mut B,
674+
choices: &[Choice],
675+
) -> Result<(), Self::Error>
676+
where
677+
B: Buf<Block>,
678+
{
679+
self.receive_into(ots, choices).await?;
680+
let mut recv_buf: Vec<Block> = Vec::zeroed(COR_CHUNK_SIZE);
681+
let (_, mut rx) = self.conn.byte_stream().await?;
682+
for (ot_chunk, choice_chunk) in ots
683+
.chunks_mut(COR_CHUNK_SIZE)
684+
.zip(choices.chunks(COR_CHUNK_SIZE))
685+
{
686+
rx.read_exact(bytemuck::must_cast_slice_mut(
687+
&mut recv_buf[..ot_chunk.len()],
688+
))
689+
.await?;
690+
for ((ot, correction), choice) in ot_chunk.iter_mut().zip(&recv_buf).zip(choice_chunk) {
691+
let use_correction = Block::conditional_select(&Block::ZERO, &Block::ONES, *choice);
692+
*ot ^= use_correction & *correction;
693+
}
694+
}
695+
Ok(())
696+
}
697+
}
698+
624699
fn commit(b: Block) -> random_oracle::Hash {
625700
random_oracle::hash(b.as_bytes())
626701
}
@@ -651,11 +726,12 @@ impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error {
651726
#[cfg(test)]
652727
mod tests {
653728

729+
use cryprot_core::Block;
654730
use cryprot_net::testing::{init_tracing, local_conn};
655731
use rand::{SeedableRng, rngs::StdRng};
656732

657733
use crate::{
658-
MaliciousMarker, RotReceiver, RotSender,
734+
CotReceiver, CotSender, MaliciousMarker, RotReceiver, RotSender,
659735
extension::{
660736
DEFAULT_OT_BATCH_SIZE, OtExtensionReceiver, OtExtensionSender,
661737
SemiHonestOtExtensionReceiver, SemiHonestOtExtensionSender,
@@ -730,4 +806,28 @@ mod tests {
730806
assert_eq!(r, s[c.unwrap_u8() as usize]);
731807
}
732808
}
809+
810+
#[tokio::test]
811+
async fn test_correlated_extension() {
812+
let _g = init_tracing();
813+
const COUNT: usize = 128;
814+
let (c1, c2) = local_conn().await.unwrap();
815+
let rng1 = StdRng::seed_from_u64(42);
816+
let mut rng2 = StdRng::seed_from_u64(24);
817+
let choices = random_choices(COUNT, &mut rng2);
818+
let mut sender = SemiHonestOtExtensionSender::new_with_rng(c1, rng1);
819+
let mut receiver = SemiHonestOtExtensionReceiver::new_with_rng(c2, rng2);
820+
let (send_ots, recv_ots) = tokio::try_join!(
821+
sender.correlated_send(COUNT, |_, b| b ^ Block::ONES),
822+
receiver.correlated_receive(&choices)
823+
)
824+
.unwrap();
825+
for (i, ((r, s), c)) in recv_ots.into_iter().zip(send_ots).zip(choices).enumerate() {
826+
if bool::from(c) {
827+
assert_eq!(r ^ Block::ONES, s, "Block {i}");
828+
} else {
829+
assert_eq!(r, s, "Block {i}")
830+
}
831+
}
832+
}
733833
}

0 commit comments

Comments
 (0)