Skip to content

Commit 9c1377f

Browse files
davidv1992bjorn3
authored andcommitted
Initial implementation of service handling.
1 parent 82972fd commit 9c1377f

File tree

4 files changed

+268
-8
lines changed

4 files changed

+268
-8
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ aws-lc-rs = { version = "1.12", features = ["prebuilt-nasm"] }
1111
clap = { version = "4.5.24", features = ["derive"] }
1212
listenfd = "1.0.2"
1313
thiserror = "2.0.9"
14-
tokio = { version = "1.42", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] }
14+
tokio = { version = "1.42", features = ["io-util", "macros", "net", "rt-multi-thread", "time", "sync"] }
1515
tracing = "0.1.41"
1616
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }
1717

src/lib.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ use tracing::{debug, warn};
1111

1212
mod key_exchange;
1313
use key_exchange::KeyExchange;
14-
mod proto;
14+
pub mod proto;
1515
use proto::{DecryptingReader, Encode, EncryptingWriter, Packet};
16+
pub mod service;
1617

1718
/// A single SSH connection
1819
pub struct Connection {
@@ -61,7 +62,10 @@ impl Connection {
6162
Ok(self.stream_read.read_packet().await?)
6263
}
6364

64-
pub(crate) async fn send_packet(&mut self, payload: &impl Encode) -> anyhow::Result<()> {
65+
pub(crate) async fn send_packet(
66+
&mut self,
67+
payload: &(impl Encode + ?Sized),
68+
) -> anyhow::Result<()> {
6569
Ok(self.stream_write.write_packet(payload, |_| {}).await?)
6670
}
6771
}

src/proto.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ use tracing::debug;
1010

1111
use crate::Error;
1212

13+
// Message type for the transport layer and key exchange messages.
14+
// Note: this MUST map service messages to the unknown type, otherwise
15+
// the service manager will not work right.
1316
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
1417
pub(crate) enum MessageType {
1518
Disconnect,
@@ -300,7 +303,7 @@ impl<W: AsyncWriteExt + Unpin> EncryptingWriter<W> {
300303
/// Write a packet. Returns written [`Packet`].
301304
pub(crate) async fn write_packet(
302305
&mut self,
303-
payload: &impl Encode,
306+
payload: &(impl Encode + ?Sized),
304307
update_exchange_hash: impl FnOnce(&[u8]),
305308
) -> Result<(), Error> {
306309
self.buf.clear();
@@ -339,8 +342,8 @@ impl<W: AsyncWriteExt + Unpin> EncryptingWriter<W> {
339342
}
340343
}
341344

342-
pub(crate) struct Packet<'a> {
343-
pub(crate) payload: &'a [u8],
345+
pub struct Packet<'a> {
346+
pub payload: &'a [u8],
344347
}
345348

346349
impl<'a> Packet<'a> {
@@ -400,7 +403,10 @@ pub(crate) struct PacketBuilder<'a> {
400403
}
401404

402405
impl<'a> PacketBuilder<'a> {
403-
pub(crate) fn with_payload(self, payload: &impl Encode) -> PacketBuilderWithPayload<'a> {
406+
pub(crate) fn with_payload(
407+
self,
408+
payload: &(impl Encode + ?Sized),
409+
) -> PacketBuilderWithPayload<'a> {
404410
let Self { buf, start } = self;
405411
payload.encode(buf);
406412
PacketBuilderWithPayload { buf, start }
@@ -576,7 +582,7 @@ impl<'a> Decode<'a> for u8 {
576582
}
577583
}
578584

579-
pub(crate) trait Encode {
585+
pub trait Encode {
580586
fn encode(&self, buf: &mut Vec<u8>);
581587
}
582588

src/service.rs

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
use std::borrow::Cow;
2+
3+
use tracing::debug;
4+
5+
use crate::{
6+
proto::{Decode, Decoded, Encode, MessageType, Packet},
7+
Connection,
8+
};
9+
10+
pub trait Service {
11+
fn packet_types(&self) -> &'static [u8];
12+
fn handle_packet(&mut self, packet: Packet<'_>);
13+
}
14+
15+
pub struct ServiceRunner<F> {
16+
services: Vec<Box<dyn Service>>,
17+
outgoing_receiver: tokio::sync::mpsc::UnboundedReceiver<Box<dyn Encode>>,
18+
outgoing_sender: tokio::sync::mpsc::UnboundedSender<Box<dyn Encode>>,
19+
connection: Connection,
20+
service_provider: F,
21+
}
22+
23+
#[derive(Debug, Clone, Copy)]
24+
#[allow(unused)]
25+
enum DisconnectReason {
26+
HostNotAllowedToConnect,
27+
ProtocolError,
28+
KeyExchangeFailed,
29+
Reserved,
30+
MacError,
31+
CompressionError,
32+
ServiceNotAvailable,
33+
ProtocolVersionNotSupported,
34+
HostKeyNotVerifiable,
35+
ConnectionLost,
36+
ByApplication,
37+
TooManyConnections,
38+
AuthCancelledByUser,
39+
NoMoreAuthMethodsAvailable,
40+
IllegalUserName,
41+
Unknown(u32),
42+
}
43+
44+
impl Encode for DisconnectReason {
45+
fn encode(&self, buf: &mut Vec<u8>) {
46+
match self {
47+
Self::HostNotAllowedToConnect => 1,
48+
Self::ProtocolError => 2,
49+
Self::KeyExchangeFailed => 3,
50+
Self::Reserved => 4,
51+
Self::MacError => 5,
52+
Self::CompressionError => 6,
53+
Self::ServiceNotAvailable => 7,
54+
Self::ProtocolVersionNotSupported => 8,
55+
Self::HostKeyNotVerifiable => 9,
56+
Self::ConnectionLost => 10,
57+
Self::ByApplication => 11,
58+
Self::TooManyConnections => 12,
59+
Self::AuthCancelledByUser => 13,
60+
Self::NoMoreAuthMethodsAvailable => 14,
61+
Self::IllegalUserName => 15,
62+
Self::Unknown(v) => *v,
63+
}
64+
.encode(buf);
65+
}
66+
}
67+
68+
struct DisconnectMsg(DisconnectReason);
69+
70+
impl Encode for DisconnectMsg {
71+
fn encode(&self, buf: &mut Vec<u8>) {
72+
MessageType::Disconnect.encode(buf);
73+
self.0.encode(buf);
74+
b"".encode(buf);
75+
b"".encode(buf);
76+
}
77+
}
78+
79+
struct ServiceAcceptMsg<'a> {
80+
name: Cow<'a, [u8]>,
81+
}
82+
83+
impl Encode for ServiceAcceptMsg<'_> {
84+
fn encode(&self, buf: &mut Vec<u8>) {
85+
MessageType::ServiceAccept.encode(buf);
86+
self.name.encode(buf);
87+
}
88+
}
89+
90+
struct UnimplementedMsg {
91+
sequence_no: u32,
92+
}
93+
94+
impl Encode for UnimplementedMsg {
95+
fn encode(&self, buf: &mut Vec<u8>) {
96+
MessageType::Unimplemented.encode(buf);
97+
self.sequence_no.encode(buf);
98+
}
99+
}
100+
101+
impl<
102+
F: FnMut(
103+
&[u8],
104+
tokio::sync::mpsc::UnboundedSender<Box<dyn Encode>>,
105+
) -> Option<Box<dyn Service>>,
106+
> ServiceRunner<F>
107+
{
108+
pub fn new(connection: Connection, service_provider: F) -> Self {
109+
let (outgoing_sender, outgoing_receiver) = tokio::sync::mpsc::unbounded_channel();
110+
Self {
111+
services: vec![],
112+
outgoing_receiver,
113+
outgoing_sender,
114+
connection,
115+
service_provider,
116+
}
117+
}
118+
119+
pub async fn run(mut self) {
120+
enum SelectResult<'a> {
121+
Recv(anyhow::Result<Packet<'a>>),
122+
Send(Option<Box<dyn Encode>>),
123+
}
124+
loop {
125+
let select_result = tokio::select! {
126+
recv = self.connection.recv_packet() => SelectResult::Recv(recv),
127+
send = self.outgoing_receiver.recv() => SelectResult::Send(send),
128+
};
129+
130+
match select_result {
131+
SelectResult::Recv(result) => {
132+
match result {
133+
Ok(packet) => {
134+
match MessageType::decode(packet.payload) {
135+
Ok(Decoded {
136+
value: MessageType::Disconnect,
137+
..
138+
}) => {
139+
return;
140+
}
141+
Ok(Decoded {
142+
value: MessageType::Unknown(v),
143+
..
144+
}) => {
145+
let mut handled = false;
146+
for service in self.services.iter_mut() {
147+
if service.packet_types().contains(&v) {
148+
service.handle_packet(packet);
149+
handled = true;
150+
break;
151+
}
152+
}
153+
if !handled {
154+
// FIXME: send proper packet sequence number
155+
if let Err(e) = self
156+
.connection
157+
.send_packet(&UnimplementedMsg { sequence_no: 0 })
158+
.await
159+
{
160+
debug!("Error sending packet: {e}");
161+
return;
162+
}
163+
}
164+
}
165+
Ok(Decoded {
166+
value: MessageType::ServiceRequest,
167+
next,
168+
}) => {
169+
let service_name = match <&[u8]>::decode(next) {
170+
Ok(Decoded { value, next: &[] }) => value,
171+
Ok(_) => {
172+
debug!("Excess bytes in packet, dropping connection");
173+
if let Err(e) = self
174+
.connection
175+
.send_packet(&DisconnectMsg(
176+
DisconnectReason::ProtocolError,
177+
))
178+
.await
179+
{
180+
debug!("Error sending packet: {e}");
181+
}
182+
return;
183+
}
184+
Err(_) => todo!(),
185+
};
186+
187+
if let Some(service) = (self.service_provider)(
188+
service_name,
189+
self.outgoing_sender.clone(),
190+
) {
191+
self.services.push(service);
192+
let packet = ServiceAcceptMsg {
193+
name: service_name.to_vec().into(),
194+
};
195+
if let Err(e) = self.connection.send_packet(&packet).await {
196+
debug!("Error sending packet: {e}");
197+
return;
198+
}
199+
} else {
200+
debug!(
201+
"Request for unknown service {}",
202+
String::from_utf8_lossy(service_name)
203+
);
204+
if let Err(e) = self
205+
.connection
206+
.send_packet(&DisconnectMsg(
207+
DisconnectReason::ServiceNotAvailable,
208+
))
209+
.await
210+
{
211+
debug!("Error sending packet: {e}");
212+
}
213+
return;
214+
}
215+
}
216+
Ok(_) => {
217+
// FIXME: Figure out what to do with the other known message types instead of ignoring
218+
}
219+
Err(e) => {
220+
debug!("Error decoding packet type: {e}");
221+
if let Err(e) = self
222+
.connection
223+
.send_packet(&DisconnectMsg(
224+
DisconnectReason::ProtocolError,
225+
))
226+
.await
227+
{
228+
debug!("Error sending packet: {e}");
229+
}
230+
return;
231+
}
232+
}
233+
}
234+
Err(e) => {
235+
debug!("Receiving packet failed with error {e}, dropping connection");
236+
return;
237+
}
238+
}
239+
}
240+
SelectResult::Send(Some(payload)) => {
241+
if let Err(e) = self.connection.send_packet(&*payload).await {
242+
debug!("Error sending packet: {e}");
243+
return;
244+
}
245+
}
246+
SelectResult::Send(None) => {}
247+
}
248+
}
249+
}
250+
}

0 commit comments

Comments
 (0)