Skip to content

Commit ec9aa56

Browse files
authored
implemented custom headers passing in auth layer (#1472)
1 parent 4b21878 commit ec9aa56

File tree

13 files changed

+176
-141
lines changed

13 files changed

+176
-141
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1+
pub(crate) static AUTH_HEADER: &str = "authorization";
12
pub(crate) static GRPC_AUTH_HEADER: &str = "x-authorization";
23
pub(crate) static GRPC_PROXY_AUTH_HEADER: &str = "x-proxy-authorization";

libsql-server/src/auth/mod.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,7 @@ impl Auth {
2727
}
2828
}
2929

30-
pub fn authenticate(
31-
&self,
32-
context: Result<UserAuthContext, AuthError>,
33-
) -> Result<Authenticated, AuthError> {
30+
pub fn authenticate(&self, context: UserAuthContext) -> Result<Authenticated, AuthError> {
3431
self.user_strategy.authenticate(context)
3532
}
3633
}

libsql-server/src/auth/parsers.rs

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::auth::{constants::GRPC_AUTH_HEADER, AuthError};
1+
use crate::auth::AuthError;
22

33
use anyhow::{bail, Context as _, Result};
44
use axum::http::HeaderValue;
@@ -41,12 +41,21 @@ pub fn parse_jwt_keys(data: &str) -> Result<Vec<jsonwebtoken::DecodingKey>> {
4141
}
4242
}
4343

44-
pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap) -> Result<UserAuthContext, AuthError> {
45-
metadata
46-
.get(GRPC_AUTH_HEADER)
47-
.ok_or(AuthError::AuthHeaderNotFound)
48-
.and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii))
49-
.and_then(|t| UserAuthContext::from_auth_str(t))
44+
pub(crate) fn parse_grpc_auth_header(
45+
metadata: &MetadataMap,
46+
required_fields: &Vec<&'static str>,
47+
) -> Result<UserAuthContext> {
48+
let mut context = UserAuthContext::empty();
49+
50+
for field in required_fields.iter() {
51+
metadata
52+
.get(*field)
53+
.ok_or_else(|| AuthError::AuthHeaderNotFound)
54+
.and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii))
55+
.map(|v| context.add_field(field, v.into()))?;
56+
}
57+
58+
Ok(context)
5059
}
5160

