Skip to content

Commit bf4664e

Browse files
committed
migrate TFO to separate tokio-tfo crate
1 parent 6915ce9 commit bf4664e

File tree

9 files changed

+339
-794
lines changed

9 files changed

+339
-794
lines changed

Cargo.lock

Lines changed: 18 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/shadowsocks/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ mio = "0.7"
6060
socket2 = { version = "0.4", features = ["all"] }
6161
tokio = { version = "1.9.0", features = ["io-util", "macros", "net", "parking_lot", "process", "rt", "sync", "time"] }
6262
tokio-io-timeout = "1.1"
63+
tokio-tfo = "0.1.3"
6364

6465
trust-dns-resolver = { version = "0.20", optional = true }
6566
arc-swap = { version = "1.3", optional = true }

crates/shadowsocks/src/net/sys/mod.rs

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -45,86 +45,6 @@ fn set_common_sockopt_for_connect(addr: SocketAddr, socket: &TcpSocket, opts: &C
4545
Ok(())
4646
}
4747

48-
fn set_common_sockopt_after_connect(stream: &tokio::net::TcpStream, opts: &ConnectOpts) -> io::Result<()> {
49-
stream.set_nodelay(opts.tcp.nodelay)?;
50-
set_common_sockopt_after_connect_sys(stream, opts)?;
51-
52-
Ok(())
53-
}
54-
55-
#[cfg(unix)]
56-
#[inline]
57-
fn set_common_sockopt_after_connect_sys(stream: &tokio::net::TcpStream, opts: &ConnectOpts) -> io::Result<()> {
58-
use socket2::{Socket, TcpKeepalive};
59-
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
60-
61-
let socket = unsafe { Socket::from_raw_fd(stream.as_raw_fd()) };
62-
63-
macro_rules! try_sockopt {
64-
($socket:ident . $func:ident ($($arg:expr),*)) => {
65-
match $socket . $func ($($arg),*) {
66-
Ok(e) => e,
67-
Err(err) => {
68-
let _ = socket.into_raw_fd();
69-
return Err(err);
70-
}
71-
}
72-
};
73-
}
74-
75-
if let Some(keepalive_duration) = opts.tcp.keepalive {
76-
#[allow(unused_mut)]
77-
let mut keepalive = TcpKeepalive::new().with_time(keepalive_duration);
78-
79-
#[cfg(any(
80-
target_os = "freebsd",
81-
target_os = "fuchsia",
82-
target_os = "linux",
83-
target_os = "netbsd",
84-
target_vendor = "apple",
85-
))]
86-
{
87-
keepalive = keepalive.with_interval(keepalive_duration);
88-
}
89-
90-
try_sockopt!(socket.set_tcp_keepalive(&keepalive));
91-
}
92-
93-
let _ = socket.into_raw_fd();
94-
Ok(())
95-
}
96-
97-
#[cfg(windows)]
98-
#[inline]
99-
fn set_common_sockopt_after_connect_sys(stream: &tokio::net::TcpStream, opts: &ConnectOpts) -> io::Result<()> {
100-
use socket2::{Socket, TcpKeepalive};
101-
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket};
102-
103-
let socket = unsafe { Socket::from_raw_socket(stream.as_raw_socket()) };
104-
105-
macro_rules! try_sockopt {
106-
($socket:ident . $func:ident ($($arg:expr),*)) => {
107-
match $socket . $func ($($arg),*) {
108-
Ok(e) => e,
109-
Err(err) => {
110-
let _ = socket.into_raw_socket();
111-
return Err(err);
112-
}
113-
}
114-
};
115-
}
116-
117-
if let Some(keepalive_duration) = opts.tcp.keepalive {
118-
let keepalive = TcpKeepalive::new()
119-
.with_time(keepalive_duration)
120-
.with_interval(keepalive_duration);
121-
try_sockopt!(socket.set_tcp_keepalive(&keepalive));
122-
}
123-
124-
let _ = socket.into_raw_socket();
125-
Ok(())
126-
}
127-
12848
#[cfg(all(not(windows), not(unix)))]
12949
#[inline]
13050
fn set_common_sockopt_after_connect_sys(_: &tokio::net::TcpStream, _: &ConnectOpts) -> io::Result<()> {

crates/shadowsocks/src/net/sys/unix/bsd/freebsd.rs

Lines changed: 58 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,31 @@
11
use std::{
2-
io::{self, ErrorKind},
2+
io,
33
mem,
4-
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpStream as StdTcpStream},
5-
ops::{Deref, DerefMut},
6-
os::unix::io::{AsRawFd, FromRawFd, IntoRawFd},
4+
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
5+
os::unix::io::{AsRawFd, RawFd},
76
pin::Pin,
87
task::{self, Poll},
98
};
109

