1- use std:: { future:: Future , pin:: Pin , rc:: Rc , str:: FromStr , sync:: Arc } ;
1+ use std:: future:: ready;
2+ use std:: { future:: Future , pin:: Pin , str:: FromStr , sync:: Arc } ;
23
4+ use crate :: webserver:: http_client:: get_http_client_from_appdata;
35use crate :: { app_config:: AppConfig , AppState } ;
46use actix_web:: {
7+ body:: BoxBody ,
58 cookie:: Cookie ,
69 dev:: { forward_ready, Service , ServiceRequest , ServiceResponse , Transform } ,
710 middleware:: Condition ,
@@ -21,6 +24,8 @@ use serde::{Deserialize, Serialize};
2124
2225use super :: http_client:: make_http_client;
2326
27+ type LocalBoxFuture < T > = Pin < Box < dyn Future < Output = T > + ' static > > ;
28+
2429const SQLPAGE_AUTH_COOKIE_NAME : & str = "sqlpage_auth" ;
2530const SQLPAGE_REDIRECT_URI : & str = "/sqlpage/oidc_callback" ;
2631const SQLPAGE_STATE_COOKIE_NAME : & str = "sqlpage_oidc_state" ;
@@ -83,45 +88,54 @@ fn get_app_host(config: &AppConfig) -> String {
8388 host
8489}
8590
91+ pub struct OidcState {
92+ pub config : Arc < OidcConfig > ,
93+ pub client : Arc < OidcClient > ,
94+ }
95+
96+ pub async fn initialize_oidc_state (
97+ app_config : & AppConfig ,
98+ ) -> anyhow:: Result < Option < Arc < OidcState > > > {
99+ let oidc_cfg = match OidcConfig :: try_from ( app_config) {
100+ Ok ( c) => Arc :: new ( c) ,
101+ Err ( None ) => return Ok ( None ) , // OIDC not configured
102+ Err ( Some ( e) ) => return Err ( anyhow:: anyhow!( e) ) ,
103+ } ;
104+
105+ let http_client = make_http_client ( app_config) ?;
106+ let provider_metadata =
107+ discover_provider_metadata ( & http_client, oidc_cfg. issuer_url . clone ( ) ) . await ?;
108+ let client = make_oidc_client ( & oidc_cfg, provider_metadata) ?;
109+
110+ Ok ( Some ( Arc :: new ( OidcState {
111+ config : oidc_cfg,
112+ client : Arc :: new ( client) ,
113+ } ) ) )
114+ }
115+
86116pub struct OidcMiddleware {
87- pub config : Option < Arc < OidcConfig > > ,
88- app_state : web:: Data < AppState > ,
117+ oidc_state : Option < Arc < OidcState > > ,
89118}
90119
91120impl OidcMiddleware {
121+ #[ must_use]
92122 pub fn new ( app_state : & web:: Data < AppState > ) -> Condition < Self > {
93- let config = OidcConfig :: try_from ( & app_state. config ) ;
94- match & config {
95- Ok ( config) => {
96- log:: debug!( "Setting up OIDC with issuer: {}" , config. issuer_url) ;
97- }
98- Err ( Some ( err) ) => {
99- log:: error!( "Invalid OIDC configuration: {err}" ) ;
100- }
101- Err ( None ) => {
102- log:: debug!( "No OIDC configuration provided, skipping middleware." ) ;
103- }
104- }
105- let config = config. ok ( ) . map ( Arc :: new) ;
106- Condition :: new (
107- config. is_some ( ) ,
108- Self {
109- config,
110- app_state : web:: Data :: clone ( app_state) ,
111- } ,
112- )
123+ let oidc_state = app_state. oidc_state . clone ( ) ;
124+ Condition :: new ( oidc_state. is_some ( ) , Self { oidc_state } )
113125 }
114126}
115127
116128async fn discover_provider_metadata (
117- http_client : & AwcHttpClient ,
129+ http_client : & awc :: Client ,
118130 issuer_url : IssuerUrl ,
119131) -> anyhow:: Result < openidconnect:: core:: CoreProviderMetadata > {
120132 log:: debug!( "Discovering provider metadata for {issuer_url}" ) ;
121- let provider_metadata =
122- openidconnect:: core:: CoreProviderMetadata :: discover_async ( issuer_url, http_client)
123- . await
124- . with_context ( || "Failed to discover OIDC provider metadata" . to_string ( ) ) ?;
133+ let provider_metadata = openidconnect:: core:: CoreProviderMetadata :: discover_async (
134+ issuer_url,
135+ & AwcHttpClient :: from_client ( http_client) ,
136+ )
137+ . await
138+ . with_context ( || "Failed to discover OIDC provider metadata" . to_string ( ) ) ?;
125139 log:: debug!( "Provider metadata discovered: {provider_metadata:?}" ) ;
126140 Ok ( provider_metadata)
127141}
@@ -135,52 +149,28 @@ where
135149 type Error = Error ;
136150 type InitError = ( ) ;
137151 type Transform = OidcService < S > ;
138- type Future = Pin < Box < dyn Future < Output = Result < Self :: Transform , Self :: InitError > > + ' static > > ;
152+ type Future = std :: future :: Ready < Result < Self :: Transform , Self :: InitError > > ;
139153
140154 fn new_transform ( & self , service : S ) -> Self :: Future {
141- let config = self . config . clone ( ) ;
142- let app_state = web:: Data :: clone ( & self . app_state ) ;
143- Box :: pin ( async move {
144- match config {
145- Some ( config) => Ok ( OidcService :: new ( service, & app_state, Arc :: clone ( & config) )
146- . await
147- . map_err ( |err| {
148- log:: error!(
149- "Error creating OIDC service with issuer: {}: {err:?}" ,
150- config. issuer_url
151- ) ;
152- } ) ?) ,
153- None => Err ( ( ) ) ,
154- }
155- } )
155+ match & self . oidc_state {
156+ Some ( state) => ready ( Ok ( OidcService :: new ( service, Arc :: clone ( state) ) ) ) ,
157+ None => ready ( Err ( ( ) ) ) ,
158+ }
156159 }
157160}
158161
159162#[ derive( Clone ) ]
160163pub struct OidcService < S > {
161164 service : S ,
162- config : Arc < OidcConfig > ,
163- oidc_client : Arc < OidcClient > ,
164- http_client : Rc < AwcHttpClient > ,
165+ oidc_state : Arc < OidcState > ,
165166}
166167
167168impl < S > OidcService < S > {
168- pub async fn new (
169- service : S ,
170- app_state : & web:: Data < AppState > ,
171- config : Arc < OidcConfig > ,
172- ) -> anyhow:: Result < Self > {
173- let issuer_url = config. issuer_url . clone ( ) ;
174- let http_client = AwcHttpClient :: new ( & app_state. config ) ?;
175- let provider_metadata = discover_provider_metadata ( & http_client, issuer_url) . await ?;
176- let client: OidcClient = make_oidc_client ( & config, provider_metadata)
177- . with_context ( || format ! ( "Unable to create OIDC client with config: {config:?}" ) ) ?;
178- Ok ( Self {
169+ pub fn new ( service : S , oidc_state : Arc < OidcState > ) -> Self {
170+ Self {
179171 service,
180- config,
181- oidc_client : Arc :: new ( client) ,
182- http_client : Rc :: new ( http_client) ,
183- } )
172+ oidc_state,
173+ }
184174 }
185175
186176 fn handle_unauthenticated_request (
@@ -195,22 +185,24 @@ impl<S> OidcService<S> {
195185
196186 log:: debug!( "Redirecting to OIDC provider" ) ;
197187
198- let response =
199- build_auth_provider_redirect_response ( & self . oidc_client , & self . config , & request) ;
188+ let response = build_auth_provider_redirect_response (
189+ & self . oidc_state . client ,
190+ & self . oidc_state . config ,
191+ & request,
192+ ) ;
200193 Box :: pin ( async move { Ok ( request. into_response ( response) ) } )
201194 }
202195
203196 fn handle_oidc_callback (
204197 & self ,
205198 request : ServiceRequest ,
206199 ) -> LocalBoxFuture < Result < ServiceResponse < BoxBody > , Error > > {
207- let oidc_client = Arc :: clone ( & self . oidc_client ) ;
208- let http_client = Rc :: clone ( & self . http_client ) ;
209- let oidc_config = Arc :: clone ( & self . config ) ;
200+ let oidc_client = Arc :: clone ( & self . oidc_state . client ) ;
201+ let oidc_config = Arc :: clone ( & self . oidc_state . config ) ;
210202
211203 Box :: pin ( async move {
212204 let query_string = request. query_string ( ) ;
213- match process_oidc_callback ( & oidc_client, & http_client , query_string, & request) . await {
205+ match process_oidc_callback ( & oidc_client, query_string, & request) . await {
214206 Ok ( response) => Ok ( request. into_response ( response) ) ,
215207 Err ( e) => {
216208 log:: error!( "Failed to process OIDC callback with params {query_string}: {e}" ) ;
@@ -223,9 +215,6 @@ impl<S> OidcService<S> {
223215 }
224216}
225217
226- type LocalBoxFuture < T > = Pin < Box < dyn Future < Output = T > + ' static > > ;
227- use actix_web:: body:: BoxBody ;
228-
229218impl < S > Service < ServiceRequest > for OidcService < S >
230219where
231220 S : Service < ServiceRequest , Response = ServiceResponse < BoxBody > , Error = Error > ,
@@ -238,8 +227,11 @@ where
238227 forward_ready ! ( service) ;
239228
240229 fn call ( & self , request : ServiceRequest ) -> Self :: Future {
241- log:: debug!( "Started OIDC middleware with config: {:?}" , self . config) ;
242- let oidc_client = Arc :: clone ( & self . oidc_client ) ;
230+ log:: debug!(
231+ "Started OIDC middleware with config: {:?}" ,
232+ self . oidc_state. config
233+ ) ;
234+ let oidc_client = Arc :: clone ( & self . oidc_state . client ) ;
243235 match get_sqlpage_auth_cookie ( & oidc_client, & request) {
244236 Ok ( Some ( cookie) ) => {
245237 log:: trace!( "Found SQLPage auth cookie: {cookie}" ) ;
@@ -269,10 +261,11 @@ where
269261
270262async fn process_oidc_callback (
271263 oidc_client : & OidcClient ,
272- http_client : & AwcHttpClient ,
273264 query_string : & str ,
274265 request : & ServiceRequest ,
275266) -> anyhow:: Result < HttpResponse > {
267+ let http_client = get_http_client_from_appdata ( request) ?;
268+
276269 let state = get_state_from_cookie ( request) ?;
277270
278271 let params = Query :: < OidcCallbackParams > :: from_query ( query_string)
@@ -299,15 +292,14 @@ async fn process_oidc_callback(
299292
300293async fn exchange_code_for_token (
301294 oidc_client : & OidcClient ,
302- http_client : & AwcHttpClient ,
295+ http_client : & awc :: Client ,
303296 oidc_callback_params : OidcCallbackParams ,
304297) -> anyhow:: Result < openidconnect:: core:: CoreTokenResponse > {
305- // TODO: Verify the state matches the expected CSRF token
306298 let token_response = oidc_client
307299 . exchange_code ( openidconnect:: AuthorizationCode :: new (
308300 oidc_callback_params. code ,
309301 ) ) ?
310- . request_async ( http_client)
302+ . request_async ( & AwcHttpClient :: from_client ( http_client) )
311303 . await ?;
312304 Ok ( token_response)
313305}
@@ -376,19 +368,18 @@ fn get_sqlpage_auth_cookie(
376368 Ok ( Some ( cookie_value) )
377369}
378370
379- pub struct AwcHttpClient {
380- client : Client ,
371+ pub struct AwcHttpClient < ' c > {
372+ client : & ' c awc :: Client ,
381373}
382374
383- impl AwcHttpClient {
384- pub fn new ( app_config : & AppConfig ) -> anyhow:: Result < Self > {
385- Ok ( Self {
386- client : make_http_client ( app_config) ?,
387- } )
375+ impl < ' c > AwcHttpClient < ' c > {
376+ #[ must_use]
377+ pub fn from_client ( client : & ' c awc:: Client ) -> Self {
378+ Self { client }
388379 }
389380}
390381
391- impl < ' c > AsyncHttpClient < ' c > for AwcHttpClient {
382+ impl < ' c > AsyncHttpClient < ' c > for AwcHttpClient < ' c > {
392383 type Error = AwcWrapperError ;
393384 type Future =
394385 Pin < Box < dyn Future < Output = Result < openidconnect:: HttpResponse , Self :: Error > > + ' c > > ;
0 commit comments