1
1
use std:: {
2
- io:: { self , ErrorKind } ,
2
+ io,
3
3
mem,
4
- net:: { IpAddr , Ipv4Addr , Ipv6Addr , SocketAddr , TcpStream as StdTcpStream } ,
5
- ops:: { Deref , DerefMut } ,
6
- os:: unix:: io:: { AsRawFd , FromRawFd , IntoRawFd } ,
4
+ net:: { IpAddr , Ipv4Addr , Ipv6Addr , SocketAddr } ,
5
+ os:: unix:: io:: { AsRawFd , RawFd } ,
7
6
pin:: Pin ,
8
7
task:: { self , Poll } ,
9
8
} ;
10
9
11
- use futures:: ready;
12
10
use log:: error;
13
11
use pin_project:: pin_project;
14
- use socket2:: SockAddr ;
15
12
use tokio:: {
16
- io:: { AsyncRead , AsyncWrite , Interest , ReadBuf } ,
13
+ io:: { AsyncRead , AsyncWrite , ReadBuf } ,
17
14
net:: { TcpSocket , TcpStream as TokioTcpStream , UdpSocket } ,
18
15
} ;
16
+ use tokio_tfo:: TfoStream ;
19
17
20
18
use crate :: net:: {
21
19
sys:: { set_common_sockopt_after_connect, set_common_sockopt_for_connect} ,
22
20
AddrFamily ,
23
21
ConnectOpts ,
24
22
} ;
25
23
26
- enum TcpStreamState {
27
- Connected ,
28
- FastOpenConnect ( SocketAddr ) ,
29
- }
30
-
31
24
/// A `TcpStream` that supports TFO (TCP Fast Open)
32
25
#[ pin_project( project = TcpStreamProj ) ]
33
- pub struct TcpStream {
34
- #[ pin]
35
- inner : TokioTcpStream ,
36
- state : TcpStreamState ,
26
+ pub enum TcpStream {
27
+ Standard ( #[ pin] TokioTcpStream ) ,
28
+ FastOpen ( #[ pin] TfoStream ) ,
37
29
}
38
30
39
31
impl TcpStream {
@@ -50,142 +42,82 @@ impl TcpStream {
50
42
let stream = socket. connect ( addr) . await ?;
51
43
set_common_sockopt_after_connect ( & stream, opts) ?;
52
44
53
- return Ok ( TcpStream {
54
- inner : stream,
55
- state : TcpStreamState :: Connected ,
56
- } ) ;
45
+ return Ok ( TcpStream :: Standard ( stream) ) ;
57
46
}
58
47
59
- unsafe {
60
- let enable: libc:: c_int = 1 ;
61
-
62
- let ret = libc:: setsockopt (
63
- socket. as_raw_fd ( ) ,
64
- libc:: IPPROTO_TCP ,
65
- libc:: TCP_FASTOPEN ,
66
- & enable as * const _ as * const libc:: c_void ,
67
- mem:: size_of_val ( & enable) as libc:: socklen_t ,
68
- ) ;
69
-
70
- if ret != 0 {
71
- let err = io:: Error :: last_os_error ( ) ;
72
- error ! ( "set TCP_FASTOPEN error: {}" , err) ;
73
- return Err ( err) ;
74
- }
75
- }
76
-
77
- let stream = TokioTcpStream :: from_std ( unsafe { StdTcpStream :: from_raw_fd ( socket. into_raw_fd ( ) ) } ) ?;
48
+ let stream = TfoStream :: connect_with_socket ( socket, addr) . await ?;
78
49
set_common_sockopt_after_connect ( & stream, opts) ?;
79
50
80
- Ok ( TcpStream {
81
- inner : stream,
82
- state : TcpStreamState :: FastOpenConnect ( addr) ,
83
- } )
51
+ Ok ( TcpStream :: FastOpen ( stream) )
52
+ }
53
+
54
+ pub fn local_addr ( & self ) -> io:: Result < SocketAddr > {
55
+ match * self {
56
+ TcpStream :: Standard ( ref s) => s. local_addr ( ) ,
57
+ TcpStream :: FastOpen ( ref s) => s. local_addr ( ) ,
58
+ }
59
+ }
60
+
61
+ pub fn peer_addr ( & self ) -> io:: Result < SocketAddr > {
62
+ match * self {
63
+ TcpStream :: Standard ( ref s) => s. peer_addr ( ) ,
64
+ TcpStream :: FastOpen ( ref s) => s. peer_addr ( ) ,
65
+ }
84
66
}
85
- }
86
67
87
- impl Deref for TcpStream {
88
- type Target = TokioTcpStream ;
68
+ pub fn nodelay ( & self ) -> io:: Result < bool > {
69
+ match * self {
70
+ TcpStream :: Standard ( ref s) => s. nodelay ( ) ,
71
+ TcpStream :: FastOpen ( ref s) => s. nodelay ( ) ,
72
+ }
73
+ }
89
74
90
- fn deref ( & self ) -> & Self :: Target {
91
- & self . inner
75
+ pub fn set_nodelay ( & self , nodelay : bool ) -> io:: Result < ( ) > {
76
+ match * self {
77
+ TcpStream :: Standard ( ref s) => s. set_nodelay ( nodelay) ,
78
+ TcpStream :: FastOpen ( ref s) => s. set_nodelay ( nodelay) ,
79
+ }
92
80
}
93
81
}
94
82
95
- impl DerefMut for TcpStream {
96
- fn deref_mut ( & mut self ) -> & mut Self :: Target {
97
- & mut self . inner
83
+ impl AsRawFd for TcpStream {
84
+ fn as_raw_fd ( & self ) -> RawFd {
85
+ match * self {
86
+ TcpStream :: Standard ( ref s) => s. as_raw_fd ( ) ,
87
+ TcpStream :: FastOpen ( ref s) => s. as_raw_fd ( ) ,
88
+ }
98
89
}
99
90
}
100
91
101
92
impl AsyncRead for TcpStream {
102
93
fn poll_read ( self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > , buf : & mut ReadBuf < ' _ > ) -> Poll < io:: Result < ( ) > > {
103
- self . project ( ) . inner . poll_read ( cx, buf)
94
+ match self . project ( ) {
95
+ TcpStreamProj :: Standard ( s) => s. poll_read ( cx, buf) ,
96
+ TcpStreamProj :: FastOpen ( s) => s. poll_read ( cx, buf) ,
97
+ }
104
98
}
105
99
}
106
100
107
101
impl AsyncWrite for TcpStream {
108
- fn poll_write ( mut self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > , buf : & [ u8 ] ) -> Poll < io:: Result < usize > > {
109
- loop {
110
- let TcpStreamProj { inner, state } = self . project ( ) ;
111
-
112
- match * state {
113
- TcpStreamState :: Connected => return inner. poll_write ( cx, buf) ,
114
-
115
- TcpStreamState :: FastOpenConnect ( addr) => {
116
- // TCP_FASTOPEN was supported since FreeBSD 12.0
117
- //
118
- // Example program:
119
- // <https://people.freebsd.org/~pkelsey/tfo-tools/tfo-client.c>
120
-
121
- let saddr = SockAddr :: from ( addr) ;
122
-
123
- let stream = inner. get_mut ( ) ;
124
-
125
- // Ensure socket is writable
126
- ready ! ( stream. poll_write_ready( cx) ) ?;
127
-
128
- let mut connecting = false ;
129
- let send_result = stream. try_io ( Interest :: WRITABLE , || {
130
- unsafe {
131
- let ret = libc:: sendto (
132
- stream. as_raw_fd ( ) ,
133
- buf. as_ptr ( ) as * const libc:: c_void ,
134
- buf. len ( ) ,
135
- 0 , // Yes, BSD doesn't need MSG_FASTOPEN
136
- saddr. as_ptr ( ) ,
137
- saddr. len ( ) ,
138
- ) ;
139
-
140
- if ret >= 0 {
141
- Ok ( ret as usize )
142
- } else {
143
- // Error occurs
144
- let err = io:: Error :: last_os_error ( ) ;
145
-
146
- // EINPROGRESS
147
- if let Some ( libc:: EINPROGRESS ) = err. raw_os_error ( ) {
148
- // For non-blocking socket, it returns the number of bytes queued (and transmitted in the SYN-data packet) if cookie is available.
149
- // If cookie is not available, it transmits a data-less SYN packet with Fast Open cookie request option and returns -EINPROGRESS like connect().
150
- //
151
- // So in this state. We have to loop again to call `poll_write` for sending the first packet.
152
- connecting = true ;
153
-
154
- // Let `poll_write_io` clears the write readiness.
155
- Err ( ErrorKind :: WouldBlock . into ( ) )
156
- } else {
157
- // Other errors, including EAGAIN, EWOULDBLOCK
158
- Err ( err)
159
- }
160
- }
161
- }
162
- } ) ;
163
-
164
- match send_result {
165
- Ok ( n) => {
166
- // Connected successfully with fast open
167
- * state = TcpStreamState :: Connected ;
168
- return Ok ( n) . into ( ) ;
169
- }
170
- Err ( ref err) if err. kind ( ) == ErrorKind :: WouldBlock => {
171
- if connecting {
172
- // Connecting with normal TCP handshakes, write the first packet after connected
173
- * state = TcpStreamState :: Connected ;
174
- }
175
- }
176
- Err ( err) => return Err ( err) . into ( ) ,
177
- }
178
- }
179
- }
102
+ fn poll_write ( self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > , buf : & [ u8 ] ) -> Poll < io:: Result < usize > > {
103
+ match self . project ( ) {
104
+ TcpStreamProj :: Standard ( s) => s. poll_write ( cx, buf) ,
105
+ TcpStreamProj :: FastOpen ( s) => s. poll_write ( cx, buf) ,
180
106
}
181
107
}
182
108
183
109
fn poll_flush ( self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
184
- self . project ( ) . inner . poll_flush ( cx)
110
+ match self . project ( ) {
111
+ TcpStreamProj :: Standard ( s) => s. poll_flush ( cx) ,
112
+ TcpStreamProj :: FastOpen ( s) => s. poll_flush ( cx) ,
113
+ }
185
114
}
186
115
187
116
fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
188
- self . project ( ) . inner . poll_shutdown ( cx)
117
+ match self . project ( ) {
118
+ TcpStreamProj :: Standard ( s) => s. poll_shutdown ( cx) ,
119
+ TcpStreamProj :: FastOpen ( s) => s. poll_shutdown ( cx) ,
120
+ }
189
121
}
190
122
}
191
123
0 commit comments