11-
use futures::ready;
1210
use log::error;
1311
use pin_project::pin_project;
14-
use socket2::SockAddr;
1512
use tokio::{
16-
io::{AsyncRead, AsyncWrite, Interest, ReadBuf},
13+
io::{AsyncRead, AsyncWrite, ReadBuf},
1714
net::{TcpSocket, TcpStream as TokioTcpStream, UdpSocket},
1815
};
16+
use tokio_tfo::TfoStream;
1917

2018
use crate::net::{
2119
sys::{set_common_sockopt_after_connect, set_common_sockopt_for_connect},
2220
AddrFamily,
2321
ConnectOpts,
2422
};
2523

26-
enum TcpStreamState {
27-
Connected,
28-
FastOpenConnect(SocketAddr),
29-
}
30-
3124
/// A `TcpStream` that supports TFO (TCP Fast Open)
3225
#[pin_project(project = TcpStreamProj)]
33-
pub struct TcpStream {
34-
#[pin]
35-
inner: TokioTcpStream,
36-
state: TcpStreamState,
26+
pub enum TcpStream {
27+
Standard(#[pin] TokioTcpStream),
28+
FastOpen(#[pin] TfoStream),
3729
}
3830

3931
impl TcpStream {
@@ -50,142 +42,82 @@ impl TcpStream {
5042
let stream = socket.connect(addr).await?;
5143
set_common_sockopt_after_connect(&stream, opts)?;
5244

53-
return Ok(TcpStream {
54-
inner: stream,
55-
state: TcpStreamState::Connected,
56-
});
45+
return Ok(TcpStream::Standard(stream));
5746
}
5847

59-
unsafe {
60-
let enable: libc::c_int = 1;
61-
62-
let ret = libc::setsockopt(
63-
socket.as_raw_fd(),
64-
libc::IPPROTO_TCP,
65-
libc::TCP_FASTOPEN,
66-
&enable as *const _ as *const libc::c_void,
67-
mem::size_of_val(&enable) as libc::socklen_t,
68-
);
69-
70-
if ret != 0 {
71-
let err = io::Error::last_os_error();
72-
error!("set TCP_FASTOPEN error: {}", err);
73-
return Err(err);
74-
}
75-
}
76-
77-
let stream = TokioTcpStream::from_std(unsafe { StdTcpStream::from_raw_fd(socket.into_raw_fd()) })?;
48+
let stream = TfoStream::connect_with_socket(socket, addr).await?;
7849
set_common_sockopt_after_connect(&stream, opts)?;
7950

80-
Ok(TcpStream {
81-
inner: stream,
82-
state: TcpStreamState::FastOpenConnect(addr),
83-
})
51+
Ok(TcpStream::FastOpen(stream))
52+
}
53+
54+
pub fn local_addr(&self) -> io::Result<SocketAddr> {
55+
match *self {
56+
TcpStream::Standard(ref s) => s.local_addr(),
57+
TcpStream::FastOpen(ref s) => s.local_addr(),
58+
}
59+
}
60+
61+
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
62+
match *self {
63+
TcpStream::Standard(ref s) => s.peer_addr(),
64+
TcpStream::FastOpen(ref s) => s.peer_addr(),
65+
}
8466
}
85-
}
8667

87-
impl Deref for TcpStream {
88-
type Target = TokioTcpStream;
68+
pub fn nodelay(&self) -> io::Result<bool> {
69+
match *self {
70+
TcpStream::Standard(ref s) => s.nodelay(),
71+
TcpStream::FastOpen(ref s) => s.nodelay(),
72+
}
73+
}
8974

90-
fn deref(&self) -> &Self::Target {
91-
&self.inner
75+
pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
76+
match *self {
77+
TcpStream::Standard(ref s) => s.set_nodelay(nodelay),
78+
TcpStream::FastOpen(ref s) => s.set_nodelay(nodelay),
79+
}
9280
}
9381
}
9482

95-
impl DerefMut for TcpStream {
96-
fn deref_mut(&mut self) -> &mut Self::Target {
97-
&mut self.inner
83+
impl AsRawFd for TcpStream {
84+
fn as_raw_fd(&self) -> RawFd {
85+
match *self {
86+
TcpStream::Standard(ref s) => s.as_raw_fd(),
87+
TcpStream::FastOpen(ref s) => s.as_raw_fd(),
88+
}
9889
}
9990
}
10091

10192
impl AsyncRead for TcpStream {
10293
fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
103-
self.project().inner.poll_read(cx, buf)
94+
match self.project() {
95+
TcpStreamProj::Standard(s) => s.poll_read(cx, buf),
96+
TcpStreamProj::FastOpen(s) => s.poll_read(cx, buf),
97+
}
10498
}
10599
}
106100

107101
impl AsyncWrite for TcpStream {
108-
fn poll_write(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
109-
loop {
110-
let TcpStreamProj { inner, state } = self.project();
111-
112-
match *state {
113-
TcpStreamState::Connected => return inner.poll_write(cx, buf),
114-
115-
TcpStreamState::FastOpenConnect(addr) => {
116-
// TCP_FASTOPEN was supported since FreeBSD 12.0
117-
//
118-
// Example program:
119-
// <https://people.freebsd.org/~pkelsey/tfo-tools/tfo-client.c>
120-
121-
let saddr = SockAddr::from(addr);
122-
123-
let stream = inner.get_mut();
124-
125-
// Ensure socket is writable
126-
ready!(stream.poll_write_ready(cx))?;
127-
128-
let mut connecting = false;
129-
let send_result = stream.try_io(Interest::WRITABLE, || {
130-
unsafe {
131-
let ret = libc::sendto(
132-
stream.as_raw_fd(),
133-
buf.as_ptr() as *const libc::c_void,
134-
buf.len(),
135-
0, // Yes, BSD doesn't need MSG_FASTOPEN
136-
saddr.as_ptr(),
137-
saddr.len(),
138-
);
139-
140-
if ret >= 0 {
141-
Ok(ret as usize)
142-
} else {
143-
// Error occurs
144-
let err = io::Error::last_os_error();
145-
146-
// EINPROGRESS
147-
if let Some(libc::EINPROGRESS) = err.raw_os_error() {
148-
// For non-blocking socket, it returns the number of bytes queued (and transmitted in the SYN-data packet) if cookie is available.
149-
// If cookie is not available, it transmits a data-less SYN packet with Fast Open cookie request option and returns -EINPROGRESS like connect().
150-
//
151-
// So in this state. We have to loop again to call `poll_write` for sending the first packet.
152-
connecting = true;
153-
154-
// Let `poll_write_io` clears the write readiness.
155-
Err(ErrorKind::WouldBlock.into())
156-
} else {
157-
// Other errors, including EAGAIN, EWOULDBLOCK
158-
Err(err)
159-
}
160-
}
161-
}
162-
});
163-
164-
match send_result {
165-
Ok(n) => {
166-
// Connected successfully with fast open
167-
*state = TcpStreamState::Connected;
168-
return Ok(n).into();
169-
}
170-
Err(ref err) if err.kind() == ErrorKind::WouldBlock => {
171-
if connecting {
172-
// Connecting with normal TCP handshakes, write the first packet after connected
173-
*state = TcpStreamState::Connected;
174-
}
175-
}
176-
Err(err) => return Err(err).into(),
177-
}
178-
}
179-
}
102+
fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
103+
match self.project() {
104+
TcpStreamProj::Standard(s) => s.poll_write(cx, buf),
105+
TcpStreamProj::FastOpen(s) => s.poll_write(cx, buf),
180106
}
181107
}
182108

183109
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
184-
self.project().inner.poll_flush(cx)
110+
match self.project() {
111+
TcpStreamProj::Standard(s) => s.poll_flush(cx),
112+
TcpStreamProj::FastOpen(s) => s.poll_flush(cx),
113+
}
185114
}
186115

187116
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
188-
self.project().inner.poll_shutdown(cx)
117+
match self.project() {
118+
TcpStreamProj::Standard(s) => s.poll_shutdown(cx),
119+
TcpStreamProj::FastOpen(s) => s.poll_shutdown(cx),
120+
}
189121
}
190122
}
191123

0 commit comments

Comments
 (0)