Skip to content

Commit a39be58

Browse files
committed
feat(http): implement forward extension type logic
1 parent d87e502 commit a39be58

File tree

5 files changed

+96
-27
lines changed

5 files changed

+96
-27
lines changed

htsget-config/src/config/advanced/auth/mod.rs

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -415,19 +415,15 @@ mod tests {
415415
config.authorization_url().unwrap(),
416416
&UrlOrStatic::Url("https://www.example.com".parse::<Uri>().unwrap())
417417
);
418-
assert!(
419-
config.passthrough_auth()
420-
);
421-
assert_eq!(
422-
config.forward_headers(),
423-
["header".to_string()]
424-
);
425-
assert!(
426-
config.forward_endpoint_type()
427-
);
418+
assert!(config.passthrough_auth());
419+
assert_eq!(config.forward_headers(), ["header".to_string()]);
420+
assert!(config.forward_endpoint_type());
428421
assert_eq!(
429422
config.forward_extensions(),
430-
[ForwardExtensions::new("$.extension".to_string(), "Extension".to_string())]
423+
[ForwardExtensions::new(
424+
"$.extension".to_string(),
425+
"Extension".to_string()
426+
)]
431427
);
432428
}
433429

htsget-http/src/error.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use http::StatusCode;
2+
use http::header::{InvalidHeaderName, InvalidHeaderValue};
23
use serde::Serialize;
34
use thiserror::Error;
45

@@ -84,3 +85,15 @@ impl From<HtsGetSearchError> for HtsGetError {
8485
}
8586
}
8687
}
88+
89+
impl From<InvalidHeaderName> for HtsGetError {
90+
fn from(err: InvalidHeaderName) -> Self {
91+
Self::InternalError(err.to_string())
92+
}
93+
}
94+
95+
impl From<InvalidHeaderValue> for HtsGetError {
96+
fn from(err: InvalidHeaderValue) -> Self {
97+
Self::InternalError(err.to_string())
98+
}
99+
}

