Skip to content

Commit b68a273

Browse files
committed
Implement actual decryption
1 parent 65a1d51 commit b68a273

File tree

3 files changed

+192
-33
lines changed

3 files changed

+192
-33
lines changed

src/key_exchange.rs

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ use std::str;
22

33
use aws_lc_rs::{
44
agreement::{self, EphemeralPrivateKey, UnparsedPublicKey, X25519},
5-
digest,
5+
cipher::{StreamingDecryptingKey, UnboundCipherKey, AES_128},
6+
digest, hmac,
67
rand::{self, SystemRandom},
78
signature::KeyPair,
89
};
@@ -117,7 +118,36 @@ impl EcdhKeyExchange {
117118
return Err(());
118119
}
119120

120-
// FIXME wait for and send newkey packet
121+
let packet = match conn.stream_read.read_packet().await {
122+
Ok(packet) => packet,
123+
Err(error) => {
124+
warn!(addr = %conn.addr, %error, "failed to read packet");
125+
return Err(());
126+
}
127+
};
128+
let Decoded {
129+
value: r#type,
130+
next: _,
131+
} = MessageType::decode(packet.payload)
132+
.map_err(|error| warn!(addr = %conn.addr, %error, "failed to read packet type"))?;
133+
if r#type != MessageType::NewKeys {
134+
warn!(addr = %conn.addr, "unexpected message type {:?}", r#type);
135+
return Err(());
136+
}
137+
138+
conn.write_buf.clear();
139+
let Ok(packet) = Packet::builder(&mut conn.write_buf)
140+
.with_payload(&MessageType::NewKeys)
141+
.without_mac()
142+
else {
143+
error!(addr = %conn.addr, "failed to build newkeys packet");
144+
return Err(());
145+
};
146+
147+
if let Err(error) = conn.stream_write.write_all(packet).await {
148+
warn!(addr = %conn.addr, %error, "failed to send newkeys packet");
149+
return Err(());
150+
}
121151

122152
// The first exchange hash is used as session id.
123153
let session_id = self.session_id.as_ref().unwrap_or(&exchange_hash);
@@ -126,12 +156,31 @@ impl EcdhKeyExchange {
126156
exchange_hash,
127157
session_id,
128158
};
129-
#[expect(clippy::unnecessary_operation)]
130-
RawKeySet {
159+
let raw_keys = RawKeySet {
131160
client_to_server: RawKeys::client_to_server(&derivation),
132161
server_to_client: RawKeys::server_to_client(&derivation),
133162
};
134163

164+
conn.stream_read.set_decryption_key(
165+
StreamingDecryptingKey::ctr(
166+
UnboundCipherKey::new(
167+
&AES_128,
168+
&raw_keys.client_to_server.encryption_key.as_ref()[..16],
169+
)
170+
.unwrap(),
171+
aws_lc_rs::cipher::DecryptionContext::Iv128(
172+
raw_keys.client_to_server.initial_iv.as_ref()[..16]
173+
.try_into()
174+
.unwrap(),
175+
),
176+
)
177+
.unwrap(),
178+
hmac::Key::new(
179+
hmac::HMAC_SHA256,
180+
&raw_keys.client_to_server.integrity_key.as_ref()[..32],
181+
),
182+
);
183+
135184
Ok(())
136185
}
137186
}
@@ -555,7 +604,6 @@ struct RawKeySet {
555604
server_to_client: RawKeys,
556605
}
557606

