Skip to content

Commit c94f314

Browse files
authored
feat(s3s/host): add S3Host (#178)
1 parent 75bc2ff commit c94f314

File tree

9 files changed

+155
-75
lines changed

9 files changed

+155
-75
lines changed

crates/s3s-fs/src/main.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use s3s_fs::FileSystem;
55
use s3s_fs::Result;
66

77
use s3s::auth::SimpleAuth;
8+
use s3s::host::SingleDomain;
89
use s3s::service::S3ServiceBuilder;
910

1011
use std::io::IsTerminal;
@@ -103,7 +104,7 @@ async fn run(opt: Opt) -> Result {
103104

104105
// Enable parsing virtual-hosted-style requests
105106
if let Some(domain_name) = opt.domain_name {
106-
b.set_base_domain(domain_name);
107+
b.set_host(SingleDomain::new(domain_name));
107108
info!("virtual-hosted-style requests are enabled");
108109
}
109110

crates/s3s-fs/tests/it_aws.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
)]
66

77
use s3s::auth::SimpleAuth;
8+
use s3s::host::SingleDomain;
89
use s3s::service::S3ServiceBuilder;
910
use s3s_fs::FileSystem;
1011

@@ -64,7 +65,7 @@ fn config() -> &'static SdkConfig {
6465
let service = {
6566
let mut b = S3ServiceBuilder::new(fs);
6667
b.set_auth(SimpleAuth::from_single(cred.access_key_id(), cred.secret_access_key()));
67-
b.set_base_domain(DOMAIN_NAME);
68+
b.set_host(SingleDomain::new(DOMAIN_NAME));
6869
b.build()
6970
};
7071

crates/s3s-proxy/src/main.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#![deny(clippy::all, clippy::pedantic)]
33

44
use s3s::auth::SimpleAuth;
5+
use s3s::host::SingleDomain;
56
use s3s::service::S3ServiceBuilder;
67
use tokio::net::TcpListener;
78

@@ -66,7 +67,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
6667

6768
// Enable parsing virtual-hosted-style requests
6869
if let Some(domain_name) = opt.domain_name {
69-
b.set_base_domain(domain_name);
70+
b.set_host(SingleDomain::new(domain_name));
7071
}
7172

7273
b.build()

crates/s3s/src/host.rs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
use crate::error::S3Result;
2+
3+
use std::borrow::Cow;
4+
5+
#[derive(Debug, Clone)]
6+
pub struct VirtualHost<'a> {
7+
domain: Cow<'a, str>,
8+
bucket: Option<Cow<'a, str>>,
9+
// pub(crate) region: Option<Cow<'a, str>>,
10+
}
11+
12+
impl<'a> VirtualHost<'a> {
13+
pub fn new(domain: impl Into<Cow<'a, str>>) -> Self {
14+
Self {
15+
domain: domain.into(),
16+
bucket: None,
17+
}
18+
}
19+
20+
pub fn with_bucket(domain: impl Into<Cow<'a, str>>, bucket: impl Into<Cow<'a, str>>) -> Self {
21+
Self {
22+
domain: domain.into(),
23+
bucket: Some(bucket.into()),
24+
}
25+
}
26+
27+
#[inline]
28+
#[must_use]
29+
pub fn domain(&self) -> &str {
30+
self.domain.as_ref()
31+
}
32+
33+
#[inline]
34+
#[must_use]
35+
pub fn bucket(&self) -> Option<&str> {
36+
self.bucket.as_deref()
37+
}
38+
}
39+
40+
pub trait S3Host: Send + Sync + 'static {
41+
/// Parses the `Host` header of the HTTP request.
42+
///
43+
/// # Errors
44+
/// Returns an error if the `Host` is invalid for this service.
45+
fn parse_host_header<'a>(&'a self, host: &'a str) -> S3Result<VirtualHost<'a>>;
46+
}
47+
48+
pub struct SingleDomain {
49+
base_domain: String,
50+
}
51+
52+
impl SingleDomain {
53+
#[must_use]
54+
pub fn new(base_domain: impl Into<String>) -> Self {
55+
Self {
56+
base_domain: base_domain.into(),
57+
}
58+
}
59+
}
60+
61+
impl S3Host for SingleDomain {
62+
fn parse_host_header<'a>(&'a self, host: &'a str) -> S3Result<VirtualHost<'a>> {
63+
let base_domain = self.base_domain.as_str();
64+
65+
if host == base_domain {
66+
return Ok(VirtualHost::new(base_domain));
67+
}
68+
69+
if let Some(bucket) = host.strip_suffix(&self.base_domain).and_then(|h| h.strip_suffix('.')) {
70+
return Ok(VirtualHost::with_bucket(base_domain, bucket));
71+
};
72+
73+
let bucket = host.to_ascii_lowercase();
74+
Ok(VirtualHost::with_bucket(host, bucket))
75+
}
76+
}

