Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 86 additions & 48 deletions src/key_exchange.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ use std::str;

use aws_lc_rs::{
agreement::{self, EphemeralPrivateKey, UnparsedPublicKey, X25519},
digest,
cipher::{StreamingDecryptingKey, StreamingEncryptingKey, UnboundCipherKey, AES_128},
digest, hmac,
rand::{self, SystemRandom},
signature::KeyPair,
};
use tokio::io::AsyncWriteExt;
use tracing::{debug, error, warn};

use crate::{
proto::{read, with_mpint_bytes, Decode, Decoded, Encode, MessageType, Packet},
proto::{with_mpint_bytes, Decode, Decoded, Encode, MessageType, Packet},
Connection, Error,
};

Expand All @@ -25,11 +25,8 @@ impl EcdhKeyExchange {
mut exchange: digest::Context,
conn: &mut Connection,
) -> Result<(), ()> {
let (packet, _rest) = match read::<Packet<'_>>(&mut conn.stream, &mut conn.read_buf).await {
Ok(Decoded {
value: packet,
next,
}) => (packet, next.len()),
let packet = match conn.stream_read.read_packet().await {
Ok(packet) => packet,
Err(error) => {
warn!(addr = %conn.addr, %error, "failed to read packet");
return Err(());
Expand Down Expand Up @@ -106,21 +103,40 @@ impl EcdhKeyExchange {
},
};

conn.write_buf.clear();
let Ok(packet) = Packet::builder(&mut conn.write_buf)
.with_payload(&key_exchange_reply)
.without_mac()
else {
error!(addr = %conn.addr, "failed to build key exchange init packet");
if let Err(error) = conn
.stream_write
.write_packet(&key_exchange_reply, |_| {})
.await
{
warn!(addr = %conn.addr, %error, "failed to send version exchange");
return Err(());
};
}

if let Err(error) = conn.stream.write_all(packet).await {
warn!(addr = %conn.addr, %error, "failed to send version exchange");
let packet = match conn.stream_read.read_packet().await {
Ok(packet) => packet,
Err(error) => {
warn!(addr = %conn.addr, %error, "failed to read packet");
return Err(());
}
};
let Decoded {
value: r#type,
next: _,
} = MessageType::decode(packet.payload)
.map_err(|error| warn!(addr = %conn.addr, %error, "failed to read packet type"))?;
if r#type != MessageType::NewKeys {
warn!(addr = %conn.addr, "unexpected message type {:?}", r#type);
return Err(());
}

// FIXME wait for and send newkey packet
if let Err(error) = conn
.stream_write
.write_packet(&MessageType::NewKeys, |_| {})
.await
{
warn!(addr = %conn.addr, %error, "failed to send newkeys packet");
return Err(());
}

// The first exchange hash is used as session id.
let session_id = self.session_id.as_ref().unwrap_or(&exchange_hash);
Expand All @@ -129,12 +145,51 @@ impl EcdhKeyExchange {
exchange_hash,
session_id,
};
#[expect(clippy::unnecessary_operation)]
RawKeySet {
let raw_keys = RawKeySet {
client_to_server: RawKeys::client_to_server(&derivation),
server_to_client: RawKeys::server_to_client(&derivation),
};

conn.stream_read.set_decryption_key(
StreamingDecryptingKey::ctr(
UnboundCipherKey::new(
&AES_128,
&raw_keys.client_to_server.encryption_key.as_ref()[..16],
)
.unwrap(),
aws_lc_rs::cipher::DecryptionContext::Iv128(
raw_keys.client_to_server.initial_iv.as_ref()[..16]
.try_into()
.unwrap(),
),
)
.unwrap(),
hmac::Key::new(
hmac::HMAC_SHA256,
&raw_keys.client_to_server.integrity_key.as_ref()[..32],
),
);

conn.stream_write.set_encryption_key(
StreamingEncryptingKey::less_safe_ctr(
UnboundCipherKey::new(
&AES_128,
&raw_keys.server_to_client.encryption_key.as_ref()[..16],
)
.unwrap(),
aws_lc_rs::cipher::EncryptionContext::Iv128(
raw_keys.server_to_client.initial_iv.as_ref()[..16]
.try_into()
.unwrap(),
),
)
.unwrap(),
hmac::Key::new(
hmac::HMAC_SHA256,
&raw_keys.server_to_client.integrity_key.as_ref()[..32],
),
);

Ok(())
}
}
Expand Down Expand Up @@ -243,11 +298,8 @@ impl KeyExchange {
exchange: &mut digest::Context,
conn: &mut Connection,
) -> Result<EcdhKeyExchange, ()> {
let (packet, rest) = match read::<Packet<'_>>(&mut conn.stream, &mut conn.read_buf).await {
Ok(Decoded {
value: packet,
next,
}) => (packet, next.len()),
let packet = match conn.stream_read.read_packet().await {
Ok(packet) => packet,
Err(error) => {
warn!(addr = %conn.addr, %error, "failed to read packet");
return Err(());
Expand All @@ -273,21 +325,15 @@ impl KeyExchange {
}
};

conn.write_buf.clear();
let builder = Packet::builder(&mut conn.write_buf).with_payload(&key_exchange_init);

if let Ok(kex_init_payload) = builder.payload() {
exchange.update(&(kex_init_payload.len() as u32).to_be_bytes());
exchange.update(kex_init_payload);
};

let Ok(packet) = builder.without_mac() else {
error!(addr = %conn.addr, "failed to build key exchange init packet");
return Err(());
};

if let Err(error) = conn.stream.write_all(packet).await {
warn!(addr = %conn.addr, %error, "failed to send version exchange");
if let Err(error) = conn
.stream_write
.write_packet(&key_exchange_init, |kex_init_payload| {
exchange.update(&(kex_init_payload.len() as u32).to_be_bytes());
exchange.update(kex_init_payload);
})
.await
{
warn!(addr = %conn.addr, %error, "failed to send key exchange init packet");
return Err(());
}

Expand All @@ -307,12 +353,6 @@ impl KeyExchange {
return Err(());
}

if rest > 0 {
let start = conn.read_buf.len() - rest;
conn.read_buf.copy_within(start.., 0);
}
conn.read_buf.truncate(rest);

Ok(EcdhKeyExchange {
session_id: self.session_id,
})
Expand Down Expand Up @@ -561,13 +601,11 @@ impl<'a, T: From<&'a str>> Decode<'a> for Vec<T> {
/// The raw hashes from which we will derive the crypto keys.
///
/// <https://www.rfc-editor.org/rfc/rfc4253#section-7.2>
#[expect(dead_code)] // FIXME implement encryption/decryption and MAC
struct RawKeySet {
client_to_server: RawKeys,
server_to_client: RawKeys,
}

#[expect(dead_code)] // FIXME implement encryption/decryption and MAC
struct RawKeys {
initial_iv: digest::Digest,
encryption_key: digest::Digest,
Expand Down
Loading
Loading