Skip to content

Commit 3a62c41

Browse files
committed
Add handshake timeout, and minor buffer initialization fix
- Add `authenticate_with_timeout` to `IncomingConnection` for handshake timeouts - Change `authenticate` to return `crate::Result` and wrap unsupported handshake error with `crate::Error::Io` - Use precomputed `len` when creating `BytesMut` in `StreamOperation::write_to_stream`
1 parent d6e6f6f commit 3a62c41

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

src/protocol/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ pub trait StreamOperation {
6464
Self: Sized;
6565

6666
fn write_to_stream<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
67-
let mut buf = bytes::BytesMut::with_capacity(self.len());
67+
let len = self.len();
68+
let mut buf = bytes::BytesMut::with_capacity(len);
6869
self.write_to_buf(&mut buf);
6970
w.write_all(&buf)
7071
}

src/server/connection/mod.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,13 @@ impl<O: 'static> IncomingConnection<O> {
9494
self.stream.set_ttl(ttl)
9595
}
9696

97+
/// Set a timeout for the SOCKS5 handshake.
98+
pub async fn authenticate_with_timeout(self, timeout: Duration) -> crate::Result<(Authenticated, O)> {
99+
tokio::time::timeout(timeout, self.authenticate())
100+
.await
101+
.map_err(|_| crate::Error::String("handshake timeout".into()))?
102+
}
103+
97104
/// Perform a SOCKS5 authentication handshake using the given
98105
/// [`AuthExecutor`](crate::server::auth::AuthExecutor) adapter.
99106
///
@@ -102,7 +109,7 @@ impl<O: 'static> IncomingConnection<O> {
102109
/// Otherwise, the error and the original [`TcpStream`](https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html) is returned.
103110
///
104111
/// Note that this method will not implicitly close the connection even if the handshake failed.
105-
pub async fn authenticate(mut self) -> std::io::Result<(Authenticated, O)> {
112+
pub async fn authenticate(mut self) -> crate::Result<(Authenticated, O)> {
106113
let request = handshake::Request::retrieve_from_async_stream(&mut self.stream).await?;
107114
if let Some(method) = self.evaluate_request(&request) {
108115
let response = handshake::Response::new(method);
@@ -113,7 +120,7 @@ impl<O: 'static> IncomingConnection<O> {
113120
let response = handshake::Response::new(AuthMethod::NoAcceptableMethods);
114121
response.write_to_async_stream(&mut self.stream).await?;
115122
let err = "No available handshake method provided by client";
116-
Err(std::io::Error::new(std::io::ErrorKind::Unsupported, err))
123+
Err(crate::Error::Io(std::io::Error::new(std::io::ErrorKind::Unsupported, err)))
117124
}
118125
}
119126

0 commit comments

Comments
 (0)