htsget-http/src/http_core.rs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,11 @@ async fn authorize(
3737
queries: &mut [Query],
3838
auth: Option<(TokenData<Value>, Auth)>,
3939
extensions: Option<Value>,
40+
endpoint: &Endpoint,
4041
) -> Result<Option<AuthorizationRestrictions>> {
4142
if let Some((_, mut auth)) = auth {
4243
let _rules = auth
43-
.validate_authorization(headers, path, queries, extensions)
44+
.validate_authorization(headers, path, queries, extensions, endpoint)
4445
.await?;
4546
cfg_if! {
4647
if #[cfg(feature = "experimental")] {
@@ -77,7 +78,15 @@ pub async fn get(
7778

7879
let format = match_format_from_query(&endpoint, request.query())?;
7980
let mut query = vec![convert_to_query(request, format)?];
80-
let rules = authorize(&headers, &path, query.as_mut_slice(), auth, extensions).await?;
81+
let rules = authorize(
82+
&headers,
83+
&path,
84+
query.as_mut_slice(),
85+
auth,
86+
extensions,
87+
&endpoint,
88+
)
89+
.await?;
8190

8291
debug!(endpoint = ?endpoint, query = ?query, "getting GET response");
8392

@@ -137,7 +146,15 @@ pub async fn post(
137146
}
138147

139148
let mut queries = body.get_queries(request, &endpoint)?;
140-
let rules = authorize(&headers, &path, queries.as_mut_slice(), auth, extensions).await?;
149+
let rules = authorize(
150+
&headers,
151+
&path,
152+
queries.as_mut_slice(),
153+
auth,
154+
extensions,
155+
&endpoint,
156+
)
157+
.await?;
141158

142159
debug!(endpoint = ?endpoint, queries = ?queries, "getting POST response");
143160

htsget-http/src/lib.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@ use query_builder::QueryBuilder;
99
pub use service_info::get_service_info_json;
1010
pub use service_info::{Htsget, ServiceInfo, Type};
1111
use std::collections::HashMap;
12-
use std::result;
12+
use std::fmt::{Display, Formatter};
1313
use std::str::FromStr;
14+
use std::{fmt, result};
1415

1516
pub mod error;
1617
pub mod http_core;
@@ -39,6 +40,15 @@ impl FromStr for Endpoint {
3940
}
4041
}
4142

43+
impl Display for Endpoint {
44+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
45+
match self {
46+
Self::Reads => write!(f, "reads"),
47+
Self::Variants => write!(f, "variants"),
48+
}
49+
}
50+
}
51+
4252
/// Match the format from a query parameter.
4353
pub fn match_format_from_query(
4454
endpoint: &Endpoint,

htsget-http/src/middleware/auth.rs

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
//! The htsget authorization middleware.
22
//!
33
4-
use crate::HtsGetError;
54
use crate::error::Result as HtsGetResult;
65
use crate::middleware::error::Error::AuthBuilderError;
76
use crate::middleware::error::Result;
7+
use crate::{Endpoint, HtsGetError};
88
use cfg_if::cfg_if;
99
use headers::authorization::Bearer;
1010
use headers::{Authorization, Header};
@@ -72,6 +72,7 @@ impl Debug for Auth {
7272
}
7373

7474
const FORWARD_HEADER_PREFIX: &str = "Htsget-Context-";
75+
const ENDPOINT_TYPE_NAME: &str = "Endpoint-Type";
7576

7677
impl Auth {
7778
/// Get the config for this auth layer instance.
@@ -134,6 +135,7 @@ impl Auth {
134135
&self,
135136
request_headers: &HeaderMap,
136137
request_extensions: Option<Value>,
138+
request_endpoint: &Endpoint,
137139
) -> HtsGetResult<HeaderMap> {
138140
let mut forwarded_headers = if self.config.passthrough_auth() {
139141
let auth_header = request_headers
@@ -186,14 +188,20 @@ impl Auth {
186188
})?;
187189

188190
let header_name =
189-
HeaderName::from_str(&format!("{}{}", FORWARD_HEADER_PREFIX, extension.name()))
190-
.map_err(|err| HtsGetError::InternalError(err.to_string()))?;
191-
let value = HeaderValue::from_str(value)
192-
.map_err(|err| HtsGetError::InternalError(err.to_string()))?;
191+
HeaderName::from_str(&format!("{}{}", FORWARD_HEADER_PREFIX, extension.name()))?;
192+
let value = HeaderValue::from_str(value)?;
193193
forwarded_headers.insert(header_name, value);
194194
}
195195
}
196196

197+
if self.config.forward_endpoint_type() {
198+
let header_name =
199+
HeaderName::from_str(&format!("{}{}", FORWARD_HEADER_PREFIX, ENDPOINT_TYPE_NAME))?;
200+
let value = HeaderValue::from_str(&request_endpoint.to_string())?;
201+
202+
forwarded_headers.insert(header_name, value);
203+
}
204+
197205
Ok(forwarded_headers)
198206
}
199207

@@ -204,10 +212,12 @@ impl Auth {
204212
&mut self,
205213
headers: &HeaderMap,
206214
request_extensions: Option<Value>,
215+
request_endpoint: &Endpoint,
207216
) -> HtsGetResult<Option<AuthorizationRestrictions>> {
208217
match self.config.authorization_url() {
209218
Some(UrlOrStatic::Url(uri)) => {
210-
let forwarded_headers = self.forwarded_headers(headers, request_extensions)?;
219+
let forwarded_headers =
220+
self.forwarded_headers(headers, request_extensions, request_endpoint)?;
211221

212222
self
213223
.fetch_from_url(&uri.to_string(), forwarded_headers)
@@ -408,9 +418,10 @@ impl Auth {
408418
path: &str,
409419
queries: &mut [Query],
410420
request_extensions: Option<Value>,
421+
endpoint: &Endpoint,
411422
) -> HtsGetResult<Option<AuthorizationRestrictions>> {
412423
let restrictions = self
413-
.query_authorization_service(headers, request_extensions)
424+
.query_authorization_service(headers, request_extensions, endpoint)
414425
.await?;
415426

416427
if let Some(restrictions) = restrictions {
@@ -601,7 +612,9 @@ mod tests {
601612
("Custom1".parse().unwrap(), "Value".parse().unwrap()),
602613
("Custom2".parse().unwrap(), "Value".parse().unwrap()),
603614
]);
604-
let forwarded_headers = result.forwarded_headers(&request_headers, None).unwrap();
615+
let forwarded_headers = result
616+
.forwarded_headers(&request_headers, None, &Endpoint::Reads)
617+
.unwrap();
605618
assert_eq!(
606619
forwarded_headers,
607620
HeaderMap::from_iter([
@@ -624,7 +637,9 @@ mod tests {
624637
.unwrap();
625638
let result = AuthBuilder::default().with_config(config).build().unwrap();
626639

627-
let forwarded_headers = result.forwarded_headers(&request_headers, None).unwrap();
640+
let forwarded_headers = result
641+
.forwarded_headers(&request_headers, None, &Endpoint::Reads)
642+
.unwrap();
628643
assert_eq!(
629644
forwarded_headers,
630645
HeaderMap::from_iter([
@@ -648,12 +663,13 @@ mod tests {
648663
let config = builder
649664
.clone()
650665
.forward_headers(vec!["Custom1".to_string()])
651-
.passthrough_auth(false)
652666
.build()
653667
.unwrap();
654668
let result = AuthBuilder::default().with_config(config).build().unwrap();
655669

656-
let forwarded_headers = result.forwarded_headers(&request_headers, None).unwrap();
670+
let forwarded_headers = result
671+
.forwarded_headers(&request_headers, None, &Endpoint::Reads)
672+
.unwrap();
657673
assert_eq!(
658674
forwarded_headers,
659675
HeaderMap::from_iter([(
@@ -663,11 +679,11 @@ mod tests {
663679
);
664680

665681
let config = builder
682+
.clone()
666683
.forward_extensions(vec![ForwardExtensions::new(
667684
"$.Key".to_string(),
668685
"Custom1".to_string(),
669686
)])
670-
.passthrough_auth(false)
671687
.build()
672688
.unwrap();
673689
let result = AuthBuilder::default().with_config(config).build().unwrap();
@@ -678,6 +694,7 @@ mod tests {
678694
Some(json!({
679695
"Key": "Value"
680696
})),
697+
&Endpoint::Reads,
681698
)
682699
.unwrap();
683700
assert_eq!(
@@ -687,6 +704,22 @@ mod tests {
687704
"Value".parse().unwrap()
688705
),])
689706
);
707+
708+
let config = builder.forward_endpoint_type(true).build().unwrap();
709+
let result = AuthBuilder::default().with_config(config).build().unwrap();
710+
711+
let forwarded_headers = result
712+
.forwarded_headers(&request_headers, None, &Endpoint::Variants)
713+
.unwrap();
714+
assert_eq!(
715+
forwarded_headers,
716+
HeaderMap::from_iter([(
717+
format!("{}{}", FORWARD_HEADER_PREFIX, ENDPOINT_TYPE_NAME)
718+
.parse()
719+
.unwrap(),
720+
"variants".parse().unwrap()
721+
),])
722+
);
690723
}
691724

692725
#[test]

0 commit comments

Comments
 (0)