Skip to content

Commit 3fc5c5b

Browse files
committed
Introduce EncryptingReader which will in the future handle encryption
1 parent b68a273 commit 3fc5c5b

File tree

3 files changed

+92
-53
lines changed

3 files changed

+92
-53
lines changed

src/key_exchange.rs

Lines changed: 19 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use aws_lc_rs::{
77
rand::{self, SystemRandom},
88
signature::KeyPair,
99
};
10-
use tokio::io::AsyncWriteExt;
1110
use tracing::{debug, error, warn};
1211

1312
use crate::{
@@ -104,16 +103,11 @@ impl EcdhKeyExchange {
104103
},
105104
};
106105

107-
conn.write_buf.clear();
108-
let Ok(packet) = Packet::builder(&mut conn.write_buf)
109-
.with_payload(&key_exchange_reply)
110-
.without_mac()
111-
else {
112-
error!(addr = %conn.addr, "failed to build key exchange init packet");
113-
return Err(());
114-
};
115-
116-
if let Err(error) = conn.stream_write.write_all(packet).await {
106+
if let Err(error) = conn
107+
.stream_write
108+
.write_packet(&key_exchange_reply, |_| {})
109+
.await
110+
{
117111
warn!(addr = %conn.addr, %error, "failed to send version exchange");
118112
return Err(());
119113
}
@@ -135,16 +129,11 @@ impl EcdhKeyExchange {
135129
return Err(());
136130
}
137131

138-
conn.write_buf.clear();
139-
let Ok(packet) = Packet::builder(&mut conn.write_buf)
140-
.with_payload(&MessageType::NewKeys)
141-
.without_mac()
142-
else {
143-
error!(addr = %conn.addr, "failed to build newkeys packet");
144-
return Err(());
145-
};
146-
147-
if let Err(error) = conn.stream_write.write_all(packet).await {
132+
if let Err(error) = conn
133+
.stream_write
134+
.write_packet(&MessageType::NewKeys, |_| {})
135+
.await
136+
{
148137
warn!(addr = %conn.addr, %error, "failed to send newkeys packet");
149138
return Err(());
150139
}
@@ -316,21 +305,15 @@ impl KeyExchange {
316305
}
317306
};
318307

319-
conn.write_buf.clear();
320-
let builder = Packet::builder(&mut conn.write_buf).with_payload(&key_exchange_init);
321-
322-
if let Ok(kex_init_payload) = builder.payload() {
323-
exchange.update(&(kex_init_payload.len() as u32).to_be_bytes());
324-
exchange.update(kex_init_payload);
325-
};
326-
327-
let Ok(packet) = builder.without_mac() else {
328-
error!(addr = %conn.addr, "failed to build key exchange init packet");
329-
return Err(());
330-
};
331-
332-
if let Err(error) = conn.stream_write.write_all(packet).await {
333-
warn!(addr = %conn.addr, %error, "failed to send version exchange");
308+
if let Err(error) = conn
309+
.stream_write
310+
.write_packet(&key_exchange_init, |kex_init_payload| {
311+
exchange.update(&(kex_init_payload.len() as u32).to_be_bytes());
312+
exchange.update(kex_init_payload);
313+
})
314+
.await
315+
{
316+
warn!(addr = %conn.addr, %error, "failed to send key exchange init packet");
334317
return Err(());
335318
}
336319

src/lib.rs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,22 @@ use std::{io, str, sync::Arc};
44
use aws_lc_rs::{digest, signature::Ed25519KeyPair};
55
use thiserror::Error;
66
use tokio::{
7-
io::{AsyncReadExt, AsyncWriteExt},
7+
io::AsyncReadExt,
88
net::{tcp, TcpStream},
99
};
1010
use tracing::{debug, warn};
1111

1212
mod key_exchange;
1313
use key_exchange::KeyExchange;
1414
mod proto;
15-
use proto::{DecryptingReader, Encode};
16-
17-
use crate::proto::Packet;
15+
use proto::{DecryptingReader, Encode, EncryptingWriter, Packet};
1816

1917
/// A single SSH connection
2018
pub struct Connection {
2119
stream_read: DecryptingReader<tcp::OwnedReadHalf>,
22-
stream_write: tcp::OwnedWriteHalf,
20+
stream_write: EncryptingWriter<tcp::OwnedWriteHalf>,
2321
addr: SocketAddr,
2422
host_key: Arc<Ed25519KeyPair>,
25-
write_buf: Vec<u8>,
2623
}
2724

2825
impl Connection {
@@ -38,10 +35,9 @@ impl Connection {
3835

3936
Ok(Self {
4037
stream_read: DecryptingReader::new(stream_read),
41-
stream_write,
38+
stream_write: EncryptingWriter::new(stream_write),
4239
addr,
4340
host_key,
44-
write_buf: Vec::with_capacity(16_384),
4541
})
4642
}
4743

@@ -119,14 +115,18 @@ impl VersionExchange {
119115
exchange.update(v_c);
120116

121117
let ident = Identification::outgoing();
122-
ident.encode(&mut conn.write_buf);
123-
if let Err(error) = conn.stream_write.write_all(&conn.write_buf).await {
118+
let server_ident_bytes = ident.encode();
119+
if let Err(error) = conn
120+
.stream_write
121+
.write_raw_cleartext(&server_ident_bytes)
122+
.await
123+
{
124124
warn!(addr = %conn.addr, %error, "failed to send version exchange");
125125
return Err(());
126126
}
127127

128-
let v_s_len = conn.write_buf.len() - 2;
129-
if let Some(v_s) = conn.write_buf.get(..v_s_len) {
128+
let v_s_len = server_ident_bytes.len() - 2;
129+
if let Some(v_s) = server_ident_bytes.get(..v_s_len) {
130130
exchange.update(&(v_s.len() as u32).to_be_bytes());
131131
exchange.update(v_s);
132132
}
@@ -197,10 +197,9 @@ impl<'a> Identification<'a> {
197197
comments,
198198
})
199199
}
200-
}
201200

202-
impl Encode for Identification<'_> {
203-
fn encode(&self, buf: &mut Vec<u8>) {
201+
fn encode(&self) -> Vec<u8> {
202+
let mut buf = vec![];
204203
buf.extend_from_slice(b"SSH-");
205204
buf.extend_from_slice(self.protocol.as_bytes());
206205
buf.push(b'-');
@@ -210,6 +209,7 @@ impl Encode for Identification<'_> {
210209
buf.extend_from_slice(self.comments.as_bytes());
211210
}
212211
buf.extend_from_slice(b"\r\n");
212+
buf
213213
}
214214
}
215215

src/proto.rs

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
use core::iter;
22
use std::io;
33

4-
use aws_lc_rs::{cipher::StreamingDecryptingKey, constant_time, hmac, rand};
5-
use tokio::io::AsyncReadExt;
4+
use aws_lc_rs::{
5+
cipher::{StreamingDecryptingKey, StreamingEncryptingKey},
6+
constant_time, hmac, rand,
7+
};
8+
use tokio::io::{AsyncReadExt, AsyncWriteExt};
69
use tracing::debug;
710

811
use crate::Error;
@@ -257,6 +260,59 @@ impl<R: AsyncReadExt + Unpin> DecryptingReader<R> {
257260
}
258261
}
259262

263+
pub(crate) struct EncryptingWriter<W: AsyncWriteExt + Unpin> {
264+
stream: W,
265+
buf: Vec<u8>,
266+
encrypted_buf: Vec<u8>,
267+
268+
packet_number: u32,
269+
encryption_key: Option<(StreamingEncryptingKey, hmac::Key)>,
270+
}
271+
272+
impl<W: AsyncWriteExt + Unpin> EncryptingWriter<W> {
273+
pub(crate) fn new(stream: W) -> Self {
274+
Self {
275+
stream,
276+
buf: Vec::with_capacity(16_384),
277+
encrypted_buf: Vec::with_capacity(16_384),
278+
packet_number: 0,
279+
encryption_key: None,
280+
}
281+
}
282+
283+
/// Write raw bytes without packet structure or encryption.
284+
///
285+
/// This should only be used for writing the identification string.
286+
pub(crate) async fn write_raw_cleartext(&mut self, bytes: &[u8]) -> Result<(), Error> {
287+
assert!(self.encryption_key.is_none());
288+
self.stream.write_all(bytes).await?;
289+
Ok(())
290+
}
291+
292+
/// Write a packet. Returns written [`Packet`].
293+
pub(crate) async fn write_packet(
294+
&mut self,
295+
payload: &impl Encode,
296+
update_exchange_hash: impl FnOnce(&[u8]),
297+
) -> Result<(), Error> {
298+
self.buf.clear();
299+
self.encrypted_buf.clear();
300+
301+
let packet_number = self.packet_number;
302+
self.packet_number = self.packet_number.wrapping_add(1);
303+
304+
if let Some((encryption_key, integrity_key)) = &mut self.encryption_key {
305+
todo!()
306+
} else {
307+
let packet = Packet::builder(&mut self.buf).with_payload(payload);
308+
update_exchange_hash(packet.payload()?);
309+
self.stream.write_all(packet.without_mac()?).await?;
310+
};
311+
312+
Ok(())
313+
}
314+
}
315+
260316
pub(crate) struct Packet<'a> {
261317
pub(crate) payload: &'a [u8],
262318
}

0 commit comments

Comments
 (0)