2
2
3
3
use std:: {
4
4
convert:: Infallible ,
5
+ error:: Error as StdError ,
5
6
fmt:: Debug ,
6
7
future:: { Future , IntoFuture } ,
7
8
io,
@@ -11,6 +12,7 @@ use std::{
11
12
12
13
use axum_core:: { body:: Body , extract:: Request , response:: Response } ;
13
14
use futures_util:: FutureExt ;
15
+ use http_body:: Body as HttpBody ;
14
16
use hyper:: body:: Incoming ;
15
17
use hyper_util:: rt:: { TokioExecutor , TokioIo } ;
16
18
#[ cfg( any( feature = "http1" , feature = "http2" ) ) ]
@@ -94,12 +96,15 @@ pub use self::listener::{Listener, ListenerExt, TapIo};
94
96
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
95
97
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
96
98
#[ cfg( all( feature = "tokio" , any( feature = "http1" , feature = "http2" ) ) ) ]
97
- pub fn serve < L , M , S > ( listener : L , make_service : M ) -> Serve < L , M , S >
99
+ pub fn serve < L , M , S , B > ( listener : L , make_service : M ) -> Serve < L , M , S , B >
98
100
where
99
101
L : Listener ,
100
102
M : for < ' a > Service < IncomingStream < ' a , L > , Error = Infallible , Response = S > ,
101
- S : Service < Request , Response = Response , Error = Infallible > + Clone + Send + ' static ,
103
+ S : Service < Request , Response = Response < B > , Error = Infallible > + Clone + Send + ' static ,
102
104
S :: Future : Send ,
105
+ B : HttpBody + Send + ' static ,
106
+ B :: Data : Send ,
107
+ B :: Error : Into < Box < dyn StdError + Send + Sync > > ,
103
108
{
104
109
Serve {
105
110
listener,
@@ -111,14 +116,14 @@ where
111
116
/// Future returned by [`serve`].
112
117
#[ cfg( all( feature = "tokio" , any( feature = "http1" , feature = "http2" ) ) ) ]
113
118
#[ must_use = "futures must be awaited or polled" ]
114
- pub struct Serve < L , M , S > {
119
+ pub struct Serve < L , M , S , B > {
115
120
listener : L ,
116
121
make_service : M ,
117
- _marker : PhantomData < S > ,
122
+ _marker : PhantomData < ( S , B ) > ,
118
123
}
119
124
120
125
#[ cfg( all( feature = "tokio" , any( feature = "http1" , feature = "http2" ) ) ) ]
121
- impl < L , M , S > Serve < L , M , S >
126
+ impl < L , M , S , B > Serve < L , M , S , B >
122
127
where
123
128
L : Listener ,
124
129
{
@@ -148,7 +153,7 @@ where
148
153
///
149
154
/// Similarly to [`serve`], although this future resolves to `io::Result<()>`, it will never
150
155
/// error. It returns `Ok(())` only after the `signal` future completes.
151
- pub fn with_graceful_shutdown < F > ( self , signal : F ) -> WithGracefulShutdown < L , M , S , F >
156
+ pub fn with_graceful_shutdown < F > ( self , signal : F ) -> WithGracefulShutdown < L , M , S , F , B >
152
157
where
153
158
F : Future < Output = ( ) > + Send + ' static ,
154
159
{
@@ -167,14 +172,17 @@ where
167
172
}
168
173
169
174
#[ cfg( all( feature = "tokio" , any( feature = "http1" , feature = "http2" ) ) ) ]
170
- impl < L , M , S > Serve < L , M , S >
175
+ impl < L , M , S , B > Serve < L , M , S , B >
171
176
where
172
177
L : Listener ,
173
178
L :: Addr : Debug ,
174
179
M : for < ' a > Service < IncomingStream < ' a , L > , Error = Infallible , Response = S > + Send + ' static ,
175
180
for < ' a > <M as Service < IncomingStream < ' a , L > > >:: Future : Send ,
176
- S : Service < Request , Response = Response , Error = Infallible > + Clone + Send + ' static ,
181
+ S : Service < Request , Response = Response < B > , Error = Infallible > + Clone + Send + ' static ,
177
182
S :: Future : Send ,
183
+ B : HttpBody + Send + ' static ,
184
+ B :: Data : Send ,
185
+ B :: Error : Into < Box < dyn StdError + Send + Sync > > ,
178
186
{
179
187
async fn run ( self ) -> ! {
180
188
let Self {
@@ -194,7 +202,7 @@ where
194
202
}
195
203
196
204
#[ cfg( all( feature = "tokio" , any( feature = "http1" , feature = "http2" ) ) ) ]
197
- impl < L , M , S > Debug for Serve < L , M , S >
205
+ impl < L , M , S , B > Debug for Serve < L , M , S , B >
198
206
where
199
207
L : Debug + ' static ,
200
208
M : Debug ,
@@ -215,14 +223,17 @@ where
215
223
}
216
224
217
225
#[ cfg( all( feature = "tokio" , any( feature = "http1" , feature = "http2" ) ) ) ]
218
- impl < L , M , S > IntoFuture for Serve < L , M , S >
226
+ impl < L , M , S , B > IntoFuture for Serve < L , M , S , B >
219
227
where
220
228
L : Listener ,
221
229
L :: Addr : Debug ,
222
230
M : for < ' a > Service < IncomingStream < ' a , L > , Error = Infallible , Response = S > + Send + ' static ,
223
231
for < ' a > <M as Service < IncomingStream < ' a , L > > >:: Future : Send ,
224
- S : Service < Request , Response = Response , Error = Infallible > + Clone + Send + ' static ,
232
+ S : Service < Request , Response = Response < B > , Error = Infallible > + Clone + Send + ' static ,
225
233
S :: Future : Send ,
234
+ B : HttpBody + Send + ' static ,
235
+ B :: Data : Send ,
236
+ B :: Error : Into < Box < dyn StdError + Send + Sync > > ,
226
237
{
227
238
type Output = io:: Result < ( ) > ;
228
239
type IntoFuture = private:: ServeFuture ;
@@ -235,15 +246,15 @@ where
235
246
/// Serve future with graceful shutdown enabled.
236
247
#[ cfg( all( feature = "tokio" , any( feature = "http1" , feature = "http2" ) ) ) ]
237
248
#[ must_use = "futures must be awaited or polled" ]
238
- pub struct WithGracefulShutdown < L , M , S , F > {
249
+ pub struct WithGracefulShutdown < L , M , S , F , B > {
239
250
listener : L ,
240
251
make_service : M ,
241
252
signal : F ,
242
- _marker : PhantomData < S > ,
253
+ _marker : PhantomData < ( S , B ) > ,
243
254
}
244
255
245
256
#[ cfg( all( feature = "tokio" , any( feature = "http1" , feature = "http2" ) ) ) ]
246
- impl < L , M , S , F > WithGracefulShutdown < L , M , S , F >
257
+ impl < L , M , S , F , B > WithGracefulShutdown < L , M , S , F , B >
247
258
where
248
259
L : Listener ,
249
260
{
@@ -254,15 +265,18 @@ where
254
265
}
255
266
256
267
#[ cfg( all( feature = "tokio" , any( feature = "http1" , feature = "http2" ) ) ) ]
257
- impl < L , M , S , F > WithGracefulShutdown < L , M , S , F >
268
+ impl < L , M , S , F , B > WithGracefulShutdown < L , M , S , F , B >
258
269
where
259
270
L : Listener ,
260
271
L :: Addr : Debug ,
261
272
M : for < ' a > Service < IncomingStream < ' a , L > , Error = Infallible , Response = S > + Send + ' static ,
262
273
for < ' a > <M as Service < IncomingStream < ' a , L > > >:: Future : Send ,
263
- S : Service < Request , Response = Response , Error = Infallible > + Clone + Send + ' static ,
274
+ S : Service < Request , Response = Response < B > , Error = Infallible > + Clone + Send + ' static ,
264
275
S :: Future : Send ,
265
276
F : Future < Output = ( ) > + Send + ' static ,
277
+ B : HttpBody + Send + ' static ,
278
+ B :: Data : Send ,
279
+ B :: Error : Into < Box < dyn StdError + Send + Sync > > ,
266
280
{
267
281
async fn run ( self ) {
268
282
let Self {
@@ -305,7 +319,7 @@ where
305
319
}
306
320
307
321
#[ cfg( all( feature = "tokio" , any( feature = "http1" , feature = "http2" ) ) ) ]
308
- impl < L , M , S , F > Debug for WithGracefulShutdown < L , M , S , F >
322
+ impl < L , M , S , F , B > Debug for WithGracefulShutdown < L , M , S , F , B >
309
323
where
310
324
L : Debug + ' static ,
311
325
M : Debug ,
@@ -329,15 +343,18 @@ where
329
343
}
330
344
331
345
#[ cfg( all( feature = "tokio" , any( feature = "http1" , feature = "http2" ) ) ) ]
332
- impl < L , M , S , F > IntoFuture for WithGracefulShutdown < L , M , S , F >
346
+ impl < L , M , S , F , B > IntoFuture for WithGracefulShutdown < L , M , S , F , B >
333
347
where
334
348
L : Listener ,
335
349
L :: Addr : Debug ,
336
350
M : for < ' a > Service < IncomingStream < ' a , L > , Error = Infallible , Response = S > + Send + ' static ,
337
351
for < ' a > <M as Service < IncomingStream < ' a , L > > >:: Future : Send ,
338
- S : Service < Request , Response = Response , Error = Infallible > + Clone + Send + ' static ,
352
+ S : Service < Request , Response = Response < B > , Error = Infallible > + Clone + Send + ' static ,
339
353
S :: Future : Send ,
340
354
F : Future < Output = ( ) > + Send + ' static ,
355
+ B : HttpBody + Send + ' static ,
356
+ B :: Data : Send ,
357
+ B :: Error : Into < Box < dyn StdError + Send + Sync > > ,
341
358
{
342
359
type Output = io:: Result < ( ) > ;
343
360
type IntoFuture = private:: ServeFuture ;
@@ -350,7 +367,7 @@ where
350
367
}
351
368
}
352
369
353
- async fn handle_connection < L , M , S > (
370
+ async fn handle_connection < L , M , S , B > (
354
371
make_service : & mut M ,
355
372
signal_tx : & watch:: Sender < ( ) > ,
356
373
close_rx : & watch:: Receiver < ( ) > ,
@@ -361,8 +378,11 @@ async fn handle_connection<L, M, S>(
361
378
L :: Addr : Debug ,
362
379
M : for < ' a > Service < IncomingStream < ' a , L > , Error = Infallible , Response = S > + Send + ' static ,
363
380
for < ' a > <M as Service < IncomingStream < ' a , L > > >:: Future : Send ,
364
- S : Service < Request , Response = Response , Error = Infallible > + Clone + Send + ' static ,
381
+ S : Service < Request , Response = Response < B > , Error = Infallible > + Clone + Send + ' static ,
365
382
S :: Future : Send ,
383
+ B : HttpBody + Send + ' static ,
384
+ B :: Data : Send ,
385
+ B :: Error : Into < Box < dyn StdError + Send + Sync > > ,
366
386
{
367
387
let io = TokioIo :: new ( io) ;
368
388
@@ -478,14 +498,15 @@ mod tests {
478
498
} ;
479
499
480
500
use axum_core:: { body:: Body , extract:: Request } ;
481
- use http:: StatusCode ;
501
+ use http:: { Response , StatusCode } ;
482
502
use hyper_util:: rt:: TokioIo ;
483
503
#[ cfg( unix) ]
484
504
use tokio:: net:: UnixListener ;
485
505
use tokio:: {
486
506
io:: { self , AsyncRead , AsyncWrite } ,
487
507
net:: TcpListener ,
488
508
} ;
509
+ use tower:: ServiceBuilder ;
489
510
490
511
#[ cfg( unix) ]
491
512
use super :: IncomingStream ;
@@ -497,7 +518,7 @@ mod tests {
497
518
handler:: { Handler , HandlerWithoutStateExt } ,
498
519
routing:: get,
499
520
serve:: ListenerExt ,
500
- Router ,
521
+ Router , ServiceExt ,
501
522
} ;
502
523
503
524
#[ allow( dead_code, unused_must_use) ]
@@ -725,4 +746,31 @@ mod tests {
725
746
let body = String :: from_utf8 ( body. to_vec ( ) ) . unwrap ( ) ;
726
747
assert_eq ! ( body, "Hello, World!" ) ;
727
748
}
749
+
750
+ #[ crate :: test]
751
+ async fn serving_with_custom_body_type ( ) {
752
+ struct CustomBody ;
753
+ impl http_body:: Body for CustomBody {
754
+ type Data = bytes:: Bytes ;
755
+ type Error = std:: convert:: Infallible ;
756
+ fn poll_frame (
757
+ self : std:: pin:: Pin < & mut Self > ,
758
+ _cx : & mut std:: task:: Context < ' _ > ,
759
+ ) -> std:: task:: Poll < Option < Result < http_body:: Frame < Self :: Data > , Self :: Error > > >
760
+ {
761
+ #![ allow( clippy:: unreachable) ] // The implementation is not used, we just need to provide one.
762
+ unreachable ! ( ) ;
763
+ }
764
+ }
765
+
766
+ let app = ServiceBuilder :: new ( )
767
+ . layer_fn ( |_| tower:: service_fn ( |_| std:: future:: ready ( Ok ( Response :: new ( CustomBody ) ) ) ) )
768
+ . service ( Router :: < ( ) > :: new ( ) . route ( "/hello" , get ( || async { } ) ) ) ;
769
+ let addr = "0.0.0.0:0" ;
770
+
771
+ _ = serve (
772
+ TcpListener :: bind ( addr) . await . unwrap ( ) ,
773
+ app. into_make_service ( ) ,
774
+ ) ;
775
+ }
728
776
}
0 commit comments