Skip to content

Commit 0821a9d

Browse files
authored
axum: Allow body types other than axum::body::Body in serve (#3205)
1 parent cd1453f commit 0821a9d

File tree

2 files changed

+78
-23
lines changed

2 files changed

+78
-23
lines changed

axum/CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
# Unreleased
9+
10+
- **changed:** `serve` has an additional generic argument and can now work with any response body
11+
type, not just `axum::body::Body` ([#3205])
12+
13+
[#3205]: https://github.com/tokio-rs/axum/pull/3205
14+
815
# 0.8.4
916

1017
- **added:** `Router::reset_fallback` ([#3320])

axum/src/serve/mod.rs

Lines changed: 71 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
use std::{
44
convert::Infallible,
5+
error::Error as StdError,
56
fmt::Debug,
67
future::{Future, IntoFuture},
78
io,
@@ -11,6 +12,7 @@ use std::{
1112

1213
use axum_core::{body::Body, extract::Request, response::Response};
1314
use futures_util::FutureExt;
15+
use http_body::Body as HttpBody;
1416
use hyper::body::Incoming;
1517
use hyper_util::rt::{TokioExecutor, TokioIo};
1618
#[cfg(any(feature = "http1", feature = "http2"))]
@@ -94,12 +96,15 @@ pub use self::listener::{Listener, ListenerExt, TapIo};
9496
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
9597
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
9698
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
97-
pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
99+
pub fn serve<L, M, S, B>(listener: L, make_service: M) -> Serve<L, M, S, B>
98100
where
99101
L: Listener,
100102
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
101-
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
103+
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
102104
S::Future: Send,
105+
B: HttpBody + Send + 'static,
106+
B::Data: Send,
107+
B::Error: Into<Box<dyn StdError + Send + Sync>>,
103108
{
104109
Serve {
105110
listener,
@@ -111,14 +116,14 @@ where
111116
/// Future returned by [`serve`].
112117
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
113118
#[must_use = "futures must be awaited or polled"]
114-
pub struct Serve<L, M, S> {
119+
pub struct Serve<L, M, S, B> {
115120
listener: L,
116121
make_service: M,
117-
_marker: PhantomData<S>,
122+
_marker: PhantomData<(S, B)>,
118123
}
119124

120125
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
121-
impl<L, M, S> Serve<L, M, S>
126+
impl<L, M, S, B> Serve<L, M, S, B>
122127
where
123128
L: Listener,
124129
{
@@ -148,7 +153,7 @@ where
148153
///
149154
/// Similarly to [`serve`], although this future resolves to `io::Result<()>`, it will never
150155
/// error. It returns `Ok(())` only after the `signal` future completes.
151-
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
156+
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F, B>
152157
where
153158
F: Future<Output = ()> + Send + 'static,
154159
{
@@ -167,14 +172,17 @@ where
167172
}
168173

169174
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
170-
impl<L, M, S> Serve<L, M, S>
175+
impl<L, M, S, B> Serve<L, M, S, B>
171176
where
172177
L: Listener,
173178
L::Addr: Debug,
174179
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
175180
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
176-
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
181+
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
177182
S::Future: Send,
183+
B: HttpBody + Send + 'static,
184+
B::Data: Send,
185+
B::Error: Into<Box<dyn StdError + Send + Sync>>,
178186
{
179187
async fn run(self) -> ! {
180188
let Self {
@@ -194,7 +202,7 @@ where
194202
}
195203

196204
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
197-
impl<L, M, S> Debug for Serve<L, M, S>
205+
impl<L, M, S, B> Debug for Serve<L, M, S, B>
198206
where
199207
L: Debug + 'static,
200208
M: Debug,
@@ -215,14 +223,17 @@ where
215223
}
216224

217225
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
218-
impl<L, M, S> IntoFuture for Serve<L, M, S>
226+
impl<L, M, S, B> IntoFuture for Serve<L, M, S, B>
219227
where
220228
L: Listener,
221229
L::Addr: Debug,
222230
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
223231
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
224-
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
232+
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
225233
S::Future: Send,
234+
B: HttpBody + Send + 'static,
235+
B::Data: Send,
236+
B::Error: Into<Box<dyn StdError + Send + Sync>>,
226237
{
227238
type Output = io::Result<()>;
228239
type IntoFuture = private::ServeFuture;
@@ -235,15 +246,15 @@ where
235246
/// Serve future with graceful shutdown enabled.
236247
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
237248
#[must_use = "futures must be awaited or polled"]
238-
pub struct WithGracefulShutdown<L, M, S, F> {
249+
pub struct WithGracefulShutdown<L, M, S, F, B> {
239250
listener: L,
240251
make_service: M,
241252
signal: F,
242-
_marker: PhantomData<S>,
253+
_marker: PhantomData<(S, B)>,
243254
}
244255

245256
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
246-
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
257+
impl<L, M, S, F, B> WithGracefulShutdown<L, M, S, F, B>
247258
where
248259
L: Listener,
249260
{
@@ -254,15 +265,18 @@ where
254265
}
255266

256267
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
257-
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
268+
impl<L, M, S, F, B> WithGracefulShutdown<L, M, S, F, B>
258269
where
259270
L: Listener,
260271
L::Addr: Debug,
261272
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
262273
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
263-
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
274+
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
264275
S::Future: Send,
265276
F: Future<Output = ()> + Send + 'static,
277+
B: HttpBody + Send + 'static,
278+
B::Data: Send,
279+
B::Error: Into<Box<dyn StdError + Send + Sync>>,
266280
{
267281
async fn run(self) {
268282
let Self {
@@ -305,7 +319,7 @@ where
305319
}
306320

307321
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
308-
impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
322+
impl<L, M, S, F, B> Debug for WithGracefulShutdown<L, M, S, F, B>
309323
where
310324
L: Debug + 'static,
311325
M: Debug,
@@ -329,15 +343,18 @@ where
329343
}
330344

331345
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
332-
impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
346+
impl<L, M, S, F, B> IntoFuture for WithGracefulShutdown<L, M, S, F, B>
333347
where
334348
L: Listener,
335349
L::Addr: Debug,
336350
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
337351
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
338-
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
352+
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
339353
S::Future: Send,
340354
F: Future<Output = ()> + Send + 'static,
355+
B: HttpBody + Send + 'static,
356+
B::Data: Send,
357+
B::Error: Into<Box<dyn StdError + Send + Sync>>,
341358
{
342359
type Output = io::Result<()>;
343360
type IntoFuture = private::ServeFuture;
@@ -350,7 +367,7 @@ where
350367
}
351368
}
352369

353-
async fn handle_connection<L, M, S>(
370+
async fn handle_connection<L, M, S, B>(
354371
make_service: &mut M,
355372
signal_tx: &watch::Sender<()>,
356373
close_rx: &watch::Receiver<()>,
@@ -361,8 +378,11 @@ async fn handle_connection<L, M, S>(
361378
L::Addr: Debug,
362379
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
363380
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
364-
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
381+
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
365382
S::Future: Send,
383+
B: HttpBody + Send + 'static,
384+
B::Data: Send,
385+
B::Error: Into<Box<dyn StdError + Send + Sync>>,
366386
{
367387
let io = TokioIo::new(io);
368388

@@ -478,14 +498,15 @@ mod tests {
478498
};
479499

480500
use axum_core::{body::Body, extract::Request};
481-
use http::StatusCode;
501+
use http::{Response, StatusCode};
482502
use hyper_util::rt::TokioIo;
483503
#[cfg(unix)]
484504
use tokio::net::UnixListener;
485505
use tokio::{
486506
io::{self, AsyncRead, AsyncWrite},
487507
net::TcpListener,
488508
};
509+
use tower::ServiceBuilder;
489510

490511
#[cfg(unix)]
491512
use super::IncomingStream;
@@ -497,7 +518,7 @@ mod tests {
497518
handler::{Handler, HandlerWithoutStateExt},
498519
routing::get,
499520
serve::ListenerExt,
500-
Router,
521+
Router, ServiceExt,
501522
};
502523

503524
#[allow(dead_code, unused_must_use)]
@@ -725,4 +746,31 @@ mod tests {
725746
let body = String::from_utf8(body.to_vec()).unwrap();
726747
assert_eq!(body, "Hello, World!");
727748
}
749+
750+
#[crate::test]
751+
async fn serving_with_custom_body_type() {
752+
struct CustomBody;
753+
impl http_body::Body for CustomBody {
754+
type Data = bytes::Bytes;
755+
type Error = std::convert::Infallible;
756+
fn poll_frame(
757+
self: std::pin::Pin<&mut Self>,
758+
_cx: &mut std::task::Context<'_>,
759+
) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>>
760+
{
761+
#![allow(clippy::unreachable)] // The implementation is not used, we just need to provide one.
762+
unreachable!();
763+
}
764+
}
765+
766+
let app = ServiceBuilder::new()
767+
.layer_fn(|_| tower::service_fn(|_| std::future::ready(Ok(Response::new(CustomBody)))))
768+
.service(Router::<()>::new().route("/hello", get(|| async {})));
769+
let addr = "0.0.0.0:0";
770+
771+
_ = serve(
772+
TcpListener::bind(addr).await.unwrap(),
773+
app.into_make_service(),
774+
);
775+
}
728776
}

0 commit comments

Comments
 (0)