Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions axum-extra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ http = "1.0.0"
http-body = "1.0.0"
http-body-util = "0.1.0"
mime = "0.3"
paste = "1.0"
pin-project-lite = "0.2"
rustversion = "1.0.9"
serde = "1.0"
Expand Down
181 changes: 178 additions & 3 deletions axum-extra/src/either.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! `Either*` types for combining extractors or responses into a single type.
//!
//! # As an extractor
//! # As an `FromRequestParts` extractor
//!
//! ```
//! use axum_extra::either::Either3;
Expand Down Expand Up @@ -54,6 +54,42 @@
//! Note that if all the inner extractors reject the request, the rejection from the last
//! extractor will be returned. For the example above that would be [`BytesRejection`].
//!
//! # As an `FromRequest` extractor
//!
//! In the following example, we can first try to deserialize the payload as JSON, if that fails try
//! to interpret it as a UTF-8 string, and lastly just take the raw bytes.
//!
//! It might be preferable to instead extract `Bytes` directly and then fallibly convert them to
//! `String` and then deserialize the data inside the handler.
//!
//! ```
//! use axum_extra::either::Either3;
//! use axum::{
//! body::Bytes,
//! Json,
//! Router,
//! routing::get,
//! extract::FromRequestParts,
//! };
//!
//! #[derive(serde::Deserialize)]
//! struct Payload {
//! user: String,
//! request_id: u32,
//! }
//!
//! async fn handler(
//! body: Either3<Json<Payload>, String, Bytes>,
//! ) {
//! match body {
//! Either3::E1(json) => { /* ... */ }
//! Either3::E2(string) => { /* ... */ }
//! Either3::E3(bytes) => { /* ... */ }
//! }
//! }
//! #
//! # let _: axum::routing::MethodRouter = axum::routing::get(handler);
//! ```
//! # As a response
//!
//! ```
Expand Down Expand Up @@ -93,17 +129,19 @@
use std::task::{Context, Poll};

use axum::{
extract::FromRequestParts,
extract::{rejection::BytesRejection, FromRequest, FromRequestParts, Request},
response::{IntoResponse, Response},
};
use bytes::Bytes;
use http::request::Parts;
use paste::paste;
use tower_layer::Layer;
use tower_service::Service;

/// Combines two extractors or responses into a single type.
///
/// See the [module docs](self) for examples.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
#[must_use]
pub enum Either<E1, E2> {
#[allow(missing_docs)]
Expand Down Expand Up @@ -226,6 +264,28 @@ pub enum Either8<E1, E2, E3, E4, E5, E6, E7, E8> {
E8(E8),
}

/// Rejection used for [`Either`], [`Either3`], etc.
///
/// Contains one variant for a case when the whole request could not be loaded and one variant
/// containing the rejection of the last variant if all extractors failed..
#[derive(Debug)]
pub enum EitherRejection<E> {
/// Buffering of the request body failed.
Bytes(BytesRejection),

/// All extractors failed. This contains the error returned by the last extractor.
LastRejection(E),
}

impl<E: IntoResponse> IntoResponse for EitherRejection<E> {
fn into_response(self) -> Response {
match self {
EitherRejection::Bytes(rejection) => rejection.into_response(),
EitherRejection::LastRejection(rejection) => rejection.into_response(),
}
}
}

macro_rules! impl_traits_for_either {
(
$either:ident =>
Expand All @@ -251,6 +311,45 @@ macro_rules! impl_traits_for_either {
}
}

paste! {
impl<S, $($ident),*, $last, $([< $ident Via >]),*, [<$last Via>]> FromRequest<S, ($([< $ident Via >]),*, [<$last Via>])> for $either<$($ident),*, $last>
where
S: Send + Sync,
$($ident: FromRequest<S, [<$ident Via>]>),*,
$last: FromRequest<S, [<$last Via>]>,
$($ident::Rejection: Send),*,
$last::Rejection: IntoResponse + Send,
{
type Rejection = EitherRejection<$last::Rejection>;

async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let (parts, body) = req.into_parts();
let bytes = Bytes::from_request(Request::from_parts(parts.clone(), body), state)
.await
.map_err(EitherRejection::Bytes)?;

$(
let req = Request::from_parts(
parts.clone(),
axum::body::Body::new(http_body_util::Full::new(bytes.clone())),
);
if let Ok(extracted) = $ident::from_request(req, state).await {
return Ok(Self::$ident(extracted));
}
)*

let req = Request::from_parts(
parts.clone(),
axum::body::Body::new(http_body_util::Full::new(bytes.clone())),
);
match $last::from_request(req, state).await {
Ok(extracted) => Ok(Self::$last(extracted)),
Err(error) => Err(EitherRejection::LastRejection(error)),
}
}
}
}

impl<$($ident),*, $last> IntoResponse for $either<$($ident),*, $last>
where
$($ident: IntoResponse),*,
Expand Down Expand Up @@ -312,3 +411,79 @@ where
}
}
}

#[cfg(test)]
mod tests {
use std::future::Future;

use axum::body::Body;
use axum::extract::rejection::StringRejection;
use axum::extract::{FromRequest, Request, State};
use bytes::Bytes;
use http_body_util::Full;

use super::*;

#[derive(Debug, PartialEq)]
struct False;

impl<S> FromRequestParts<S> for False {
type Rejection = ();

fn from_request_parts(
_parts: &mut Parts,
_state: &S,
) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
std::future::ready(Err(()))
}
}

#[tokio::test]
async fn either_from_request() {
// The body is by design not valid UTF-8.
let request = Request::new(Body::new(Full::new(Bytes::from_static(&[255]))));

let either = Either4::<String, String, Request, Bytes>::from_request(request, &())
.await
.unwrap();

assert!(matches!(either, Either4::E3(_)));
}

#[tokio::test]
async fn either_from_request_rejection() {
// The body is by design not valid UTF-8.
let request = Request::new(Body::new(Full::new(Bytes::from_static(&[255]))));

let either = Either::<String, String>::from_request(request, &())
.await
.unwrap_err();

assert!(matches!(
either,
EitherRejection::LastRejection(StringRejection::InvalidUtf8(_))
));
}

#[tokio::test]
async fn either_from_request_parts() {
let (mut parts, _) = Request::new(Body::empty()).into_parts();

let either = Either3::<False, False, State<()>>::from_request_parts(&mut parts, &())
.await
.unwrap();

assert!(matches!(either, Either3::E3(State(()))));
}

#[tokio::test]
async fn either_from_request_or_parts() {
let request = Request::new(Body::empty());

let either = Either::<False, Bytes>::from_request(request, &())
.await
.unwrap();

assert_eq!(either, Either::E2(Bytes::new()));
}
}
Loading