Skip to content

Commit 769e406

Browse files
authored
Implement the OptionalFromRequestParts trait for the Host extractor (#3177)
1 parent 869ba86 commit 769e406

File tree

1 file changed

+53
-8
lines changed

1 file changed

+53
-8
lines changed

axum-extra/src/extract/host.rs

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
use super::rejection::{FailedToResolveHost, HostRejection};
2-
use axum::extract::FromRequestParts;
2+
use axum::{
3+
extract::{FromRequestParts, OptionalFromRequestParts},
4+
RequestPartsExt,
5+
};
36
use http::{
47
header::{HeaderMap, FORWARDED},
58
request::Parts,
69
uri::Authority,
710
};
11+
use std::convert::Infallible;
812

913
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
1014

@@ -31,31 +35,50 @@ where
3135
type Rejection = HostRejection;
3236

3337
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
38+
parts
39+
.extract::<Option<Host>>()
40+
.await
41+
.ok()
42+
.flatten()
43+
.ok_or(HostRejection::FailedToResolveHost(FailedToResolveHost))
44+
}
45+
}
46+
47+
impl<S> OptionalFromRequestParts<S> for Host
48+
where
49+
S: Send + Sync,
50+
{
51+
type Rejection = Infallible;
52+
53+
async fn from_request_parts(
54+
parts: &mut Parts,
55+
_state: &S,
56+
) -> Result<Option<Self>, Self::Rejection> {
3457
if let Some(host) = parse_forwarded(&parts.headers) {
35-
return Ok(Host(host.to_owned()));
58+
return Ok(Some(Host(host.to_owned())));
3659
}
3760

3861
if let Some(host) = parts
3962
.headers
4063
.get(X_FORWARDED_HOST_HEADER_KEY)
4164
.and_then(|host| host.to_str().ok())
4265
{
43-
return Ok(Host(host.to_owned()));
66+
return Ok(Some(Host(host.to_owned())));
4467
}
4568

4669
if let Some(host) = parts
4770
.headers
4871
.get(http::header::HOST)
4972
.and_then(|host| host.to_str().ok())
5073
{
51-
return Ok(Host(host.to_owned()));
74+
return Ok(Some(Host(host.to_owned())));
5275
}
5376

5477
if let Some(authority) = parts.uri.authority() {
55-
return Ok(Host(parse_authority(authority).to_owned()));
78+
return Ok(Some(Host(parse_authority(authority).to_owned())));
5679
}
5780

58-
Err(HostRejection::FailedToResolveHost(FailedToResolveHost))
81+
Ok(None)
5982
}
6083
}
6184

@@ -148,18 +171,40 @@ mod tests {
148171
async fn ip4_uri_host() {
149172
let mut parts = Request::new(()).into_parts().0;
150173
parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
151-
let host = Host::from_request_parts(&mut parts, &()).await.unwrap();
174+
let host = parts.extract::<Host>().await.unwrap();
152175
assert_eq!(host.0, "127.0.0.1:1234");
153176
}
154177

155178
#[crate::test]
156179
async fn ip6_uri_host() {
157180
let mut parts = Request::new(()).into_parts().0;
158181
parts.uri = "http://cool:user@[::1]:456/file.txt".parse().unwrap();
159-
let host = Host::from_request_parts(&mut parts, &()).await.unwrap();
182+
let host = parts.extract::<Host>().await.unwrap();
160183
assert_eq!(host.0, "[::1]:456");
161184
}
162185

186+
#[crate::test]
187+
async fn missing_host() {
188+
let mut parts = Request::new(()).into_parts().0;
189+
let host = parts.extract::<Host>().await.unwrap_err();
190+
assert!(matches!(host, HostRejection::FailedToResolveHost(_)));
191+
}
192+
193+
#[crate::test]
194+
async fn optional_extractor() {
195+
let mut parts = Request::new(()).into_parts().0;
196+
parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
197+
let host = parts.extract::<Option<Host>>().await.unwrap();
198+
assert!(host.is_some());
199+
}
200+
201+
#[crate::test]
202+
async fn optional_extractor_none() {
203+
let mut parts = Request::new(()).into_parts().0;
204+
let host = parts.extract::<Option<Host>>().await.unwrap();
205+
assert!(host.is_none());
206+
}
207+
163208
#[test]
164209
fn forwarded_parsing() {
165210
// the basic case

0 commit comments

Comments
 (0)