Skip to content

Commit 019f5e1

Browse files
committed
wip: implement OAuth2 request guard and userinfo endpoint for JWT validation
1 parent 3918824 commit 019f5e1

File tree

5 files changed

+99
-4
lines changed

5 files changed

+99
-4
lines changed

src/visualization/jwt/jwt_validator.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
use anyhow::{anyhow, Result};
5050
use chrono::{DateTime, TimeZone, Utc};
5151
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
52+
use log::debug;
5253
use serde::{Deserialize, Serialize};
5354
use std::collections::HashMap;
5455

@@ -299,13 +300,15 @@ impl JwtValidator {
299300
.hmac_key
300301
.as_ref()
301302
.ok_or_else(|| anyhow!("HS256 key not configured"))?;
303+
debug!("Using HS256 key for validation");
302304
(key, Algorithm::HS256)
303305
}
304306
Algorithm::RS256 => {
305307
let key = self
306308
.rs256_key
307309
.as_ref()
308310
.ok_or_else(|| anyhow!("RS256 key not configured"))?;
311+
debug!("Using RS256 key for validation");
309312
(key, Algorithm::RS256)
310313
}
311314
_ => return Err(anyhow!("Unsupported JWT algorithm: {:?}", alg)),
@@ -314,13 +317,24 @@ impl JwtValidator {
314317
validation.validate_exp = true;
315318
validation.validate_nbf = true;
316319
if let Some(ref issuer) = self.expected_issuer {
320+
debug!("Validating issuer: {}", issuer);
317321
validation.set_issuer(&[issuer]);
318322
}
319-
if let Some(ref aud) = self.expected_audience {
323+
324+
// Get expected audience from config or set a default
325+
let expected_audience = Some("LaserSmartClient"); // TODO: replace with config if needed
326+
if let Some(ref aud) = expected_audience {
327+
debug!("Validating audience: {}", aud);
320328
validation.set_audience(&[aud]);
321329
}
330+
331+
// TODO: remove this debug log in production
332+
validation.validate_aud = false;
322333
let token_data = decode::<JwtClaims>(token, key, &validation)
323-
.map_err(|e| anyhow!("JWT validation failed: {}", e))?;
334+
.map_err(|e|{
335+
debug!("JWT validation error: {}", e);
336+
anyhow!("JWT validation failed: {}", e)
337+
})?;
324338
let now = Utc::now();
325339
let exp_time = Utc
326340
.timestamp_opt(token_data.claims.exp, 0)

src/visualization/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ pub mod server;
130130
/// OIDC discovery and configuration
131131
pub mod oidc;
132132

133+
/// Oauth2 Request Guard
134+
pub mod oauth_guard;
135+
133136
use crate::{config::Config, AnalysisResult};
134137
use anyhow::Result;
135138
use base64::{self, Engine};

src/visualization/oauth_guard.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Copyright (c) 2025 Ronan LE MEILLAT, SCTG Development
2+
// This file is part of the rust-photoacoustic project and is licensed under the
3+
// SCTG Development Non-Commercial License v1.0 (see LICENSE.md for details).
4+
5+
//! Rocket request guard for validating Bearer tokens using JwtValidator (HS256/RS256)
6+
7+
use rocket::http::Status;
8+
use rocket::request::{FromRequest, Outcome, Request};
9+
use rocket::State;
10+
use crate::visualization::jwt::jwt_validator::{JwtValidator, UserInfo};
11+
use crate::visualization::oidc_auth::OxideState;
12+
use base64::Engine;
13+
14+
/// Request guard for extracting and validating a Bearer JWT from the Authorization header
15+
pub struct OAuthBearer(pub UserInfo);
16+
17+
#[rocket::async_trait]
18+
impl<'r> FromRequest<'r> for OAuthBearer {
19+
type Error = (Status, &'static str);
20+
21+
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
22+
// Get the Authorization header
23+
let auth_header = request.headers().get_one("Authorization");
24+
25+
if let Some(header) = auth_header {
26+
if let Some(token) = header.strip_prefix("Bearer ") {
27+
// Get the OxideState from Rocket state
28+
let state = match request.guard::<&State<OxideState>>().await {
29+
Outcome::Success(state) => state,
30+
_ => return Outcome::Error((Status::InternalServerError,(Status::InternalServerError, "Missing state"))),
31+
};
32+
// Build JwtValidator from state (supporting both HS256 and RS256)
33+
let hmac_secret = state.hmac_secret.as_bytes();
34+
let rs256_public_key = if !state.rs256_public_key.is_empty() {
35+
base64::engine::general_purpose::STANDARD.decode(&state.rs256_public_key).ok()
36+
} else {
37+
None
38+
};
39+
40+
let mut validator = match rs256_public_key {
41+
Some(ref pem) => JwtValidator::new(Some(hmac_secret), Some(pem)),
42+
None => JwtValidator::new(Some(hmac_secret), None),
43+
};
44+
match validator {
45+
Ok(validator) => {
46+
match validator.get_user_info(token) {
47+
Ok(user_info) => Outcome::Success(OAuthBearer(user_info)),
48+
Err(_) => Outcome::Error((Status::Unauthorized,(Status::Unauthorized, "Invalid token"))),
49+
}
50+
}
51+
Err(_) => Outcome::Error((Status::InternalServerError,(Status::InternalServerError, "Validator error"))),
52+
}
53+
} else {
54+
Outcome::Error((Status::Unauthorized,(Status::Unauthorized, "Missing Bearer token")))
55+
}
56+
} else {
57+
Outcome::Error((Status::Unauthorized,(Status::Unauthorized, "Missing Authorization header")))
58+
}
59+
}
60+
}

src/visualization/oidc_auth.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ use oxide_auth::primitives::prelude::*;
5555
use oxide_auth::primitives::registrar::RegisteredUrl;
5656
use oxide_auth_rocket;
5757
use oxide_auth_rocket::{OAuthFailure, OAuthRequest, OAuthResponse};
58-
use rand::rand_core::le;
5958
use rocket::figment::Figment;
6059
use rocket::State;
6160
use rocket::{get, post};
@@ -64,6 +63,7 @@ use serde_json::json;
6463
use url::Url;
6564

6665
use super::jwt::JwtIssuer;
66+
use super::oauth_guard::OAuthBearer;
6767

6868
use crate::config::{AccessConfig, User, USER_SESSION_SEPARATOR};
6969
use base64::Engine;
@@ -830,6 +830,23 @@ pub async fn refresh<'r>(
830830
.map_err(|err| err.pack::<OAuthFailure>())
831831
}
832832

833+
/// Openid userinfo endpoint
834+
/// Accessed via `GET /userinfo`
835+
/// This endpoint returns user information based on the access token provided in the Authorization header.
836+
/// It requires a valid JWT access token to be present in the request Authorization header.
837+
#[get("/userinfo")]
838+
pub async fn userinfo(
839+
bearer: OAuthBearer,
840+
state: &State<OxideState>,
841+
) -> Result<rocket::serde::json::Json<User>, OAuthFailure> {
842+
// Return the authenticated user's information
843+
debug!("Userinfo endpoint accessed with bearer token");
844+
Ok(rocket::serde::json::Json(User {
845+
user: "toto".to_string(), // Username is not returned in userinfo
846+
pass: String::new(), // Password is not returned in userinfo
847+
permissions: vec!["read:api".to_string(), "write:api".to_string()],
848+
}))
849+
}
833850
impl OxideState {
834851
/// Create a preconfigured OxideState with default settings
835852
///

src/visualization/server.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
7373
use std::ops::Deref;
7474
use std::path::PathBuf;
7575

76-
use super::oidc_auth::{login, OxideState};
76+
use super::oidc_auth::{login, userinfo, OxideState};
7777

7878
/// Static directory containing the web client files
7979
///
@@ -496,6 +496,7 @@ pub async fn build_rocket(figment: Figment) -> Rocket<Build> {
496496
authorize,
497497
authorize_consent,
498498
login,
499+
userinfo,
499500
token,
500501
refresh,
501502
super::introspection::introspect,

0 commit comments

Comments
 (0)