diff --git a/Cargo.lock b/Cargo.lock index 23604016d8..0625659947 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -390,6 +390,7 @@ dependencies = [ "hyper 1.5.2", "mime", "multer", + "paste", "percent-encoding", "pin-project-lite", "prost", diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index 64b4ea9597..4ad105e350 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -57,6 +57,7 @@ http = "1.0.0" http-body = "1.0.0" http-body-util = "0.1.0" mime = "0.3" +paste = "1.0" pin-project-lite = "0.2" rustversion = "1.0.9" serde = "1.0" diff --git a/axum-extra/src/either.rs b/axum-extra/src/either.rs index 5d0e33eb7c..a58ba5e04e 100755 --- a/axum-extra/src/either.rs +++ b/axum-extra/src/either.rs @@ -1,6 +1,6 @@ //! `Either*` types for combining extractors or responses into a single type. //! -//! # As an extractor +//! # As an `FromRequestParts` extractor //! //! ``` //! use axum_extra::either::Either3; @@ -54,6 +54,42 @@ //! Note that if all the inner extractors reject the request, the rejection from the last //! extractor will be returned. For the example above that would be [`BytesRejection`]. //! +//! # As an `FromRequest` extractor +//! +//! In the following example, we can first try to deserialize the payload as JSON, if that fails try +//! to interpret it as a UTF-8 string, and lastly just take the raw bytes. +//! +//! It might be preferable to instead extract `Bytes` directly and then fallibly convert them to +//! `String` and then deserialize the data inside the handler. +//! +//! ``` +//! use axum_extra::either::Either3; +//! use axum::{ +//! body::Bytes, +//! Json, +//! Router, +//! routing::get, +//! extract::FromRequestParts, +//! }; +//! +//! #[derive(serde::Deserialize)] +//! struct Payload { +//! user: String, +//! request_id: u32, +//! } +//! +//! async fn handler( +//! body: Either3, String, Bytes>, +//! ) { +//! match body { +//! Either3::E1(json) => { /* ... */ } +//! Either3::E2(string) => { /* ... */ } +//! Either3::E3(bytes) => { /* ... */ } +//! } +//! } +//! # +//! # let _: axum::routing::MethodRouter = axum::routing::get(handler); +//! ``` //! # As a response //! //! ``` @@ -93,17 +129,19 @@ use std::task::{Context, Poll}; use axum::{ - extract::FromRequestParts, + extract::{rejection::BytesRejection, FromRequest, FromRequestParts, Request}, response::{IntoResponse, Response}, }; +use bytes::Bytes; use http::request::Parts; +use paste::paste; use tower_layer::Layer; use tower_service::Service; /// Combines two extractors or responses into a single type. /// /// See the [module docs](self) for examples. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] #[must_use] pub enum Either { #[allow(missing_docs)] @@ -226,6 +264,28 @@ pub enum Either8 { E8(E8), } +/// Rejection used for [`Either`], [`Either3`], etc. +/// +/// Contains one variant for a case when the whole request could not be loaded and one variant +/// containing the rejection of the last variant if all extractors failed.. +#[derive(Debug)] +pub enum EitherRejection { + /// Buffering of the request body failed. + Bytes(BytesRejection), + + /// All extractors failed. This contains the error returned by the last extractor. + LastRejection(E), +} + +impl IntoResponse for EitherRejection { + fn into_response(self) -> Response { + match self { + EitherRejection::Bytes(rejection) => rejection.into_response(), + EitherRejection::LastRejection(rejection) => rejection.into_response(), + } + } +} + macro_rules! impl_traits_for_either { ( $either:ident => @@ -251,6 +311,45 @@ macro_rules! impl_traits_for_either { } } + paste! { + impl]),*, [<$last Via>]> FromRequest]),*, [<$last Via>])> for $either<$($ident),*, $last> + where + S: Send + Sync, + $($ident: FromRequest]>),*, + $last: FromRequest]>, + $($ident::Rejection: Send),*, + $last::Rejection: IntoResponse + Send, + { + type Rejection = EitherRejection<$last::Rejection>; + + async fn from_request(req: Request, state: &S) -> Result { + let (parts, body) = req.into_parts(); + let bytes = Bytes::from_request(Request::from_parts(parts.clone(), body), state) + .await + .map_err(EitherRejection::Bytes)?; + + $( + let req = Request::from_parts( + parts.clone(), + axum::body::Body::new(http_body_util::Full::new(bytes.clone())), + ); + if let Ok(extracted) = $ident::from_request(req, state).await { + return Ok(Self::$ident(extracted)); + } + )* + + let req = Request::from_parts( + parts.clone(), + axum::body::Body::new(http_body_util::Full::new(bytes.clone())), + ); + match $last::from_request(req, state).await { + Ok(extracted) => Ok(Self::$last(extracted)), + Err(error) => Err(EitherRejection::LastRejection(error)), + } + } + } + } + impl<$($ident),*, $last> IntoResponse for $either<$($ident),*, $last> where $($ident: IntoResponse),*, @@ -312,3 +411,79 @@ where } } } + +#[cfg(test)] +mod tests { + use std::future::Future; + + use axum::body::Body; + use axum::extract::rejection::StringRejection; + use axum::extract::{FromRequest, Request, State}; + use bytes::Bytes; + use http_body_util::Full; + + use super::*; + + #[derive(Debug, PartialEq)] + struct False; + + impl FromRequestParts for False { + type Rejection = (); + + fn from_request_parts( + _parts: &mut Parts, + _state: &S, + ) -> impl Future> + Send { + std::future::ready(Err(())) + } + } + + #[tokio::test] + async fn either_from_request() { + // The body is by design not valid UTF-8. + let request = Request::new(Body::new(Full::new(Bytes::from_static(&[255])))); + + let either = Either4::::from_request(request, &()) + .await + .unwrap(); + + assert!(matches!(either, Either4::E3(_))); + } + + #[tokio::test] + async fn either_from_request_rejection() { + // The body is by design not valid UTF-8. + let request = Request::new(Body::new(Full::new(Bytes::from_static(&[255])))); + + let either = Either::::from_request(request, &()) + .await + .unwrap_err(); + + assert!(matches!( + either, + EitherRejection::LastRejection(StringRejection::InvalidUtf8(_)) + )); + } + + #[tokio::test] + async fn either_from_request_parts() { + let (mut parts, _) = Request::new(Body::empty()).into_parts(); + + let either = Either3::>::from_request_parts(&mut parts, &()) + .await + .unwrap(); + + assert!(matches!(either, Either3::E3(State(())))); + } + + #[tokio::test] + async fn either_from_request_or_parts() { + let request = Request::new(Body::empty()); + + let either = Either::::from_request(request, &()) + .await + .unwrap(); + + assert_eq!(either, Either::E2(Bytes::new())); + } +}