Skip to content

Commit cfde48c

Browse files
committed
Add correlated OT
1 parent 59c64fd commit cfde48c

File tree

6 files changed

+379
-54
lines changed

6 files changed

+379
-54
lines changed

cryprot-core/src/buf.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ use crate::alloc::{HugePageMemory, allocate_zeroed_vec};
1414
pub trait Buf<T>:
1515
Default + Debug + Deref<Target = [T]> + DerefMut + Send + Sync + 'static + private::Sealed
1616
{
17+
type BufKind<E>: Buf<E>
18+
where
19+
E: Zeroable + Clone + Default + Debug + Send + Sync + 'static;
1720
/// Create a new `Buf` of length `len` with all elements set to zero.
1821
///
1922
/// Implementations of this directly allocate zeroed memory and do not write
@@ -41,6 +44,11 @@ pub trait Buf<T>:
4144
}
4245

4346
impl<T: Zeroable + Clone + Default + Debug + Send + Sync + 'static> Buf<T> for Vec<T> {
47+
type BufKind<E>
48+
= Vec<E>
49+
where
50+
E: Zeroable + Clone + Default + Debug + Send + Sync + 'static;
51+
4452
fn zeroed(len: usize) -> Self {
4553
allocate_zeroed_vec(len)
4654
}
@@ -66,6 +74,11 @@ impl<T: Zeroable + Clone + Default + Debug + Send + Sync + 'static> Buf<T> for V
6674
}
6775

6876
impl<T: Zeroable + Clone + Default + Debug + Send + Sync + 'static> Buf<T> for HugePageMemory<T> {
77+
type BufKind<E>
78+
= HugePageMemory<E>
79+
where
80+
E: Zeroable + Clone + Default + Debug + Send + Sync + 'static;
81+
6982
fn zeroed(len: usize) -> Self {
7083
HugePageMemory::zeroed(len)
7184
}

cryprot-ot/benches/bench.rs

Lines changed: 33 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use criterion::{BatchSize, Criterion, criterion_group, criterion_main};
77
use cryprot_core::{Block, alloc::HugePageMemory};
88
use cryprot_net::testing::{init_bench_tracing, local_conn};
99
use cryprot_ot::{
10-
RotReceiver, RotSender,
10+
CotReceiver, CotSender, RotReceiver, RotSender,
1111
base::SimplestOt,
1212
extension::{
1313
MaliciousOtExtensionReceiver, MaliciousOtExtensionSender, SemiHonestOtExtensionReceiver,
@@ -99,7 +99,7 @@ fn criterion_benchmark(c: &mut Criterion) {
9999
}),
100100
tokio::spawn(async move {
101101
receiver
102-
.receive_into(&choices, &mut receiver_ots)
102+
.receive_into(&mut receiver_ots, &choices)
103103
.await
104104
.unwrap();
105105
receiver_ots
@@ -113,64 +113,53 @@ fn criterion_benchmark(c: &mut Criterion) {
113113
})
114114
});
115115