558-
#[expect(dead_code)] // FIXME implement encryption/decryption and MAC
559607
struct RawKeys {
560608
initial_iv: digest::Digest,
561609
encryption_key: digest::Digest,

src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async fn main() -> anyhow::Result<()> {
5959
Ok((stream, addr)) => {
6060
debug!(%addr, "accepted connection");
6161
let conn = Connection::new(stream, addr, host_key.clone())?;
62-
tokio::spawn(conn.run());
62+
conn.run().await; // FIXME(aws/aws-lc-rs#975) use tokio::spawn() once StreamingDecryptingKey is Send
6363
}
6464
Err(error) => {
6565
warn!(%error, "failed to accept connection");

src/proto.rs

Lines changed: 138 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use core::iter;
22
use std::io;
33

4-
use aws_lc_rs::rand;
4+
use aws_lc_rs::{cipher::StreamingDecryptingKey, constant_time, hmac, rand};
55
use tokio::io::AsyncReadExt;
66
use tracing::debug;
77

@@ -75,25 +75,36 @@ impl From<u8> for MessageType {
7575
/// v
7676
/// |read|unread and not yet decrypted|
7777
/// ```
78-
// FIXME implement actual decryption
7978
pub(crate) struct DecryptingReader<R: AsyncReadExt + Unpin> {
8079
stream: R,
8180
buf: Vec<u8>,
81+
decrypted_buf: Vec<u8>,
8282
unread_start: usize,
83+
84+
packet_number: u32,
85+
decryption_key: Option<(StreamingDecryptingKey, hmac::Key)>,
8386
}
8487

8588
impl<R: AsyncReadExt + Unpin> DecryptingReader<R> {
8689
pub(crate) fn new(stream: R) -> Self {
8790
Self {
8891
stream,
8992
buf: Vec::with_capacity(16_384),
93+
decrypted_buf: Vec::with_capacity(16_384),
9094
unread_start: 0,
95+
packet_number: 0,
96+
decryption_key: None,
9197
}
9298
}
9399

94-
async fn ensure_at_least(&mut self, n: u32) -> Result<(), Error> {
95-
while self.buf.len() - self.unread_start < n as usize {
96-
let read = self.stream.read_buf(&mut self.buf).await?;
100+
async fn ensure_at_least(
101+
stream: &mut R,
102+
buf: &mut Vec<u8>,
103+
unread_start: &mut usize,
104+
n: u32,
105+
) -> Result<(), Error> {
106+
while buf.len() - *unread_start < n as usize {
107+
let read = stream.read_buf(buf).await?;
97108
debug!(bytes = read, "read from stream");
98109
if read == 0 {
99110
return Err(Error::Io(io::Error::new(
@@ -109,49 +120,149 @@ impl<R: AsyncReadExt + Unpin> DecryptingReader<R> {
109120
///
110121
/// This should only be used for reading the identification string.
111122
pub(crate) async fn read_u8_cleartext(&mut self) -> Result<u8, Error> {
112-
self.ensure_at_least(1).await?;
123+
assert!(self.decryption_key.is_none());
124+
125+
Self::ensure_at_least(&mut self.stream, &mut self.buf, &mut self.unread_start, 1).await?;
113126

114127
let byte = self.buf[self.unread_start];
115128
self.unread_start += 1;
116129
Ok(byte)
117130
}
118131

132+
pub(crate) fn set_decryption_key(
133+
&mut self,
134+
decryption_key: StreamingDecryptingKey,
135+
integrity_key: hmac::Key,
136+
) {
137+
self.decrypted_buf.clear();
138+
self.decryption_key = Some((decryption_key, integrity_key));
139+
}
140+
119141
pub(crate) async fn read_packet<'a>(&'a mut self) -> Result<Packet<'a>, Error> {
120142
// Compact the internal buffer
121143
if self.unread_start > 0 {
122144
self.buf.copy_within(self.unread_start.., 0);
123145
}
124146
self.buf.truncate(self.buf.len() - self.unread_start);
147+
self.decrypted_buf.clear();
125148
self.unread_start = 0;
126149

127-
self.ensure_at_least(4).await?;
128-
let Decoded {
129-
value: packet_length,
130-
next,
131-
} = PacketLength::decode(&self.buf[self.unread_start..self.unread_start + 4])?;
132-
assert!(next.is_empty());
133-
134-
self.ensure_at_least(4 + packet_length.inner).await?;
135-
let Decoded {
136-
value: packet,
137-
next,
138-
} = Packet::decode(
139-
&self.buf[self.unread_start..self.unread_start + 4 + packet_length.inner as usize],
140-
)?;
141-
assert!(next.is_empty());
142-
143-
self.unread_start += 4 + packet_length.inner as usize;
144-
145-
Ok(packet)
150+
let packet_number = self.packet_number;
151+
self.packet_number = self.packet_number.wrapping_add(1);
152+
153+
if let Some((decrypting_key, integrity_key)) = &mut self.decryption_key {
154+
let block_len = decrypting_key.algorithm().block_len();
155+
156+
Self::ensure_at_least(
157+
&mut self.stream,
158+
&mut self.buf,
159+
&mut self.unread_start,
160+
block_len as u32,
161+
)
162+
.await?;
163+
self.decrypted_buf.resize(self.buf.len() + block_len, 0);
164+
165+
let update = decrypting_key
166+
.update(
167+
&self.buf[self.unread_start..self.unread_start + block_len],
168+
&mut self.decrypted_buf[self.unread_start..self.unread_start + 2 * block_len],
169+
)
170+
.unwrap();
171+
assert_eq!(update.remainder().len(), block_len);
172+
173+
let Decoded {
174+
value: packet_length,
175+
next,
176+
} = PacketLength::decode(
177+
&self.decrypted_buf[self.unread_start..self.unread_start + 4],
178+
)?;
179+
assert!(next.is_empty());
180+
181+
Self::ensure_at_least(
182+
&mut self.stream,
183+
&mut self.buf,
184+
&mut self.unread_start,
185+
4 + packet_length.inner
186+
+ integrity_key.algorithm().digest_algorithm().output_len as u32,
187+
)
188+
.await?;
189+
190+
let update = decrypting_key
191+
.update(
192+
&self.buf[self.unread_start + block_len
193+
..self.unread_start + 4 + packet_length.inner as usize],
194+
&mut self.decrypted_buf[self.unread_start + block_len
195+
..self.unread_start + 4 + packet_length.inner as usize + block_len],
196+
)
197+
.unwrap();
198+
assert_eq!(update.remainder().len(), block_len);
199+
200+
let mut hmac_ctx = hmac::Context::with_key(integrity_key);
201+
hmac_ctx.update(&packet_number.to_be_bytes());
202+
hmac_ctx.update(
203+
&self.decrypted_buf
204+
[self.unread_start..self.unread_start + 4 + packet_length.inner as usize],
205+
);
206+
let actual_mac = hmac_ctx.sign();
207+
let expected_mac = &self.buf[self.unread_start + 4 + packet_length.inner as usize
208+
..self.unread_start
209+
+ 4
210+
+ packet_length.inner as usize
211+
+ integrity_key.algorithm().digest_algorithm().output_len];
212+
constant_time::verify_slices_are_equal(actual_mac.as_ref(), expected_mac).unwrap(); // FIXME report error
213+
214+
let Decoded {
215+
value: packet,
216+
next,
217+
} = Packet::decode(
218+
&self.decrypted_buf
219+
[self.unread_start..self.unread_start + 4 + packet_length.inner as usize],
220+
)?;
221+
assert!(next.is_empty());
222+
223+
self.unread_start += 4
224+
+ packet_length.inner as usize
225+
+ integrity_key.algorithm().digest_algorithm().output_len;
226+
227+
Ok(packet)
228+
} else {
229+
Self::ensure_at_least(&mut self.stream, &mut self.buf, &mut self.unread_start, 4)
230+
.await?;
231+
let Decoded {
232+
value: packet_length,
233+
next,
234+
} = PacketLength::decode(&self.buf[self.unread_start..self.unread_start + 4])?;
235+
assert!(next.is_empty());
236+
237+
Self::ensure_at_least(
238+
&mut self.stream,
239+
&mut self.buf,
240+
&mut self.unread_start,
241+
4 + packet_length.inner,
242+
)
243+
.await?;
244+
245+
let Decoded {
246+
value: packet,
247+
next,
248+
} = Packet::decode(
249+
&self.buf[self.unread_start..self.unread_start + 4 + packet_length.inner as usize],
250+
)?;
251+
assert!(next.is_empty());
252+
253+
self.unread_start += 4 + packet_length.inner as usize;
254+
255+
Ok(packet)
256+
}
146257
}
147258
}
148259

149260
pub(crate) struct Packet<'a> {
150261
pub(crate) payload: &'a [u8],
151262
}
152263

153-
impl Packet<'_> {
154-
pub(crate) fn builder(buf: &mut Vec<u8>) -> PacketBuilder<'_> {
264+
impl<'a> Packet<'a> {
265+
pub(crate) fn builder(buf: &'a mut Vec<u8>) -> PacketBuilder<'a> {
155266
let start = buf.len();
156267
buf.extend_from_slice(&[0, 0, 0, 0]); // packet_length
157268
buf.push(0); // padding_length

0 commit comments

Comments
 (0)