diff --git a/src/key_exchange.rs b/src/key_exchange.rs index eae934fb..87a81528 100644 --- a/src/key_exchange.rs +++ b/src/key_exchange.rs @@ -2,15 +2,15 @@ use std::str; use aws_lc_rs::{ agreement::{self, EphemeralPrivateKey, UnparsedPublicKey, X25519}, - digest, + cipher::{StreamingDecryptingKey, StreamingEncryptingKey, UnboundCipherKey, AES_128}, + digest, hmac, rand::{self, SystemRandom}, signature::KeyPair, }; -use tokio::io::AsyncWriteExt; use tracing::{debug, error, warn}; use crate::{ - proto::{read, with_mpint_bytes, Decode, Decoded, Encode, MessageType, Packet}, + proto::{with_mpint_bytes, Decode, Decoded, Encode, MessageType, Packet}, Connection, Error, }; @@ -25,11 +25,8 @@ impl EcdhKeyExchange { mut exchange: digest::Context, conn: &mut Connection, ) -> Result<(), ()> { - let (packet, _rest) = match read::>(&mut conn.stream, &mut conn.read_buf).await { - Ok(Decoded { - value: packet, - next, - }) => (packet, next.len()), + let packet = match conn.stream_read.read_packet().await { + Ok(packet) => packet, Err(error) => { warn!(addr = %conn.addr, %error, "failed to read packet"); return Err(()); @@ -106,21 +103,40 @@ impl EcdhKeyExchange { }, }; - conn.write_buf.clear(); - let Ok(packet) = Packet::builder(&mut conn.write_buf) - .with_payload(&key_exchange_reply) - .without_mac() - else { - error!(addr = %conn.addr, "failed to build key exchange init packet"); + if let Err(error) = conn + .stream_write + .write_packet(&key_exchange_reply, |_| {}) + .await + { + warn!(addr = %conn.addr, %error, "failed to send version exchange"); return Err(()); - }; + } - if let Err(error) = conn.stream.write_all(packet).await { - warn!(addr = %conn.addr, %error, "failed to send version exchange"); + let packet = match conn.stream_read.read_packet().await { + Ok(packet) => packet, + Err(error) => { + warn!(addr = %conn.addr, %error, "failed to read packet"); + return Err(()); + } + }; + let Decoded { + value: r#type, + next: _, + } = MessageType::decode(packet.payload) + .map_err(|error| warn!(addr = %conn.addr, %error, "failed to read packet type"))?; + if r#type != MessageType::NewKeys { + warn!(addr = %conn.addr, "unexpected message type {:?}", r#type); return Err(()); } - // FIXME wait for and send newkey packet + if let Err(error) = conn + .stream_write + .write_packet(&MessageType::NewKeys, |_| {}) + .await + { + warn!(addr = %conn.addr, %error, "failed to send newkeys packet"); + return Err(()); + } // The first exchange hash is used as session id. let session_id = self.session_id.as_ref().unwrap_or(&exchange_hash); @@ -129,12 +145,51 @@ impl EcdhKeyExchange { exchange_hash, session_id, }; - #[expect(clippy::unnecessary_operation)] - RawKeySet { + let raw_keys = RawKeySet { client_to_server: RawKeys::client_to_server(&derivation), server_to_client: RawKeys::server_to_client(&derivation), }; + conn.stream_read.set_decryption_key( + StreamingDecryptingKey::ctr( + UnboundCipherKey::new( + &AES_128, + &raw_keys.client_to_server.encryption_key.as_ref()[..16], + ) + .unwrap(), + aws_lc_rs::cipher::DecryptionContext::Iv128( + raw_keys.client_to_server.initial_iv.as_ref()[..16] + .try_into() + .unwrap(), + ), + ) + .unwrap(), + hmac::Key::new( + hmac::HMAC_SHA256, + &raw_keys.client_to_server.integrity_key.as_ref()[..32], + ), + ); + + conn.stream_write.set_encryption_key( + StreamingEncryptingKey::less_safe_ctr( + UnboundCipherKey::new( + &AES_128, + &raw_keys.server_to_client.encryption_key.as_ref()[..16], + ) + .unwrap(), + aws_lc_rs::cipher::EncryptionContext::Iv128( + raw_keys.server_to_client.initial_iv.as_ref()[..16] + .try_into() + .unwrap(), + ), + ) + .unwrap(), + hmac::Key::new( + hmac::HMAC_SHA256, + &raw_keys.server_to_client.integrity_key.as_ref()[..32], + ), + ); + Ok(()) } } @@ -243,11 +298,8 @@ impl KeyExchange { exchange: &mut digest::Context, conn: &mut Connection, ) -> Result { - let (packet, rest) = match read::>(&mut conn.stream, &mut conn.read_buf).await { - Ok(Decoded { - value: packet, - next, - }) => (packet, next.len()), + let packet = match conn.stream_read.read_packet().await { + Ok(packet) => packet, Err(error) => { warn!(addr = %conn.addr, %error, "failed to read packet"); return Err(()); @@ -273,21 +325,15 @@ impl KeyExchange { } }; - conn.write_buf.clear(); - let builder = Packet::builder(&mut conn.write_buf).with_payload(&key_exchange_init); - - if let Ok(kex_init_payload) = builder.payload() { - exchange.update(&(kex_init_payload.len() as u32).to_be_bytes()); - exchange.update(kex_init_payload); - }; - - let Ok(packet) = builder.without_mac() else { - error!(addr = %conn.addr, "failed to build key exchange init packet"); - return Err(()); - }; - - if let Err(error) = conn.stream.write_all(packet).await { - warn!(addr = %conn.addr, %error, "failed to send version exchange"); + if let Err(error) = conn + .stream_write + .write_packet(&key_exchange_init, |kex_init_payload| { + exchange.update(&(kex_init_payload.len() as u32).to_be_bytes()); + exchange.update(kex_init_payload); + }) + .await + { + warn!(addr = %conn.addr, %error, "failed to send key exchange init packet"); return Err(()); } @@ -307,12 +353,6 @@ impl KeyExchange { return Err(()); } - if rest > 0 { - let start = conn.read_buf.len() - rest; - conn.read_buf.copy_within(start.., 0); - } - conn.read_buf.truncate(rest); - Ok(EcdhKeyExchange { session_id: self.session_id, }) @@ -561,13 +601,11 @@ impl<'a, T: From<&'a str>> Decode<'a> for Vec { /// The raw hashes from which we will derive the crypto keys. /// /// -#[expect(dead_code)] // FIXME implement encryption/decryption and MAC struct RawKeySet { client_to_server: RawKeys, server_to_client: RawKeys, } -#[expect(dead_code)] // FIXME implement encryption/decryption and MAC struct RawKeys { initial_iv: digest::Digest, encryption_key: digest::Digest, diff --git a/src/lib.rs b/src/lib.rs index a2c3d6be..4056e581 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,23 +3,23 @@ use std::{io, str, sync::Arc}; use aws_lc_rs::{digest, signature::Ed25519KeyPair}; use thiserror::Error; -use tokio::{io::AsyncWriteExt, net::TcpStream}; +use tokio::{ + io::AsyncReadExt, + net::{tcp, TcpStream}, +}; use tracing::{debug, warn}; mod key_exchange; use key_exchange::KeyExchange; mod proto; -use proto::{read, Decode, Decoded, Encode}; - -use crate::proto::Packet; +use proto::{DecryptingReader, Encode, EncryptingWriter, Packet}; /// A single SSH connection pub struct Connection { - stream: TcpStream, + stream_read: DecryptingReader, + stream_write: EncryptingWriter, addr: SocketAddr, host_key: Arc, - read_buf: Vec, - write_buf: Vec, } impl Connection { @@ -30,12 +30,14 @@ impl Connection { host_key: Arc, ) -> anyhow::Result { stream.set_nodelay(true)?; + + let (stream_read, stream_write) = stream.into_split(); + Ok(Self { - stream, + stream_read: DecryptingReader::new(stream_read), + stream_write: EncryptingWriter::new(stream_write), addr, host_key, - read_buf: Vec::with_capacity(16_384), - write_buf: Vec::with_capacity(16_384), }) } @@ -58,7 +60,7 @@ impl Connection { todo!(); } - pub async fn connect( + pub(crate) async fn connect( stream: TcpStream, addr: SocketAddr, host_key: Arc, @@ -67,12 +69,12 @@ impl Connection { todo!() } - pub async fn recv_packet(&mut self) -> anyhow::Result> { - todo!() + pub(crate) async fn recv_packet(&mut self) -> anyhow::Result> { + Ok(self.stream_read.read_packet().await?) } - pub async fn send_packet(&mut self, packet: impl Encode) -> anyhow::Result<()> { - todo!() + pub(crate) async fn send_packet(&mut self, payload: &impl Encode) -> anyhow::Result<()> { + Ok(self.stream_write.write_packet(payload, |_| {}).await?) } } @@ -85,48 +87,50 @@ impl VersionExchange { exchange: &mut digest::Context, conn: &mut Connection, ) -> Result { - let (ident, rest) = - match read::>(&mut conn.stream, &mut conn.read_buf).await { - Ok(Decoded { value: ident, next }) => { - debug!(addr = %conn.addr, ?ident, "received identification"); - (ident, next.len()) - } - Err(error) => { - warn!(addr = %conn.addr, %error, "failed to read version exchange"); - return Err(()); - } - }; + let ident_bytes = match Identification::read_from_stream(&mut conn.stream_read).await { + Ok(ident_bytes) => ident_bytes, + Err(error) => { + warn!(addr = %conn.addr, %error, "failed to read version exchange"); + return Err(()); + } + }; + let ident = match Identification::decode(&ident_bytes) { + Ok(ident) => { + debug!(addr = %conn.addr, ?ident, "received identification"); + ident + } + Err(error) => { + warn!(addr = %conn.addr, %error, "failed to read version exchange"); + return Err(()); + } + }; if ident.protocol != PROTOCOL { warn!(addr = %conn.addr, ?ident, "unsupported protocol version"); return Err(()); } - let v_c_len = conn.read_buf.len() - rest - 2; - if let Some(v_c) = conn.read_buf.get(..v_c_len) { - exchange.update(&(v_c.len() as u32).to_be_bytes()); - exchange.update(v_c); - } + let v_c = &ident_bytes; + exchange.update(&(v_c.len() as u32).to_be_bytes()); + exchange.update(v_c); let ident = Identification::outgoing(); - ident.encode(&mut conn.write_buf); - if let Err(error) = conn.stream.write_all(&conn.write_buf).await { + let server_ident_bytes = ident.encode(); + if let Err(error) = conn + .stream_write + .write_raw_cleartext(&server_ident_bytes) + .await + { warn!(addr = %conn.addr, %error, "failed to send version exchange"); return Err(()); } - let v_s_len = conn.write_buf.len() - 2; - if let Some(v_s) = conn.write_buf.get(..v_s_len) { + let v_s_len = server_ident_bytes.len() - 2; + if let Some(v_s) = server_ident_bytes.get(..v_s_len) { exchange.update(&(v_s.len() as u32).to_be_bytes()); exchange.update(v_s); } - if rest > 0 { - let start = conn.read_buf.len() - rest; - conn.read_buf.copy_within(start.., 0); - } - conn.read_buf.truncate(rest); - Ok(KeyExchange::for_new_session()) } } @@ -148,19 +152,32 @@ impl Identification<'_> { } } -impl<'a> Decode<'a> for Identification<'a> { - fn decode(bytes: &'a [u8]) -> Result, Error> { +impl<'a> Identification<'a> { + /// Read the identification string as raw bytes with the CRLF stripped off. + async fn read_from_stream( + stream: &mut DecryptingReader, + ) -> Result, Error> { + let mut data = vec![]; + loop { + data.push(stream.read_u8_cleartext().await?); + if data.len() > 255 { + return Err(IdentificationError::TooLong.into()); + } + if let Some((_, b"\r\n")) = data.split_last_chunk::<2>() { + data.pop().unwrap(); + data.pop().unwrap(); + break; + } + } + debug!(bytes = data.len(), "read from stream"); + Ok(data) + } + + fn decode(bytes: &'a [u8]) -> Result { let Ok(message) = str::from_utf8(bytes) else { return Err(IdentificationError::InvalidUtf8.into()); }; - let Some((message, next)) = message.split_once("\r\n") else { - return Err(match message.len() > 256 { - true => IdentificationError::TooLong.into(), - false => Error::Incomplete(None), - }); - }; - let Some(rest) = message.strip_prefix("SSH-") else { return Err(IdentificationError::NoSsh.into()); }; @@ -174,21 +191,15 @@ impl<'a> Decode<'a> for Identification<'a> { None => (rest, ""), }; - let out = Self { + Ok(Self { protocol, software, comments, - }; - - Ok(Decoded { - value: out, - next: next.as_bytes(), }) } -} -impl Encode for Identification<'_> { - fn encode(&self, buf: &mut Vec) { + fn encode(&self) -> Vec { + let mut buf = vec![]; buf.extend_from_slice(b"SSH-"); buf.extend_from_slice(self.protocol.as_bytes()); buf.push(b'-'); @@ -198,6 +209,7 @@ impl Encode for Identification<'_> { buf.extend_from_slice(self.comments.as_bytes()); } buf.extend_from_slice(b"\r\n"); + buf } } diff --git a/src/main.rs b/src/main.rs index 3e6bddf5..ab8c5a4d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -59,7 +59,7 @@ async fn main() -> anyhow::Result<()> { Ok((stream, addr)) => { debug!(%addr, "accepted connection"); let conn = Connection::new(stream, addr, host_key.clone())?; - tokio::spawn(conn.run()); + conn.run().await; // FIXME(aws/aws-lc-rs#975) use tokio::spawn() once StreamingDecryptingKey is Send } Err(error) => { warn!(%error, "failed to accept connection"); diff --git a/src/proto.rs b/src/proto.rs index 117259f0..4b6021da 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -1,7 +1,11 @@ use core::iter; +use std::io; -use aws_lc_rs::rand; -use tokio::io::AsyncReadExt; +use aws_lc_rs::{ + cipher::{StreamingDecryptingKey, StreamingEncryptingKey}, + constant_time, hmac, rand, +}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tracing::debug; use crate::Error; @@ -67,12 +71,280 @@ impl From for MessageType { } } +/// A reader which decrypts data on the fly. +/// +/// ```text +/// +---- unread_start +/// v +/// |read|unread and not yet decrypted| +/// ``` +pub(crate) struct DecryptingReader { + stream: R, + buf: Vec, + decrypted_buf: Vec, + unread_start: usize, + + packet_number: u32, + decryption_key: Option<(StreamingDecryptingKey, hmac::Key)>, +} + +impl DecryptingReader { + pub(crate) fn new(stream: R) -> Self { + Self { + stream, + buf: Vec::with_capacity(16_384), + decrypted_buf: Vec::with_capacity(16_384), + unread_start: 0, + packet_number: 0, + decryption_key: None, + } + } + + async fn ensure_at_least( + stream: &mut R, + buf: &mut Vec, + unread_start: &mut usize, + n: u32, + ) -> Result<(), Error> { + while buf.len() - *unread_start < n as usize { + let read = stream.read_buf(buf).await?; + debug!(bytes = read, "read from stream"); + if read == 0 { + return Err(Error::Io(io::Error::new( + io::ErrorKind::UnexpectedEof, + "EOF", + ))); + } + } + Ok(()) + } + + /// Read a single byte without packet structure or encryption. + /// + /// This should only be used for reading the identification string. + pub(crate) async fn read_u8_cleartext(&mut self) -> Result { + assert!(self.decryption_key.is_none()); + + Self::ensure_at_least(&mut self.stream, &mut self.buf, &mut self.unread_start, 1).await?; + + let byte = self.buf[self.unread_start]; + self.unread_start += 1; + Ok(byte) + } + + pub(crate) fn set_decryption_key( + &mut self, + decryption_key: StreamingDecryptingKey, + integrity_key: hmac::Key, + ) { + self.decrypted_buf.clear(); + self.decryption_key = Some((decryption_key, integrity_key)); + } + + pub(crate) async fn read_packet<'a>(&'a mut self) -> Result, Error> { + // Compact the internal buffer + if self.unread_start > 0 { + self.buf.copy_within(self.unread_start.., 0); + } + self.buf.truncate(self.buf.len() - self.unread_start); + self.decrypted_buf.clear(); + self.unread_start = 0; + + let packet_number = self.packet_number; + self.packet_number = self.packet_number.wrapping_add(1); + + if let Some((decrypting_key, integrity_key)) = &mut self.decryption_key { + let block_len = decrypting_key.algorithm().block_len(); + + Self::ensure_at_least( + &mut self.stream, + &mut self.buf, + &mut self.unread_start, + block_len as u32, + ) + .await?; + self.decrypted_buf.resize(self.buf.len() + block_len, 0); + + let update = decrypting_key + .update( + &self.buf[self.unread_start..self.unread_start + block_len], + &mut self.decrypted_buf[self.unread_start..self.unread_start + 2 * block_len], + ) + .unwrap(); + assert_eq!(update.remainder().len(), block_len); + + let Decoded { + value: packet_length, + next, + } = PacketLength::decode( + &self.decrypted_buf[self.unread_start..self.unread_start + 4], + )?; + assert!(next.is_empty()); + + Self::ensure_at_least( + &mut self.stream, + &mut self.buf, + &mut self.unread_start, + 4 + packet_length.inner + + integrity_key.algorithm().digest_algorithm().output_len as u32, + ) + .await?; + + let update = decrypting_key + .update( + &self.buf[self.unread_start + block_len + ..self.unread_start + 4 + packet_length.inner as usize], + &mut self.decrypted_buf[self.unread_start + block_len + ..self.unread_start + 4 + packet_length.inner as usize + block_len], + ) + .unwrap(); + assert_eq!(update.remainder().len(), block_len); + + let mut hmac_ctx = hmac::Context::with_key(integrity_key); + hmac_ctx.update(&packet_number.to_be_bytes()); + hmac_ctx.update( + &self.decrypted_buf + [self.unread_start..self.unread_start + 4 + packet_length.inner as usize], + ); + let actual_mac = hmac_ctx.sign(); + let expected_mac = &self.buf[self.unread_start + 4 + packet_length.inner as usize + ..self.unread_start + + 4 + + packet_length.inner as usize + + integrity_key.algorithm().digest_algorithm().output_len]; + constant_time::verify_slices_are_equal(actual_mac.as_ref(), expected_mac).unwrap(); // FIXME report error + + let Decoded { + value: packet, + next, + } = Packet::decode( + &self.decrypted_buf + [self.unread_start..self.unread_start + 4 + packet_length.inner as usize], + )?; + assert!(next.is_empty()); + + self.unread_start += 4 + + packet_length.inner as usize + + integrity_key.algorithm().digest_algorithm().output_len; + + Ok(packet) + } else { + Self::ensure_at_least(&mut self.stream, &mut self.buf, &mut self.unread_start, 4) + .await?; + let Decoded { + value: packet_length, + next, + } = PacketLength::decode(&self.buf[self.unread_start..self.unread_start + 4])?; + assert!(next.is_empty()); + + Self::ensure_at_least( + &mut self.stream, + &mut self.buf, + &mut self.unread_start, + 4 + packet_length.inner, + ) + .await?; + + let Decoded { + value: packet, + next, + } = Packet::decode( + &self.buf[self.unread_start..self.unread_start + 4 + packet_length.inner as usize], + )?; + assert!(next.is_empty()); + + self.unread_start += 4 + packet_length.inner as usize; + + Ok(packet) + } + } +} + +pub(crate) struct EncryptingWriter { + stream: W, + buf: Vec, + encrypted_buf: Vec, + + packet_number: u32, + encryption_key: Option<(StreamingEncryptingKey, hmac::Key)>, +} + +impl EncryptingWriter { + pub(crate) fn new(stream: W) -> Self { + Self { + stream, + buf: Vec::with_capacity(16_384), + encrypted_buf: Vec::with_capacity(16_384), + packet_number: 0, + encryption_key: None, + } + } + + /// Write raw bytes without packet structure or encryption. + /// + /// This should only be used for writing the identification string. + pub(crate) async fn write_raw_cleartext(&mut self, bytes: &[u8]) -> Result<(), Error> { + assert!(self.encryption_key.is_none()); + self.stream.write_all(bytes).await?; + Ok(()) + } + + pub(crate) fn set_encryption_key( + &mut self, + encryption_key: StreamingEncryptingKey, + integrity_key: hmac::Key, + ) { + self.encryption_key = Some((encryption_key, integrity_key)); + } + + /// Write a packet. Returns written [`Packet`]. + pub(crate) async fn write_packet( + &mut self, + payload: &impl Encode, + update_exchange_hash: impl FnOnce(&[u8]), + ) -> Result<(), Error> { + self.buf.clear(); + self.encrypted_buf.clear(); + + let packet_number = self.packet_number; + self.packet_number = self.packet_number.wrapping_add(1); + + let packet = Packet::builder(&mut self.buf).with_payload(payload); + update_exchange_hash(packet.payload()?); + + if let Some((encryption_key, integrity_key)) = &mut self.encryption_key { + let block_len = encryption_key.algorithm().block_len(); + + let data = packet.without_mac()?; + + self.encrypted_buf.resize(data.len() + block_len, 0); + let update = encryption_key + .update(data, &mut self.encrypted_buf) + .unwrap(); + assert_eq!(update.remainder().len(), block_len); + self.encrypted_buf.truncate(data.len()); + + let mut hmac_ctx = hmac::Context::with_key(integrity_key); + hmac_ctx.update(&packet_number.to_be_bytes()); + hmac_ctx.update(data); + let mac = hmac_ctx.sign(); + self.encrypted_buf.extend_from_slice(mac.as_ref()); + + self.stream.write_all(&self.encrypted_buf).await?; + } else { + self.stream.write_all(packet.without_mac()?).await?; + }; + + Ok(()) + } +} + pub(crate) struct Packet<'a> { pub(crate) payload: &'a [u8], } -impl Packet<'_> { - pub(crate) fn builder(buf: &mut Vec) -> PacketBuilder<'_> { +impl<'a> Packet<'a> { + pub(crate) fn builder(buf: &'a mut Vec) -> PacketBuilder<'a> { let start = buf.len(); buf.extend_from_slice(&[0, 0, 0, 0]); // packet_length buf.push(0); // padding_length @@ -80,7 +352,7 @@ impl Packet<'_> { } } -impl<'a> Decode<'a> for Packet<'a> { +impl<'a> Packet<'a> { fn decode(bytes: &'a [u8]) -> Result, Error> { let Decoded { value: packet_length, @@ -184,8 +456,6 @@ impl<'a> PacketBuilderWithPayload<'a> { } } - buf.extend_from_slice(&[]); // mac - let packet_len = (buf.len() - start - 4) as u32; if let Some(packet_length_dst) = buf.get_mut(start..start + 4) { packet_length_dst.copy_from_slice(&packet_len.to_be_bytes()); @@ -309,15 +579,6 @@ pub(crate) trait Encode { fn encode(&self, buf: &mut Vec); } -pub(crate) async fn read<'a, T: Decode<'a>>( - reader: &mut (impl AsyncReadExt + Unpin), - buf: &'a mut Vec, -) -> Result, Error> { - let read = reader.read_buf(buf).await?; - debug!(bytes = read, "read from stream"); - T::decode(buf) -} - pub(crate) trait Decode<'a>: Sized { fn decode(bytes: &'a [u8]) -> Result, Error>; }