116-
g.bench_function(format!("2 parallel 2**{p} extension OTs"), |b| {
116+
g.bench_function(format!("2**{p} correlated extension OTs"), |b| {
117117
b.to_async(&rt).iter_custom(|iters| {
118118
let mut c11 = c1.sub_connection();
119119
let mut c22 = c2.sub_connection();
120+
120121
async move {
121122
let mut duration = Duration::ZERO;
123+
let mut sender_ots = HugePageMemory::zeroed(count);
124+
let mut receiver_ots = HugePageMemory::zeroed(count);
122125
for _ in 0..iters {
123-
let (
124-
mut sender1,
125-
mut receiver1,
126-
mut sender2,
127-
mut receiver2,
128-
choices1,
129-
choices2,
130-
) = {
126+
// setup not included in duration
127+
let (mut sender, mut receiver, choices) = {
131128
let mut rng1 = StdRng::seed_from_u64(42);
132-
let mut rng2 = StdRng::seed_from_u64(42 * 42);
133-
let choices1 = random_choices(count, &mut rng1);
134-
let choices2 = random_choices(count, &mut rng2);
135-
let mut sender1 = SemiHonestOtExtensionSender::new_with_rng(
136-
c11.sub_connection(),
137-
rng1.clone(),
138-
);
139-
let mut receiver1 = SemiHonestOtExtensionReceiver::new_with_rng(
140-
c22.sub_connection(),
141-
rng2.clone(),
142-
);
143-
144-
let mut sender2 =
129+
let rng2 = StdRng::seed_from_u64(42 * 42);
130+
let choices = random_choices(count, &mut rng1);
131+
let mut sender =
145132
SemiHonestOtExtensionSender::new_with_rng(c11.sub_connection(), rng1);
146-
let mut receiver2 =
133+
let mut receiver =
147134
SemiHonestOtExtensionReceiver::new_with_rng(c22.sub_connection(), rng2);
148-
149-
tokio::try_join!(
150-
sender1.do_base_ots(),
151-
receiver1.do_base_ots(),
152-
sender2.do_base_ots(),
153-
receiver2.do_base_ots()
154-
)
155-
.unwrap();
156-
(sender1, receiver1, sender2, receiver2, choices1, choices2)
135+
tokio::try_join!(sender.do_base_ots(), receiver.do_base_ots()).unwrap();
136+
(sender, receiver, choices)
157137
};
158138
let now = Instant::now();
159-
let jh1 = tokio::spawn(async move { sender1.send(count).await });
160-
let jh2 = tokio::spawn(async move { receiver1.receive(&choices1).await });
161-
let jh3 = tokio::spawn(async move { sender2.send(count).await });
162-
let jh4 = tokio::spawn(async move { receiver2.receive(&choices2).await });
163-
let (ot1, ot2, ot3, ot4) = tokio::try_join!(jh1, jh2, jh3, jh4).unwrap();
139+
(sender_ots, receiver_ots) = tokio::try_join!(
140+
tokio::spawn(async move {
141+
sender
142+
.correlated_send_into(&mut sender_ots, |_| Block::ONES)
143+
.await
144+
.unwrap();
145+
sender_ots
146+
}),
147+
tokio::spawn(async move {
148+
receiver
149+
.correlated_receive_into(&mut receiver_ots, &choices)
150+
.await
151+
.unwrap();
152+
receiver_ots
153+
})
154+
)
155+
.unwrap();
164156
duration += now.elapsed();
165-
ot1.unwrap();
166-
ot2.unwrap();
167-
ot3.unwrap();
168-
ot4.unwrap();
169157
}
170158
duration
171159
}
172160
})
173161
});
162+
174163
g.finish();
175164

176165
let mut g = c.benchmark_group("malicious OT extension");
@@ -206,7 +195,7 @@ fn criterion_benchmark(c: &mut Criterion) {
206195
}),
207196
tokio::spawn(async move {
208197
receiver
209-
.receive_into(&choices, &mut receiver_ots)
198+
.receive_into(&mut receiver_ots, &choices)
210199
.await
211200
.unwrap();
212201
receiver_ots

cryprot-ot/src/adapter.rs

Lines changed: 179 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
//! Adapters for OT types.
22
33
use bitvec::{order::Lsb0, vec::BitVec};
4+
use cryprot_core::{Block, buf::Buf};
5+
use cryprot_net::ConnectionError;
46
use futures::{SinkExt, StreamExt};
7+
use subtle::ConditionallySelectable;
8+
use tokio::io::{AsyncReadExt, AsyncWriteExt};
59

6-
use crate::{Connected, RandChoiceRotReceiver, RandChoiceRotSender, RotReceiver, RotSender};
10+
use crate::{
11+
Connected, CotReceiver, CotSender, Malicious, RandChoiceRotReceiver, RandChoiceRotSender,
12+
RotReceiver, RotSender, SemiHonest,
13+
};
714

815
/// Adapts a [`RandChoiceRotReceiver`] into a [`RotReceiver`] and
916
/// [`RandChoiceRotSender`] into [`RotSender`].
@@ -26,13 +33,18 @@ impl<P: Connected> Connected for ChosenChoice<P> {
2633
}
2734
}
2835

36+
impl<P: SemiHonest> SemiHonest for ChosenChoice<P> {}
37+
38+
// TODO is there something I can cite that this holds?
39+
impl<P: Malicious> Malicious for ChosenChoice<P> {}
40+
2941
impl<R: RandChoiceRotReceiver> RotReceiver for ChosenChoice<R> {
3042
type Error = R::Error;
3143

3244
async fn receive_into(
3345
&mut self,
34-
choices: &[subtle::Choice],
3546
ots: &mut impl cryprot_core::buf::Buf<cryprot_core::Block>,
47+
choices: &[subtle::Choice],
3648
) -> Result<(), Self::Error> {
3749
let mut rand_choices = self
3850
.0
@@ -72,6 +84,171 @@ impl<S: RotSender + RandChoiceRotSender + Send> RotSender for ChosenChoice<S> {
7284
}
7385
}
7486

87+
/// Adapts any [`RotSender`]/[`RotReceiver`] into a
88+
/// [`CotSender`]/[`CotReceiver`].
89+
///
90+
/// This adapter can also be used to easily implement the correlated OT traits
91+
/// on the protocol types directly. Because `&mut S: RotSender` when `S:
92+
/// RotSender` you can create a temporary [`CorrelatedFromRandom`] from a `&mut
93+
/// self` inside an implementation of the correlated traits.
94+
///
95+
/// ```
96+
/// use cryprot_core::{Block, buf::Buf};
97+
///
98+
/// use cryprot_ot::adapter::CorrelatedFromRandom;
99+
/// use cryprot_ot::{Connected, CotSender, RotSender};
100+
///
101+
/// struct MyRotSender;
102+
///
103+
/// # impl Connected for MyRotSender {
104+
/// # fn connection(&mut self) -> &mut cryprot_net::Connection {
105+
/// # todo!()
106+
/// # }
107+
/// # }
108+
///
109+
/// // Error type must implement `From<ConnectionError>` and `From<io::Error>` for
110+
/// // adapter
111+
/// #[derive(thiserror::Error, Debug)]
112+
/// enum Error {
113+
/// #[error("connection")]
114+
/// Connection(#[from] cryprot_net::ConnectionError),
115+
/// #[error("io")]
116+
/// Io(#[from] std::io::Error),
117+
/// }
118+
///
119+
/// impl RotSender for MyRotSender {
120+
/// type Error = Error;
121+
///
122+
/// async fn send_into(
123+
/// &mut self,
124+
/// ots: &mut impl cryprot_core::buf::Buf<[cryprot_core::Block; 2]>,
125+
/// ) -> Result<(), Self::Error> {
126+
/// todo!()
127+
/// }
128+
/// }
129+
///
130+
/// impl CotSender for MyRotSender {
131+
/// type Error = <MyRotSender as RotSender>::Error;
132+
///
133+
/// async fn correlated_send_into<B, F>(
134+
/// &mut self,
135+
/// ots: &mut B,
136+
/// correlation: F,
137+
/// ) -> Result<(), Self::Error>
138+
/// where
139+
/// B: Buf<Block>,
140+
/// F: FnMut(usize) -> Block + Send,
141+
/// {
142+
/// // because &mut self also implements RotSender, we can use it for the adapter
143+
/// CorrelatedFromRandom::new(self)
144+
/// .correlated_send_into(ots, correlation)
145+
/// .await
146+
/// }
147+
/// }
148+
/// ```
149+
#[derive(Debug)]
150+
pub struct CorrelatedFromRandom<P>(P);
151+
152+
impl<P> CorrelatedFromRandom<P> {
153+
pub fn new(protocol: P) -> Self {
154+
Self(protocol)
155+
}
156+
}
157+
158+
impl<P: Connected> Connected for CorrelatedFromRandom<P> {
159+
fn connection(&mut self) -> &mut cryprot_net::Connection {
160+
self.0.connection()
161+
}
162+
}
163+
164+
impl<P: SemiHonest> SemiHonest for CorrelatedFromRandom<P> {}
165+
166+
// TODO is there something I can cite that this holds?
167+
impl<P: Malicious> Malicious for CorrelatedFromRandom<P> {}
168+
169+
// should fit in one jumbo frame
170+
const COR_CHUNK_SIZE: usize = 8500 / Block::BYTES;
171+
172+
impl<S: RotSender> CotSender for CorrelatedFromRandom<S>
173+
where
174+
S::Error: From<ConnectionError> + From<std::io::Error>,
175+
{
176+
type Error = S::Error;
177+
178+
async fn correlated_send_into<B, F>(
179+
&mut self,
180+
ots: &mut B,
181+
mut correlation: F,
182+
) -> Result<(), Self::Error>
183+
where
184+
B: Buf<Block>,
185+
F: FnMut(usize) -> Block + Send,
186+
{
187+
let mut r_ots: B::BufKind<[Block; 2]> = B::BufKind::zeroed(ots.len());
188+
self.0.send_into(&mut r_ots).await?;
189+
let mut send_buf: Vec<Block> = Vec::zeroed(COR_CHUNK_SIZE);
190+
let (mut tx, _) = self.connection().byte_stream().await?;
191+
// Using spawn_compute here results in slightly lower performance.
192+
// I think there is just not enough work done per byte transmitted here.
193+
// This implementation is also simpler and less prone to errors than the
194+
// spawn_compute one.
195+
for (chunk_idx, (ot_chunk, rot_chunk)) in ots
196+
.chunks_mut(send_buf.len())
197+
.zip(r_ots.chunks(send_buf.len()))
198+
.enumerate()
199+
{
200+
for (idx, ((ot, r_ot), correction)) in ot_chunk
201+
.iter_mut()
202+
.zip(rot_chunk)
203+
.zip(&mut send_buf)
204+
.enumerate()
205+
{
206+
*ot = r_ot[0];
207+
*correction = r_ot[1] ^ r_ot[0] ^ correlation(chunk_idx * COR_CHUNK_SIZE + idx);
208+
}
209+
tx.write_all(bytemuck::must_cast_slice_mut(
210+
&mut send_buf[..ot_chunk.len()],
211+
))
212+
.await?;
213+
}
214+
Ok(())
215+
}
216+
}
217+
218+
impl<R: RotReceiver> CotReceiver for CorrelatedFromRandom<R>
219+
where
220+
R::Error: From<ConnectionError> + From<std::io::Error>,
221+
{
222+
type Error = R::Error;
223+
224+
async fn correlated_receive_into<B>(
225+
&mut self,
226+
ots: &mut B,
227+
choices: &[subtle::Choice],
228+
) -> Result<(), Self::Error>
229+
where
230+
B: Buf<Block>,
231+
{
232+
self.0.receive_into(ots, choices).await?;
233+
let mut recv_buf: Vec<Block> = Vec::zeroed(COR_CHUNK_SIZE);
234+
let (_, mut rx) = self.connection().byte_stream().await?;
235+
for (ot_chunk, choice_chunk) in ots
236+
.chunks_mut(COR_CHUNK_SIZE)
237+
.zip(choices.chunks(COR_CHUNK_SIZE))
238+
{
239+
rx.read_exact(bytemuck::must_cast_slice_mut(
240+
&mut recv_buf[..ot_chunk.len()],
241+
))
242+
.await?;
243+
for ((ot, correction), choice) in ot_chunk.iter_mut().zip(&recv_buf).zip(choice_chunk) {
244+
let use_correction = Block::conditional_select(&Block::ZERO, &Block::ONES, *choice);
245+
*ot ^= use_correction & *correction;
246+
}
247+
}
248+
Ok(())
249+
}
250+
}
251+
75252
#[cfg(test)]
76253
mod tests {
77254
use cryprot_net::testing::{init_tracing, local_conn};

cryprot-ot/src/base.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ impl RotReceiver for SimplestOt {
114114
#[tracing::instrument(target = "cryprot_metrics", level = Level::TRACE, skip_all, fields(phase = phase::BASE_OT))]
115115
async fn receive_into(
116116
&mut self,
117-
choices: &[Choice],
118117
ots: &mut impl Buf<Block>,
118+
choices: &[Choice],
119119
) -> Result<(), Self::Error> {
120120
assert_eq!(choices.len(), ots.len());
121121
let (mut send, mut recv) = self.conn.byte_stream().await?;

0 commit comments

Comments
 (0)