crates/s3s/src/lib.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,15 @@ mod sig_v2;
3535
mod sig_v4;
3636
mod xml;
3737

38-
pub mod header;
39-
4038
pub mod auth;
39+
pub mod checksum;
4140
pub mod dto;
41+
pub mod header;
42+
pub mod host;
4243
pub mod path;
4344
pub mod service;
4445
pub mod stream;
4546

46-
pub mod checksum;
47-
4847
pub use self::error::*;
4948
pub use self::http::Body;
5049
pub use self::request::S3Request;

crates/s3s/src/ops/mod.rs

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use crate::auth::S3Auth;
1313
use crate::auth::S3AuthContext;
1414
use crate::error::*;
1515
use crate::header;
16+
use crate::host::S3Host;
1617
use crate::http;
1718
use crate::http::Body;
1819
use crate::http::{OrderedHeaders, OrderedQs};
@@ -85,27 +86,16 @@ fn extract_host(req: &Request) -> S3Result<Option<String>> {
8586
Ok(Some(host.into()))
8687
}
8788

88-
fn is_socket_addr(host: &str) -> bool {
89+
fn is_socket_addr_or_ip_addr(host: &str) -> bool {
8990
host.parse::<SocketAddr>().is_ok() || host.parse::<IpAddr>().is_ok()
9091
}
9192

92-
fn extract_s3_path(host: Option<&str>, uri_path: &str, base_domain: Option<&str>) -> S3Result<S3Path> {
93-
let result = match (base_domain, host) {
94-
(Some(base_domain), Some(host)) if base_domain != host && !is_socket_addr(host) => {
95-
debug!(?base_domain, ?host, ?uri_path, "parsing virtual-hosted-style request");
96-
crate::path::parse_virtual_hosted_style(base_domain, host, uri_path)
97-
}
98-
_ => {
99-
debug!(?uri_path, "parsing path-style request");
100-
crate::path::parse_path_style(uri_path)
101-
}
102-
};
103-
104-
result.map_err(|err| match err {
93+
fn convert_parse_s3_path_error(err: &ParseS3PathError) -> S3Error {
94+
match err {
10595
ParseS3PathError::InvalidPath => s3_error!(InvalidURI),
10696
ParseS3PathError::InvalidBucketName => s3_error!(InvalidBucketName),
10797
ParseS3PathError::KeyTooLong => s3_error!(KeyTooLongError),
108-
})
98+
}
10999
}
110100

111101
fn extract_qs(req_uri: &Uri) -> S3Result<Option<OrderedQs>> {
@@ -183,9 +173,9 @@ pub async fn call(
183173
req: &mut Request,
184174
s3: &Arc<dyn S3>,
185175
auth: Option<&dyn S3Auth>,
186-
base_domain: Option<&str>,
176+
host: Option<&dyn S3Host>,
187177
) -> S3Result<Response> {
188-
let op = match prepare(req, auth, base_domain).await {
178+
let op = match prepare(req, auth, host).await {
189179
Ok(op) => op,
190180
Err(err) => {
191181
debug!(?err, "failed to prepare");
@@ -204,18 +194,40 @@ pub async fn call(
204194
Ok(resp)
205195
}
206196

207-
async fn prepare(req: &mut Request, auth: Option<&dyn S3Auth>, base_domain: Option<&str>) -> S3Result<&'static dyn Operation> {
197+
#[allow(clippy::too_many_lines)]
198+
async fn prepare(req: &mut Request, auth: Option<&dyn S3Auth>, s3_host: Option<&dyn S3Host>) -> S3Result<&'static dyn Operation> {
208199
let s3_path;
209200
let mut content_length;
210201
{
211202
let decoded_uri_path = urlencoding::decode(req.uri.path())
212203
.map_err(|_| S3ErrorCode::InvalidURI)?
213204
.into_owned();
214205

215-
let host = extract_host(req)?;
206+
let host_header = extract_host(req)?;
207+
let vh;
208+
let vh_bucket;
209+
{
210+
let result = 'parse: {
211+
if let (Some(host_header), Some(s3_host)) = (host_header.as_deref(), s3_host) {
212+
if !is_socket_addr_or_ip_addr(host_header) {
213+
debug!(?host_header, ?decoded_uri_path, "parsing virtual-hosted-style request");
214+
215+
vh = s3_host.parse_host_header(host_header)?;
216+
debug!(?vh);
216217

217-
req.s3ext.s3_path = Some(extract_s3_path(host.as_deref(), &decoded_uri_path, base_domain)?);
218-
s3_path = req.s3ext.s3_path.as_ref().unwrap();
218+
vh_bucket = vh.bucket();
219+
break 'parse crate::path::parse_virtual_hosted_style(vh_bucket, &decoded_uri_path);
220+
}
221+
}
222+
223+
debug!(?decoded_uri_path, "parsing path-style request");
224+
vh_bucket = None;
225+
crate::path::parse_path_style(&decoded_uri_path)
226+
};
227+
228+
req.s3ext.s3_path = Some(result.map_err(|err| convert_parse_s3_path_error(&err))?);
229+
s3_path = req.s3ext.s3_path.as_ref().unwrap();
230+
}
219231

220232
req.s3ext.qs = extract_qs(&req.uri)?;
221233
content_length = extract_content_length(req);
@@ -229,7 +241,6 @@ async fn prepare(req: &mut Request, auth: Option<&dyn S3Auth>, base_domain: Opti
229241
{
230242
let mut scx = SignatureContext {
231243
auth,
232-
base_domain,
233244

234245
req_method: &req.method,
235246
req_uri: &req.uri,
@@ -239,9 +250,8 @@ async fn prepare(req: &mut Request, auth: Option<&dyn S3Auth>, base_domain: Opti
239250
hs,
240251

241252
decoded_uri_path,
242-
s3_path,
253+
vh_bucket,
243254

244-
host: host.as_deref(),
245255
content_length,
246256
decoded_content_length,
247257
mime,

crates/s3s/src/ops/signature.rs

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use crate::error::*;
44
use crate::http;
55
use crate::http::{AwsChunkedStream, Body, Multipart};
66
use crate::http::{OrderedHeaders, OrderedQs};
7-
use crate::path::S3Path;
87
use crate::sig_v2;
98
use crate::sig_v2::{AuthorizationV2, PresignedUrlV2};
109
use crate::sig_v4;
@@ -54,7 +53,6 @@ fn extract_amz_date(hs: &'_ OrderedHeaders<'_>) -> S3Result<Option<AmzDate>> {
5453

5554
pub struct SignatureContext<'a> {
5655
pub auth: Option<&'a dyn S3Auth>,
57-
pub base_domain: Option<&'a str>,
5856

5957
pub req_method: &'a Method,
6058
pub req_uri: &'a Uri,
@@ -64,9 +62,8 @@ pub struct SignatureContext<'a> {
6462
pub hs: OrderedHeaders<'a>,
6563

6664
pub decoded_uri_path: String,
67-
pub s3_path: &'a S3Path,
65+
pub vh_bucket: Option<&'a str>,
6866

69-
pub host: Option<&'a str>,
7067
pub content_length: Option<u64>,
7168
pub mime: Option<Mime>,
7269
pub decoded_content_length: Option<usize>,
@@ -354,13 +351,6 @@ impl SignatureContext<'_> {
354351
None
355352
}
356353

357-
fn v2_virtual_hosted_bucket(&self) -> Option<&str> {
358-
match (self.base_domain, self.host) {
359-
(Some(base_domain), Some(host)) if base_domain != host => self.s3_path.get_bucket_name(),
360-
_ => None,
361-
}
362-
}
363-
364354
pub async fn v2_check_header_auth(&mut self, auth_v2: AuthorizationV2<'_>) -> S3Result<Credentials> {
365355
let method = &self.req_method;
366356

@@ -373,10 +363,14 @@ impl SignatureContext<'_> {
373363
let access_key = auth_v2.access_key;
374364
let secret_key = auth.get_secret_key(access_key).await?;
375365

376-
let vh_bucket = self.v2_virtual_hosted_bucket();
377-
378-
let string_to_sign =
379-
sig_v2::create_string_to_sign(sig_v2::Mode::HeaderAuth, method, self.req_uri.path(), self.qs, &self.hs, vh_bucket);
366+
let string_to_sign = sig_v2::create_string_to_sign(
367+
sig_v2::Mode::HeaderAuth,
368+
method,
369+
self.req_uri.path(),
370+
self.qs,
371+
&self.hs,
372+
self.vh_bucket,
373+
);
380374
let signature = sig_v2::calculate_signature(&secret_key, &string_to_sign);
381375

382376
debug!(?string_to_sign, "sig_v2 header_auth");
@@ -405,15 +399,13 @@ impl SignatureContext<'_> {
405399
let access_key = presigned_url.access_key;
406400
let secret_key = auth.get_secret_key(access_key).await?;
407401

408-
let vh_bucket = self.v2_virtual_hosted_bucket();
409-
410402
let string_to_sign = sig_v2::create_string_to_sign(
411403
sig_v2::Mode::PresignedUrl,
412404
self.req_method,
413405
self.req_uri.path(),
414406
self.qs,
415407
&self.hs,
416-
vh_bucket,
408+
self.vh_bucket,
417409
);
418410
let signature = sig_v2::calculate_signature(&secret_key, &string_to_sign);
419411

0 commit comments

Comments
 (0)