Skip to content

Commit 5340484

Browse files
authored
feat(s3s/host): add MultiDomain (#179)
* refactor: change domain arg * feat(s3s/host): add MultiDomain * fix
1 parent c94f314 commit 5340484

File tree

7 files changed

+247
-27
lines changed

7 files changed

+247
-27
lines changed

crates/s3s-fs/src/main.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@ use s3s_fs::FileSystem;
55
use s3s_fs::Result;
66

77
use s3s::auth::SimpleAuth;
8-
use s3s::host::SingleDomain;
8+
use s3s::host::MultiDomain;
99
use s3s::service::S3ServiceBuilder;
1010

1111
use std::io::IsTerminal;
12+
use std::ops::Not;
1213
use std::path::PathBuf;
1314

1415
use tokio::net::TcpListener;
@@ -38,9 +39,9 @@ struct Opt {
3839
#[arg(long)]
3940
secret_key: Option<String>,
4041

41-
/// Domain name used for virtual-hosted-style requests.
42+
/// Domain names used for virtual-hosted-style requests.
4243
#[arg(long)]
43-
domain_name: Option<String>,
44+
domain: Vec<String>,
4445

4546
/// Root directory of stored data.
4647
root: PathBuf,
@@ -70,7 +71,7 @@ fn check_cli_args(opt: &Opt) {
7071
cmd.error(ErrorKind::MissingRequiredArgument, msg).exit();
7172
}
7273

73-
if let Some(ref s) = opt.domain_name {
74+
for s in &opt.domain {
7475
if s.contains('/') {
7576
let msg = format!("expected domain name, found URL-like string: {s:?}");
7677
cmd.error(ErrorKind::InvalidValue, msg).exit();
@@ -103,8 +104,8 @@ async fn run(opt: Opt) -> Result {
103104
}
104105

105106
// Enable parsing virtual-hosted-style requests
106-
if let Some(domain_name) = opt.domain_name {
107-
b.set_host(SingleDomain::new(domain_name));
107+
if opt.domain.is_empty().not() {
108+
b.set_host(MultiDomain::new(&opt.domain)?);
108109
info!("virtual-hosted-style requests are enabled");
109110
}
110111

crates/s3s-fs/tests/it_aws.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ fn config() -> &'static SdkConfig {
6565
let service = {
6666
let mut b = S3ServiceBuilder::new(fs);
6767
b.set_auth(SimpleAuth::from_single(cred.access_key_id(), cred.secret_access_key()));
68-
b.set_host(SingleDomain::new(DOMAIN_NAME));
68+
b.set_host(SingleDomain::new(DOMAIN_NAME).unwrap());
6969
b.build()
7070
};
7171

crates/s3s-proxy/src/main.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct Opt {
2626
port: u16,
2727

2828
#[clap(long)]
29-
domain_name: Option<String>,
29+
domain: Option<String>,
3030

3131
#[clap(long)]
3232
endpoint_url: String,
@@ -66,8 +66,8 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
6666
}
6767

6868
// Enable parsing virtual-hosted-style requests
69-
if let Some(domain_name) = opt.domain_name {
70-
b.set_host(SingleDomain::new(domain_name));
69+
if let Some(domain) = opt.domain {
70+
b.set_host(SingleDomain::new(&domain)?);
7171
}
7272

7373
b.build()

crates/s3s/src/host.rs

Lines changed: 229 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ use crate::error::S3Result;
22

33
use std::borrow::Cow;
44

5+
use rust_utils::default::default;
6+
57
#[derive(Debug, Clone)]
68
pub struct VirtualHost<'a> {
79
domain: Cow<'a, str>,
@@ -45,32 +47,248 @@ pub trait S3Host: Send + Sync + 'static {
4547
fn parse_host_header<'a>(&'a self, host: &'a str) -> S3Result<VirtualHost<'a>>;
4648
}
4749

50+
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
51+
pub enum DomainError {
52+
#[error("The domain is invalid")]
53+
InvalidDomain,
54+
55+
#[error("Some subdomains overlap with each other")]
56+
OverlappingSubdomains,
57+
58+
#[error("No base domains are specified")]
59+
ZeroDomains,
60+
}
61+
62+
/// Naive check for a valid domain.
63+
fn is_valid_domain(mut s: &str) -> bool {
64+
if s.is_empty() {
65+
return false;
66+
}
67+
68+
if let Some((host, port)) = s.split_once(':') {
69+
if port.is_empty() {
70+
return false;
71+
}
72+
73+
if port.parse::<u16>().is_err() {
74+
return false;
75+
}
76+
77+
s = host;
78+
}
79+
80+
for part in s.split('.') {
81+
if part.is_empty() {
82+
return false;
83+
}
84+
85+
if part.as_bytes().iter().any(|&b| !b.is_ascii_alphanumeric() && b != b'-') {
86+
return false;
87+
}
88+
}
89+
90+
true
91+
}
92+
93+
fn parse_host_header<'a>(base_domain: &'a str, host: &'a str) -> Option<VirtualHost<'a>> {
94+
if host == base_domain {
95+
return Some(VirtualHost::new(base_domain));
96+
}
97+
98+
if let Some(bucket) = host.strip_suffix(base_domain).and_then(|h| h.strip_suffix('.')) {
99+
return Some(VirtualHost::with_bucket(base_domain, bucket));
100+
};
101+
102+
None
103+
}
104+
105+
#[derive(Debug)]
48106
pub struct SingleDomain {
49107
base_domain: String,
50108
}
51109

52110
impl SingleDomain {
53-
#[must_use]
54-
pub fn new(base_domain: impl Into<String>) -> Self {
55-
Self {
56-
base_domain: base_domain.into(),
111+
/// Create a new `SingleDomain` with the base domain.
112+
///
113+
/// # Errors
114+
/// Returns an error if the base domain is invalid.
115+
pub fn new(base_domain: &str) -> Result<Self, DomainError> {
116+
if !is_valid_domain(base_domain) {
117+
return Err(DomainError::InvalidDomain);
57118
}
119+
120+
Ok(Self {
121+
base_domain: base_domain.into(),
122+
})
58123
}
59124
}
60125

61126
impl S3Host for SingleDomain {
62127
fn parse_host_header<'a>(&'a self, host: &'a str) -> S3Result<VirtualHost<'a>> {
63128
let base_domain = self.base_domain.as_str();
64129

65-
if host == base_domain {
66-
return Ok(VirtualHost::new(base_domain));
130+
if let Some(vh) = parse_host_header(base_domain, host) {
131+
return Ok(vh);
67132
}
68133

69-
if let Some(bucket) = host.strip_suffix(&self.base_domain).and_then(|h| h.strip_suffix('.')) {
70-
return Ok(VirtualHost::with_bucket(base_domain, bucket));
71-
};
134+
if is_valid_domain(host) {
135+
let bucket = host.to_ascii_lowercase();
136+
return Ok(VirtualHost::with_bucket(host, bucket));
137+
}
138+
139+
Err(s3_error!(InvalidRequest, "Invalid host header"))
140+
}
141+
}
142+
143+
#[derive(Debug)]
144+
pub struct MultiDomain {
145+
base_domains: Vec<String>,
146+
}
147+
148+
impl MultiDomain {
149+
/// Create a new `MultiDomain` with the base domains.
150+
///
151+
/// # Errors
152+
/// Returns an error if
153+
/// + any of the base domains are invalid.
154+
/// + any of the base domains overlap with each other.
155+
/// + no base domains are specified.
156+
pub fn new<I>(base_domains: I) -> Result<Self, DomainError>
157+
where
158+
I: IntoIterator,
159+
I::Item: AsRef<str>,
160+
{
161+
let mut v: Vec<String> = default();
162+
163+
for domain in base_domains {
164+
let domain = domain.as_ref();
165+
166+
if !is_valid_domain(domain) {
167+
return Err(DomainError::InvalidDomain);
168+
}
169+
170+
for other in &v {
171+
if domain.ends_with(other) || other.ends_with(domain) {
172+
return Err(DomainError::OverlappingSubdomains);
173+
}
174+
}
175+
176+
v.push(domain.to_owned());
177+
}
178+
179+
if v.is_empty() {
180+
return Err(DomainError::ZeroDomains);
181+
}
182+
183+
Ok(Self { base_domains: v })
184+
}
185+
}
186+
187+
impl S3Host for MultiDomain {
188+
fn parse_host_header<'a>(&'a self, host: &'a str) -> S3Result<VirtualHost<'a>> {
189+
for base_domain in &self.base_domains {
190+
if let Some(vh) = parse_host_header(base_domain, host) {
191+
return Ok(vh);
192+
}
193+
}
194+
195+
if is_valid_domain(host) {
196+
let bucket = host.to_ascii_lowercase();
197+
return Ok(VirtualHost::with_bucket(host, bucket));
198+
}
199+
200+
Err(s3_error!(InvalidRequest, "Invalid host header"))
201+
}
202+
}
203+
204+
#[cfg(test)]
205+
mod tests {
206+
use super::*;
207+
208+
use crate::S3ErrorCode;
209+
210+
#[test]
211+
fn single_domain_new() {
212+
let domain = "example.com";
213+
let result = SingleDomain::new(domain);
214+
let sd = result.unwrap();
215+
assert_eq!(sd.base_domain, domain);
216+
217+
let domain = "example.com.org";
218+
let result = SingleDomain::new(domain);
219+
let sd = result.unwrap();
220+
assert_eq!(sd.base_domain, domain);
221+
222+
let domain = "example.com.";
223+
let result = SingleDomain::new(domain);
224+
let err = result.unwrap_err();
225+
assert!(matches!(err, DomainError::InvalidDomain));
226+
227+
let domain = "example.com:";
228+
let result = SingleDomain::new(domain);
229+
let err = result.unwrap_err();
230+
assert!(matches!(err, DomainError::InvalidDomain));
231+
232+
let domain = "example.com:80";
233+
let result = SingleDomain::new(domain);
234+
assert!(result.is_ok());
235+
}
236+
237+
#[test]
238+
fn multi_domain_new() {
239+
let domains = ["example.com", "example.org"];
240+
let result = MultiDomain::new(&domains);
241+
let md = result.unwrap();
242+
assert_eq!(md.base_domains, domains);
243+
244+
let domains = ["example.com", "example.com"];
245+
let result = MultiDomain::new(&domains);
246+
let err = result.unwrap_err();
247+
assert!(matches!(err, DomainError::OverlappingSubdomains));
248+
249+
let domains = ["example.com", "example.com.org"];
250+
let result = MultiDomain::new(&domains);
251+
let md = result.unwrap();
252+
assert_eq!(md.base_domains, domains);
253+
254+
let domains: [&str; 0] = [];
255+
let result = MultiDomain::new(&domains);
256+
let err = result.unwrap_err();
257+
assert!(matches!(err, DomainError::ZeroDomains));
258+
}
259+
260+
#[test]
261+
fn multi_domain_parse() {
262+
let domains = ["example.com", "example.org"];
263+
let md = MultiDomain::new(domains.iter().copied()).unwrap();
264+
265+
let host = "example.com";
266+
let result = md.parse_host_header(host);
267+
let vh = result.unwrap();
268+
assert_eq!(vh.domain(), host);
269+
assert_eq!(vh.bucket(), None);
270+
271+
let host = "example.org";
272+
let result = md.parse_host_header(host);
273+
let vh = result.unwrap();
274+
assert_eq!(vh.domain(), host);
275+
assert_eq!(vh.bucket(), None);
276+
277+
let host = "example.com.org";
278+
let result = md.parse_host_header(host);
279+
let vh = result.unwrap();
280+
assert_eq!(vh.domain(), host);
281+
assert_eq!(vh.bucket(), Some("example.com.org"));
282+
283+
let host = "example.com.org.";
284+
let result = md.parse_host_header(host);
285+
let err = result.unwrap_err();
286+
assert!(matches!(err.code(), S3ErrorCode::InvalidRequest));
72287

73-
let bucket = host.to_ascii_lowercase();
74-
Ok(VirtualHost::with_bucket(host, bucket))
288+
let host = "example.com.org.example.com";
289+
let result = md.parse_host_header(host);
290+
let vh = result.unwrap();
291+
assert_eq!(vh.domain(), "example.com");
292+
assert_eq!(vh.bucket(), Some("example.com.org"));
75293
}
76294
}

crates/s3s/src/path.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ mod tests {
257257
#[test]
258258
fn virtual_hosted_style() {
259259
{
260-
let s3_host = SingleDomain::new("s3.us-east-1.amazonaws.com");
260+
let s3_host = SingleDomain::new("s3.us-east-1.amazonaws.com").unwrap();
261261
let host = "s3.us-east-1.amazonaws.com";
262262
let uri_path = "/example.com/homepage.html";
263263
let vh = s3_host.parse_host_header(host).unwrap();
@@ -267,7 +267,7 @@ mod tests {
267267
}
268268

269269
{
270-
let s3_host = SingleDomain::new("s3.eu-west-1.amazonaws.com");
270+
let s3_host = SingleDomain::new("s3.eu-west-1.amazonaws.com").unwrap();
271271
let host = "doc-example-bucket1.eu.s3.eu-west-1.amazonaws.com";
272272
let uri_path = "/homepage.html";
273273
let vh = s3_host.parse_host_header(host).unwrap();
@@ -277,7 +277,7 @@ mod tests {
277277
}
278278

279279
{
280-
let s3_host = SingleDomain::new("s3.eu-west-1.amazonaws.com");
280+
let s3_host = SingleDomain::new("s3.eu-west-1.amazonaws.com").unwrap();
281281
let host = "doc-example-bucket1.eu.s3.eu-west-1.amazonaws.com";
282282
let uri_path = "/";
283283
let vh = s3_host.parse_host_header(host).unwrap();
@@ -287,7 +287,7 @@ mod tests {
287287
}
288288

289289
{
290-
let s3_host = SingleDomain::new("s3.us-east-1.amazonaws.com");
290+
let s3_host = SingleDomain::new("s3.us-east-1.amazonaws.com").unwrap();
291291
let host = "example.com";
292292
let uri_path = "/homepage.html";
293293
let vh = s3_host.parse_host_header(host).unwrap();

scripts/s3s-fs.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ s3s-fs \
1414
--secret-key SKEXAMPLES3S \
1515
--host localhost \
1616
--port 8014 \
17-
--domain-name localhost:8014 \
17+
--domain localhost:8014 \
18+
--domain localhost \
1819
"$DATA_DIR"

scripts/s3s-proxy.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,5 @@ export RUST_BACKTRACE=full
2121
s3s-proxy \
2222
--host localhost \
2323
--port 8014 \
24-
--domain-name localhost:8014 \
24+
--domain localhost:8014 \
2525
--endpoint-url http://localhost:9000

0 commit comments

Comments
 (0)