@@ -10,7 +10,7 @@ use std::thread;
10
10
use futures_util:: { future:: Future , ready} ;
11
11
use rustls:: pki_types:: ServerName ;
12
12
use rustls:: { self , ClientConfig , ServerConnection , Stream } ;
13
- use tokio:: io:: { AsyncRead , AsyncReadExt , AsyncWriteExt , ReadBuf } ;
13
+ use tokio:: io:: { AsyncRead , AsyncReadExt , AsyncWrite , AsyncWriteExt , ReadBuf } ;
14
14
use tokio:: net:: TcpStream ;
15
15
use tokio_rustls:: client:: TlsStream ;
16
16
use tokio_rustls:: TlsConnector ;
@@ -35,14 +35,15 @@ impl<T: AsyncRead + Unpin> Future for Read1<T> {
35
35
}
36
36
}
37
37
38
- async fn send (
38
+ async fn send < S : AsyncRead + AsyncWrite + Unpin > (
39
39
config : Arc < ClientConfig > ,
40
40
addr : SocketAddr ,
41
+ wrapper : impl Fn ( TcpStream ) -> S ,
41
42
data : & [ u8 ] ,
42
43
vectored : bool ,
43
- ) -> io:: Result < ( TlsStream < TcpStream > , Vec < u8 > ) > {
44
+ ) -> io:: Result < ( TlsStream < S > , Vec < u8 > ) > {
44
45
let connector = TlsConnector :: from ( config) . early_data ( true ) ;
45
- let stream = TcpStream :: connect ( & addr) . await ?;
46
+ let stream = wrapper ( TcpStream :: connect ( & addr) . await ?) ;
46
47
let domain = ServerName :: try_from ( "foobar.com" ) . unwrap ( ) ;
47
48
48
49
let mut stream = connector. connect ( domain, stream) . await ?;
@@ -58,15 +59,23 @@ async fn send(
58
59
59
60
#[ tokio:: test]
60
61
async fn test_0rtt ( ) -> io:: Result < ( ) > {
61
- test_0rtt_impl ( false ) . await
62
+ test_0rtt_impl ( |s| s , false ) . await
62
63
}
63
64
64
65
#[ tokio:: test]
65
66
async fn test_0rtt_vectored ( ) -> io:: Result < ( ) > {
66
- test_0rtt_impl ( true ) . await
67
+ test_0rtt_impl ( |s| s , true ) . await
67
68
}
68
69
69
- async fn test_0rtt_impl ( vectored : bool ) -> io:: Result < ( ) > {
70
+ #[ tokio:: test]
71
+ async fn test_0rtt_vectored_flush_pending ( ) -> io:: Result < ( ) > {
72
+ test_0rtt_impl ( utils:: FlushWrapper :: new, false ) . await
73
+ }
74
+
75
+ async fn test_0rtt_impl < S : AsyncRead + AsyncWrite + Unpin > (
76
+ wrapper : impl Fn ( TcpStream ) -> S ,
77
+ vectored : bool ,
78
+ ) -> io:: Result < ( ) > {
70
79
let ( mut server, mut client) = utils:: make_configs ( ) ;
71
80
server. max_early_data_size = 8192 ;
72
81
let server = Arc :: new ( server) ;
@@ -108,11 +117,11 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
108
117
let client = Arc :: new ( client) ;
109
118
let addr = SocketAddr :: from ( ( [ 127 , 0 , 0 , 1 ] , server_port) ) ;
110
119
111
- let ( io, buf) = send ( client. clone ( ) , addr, b"hello" , vectored) . await ?;
120
+ let ( io, buf) = send ( client. clone ( ) , addr, & wrapper , b"hello" , vectored) . await ?;
112
121
assert ! ( !io. get_ref( ) . 1 . is_early_data_accepted( ) ) ;
113
122
assert_eq ! ( "LATE:hello" , String :: from_utf8_lossy( & buf) ) ;
114
123
115
- let ( io, buf) = send ( client, addr, b"world!" , vectored) . await ?;
124
+ let ( io, buf) = send ( client, addr, wrapper , b"world!" , vectored) . await ?;
116
125
assert ! ( io. get_ref( ) . 1 . is_early_data_accepted( ) ) ;
117
126
assert_eq ! ( "EARLY:world!LATE:" , String :: from_utf8_lossy( & buf) ) ;
118
127
0 commit comments