5261
pub fn parse_http_auth_header<'a>(
@@ -78,6 +87,7 @@ mod tests {
7887
use hyper::header::AUTHORIZATION;
7988

8089
use crate::auth::authorized::Scopes;
90+
use crate::auth::constants::GRPC_AUTH_HEADER;
8191
use crate::auth::user_auth_strategies::jwt::Token;
8292
use crate::auth::{parse_http_auth_header, parse_jwt_keys, AuthError};
8393

@@ -86,41 +96,25 @@ mod tests {
8696
#[test]
8797
fn parse_grpc_auth_header_returns_valid_context() {
8898
let mut map = tonic::metadata::MetadataMap::new();
89-
map.insert("x-authorization", "bearer 123".parse().unwrap());
90-
let context = parse_grpc_auth_header(&map).unwrap();
91-
assert_eq!(context.scheme().as_ref().unwrap(), "bearer");
92-
assert_eq!(context.token().as_ref().unwrap(), "123");
93-
}
99+
map.insert(GRPC_AUTH_HEADER, "bearer 123".parse().unwrap());
100+
let required_fields = vec!["x-authorization".into()];
101+
let context = parse_grpc_auth_header(&map, &required_fields).unwrap();
94102

95-
#[test]
96-
fn parse_grpc_auth_header_error_no_header() {
97-
let map = tonic::metadata::MetadataMap::new();
98-
let result = parse_grpc_auth_header(&map);
99103
assert_eq!(
100-
result.unwrap_err().to_string(),
101-
"Expected authorization header but none given"
104+
context.get_field("x-authorization"),
105+
Some(&"bearer 123".to_string())
102106
);
103107
}
104108

105109
#[test]
106110
fn parse_grpc_auth_header_error_non_ascii() {
107111
let mut map = tonic::metadata::MetadataMap::new();
108112
map.insert("x-authorization", "bearer I❤NY".parse().unwrap());
109-
let result = parse_grpc_auth_header(&map);
113+
let required_fields = vec!["x-authorization".into()];
114+
let result = parse_grpc_auth_header(&map, &required_fields);
110115
assert_eq!(result.unwrap_err().to_string(), "Non-ASCII auth header")
111116
}
112117

113-
#[test]
114-
fn parse_grpc_auth_header_error_malformed_auth_str() {
115-
let mut map = tonic::metadata::MetadataMap::new();
116-
map.insert("x-authorization", "bearer123".parse().unwrap());
117-
let result = parse_grpc_auth_header(&map);
118-
assert_eq!(
119-
result.unwrap_err().to_string(),
120-
"Auth string does not conform to '<scheme> <token>' form"
121-
)
122-
}
123-
124118
#[test]
125119
fn parse_http_auth_header_returns_auth_header_param_when_valid() {
126120
assert_eq!(

libsql-server/src/auth/user_auth_strategies/disabled.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,7 @@ use crate::auth::{AuthError, Authenticated};
44
pub struct Disabled {}
55

66
impl UserAuthStrategy for Disabled {
7-
fn authenticate(
8-
&self,
9-
_context: Result<UserAuthContext, AuthError>,
10-
) -> Result<Authenticated, AuthError> {
7+
fn authenticate(&self, _context: UserAuthContext) -> Result<Authenticated, AuthError> {
118
tracing::trace!("executing disabled auth");
129
Ok(Authenticated::FullAccess)
1310
}
@@ -26,7 +23,7 @@ mod tests {
2623
#[test]
2724
fn authenticates() {
2825
let strategy = Disabled::new();
29-
let context = Ok(UserAuthContext::empty());
26+
let context = UserAuthContext::empty();
3027

3128
assert!(matches!(
3229
strategy.authenticate(context).unwrap(),

libsql-server/src/auth/user_auth_strategies/http_basic.rs

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use crate::auth::{AuthError, Authenticated};
1+
use crate::auth::{
2+
constants::{AUTH_HEADER, GRPC_AUTH_HEADER},
3+
AuthError, Authenticated,
4+
};
25

36
use super::{UserAuthContext, UserAuthStrategy};
47

@@ -7,27 +10,30 @@ pub struct HttpBasic {
710
}
811

912
impl UserAuthStrategy for HttpBasic {
10-
fn authenticate(
11-
&self,
12-
context: Result<UserAuthContext, AuthError>,
13-
) -> Result<Authenticated, AuthError> {
13+
fn authenticate(&self, ctx: UserAuthContext) -> Result<Authenticated, AuthError> {
1414
tracing::trace!("executing http basic auth");
15+
let auth_str = ctx
16+
.get_field(AUTH_HEADER)
17+
.or_else(|| ctx.get_field(GRPC_AUTH_HEADER));
18+
19+
let (_, token) = auth_str
20+
.ok_or(AuthError::AuthHeaderNotFound)
21+
.map(|s| s.split_once(' ').ok_or(AuthError::AuthStringMalformed))
22+
.and_then(|o| o)?;
1523

1624
// NOTE: this naive comparison may leak information about the `expected_value`
1725
// using a timing attack
1826
let expected_value = self.credential.trim_end_matches('=');
19-
20-
let creds_match = match context?.token {
21-
Some(s) => s.contains(expected_value),
22-
None => expected_value.is_empty(),
23-
};
24-
27+
let creds_match = token.contains(expected_value);
2528
if creds_match {
2629
return Ok(Authenticated::FullAccess);
2730
}
28-
2931
Err(AuthError::BasicRejected)
3032
}
33+
34+
fn required_fields(&self) -> Vec<&'static str> {
35+
vec![AUTH_HEADER, GRPC_AUTH_HEADER]
36+
}
3137
}
3238

3339
impl HttpBasic {
@@ -48,7 +54,7 @@ mod tests {
4854

4955
#[test]
5056
fn authenticates_with_valid_credential() {
51-
let context = Ok(UserAuthContext::basic(CREDENTIAL));
57+
let context = UserAuthContext::basic(CREDENTIAL);
5258

5359
assert!(matches!(
5460
strategy().authenticate(context).unwrap(),
@@ -59,7 +65,7 @@ mod tests {
5965
#[test]
6066
fn authenticates_with_valid_trimmed_credential() {
6167
let credential = CREDENTIAL.trim_end_matches('=');
62-
let context = Ok(UserAuthContext::basic(credential));
68+
let context = UserAuthContext::basic(credential);
6369

6470
assert!(matches!(
6571
strategy().authenticate(context).unwrap(),
@@ -69,7 +75,7 @@ mod tests {
6975

7076
#[test]
7177
fn errors_when_credentials_do_not_match() {
72-
let context = Ok(UserAuthContext::basic("abc"));
78+
let context = UserAuthContext::basic("abc");
7379

7480
assert_eq!(
7581
strategy().authenticate(context).unwrap_err(),

libsql-server/src/auth/user_auth_strategies/jwt.rs

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
use chrono::{DateTime, Utc};
22

33
use crate::{
4-
auth::{authenticated::LegacyAuth, AuthError, Authenticated, Authorized, Permission},
4+
auth::{
5+
authenticated::LegacyAuth,
6+
constants::{AUTH_HEADER, GRPC_AUTH_HEADER},
7+
AuthError, Authenticated, Authorized, Permission,
8+
},
59
namespace::NamespaceName,
610
};
711

@@ -12,28 +16,27 @@ pub struct Jwt {
1216
}
1317

1418
impl UserAuthStrategy for Jwt {
15-
fn authenticate(
16-
&self,
17-
context: Result<UserAuthContext, AuthError>,
18-
) -> Result<Authenticated, AuthError> {
19+
fn authenticate(&self, ctx: UserAuthContext) -> Result<Authenticated, AuthError> {
1920
tracing::trace!("executing jwt auth");
21+
let auth_str = ctx
22+
.get_field(AUTH_HEADER)
23+
.or_else(|| ctx.get_field(GRPC_AUTH_HEADER))
24+
.ok_or_else(|| AuthError::AuthHeaderNotFound)?;
2025

21-
let ctx = context?;
22-
23-
let UserAuthContext {
24-
scheme: Some(scheme),
25-
token: Some(token),
26-
} = ctx
27-
else {
28-
return Err(AuthError::HttpAuthHeaderInvalid);
29-
};
26+
let (scheme, token) = auth_str
27+
.split_once(' ')
28+
.ok_or(AuthError::AuthStringMalformed)?;
3029

3130
if !scheme.eq_ignore_ascii_case("bearer") {
3231
return Err(AuthError::HttpAuthHeaderUnsupportedScheme);
3332
}
3433

3534
validate_any_jwt(&self.keys, &token)
3635
}
36+
37+
fn required_fields(&self) -> Vec<&'static str> {
38+
vec![AUTH_HEADER, GRPC_AUTH_HEADER]
39+
}
3740
}
3841

3942
impl Jwt {
@@ -190,7 +193,7 @@ mod tests {
190193
};
191194
let token = encode(&token, &enc);
192195

193-
let context = Ok(UserAuthContext::bearer(token.as_str()));
196+
let context = UserAuthContext::bearer(token.as_str());
194197

195198
assert!(matches!(
196199
strategy(dec).authenticate(context).unwrap(),
@@ -212,7 +215,7 @@ mod tests {
212215
};
213216
let token = encode(&token, &enc);
214217

215-
let context = Ok(UserAuthContext::bearer(token.as_str()));
218+
let context = UserAuthContext::bearer(token.as_str());
216219

217220
let Authenticated::Legacy(a) = strategy(dec).authenticate(context).unwrap() else {
218221
panic!()
@@ -225,7 +228,7 @@ mod tests {
225228
#[test]
226229
fn errors_when_jwt_token_invalid() {
227230
let (_enc, dec) = generate_key_pair();
228-
let context = Ok(UserAuthContext::bearer("abc"));
231+
let context = UserAuthContext::bearer("abc");
229232

230233
assert_eq!(
231234
strategy(dec).authenticate(context).unwrap_err(),
@@ -245,7 +248,7 @@ mod tests {
245248

246249
let token = encode(&token, &enc);
247250

248-
let context = Ok(UserAuthContext::bearer(token.as_str()));
251+
let context = UserAuthContext::bearer(token.as_str());
249252

250253
assert_eq!(
251254
strategy(dec).authenticate(context).unwrap_err(),
@@ -267,7 +270,7 @@ mod tests {
267270

268271
let token = encode(&token, &enc);
269272

270-
let context = Ok(UserAuthContext::bearer(token.as_str()));
273+
let context = UserAuthContext::bearer(token.as_str());
271274

272275
let Authenticated::Authorized(a) = strategy(dec).authenticate(context).unwrap() else {
273276
panic!()
@@ -304,7 +307,7 @@ mod tests {
304307
for enc in multi_enc.iter() {
305308
let token = encode(&token, &enc);
306309

307-
let context = Ok(UserAuthContext::bearer(token.as_str()));
310+
let context = UserAuthContext::bearer(token.as_str());
308311

309312
let Authenticated::Authorized(a) = strategy.authenticate(context).unwrap() else {
310313
panic!()
@@ -331,7 +334,7 @@ mod tests {
331334
});
332335
let token = encode(&token, &enc);
333336

334-
let context = Ok(UserAuthContext::bearer(token.as_str()));
337+
let context = UserAuthContext::bearer(token.as_str());
335338

336339
assert_eq!(
337340
strategy_with_multiple(multi_dec)
@@ -352,7 +355,7 @@ mod tests {
352355
};
353356
let token = encode(&token, &multi_enc[0]);
354357

355-
let context = Ok(UserAuthContext::bearer(token.as_str()));
358+
let context = UserAuthContext::bearer(token.as_str());
356359

357360
assert_eq!(
358361
strategy_with_multiple(multi_dec)
@@ -373,7 +376,7 @@ mod tests {
373376
};
374377
let token = encode(&token, &multi_enc[2]);
375378

376-
let context = Ok(UserAuthContext::bearer(token.as_str()));
379+
let context = UserAuthContext::bearer(token.as_str());
377380

378381
assert_eq!(
379382
strategy_with_multiple(multi_dec)

0 commit comments

Comments
 (0)