Skip to content

Commit 7f20cd2

Browse files
committed
initialize the oidc and http clients only once
- Added OidcState struct to encapsulate OIDC configuration and client. - Refactored OidcMiddleware to utilize OidcState for improved state management. - Updated HTTP client handling in OIDC service methods for better integration with app data. - Enhanced logging for OIDC middleware initialization and request processing.
1 parent e18d03f commit 7f20cd2

File tree

6 files changed

+103
-90
lines changed

6 files changed

+103
-90
lines changed

src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,10 @@ pub mod webserver;
8383
use crate::app_config::AppConfig;
8484
use crate::filesystem::FileSystem;
8585
use crate::webserver::database::ParsedSqlFile;
86+
use crate::webserver::oidc::OidcState;
8687
use file_cache::FileCache;
8788
use std::path::{Path, PathBuf};
89+
use std::sync::Arc;
8890
use templates::AllTemplates;
8991
use webserver::Database;
9092

@@ -102,6 +104,7 @@ pub struct AppState {
102104
sql_file_cache: FileCache<ParsedSqlFile>,
103105
file_system: FileSystem,
104106
config: AppConfig,
107+
pub oidc_state: Option<Arc<OidcState>>,
105108
}
106109

107110
impl AppState {
@@ -117,12 +120,16 @@ impl AppState {
117120
PathBuf::from("index.sql"),
118121
ParsedSqlFile::new(&db, include_str!("../index.sql"), Path::new("index.sql")),
119122
);
123+
124+
let oidc_state = crate::webserver::oidc::initialize_oidc_state(config).await?;
125+
120126
Ok(AppState {
121127
db,
122128
all_templates,
123129
sql_file_cache,
124130
file_system,
125131
config: config.clone(),
132+
oidc_state,
126133
})
127134
}
128135
}

