@@ -2,6 +2,8 @@ use crate::error::S3Result;
2
2
3
3
use std:: borrow:: Cow ;
4
4
5
+ use rust_utils:: default:: default;
6
+
5
7
#[ derive( Debug , Clone ) ]
6
8
pub struct VirtualHost < ' a > {
7
9
domain : Cow < ' a , str > ,
@@ -45,32 +47,248 @@ pub trait S3Host: Send + Sync + 'static {
45
47
fn parse_host_header < ' a > ( & ' a self , host : & ' a str ) -> S3Result < VirtualHost < ' a > > ;
46
48
}
47
49
50
+ #[ derive( Debug , Clone , PartialEq , Eq , thiserror:: Error ) ]
51
+ pub enum DomainError {
52
+ #[ error( "The domain is invalid" ) ]
53
+ InvalidDomain ,
54
+
55
+ #[ error( "Some subdomains overlap with each other" ) ]
56
+ OverlappingSubdomains ,
57
+
58
+ #[ error( "No base domains are specified" ) ]
59
+ ZeroDomains ,
60
+ }
61
+
62
+ /// Naive check for a valid domain.
63
+ fn is_valid_domain ( mut s : & str ) -> bool {
64
+ if s. is_empty ( ) {
65
+ return false ;
66
+ }
67
+
68
+ if let Some ( ( host, port) ) = s. split_once ( ':' ) {
69
+ if port. is_empty ( ) {
70
+ return false ;
71
+ }
72
+
73
+ if port. parse :: < u16 > ( ) . is_err ( ) {
74
+ return false ;
75
+ }
76
+
77
+ s = host;
78
+ }
79
+
80
+ for part in s. split ( '.' ) {
81
+ if part. is_empty ( ) {
82
+ return false ;
83
+ }
84
+
85
+ if part. as_bytes ( ) . iter ( ) . any ( |& b| !b. is_ascii_alphanumeric ( ) && b != b'-' ) {
86
+ return false ;
87
+ }
88
+ }
89
+
90
+ true
91
+ }
92
+
93
+ fn parse_host_header < ' a > ( base_domain : & ' a str , host : & ' a str ) -> Option < VirtualHost < ' a > > {
94
+ if host == base_domain {
95
+ return Some ( VirtualHost :: new ( base_domain) ) ;
96
+ }
97
+
98
+ if let Some ( bucket) = host. strip_suffix ( base_domain) . and_then ( |h| h. strip_suffix ( '.' ) ) {
99
+ return Some ( VirtualHost :: with_bucket ( base_domain, bucket) ) ;
100
+ } ;
101
+
102
+ None
103
+ }
104
+
105
+ #[ derive( Debug ) ]
48
106
pub struct SingleDomain {
49
107
base_domain : String ,
50
108
}
51
109
52
110
impl SingleDomain {
53
- #[ must_use]
54
- pub fn new ( base_domain : impl Into < String > ) -> Self {
55
- Self {
56
- base_domain : base_domain. into ( ) ,
111
+ /// Create a new `SingleDomain` with the base domain.
112
+ ///
113
+ /// # Errors
114
+ /// Returns an error if the base domain is invalid.
115
+ pub fn new ( base_domain : & str ) -> Result < Self , DomainError > {
116
+ if !is_valid_domain ( base_domain) {
117
+ return Err ( DomainError :: InvalidDomain ) ;
57
118
}
119
+
120
+ Ok ( Self {
121
+ base_domain : base_domain. into ( ) ,
122
+ } )
58
123
}
59
124
}
60
125
61
126
impl S3Host for SingleDomain {
62
127
fn parse_host_header < ' a > ( & ' a self , host : & ' a str ) -> S3Result < VirtualHost < ' a > > {
63
128
let base_domain = self . base_domain . as_str ( ) ;
64
129
65
- if host == base_domain {
66
- return Ok ( VirtualHost :: new ( base_domain ) ) ;
130
+ if let Some ( vh ) = parse_host_header ( base_domain, host ) {
131
+ return Ok ( vh ) ;
67
132
}
68
133
69
- if let Some ( bucket) = host. strip_suffix ( & self . base_domain ) . and_then ( |h| h. strip_suffix ( '.' ) ) {
70
- return Ok ( VirtualHost :: with_bucket ( base_domain, bucket) ) ;
71
- } ;
134
+ if is_valid_domain ( host) {
135
+ let bucket = host. to_ascii_lowercase ( ) ;
136
+ return Ok ( VirtualHost :: with_bucket ( host, bucket) ) ;
137
+ }
138
+
139
+ Err ( s3_error ! ( InvalidRequest , "Invalid host header" ) )
140
+ }
141
+ }
142
+
143
+ #[ derive( Debug ) ]
144
+ pub struct MultiDomain {
145
+ base_domains : Vec < String > ,
146
+ }
147
+
148
+ impl MultiDomain {
149
+ /// Create a new `MultiDomain` with the base domains.
150
+ ///
151
+ /// # Errors
152
+ /// Returns an error if
153
+ /// + any of the base domains are invalid.
154
+ /// + any of the base domains overlap with each other.
155
+ /// + no base domains are specified.
156
+ pub fn new < I > ( base_domains : I ) -> Result < Self , DomainError >
157
+ where
158
+ I : IntoIterator ,
159
+ I :: Item : AsRef < str > ,
160
+ {
161
+ let mut v: Vec < String > = default ( ) ;
162
+
163
+ for domain in base_domains {
164
+ let domain = domain. as_ref ( ) ;
165
+
166
+ if !is_valid_domain ( domain) {
167
+ return Err ( DomainError :: InvalidDomain ) ;
168
+ }
169
+
170
+ for other in & v {
171
+ if domain. ends_with ( other) || other. ends_with ( domain) {
172
+ return Err ( DomainError :: OverlappingSubdomains ) ;
173
+ }
174
+ }
175
+
176
+ v. push ( domain. to_owned ( ) ) ;
177
+ }
178
+
179
+ if v. is_empty ( ) {
180
+ return Err ( DomainError :: ZeroDomains ) ;
181
+ }
182
+
183
+ Ok ( Self { base_domains : v } )
184
+ }
185
+ }
186
+
187
+ impl S3Host for MultiDomain {
188
+ fn parse_host_header < ' a > ( & ' a self , host : & ' a str ) -> S3Result < VirtualHost < ' a > > {
189
+ for base_domain in & self . base_domains {
190
+ if let Some ( vh) = parse_host_header ( base_domain, host) {
191
+ return Ok ( vh) ;
192
+ }
193
+ }
194
+
195
+ if is_valid_domain ( host) {
196
+ let bucket = host. to_ascii_lowercase ( ) ;
197
+ return Ok ( VirtualHost :: with_bucket ( host, bucket) ) ;
198
+ }
199
+
200
+ Err ( s3_error ! ( InvalidRequest , "Invalid host header" ) )
201
+ }
202
+ }
203
+
204
+ #[ cfg( test) ]
205
+ mod tests {
206
+ use super :: * ;
207
+
208
+ use crate :: S3ErrorCode ;
209
+
210
+ #[ test]
211
+ fn single_domain_new ( ) {
212
+ let domain = "example.com" ;
213
+ let result = SingleDomain :: new ( domain) ;
214
+ let sd = result. unwrap ( ) ;
215
+ assert_eq ! ( sd. base_domain, domain) ;
216
+
217
+ let domain = "example.com.org" ;
218
+ let result = SingleDomain :: new ( domain) ;
219
+ let sd = result. unwrap ( ) ;
220
+ assert_eq ! ( sd. base_domain, domain) ;
221
+
222
+ let domain = "example.com." ;
223
+ let result = SingleDomain :: new ( domain) ;
224
+ let err = result. unwrap_err ( ) ;
225
+ assert ! ( matches!( err, DomainError :: InvalidDomain ) ) ;
226
+
227
+ let domain = "example.com:" ;
228
+ let result = SingleDomain :: new ( domain) ;
229
+ let err = result. unwrap_err ( ) ;
230
+ assert ! ( matches!( err, DomainError :: InvalidDomain ) ) ;
231
+
232
+ let domain = "example.com:80" ;
233
+ let result = SingleDomain :: new ( domain) ;
234
+ assert ! ( result. is_ok( ) ) ;
235
+ }
236
+
237
+ #[ test]
238
+ fn multi_domain_new ( ) {
239
+ let domains = [ "example.com" , "example.org" ] ;
240
+ let result = MultiDomain :: new ( & domains) ;
241
+ let md = result. unwrap ( ) ;
242
+ assert_eq ! ( md. base_domains, domains) ;
243
+
244
+ let domains = [ "example.com" , "example.com" ] ;
245
+ let result = MultiDomain :: new ( & domains) ;
246
+ let err = result. unwrap_err ( ) ;
247
+ assert ! ( matches!( err, DomainError :: OverlappingSubdomains ) ) ;
248
+
249
+ let domains = [ "example.com" , "example.com.org" ] ;
250
+ let result = MultiDomain :: new ( & domains) ;
251
+ let md = result. unwrap ( ) ;
252
+ assert_eq ! ( md. base_domains, domains) ;
253
+
254
+ let domains: [ & str ; 0 ] = [ ] ;
255
+ let result = MultiDomain :: new ( & domains) ;
256
+ let err = result. unwrap_err ( ) ;
257
+ assert ! ( matches!( err, DomainError :: ZeroDomains ) ) ;
258
+ }
259
+
260
+ #[ test]
261
+ fn multi_domain_parse ( ) {
262
+ let domains = [ "example.com" , "example.org" ] ;
263
+ let md = MultiDomain :: new ( domains. iter ( ) . copied ( ) ) . unwrap ( ) ;
264
+
265
+ let host = "example.com" ;
266
+ let result = md. parse_host_header ( host) ;
267
+ let vh = result. unwrap ( ) ;
268
+ assert_eq ! ( vh. domain( ) , host) ;
269
+ assert_eq ! ( vh. bucket( ) , None ) ;
270
+
271
+ let host = "example.org" ;
272
+ let result = md. parse_host_header ( host) ;
273
+ let vh = result. unwrap ( ) ;
274
+ assert_eq ! ( vh. domain( ) , host) ;
275
+ assert_eq ! ( vh. bucket( ) , None ) ;
276
+
277
+ let host = "example.com.org" ;
278
+ let result = md. parse_host_header ( host) ;
279
+ let vh = result. unwrap ( ) ;
280
+ assert_eq ! ( vh. domain( ) , host) ;
281
+ assert_eq ! ( vh. bucket( ) , Some ( "example.com.org" ) ) ;
282
+
283
+ let host = "example.com.org." ;
284
+ let result = md. parse_host_header ( host) ;
285
+ let err = result. unwrap_err ( ) ;
286
+ assert ! ( matches!( err. code( ) , S3ErrorCode :: InvalidRequest ) ) ;
72
287
73
- let bucket = host. to_ascii_lowercase ( ) ;
74
- Ok ( VirtualHost :: with_bucket ( host, bucket) )
288
+ let host = "example.com.org.example.com" ;
289
+ let result = md. parse_host_header ( host) ;
290
+ let vh = result. unwrap ( ) ;
291
+ assert_eq ! ( vh. domain( ) , "example.com" ) ;
292
+ assert_eq ! ( vh. bucket( ) , Some ( "example.com.org" ) ) ;
75
293
}
76
294
}
0 commit comments