1
1
use super :: rejection:: { FailedToResolveHost , HostRejection } ;
2
- use axum:: extract:: FromRequestParts ;
2
+ use axum:: {
3
+ extract:: { FromRequestParts , OptionalFromRequestParts } ,
4
+ RequestPartsExt ,
5
+ } ;
3
6
use http:: {
4
7
header:: { HeaderMap , FORWARDED } ,
5
8
request:: Parts ,
6
9
uri:: Authority ,
7
10
} ;
11
+ use std:: convert:: Infallible ;
8
12
9
13
const X_FORWARDED_HOST_HEADER_KEY : & str = "X-Forwarded-Host" ;
10
14
@@ -31,31 +35,50 @@ where
31
35
type Rejection = HostRejection ;
32
36
33
37
async fn from_request_parts ( parts : & mut Parts , _state : & S ) -> Result < Self , Self :: Rejection > {
38
+ parts
39
+ . extract :: < Option < Host > > ( )
40
+ . await
41
+ . ok ( )
42
+ . flatten ( )
43
+ . ok_or ( HostRejection :: FailedToResolveHost ( FailedToResolveHost ) )
44
+ }
45
+ }
46
+
47
+ impl < S > OptionalFromRequestParts < S > for Host
48
+ where
49
+ S : Send + Sync ,
50
+ {
51
+ type Rejection = Infallible ;
52
+
53
+ async fn from_request_parts (
54
+ parts : & mut Parts ,
55
+ _state : & S ,
56
+ ) -> Result < Option < Self > , Self :: Rejection > {
34
57
if let Some ( host) = parse_forwarded ( & parts. headers ) {
35
- return Ok ( Host ( host. to_owned ( ) ) ) ;
58
+ return Ok ( Some ( Host ( host. to_owned ( ) ) ) ) ;
36
59
}
37
60
38
61
if let Some ( host) = parts
39
62
. headers
40
63
. get ( X_FORWARDED_HOST_HEADER_KEY )
41
64
. and_then ( |host| host. to_str ( ) . ok ( ) )
42
65
{
43
- return Ok ( Host ( host. to_owned ( ) ) ) ;
66
+ return Ok ( Some ( Host ( host. to_owned ( ) ) ) ) ;
44
67
}
45
68
46
69
if let Some ( host) = parts
47
70
. headers
48
71
. get ( http:: header:: HOST )
49
72
. and_then ( |host| host. to_str ( ) . ok ( ) )
50
73
{
51
- return Ok ( Host ( host. to_owned ( ) ) ) ;
74
+ return Ok ( Some ( Host ( host. to_owned ( ) ) ) ) ;
52
75
}
53
76
54
77
if let Some ( authority) = parts. uri . authority ( ) {
55
- return Ok ( Host ( parse_authority ( authority) . to_owned ( ) ) ) ;
78
+ return Ok ( Some ( Host ( parse_authority ( authority) . to_owned ( ) ) ) ) ;
56
79
}
57
80
58
- Err ( HostRejection :: FailedToResolveHost ( FailedToResolveHost ) )
81
+ Ok ( None )
59
82
}
60
83
}
61
84
@@ -148,18 +171,40 @@ mod tests {
148
171
async fn ip4_uri_host ( ) {
149
172
let mut parts = Request :: new ( ( ) ) . into_parts ( ) . 0 ;
150
173
parts. uri = "https://127.0.0.1:1234/image.jpg" . parse ( ) . unwrap ( ) ;
151
- let host = Host :: from_request_parts ( & mut parts , & ( ) ) . await . unwrap ( ) ;
174
+ let host = parts . extract :: < Host > ( ) . await . unwrap ( ) ;
152
175
assert_eq ! ( host. 0 , "127.0.0.1:1234" ) ;
153
176
}
154
177
155
178
#[ crate :: test]
156
179
async fn ip6_uri_host ( ) {
157
180
let mut parts = Request :: new ( ( ) ) . into_parts ( ) . 0 ;
158
181
parts. uri = "http://cool:user@[::1]:456/file.txt" . parse ( ) . unwrap ( ) ;
159
- let host = Host :: from_request_parts ( & mut parts , & ( ) ) . await . unwrap ( ) ;
182
+ let host = parts . extract :: < Host > ( ) . await . unwrap ( ) ;
160
183
assert_eq ! ( host. 0 , "[::1]:456" ) ;
161
184
}
162
185
186
+ #[ crate :: test]
187
+ async fn missing_host ( ) {
188
+ let mut parts = Request :: new ( ( ) ) . into_parts ( ) . 0 ;
189
+ let host = parts. extract :: < Host > ( ) . await . unwrap_err ( ) ;
190
+ assert ! ( matches!( host, HostRejection :: FailedToResolveHost ( _) ) ) ;
191
+ }
192
+
193
+ #[ crate :: test]
194
+ async fn optional_extractor ( ) {
195
+ let mut parts = Request :: new ( ( ) ) . into_parts ( ) . 0 ;
196
+ parts. uri = "https://127.0.0.1:1234/image.jpg" . parse ( ) . unwrap ( ) ;
197
+ let host = parts. extract :: < Option < Host > > ( ) . await . unwrap ( ) ;
198
+ assert ! ( host. is_some( ) ) ;
199
+ }
200
+
201
+ #[ crate :: test]
202
+ async fn optional_extractor_none ( ) {
203
+ let mut parts = Request :: new ( ( ) ) . into_parts ( ) . 0 ;
204
+ let host = parts. extract :: < Option < Host > > ( ) . await . unwrap ( ) ;
205
+ assert ! ( host. is_none( ) ) ;
206
+ }
207
+
163
208
#[ test]
164
209
fn forwarded_parsing ( ) {
165
210
// the basic case
0 commit comments