src/template_helpers.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -628,15 +628,15 @@ mod tests {
628628
const ESCAPED_UNSAFE_MARKUP: &str = "&lt;table&gt;&lt;tr&gt;&lt;td&gt;";
629629
#[test]
630630
fn test_html_blocks_with_various_settings() {
631-
let helper = MarkdownHelper::default();
632-
let content = contents();
633-
634631
struct TestCase {
635632
name: &'static str,
636633
preset: Option<Value>,
637634
expected_output: Result<&'static str, String>,
638635
}
639636

637+
let helper = MarkdownHelper::default();
638+
let content = contents();
639+
640640
let test_cases = [
641641
TestCase {
642642
name: "default settings",

src/webserver/http.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use actix_web::{
1919
};
2020
use actix_web::{HttpResponseBuilder, ResponseError};
2121

22+
use super::http_client::make_http_client;
2223
use super::https::make_auto_rustls_config;
2324
use super::oidc::OidcMiddleware;
2425
use super::response_writer::ResponseWriter;
@@ -478,6 +479,7 @@ pub fn create_app(
478479
middleware::TrailingSlash::MergeOnly,
479480
))
480481
.app_data(payload_config(&app_state))
482+
.app_data(make_http_client(&app_state.config))
481483
.app_data(form_config(&app_state))
482484
.app_data(app_state)
483485
}

src/webserver/http_client.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use actix_web::dev::ServiceRequest;
12
use anyhow::{anyhow, Context};
23
use std::sync::OnceLock;
34

@@ -10,7 +11,7 @@ pub fn make_http_client(config: &crate::app_config::AppConfig) -> anyhow::Result
1011
log::debug!("Loading native certificates because system_root_ca_certificates is enabled");
1112
let certs = rustls_native_certs::load_native_certs()
1213
.with_context(|| "Initial native certificates load failed")?;
13-
log::info!("Loaded {} native certificates", certs.len());
14+
log::debug!("Loaded {} native HTTPS client certificates", certs.len());
1415
let mut roots = rustls::RootCertStore::empty();
1516
for cert in certs {
1617
log::trace!("Adding native certificate to root store: {cert:?}");
@@ -43,3 +44,15 @@ pub fn make_http_client(config: &crate::app_config::AppConfig) -> anyhow::Result
4344
log::debug!("Created HTTP client");
4445
Ok(client)
4546
}
47+
48+
pub(crate) fn get_http_client_from_appdata(
49+
request: &ServiceRequest,
50+
) -> anyhow::Result<&awc::Client> {
51+
if let Some(result) = request.app_data::<anyhow::Result<awc::Client>>() {
52+
result
53+
.as_ref()
54+
.map_err(|e| anyhow!("HTTP client initialization failed: {e}"))
55+
} else {
56+
Err(anyhow!("HTTP client not found in app data"))
57+
}
58+
}

src/webserver/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ pub use error_with_status::ErrorWithStatus;
4343

4444
pub use database::make_placeholder;
4545
pub use database::migrations::apply;
46-
mod oidc;
46+
pub mod oidc;
4747
pub mod response_writer;
4848
pub mod routing;
4949
mod static_content;

src/webserver/oidc.rs

Lines changed: 76 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
use std::{future::Future, pin::Pin, rc::Rc, str::FromStr, sync::Arc};
1+
use std::future::ready;
2+
use std::{future::Future, pin::Pin, str::FromStr, sync::Arc};
23

4+
use crate::webserver::http_client::get_http_client_from_appdata;
35
use crate::{app_config::AppConfig, AppState};
46
use actix_web::{
7+
body::BoxBody,
58
cookie::Cookie,
69
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
710
middleware::Condition,
@@ -21,6 +24,8 @@ use serde::{Deserialize, Serialize};
2124

2225
use super::http_client::make_http_client;
2326

27+
type LocalBoxFuture<T> = Pin<Box<dyn Future<Output = T> + 'static>>;
28+
2429
const SQLPAGE_AUTH_COOKIE_NAME: &str = "sqlpage_auth";
2530
const SQLPAGE_REDIRECT_URI: &str = "/sqlpage/oidc_callback";
2631
const SQLPAGE_STATE_COOKIE_NAME: &str = "sqlpage_oidc_state";
@@ -83,45 +88,54 @@ fn get_app_host(config: &AppConfig) -> String {
8388
host
8489
}
8590

91+
pub struct OidcState {
92+
pub config: Arc<OidcConfig>,
93+
pub client: Arc<OidcClient>,
94+
}
95+
96+
pub async fn initialize_oidc_state(
97+
app_config: &AppConfig,
98+
) -> anyhow::Result<Option<Arc<OidcState>>> {
99+
let oidc_cfg = match OidcConfig::try_from(app_config) {
100+
Ok(c) => Arc::new(c),
101+
Err(None) => return Ok(None), // OIDC not configured
102+
Err(Some(e)) => return Err(anyhow::anyhow!(e)),
103+
};
104+
105+
let http_client = make_http_client(app_config)?;
106+
let provider_metadata =
107+
discover_provider_metadata(&http_client, oidc_cfg.issuer_url.clone()).await?;
108+
let client = make_oidc_client(&oidc_cfg, provider_metadata)?;
109+
110+
Ok(Some(Arc::new(OidcState {
111+
config: oidc_cfg,
112+
client: Arc::new(client),
113+
})))
114+
}
115+
86116
pub struct OidcMiddleware {
87-
pub config: Option<Arc<OidcConfig>>,
88-
app_state: web::Data<AppState>,
117+
oidc_state: Option<Arc<OidcState>>,
89118
}
90119

91120
impl OidcMiddleware {
121+
#[must_use]
92122
pub fn new(app_state: &web::Data<AppState>) -> Condition<Self> {
93-
let config = OidcConfig::try_from(&app_state.config);
94-
match &config {
95-
Ok(config) => {
96-
log::debug!("Setting up OIDC with issuer: {}", config.issuer_url);
97-
}
98-
Err(Some(err)) => {
99-
log::error!("Invalid OIDC configuration: {err}");
100-
}
101-
Err(None) => {
102-
log::debug!("No OIDC configuration provided, skipping middleware.");
103-
}
104-
}
105-
let config = config.ok().map(Arc::new);
106-
Condition::new(
107-
config.is_some(),
108-
Self {
109-
config,
110-
app_state: web::Data::clone(app_state),
111-
},
112-
)
123+
let oidc_state = app_state.oidc_state.clone();
124+
Condition::new(oidc_state.is_some(), Self { oidc_state })
113125
}
114126
}
115127

116128
async fn discover_provider_metadata(
117-
http_client: &AwcHttpClient,
129+
http_client: &awc::Client,
118130
issuer_url: IssuerUrl,
119131
) -> anyhow::Result<openidconnect::core::CoreProviderMetadata> {
120132
log::debug!("Discovering provider metadata for {issuer_url}");
121-
let provider_metadata =
122-
openidconnect::core::CoreProviderMetadata::discover_async(issuer_url, http_client)
123-
.await
124-
.with_context(|| "Failed to discover OIDC provider metadata".to_string())?;
133+
let provider_metadata = openidconnect::core::CoreProviderMetadata::discover_async(
134+
issuer_url,
135+
&AwcHttpClient::from_client(http_client),
136+
)
137+
.await
138+
.with_context(|| "Failed to discover OIDC provider metadata".to_string())?;
125139
log::debug!("Provider metadata discovered: {provider_metadata:?}");
126140
Ok(provider_metadata)
127141
}
@@ -135,52 +149,28 @@ where
135149
type Error = Error;
136150
type InitError = ();
137151
type Transform = OidcService<S>;
138-
type Future = Pin<Box<dyn Future<Output = Result<Self::Transform, Self::InitError>> + 'static>>;
152+
type Future = std::future::Ready<Result<Self::Transform, Self::InitError>>;
139153

140154
fn new_transform(&self, service: S) -> Self::Future {
141-
let config = self.config.clone();
142-
let app_state = web::Data::clone(&self.app_state);
143-
Box::pin(async move {
144-
match config {
145-
Some(config) => Ok(OidcService::new(service, &app_state, Arc::clone(&config))
146-
.await
147-
.map_err(|err| {
148-
log::error!(
149-
"Error creating OIDC service with issuer: {}: {err:?}",
150-
config.issuer_url
151-
);
152-
})?),
153-
None => Err(()),
154-
}
155-
})
155+
match &self.oidc_state {
156+
Some(state) => ready(Ok(OidcService::new(service, Arc::clone(state)))),
157+
None => ready(Err(())),
158+
}
156159
}
157160
}
158161

159162
#[derive(Clone)]
160163
pub struct OidcService<S> {
161164
service: S,
162-
config: Arc<OidcConfig>,
163-
oidc_client: Arc<OidcClient>,
164-
http_client: Rc<AwcHttpClient>,
165+
oidc_state: Arc<OidcState>,
165166
}
166167

167168
impl<S> OidcService<S> {
168-
pub async fn new(
169-
service: S,
170-
app_state: &web::Data<AppState>,
171-
config: Arc<OidcConfig>,
172-
) -> anyhow::Result<Self> {
173-
let issuer_url = config.issuer_url.clone();
174-
let http_client = AwcHttpClient::new(&app_state.config)?;
175-
let provider_metadata = discover_provider_metadata(&http_client, issuer_url).await?;
176-
let client: OidcClient = make_oidc_client(&config, provider_metadata)
177-
.with_context(|| format!("Unable to create OIDC client with config: {config:?}"))?;
178-
Ok(Self {
169+
pub fn new(service: S, oidc_state: Arc<OidcState>) -> Self {
170+
Self {
179171
service,
180-
config,
181-
oidc_client: Arc::new(client),
182-
http_client: Rc::new(http_client),
183-
})
172+
oidc_state,
173+
}
184174
}
185175

186176
fn handle_unauthenticated_request(
@@ -195,22 +185,24 @@ impl<S> OidcService<S> {
195185

196186
log::debug!("Redirecting to OIDC provider");
197187

198-
let response =
199-
build_auth_provider_redirect_response(&self.oidc_client, &self.config, &request);
188+
let response = build_auth_provider_redirect_response(
189+
&self.oidc_state.client,
190+
&self.oidc_state.config,
191+
&request,
192+
);
200193
Box::pin(async move { Ok(request.into_response(response)) })
201194
}
202195

203196
fn handle_oidc_callback(
204197
&self,
205198
request: ServiceRequest,
206199
) -> LocalBoxFuture<Result<ServiceResponse<BoxBody>, Error>> {
207-
let oidc_client = Arc::clone(&self.oidc_client);
208-
let http_client = Rc::clone(&self.http_client);
209-
let oidc_config = Arc::clone(&self.config);
200+
let oidc_client = Arc::clone(&self.oidc_state.client);
201+
let oidc_config = Arc::clone(&self.oidc_state.config);
210202

211203
Box::pin(async move {
212204
let query_string = request.query_string();
213-
match process_oidc_callback(&oidc_client, &http_client, query_string, &request).await {
205+
match process_oidc_callback(&oidc_client, query_string, &request).await {
214206
Ok(response) => Ok(request.into_response(response)),
215207
Err(e) => {
216208
log::error!("Failed to process OIDC callback with params {query_string}: {e}");
@@ -223,9 +215,6 @@ impl<S> OidcService<S> {
223215
}
224216
}
225217

226-
type LocalBoxFuture<T> = Pin<Box<dyn Future<Output = T> + 'static>>;
227-
use actix_web::body::BoxBody;
228-
229218
impl<S> Service<ServiceRequest> for OidcService<S>
230219
where
231220
S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error>,
@@ -238,8 +227,11 @@ where
238227
forward_ready!(service);
239228

240229
fn call(&self, request: ServiceRequest) -> Self::Future {
241-
log::debug!("Started OIDC middleware with config: {:?}", self.config);
242-
let oidc_client = Arc::clone(&self.oidc_client);
230+
log::debug!(
231+
"Started OIDC middleware with config: {:?}",
232+
self.oidc_state.config
233+
);
234+
let oidc_client = Arc::clone(&self.oidc_state.client);
243235
match get_sqlpage_auth_cookie(&oidc_client, &request) {
244236
Ok(Some(cookie)) => {
245237
log::trace!("Found SQLPage auth cookie: {cookie}");
@@ -269,10 +261,11 @@ where
269261

270262
async fn process_oidc_callback(
271263
oidc_client: &OidcClient,
272-
http_client: &AwcHttpClient,
273264
query_string: &str,
274265
request: &ServiceRequest,
275266
) -> anyhow::Result<HttpResponse> {
267+
let http_client = get_http_client_from_appdata(request)?;
268+
276269
let state = get_state_from_cookie(request)?;
277270

278271
let params = Query::<OidcCallbackParams>::from_query(query_string)
@@ -299,15 +292,14 @@ async fn process_oidc_callback(
299292

300293
async fn exchange_code_for_token(
301294
oidc_client: &OidcClient,
302-
http_client: &AwcHttpClient,
295+
http_client: &awc::Client,
303296
oidc_callback_params: OidcCallbackParams,
304297
) -> anyhow::Result<openidconnect::core::CoreTokenResponse> {
305-
// TODO: Verify the state matches the expected CSRF token
306298
let token_response = oidc_client
307299
.exchange_code(openidconnect::AuthorizationCode::new(
308300
oidc_callback_params.code,
309301
))?
310-
.request_async(http_client)
302+
.request_async(&AwcHttpClient::from_client(http_client))
311303
.await?;
312304
Ok(token_response)
313305
}
@@ -376,19 +368,18 @@ fn get_sqlpage_auth_cookie(
376368
Ok(Some(cookie_value))
377369
}
378370

379-
pub struct AwcHttpClient {
380-
client: Client,
371+
pub struct AwcHttpClient<'c> {
372+
client: &'c awc::Client,
381373
}
382374

383-
impl AwcHttpClient {
384-
pub fn new(app_config: &AppConfig) -> anyhow::Result<Self> {
385-
Ok(Self {
386-
client: make_http_client(app_config)?,
387-
})
375+
impl<'c> AwcHttpClient<'c> {
376+
#[must_use]
377+
pub fn from_client(client: &'c awc::Client) -> Self {
378+
Self { client }
388379
}
389380
}
390381

391-
impl<'c> AsyncHttpClient<'c> for AwcHttpClient {
382+
impl<'c> AsyncHttpClient<'c> for AwcHttpClient<'c> {
392383
type Error = AwcWrapperError;
393384
type Future =
394385
Pin<Box<dyn Future<Output = Result<openidconnect::HttpResponse, Self::Error>> + 'c>>;

0 commit comments

Comments
 (0)