@@ -15,12 +15,13 @@ use openidconnect::{
1515 EndpointSet , IdToken , IdTokenClaims , IssuerUrl , Nonce , OAuth2TokenResponse , RedirectUrl , Scope ,
1616 TokenResponse ,
1717} ;
18- use serde:: Deserialize ;
18+ use serde:: { Deserialize , Serialize } ;
1919
2020use super :: http_client:: make_http_client;
2121
2222const SQLPAGE_AUTH_COOKIE_NAME : & str = "sqlpage_auth" ;
2323const SQLPAGE_REDIRECT_URI : & str = "/sqlpage/oidc_callback" ;
24+ const SQLPAGE_STATE_COOKIE_NAME : & str = "sqlpage_oidc_state" ;
2425
2526#[ derive( Clone , Debug ) ]
2627pub struct OidcConfig {
@@ -206,8 +207,19 @@ impl<S> OidcService<S> {
206207 }
207208
208209 log:: debug!( "Redirecting to OIDC provider" ) ;
209- let auth_url = self . build_auth_url ( & request) ;
210- Box :: pin ( async move { Ok ( request. into_response ( build_redirect_response ( auth_url) ) ) } )
210+
211+ let auth_url = build_auth_url (
212+ & self . oidc_client ,
213+ & self . config . scopes ,
214+ request. path ( ) . to_string ( ) ,
215+ ) ;
216+ Box :: pin ( async move {
217+ let state_cookie = create_state_cookie ( & request) ;
218+ let mut response = build_redirect_response ( auth_url) ;
219+
220+ response. add_cookie ( & state_cookie) ?;
221+ Ok ( request. into_response ( response) )
222+ } )
211223 }
212224
213225 fn handle_oidc_callback (
@@ -220,11 +232,15 @@ impl<S> OidcService<S> {
220232
221233 Box :: pin ( async move {
222234 let query_string = request. query_string ( ) ;
223- match process_oidc_callback ( & oidc_client, & http_client, query_string) . await {
235+ match process_oidc_callback ( & oidc_client, & http_client, query_string, & request ) . await {
224236 Ok ( response) => Ok ( request. into_response ( response) ) ,
225237 Err ( e) => {
226238 log:: error!( "Failed to process OIDC callback with params {query_string}: {e}" ) ;
227- let auth_url = build_auth_url ( & oidc_client, & oidc_config. scopes ) ;
239+ let auth_url = build_auth_url (
240+ & oidc_client,
241+ & oidc_config. scopes ,
242+ request. path ( ) . to_string ( ) ,
243+ ) ;
228244 Ok ( request. into_response ( build_redirect_response ( auth_url) ) )
229245 }
230246 }
@@ -274,15 +290,22 @@ async fn process_oidc_callback(
274290 oidc_client : & Arc < OidcClient > ,
275291 http_client : & Arc < AwcHttpClient > ,
276292 query_string : & str ,
293+ request : & ServiceRequest ,
277294) -> anyhow:: Result < HttpResponse > {
295+ let state = get_state_from_cookie ( request) ?;
296+
278297 let params = Query :: < OidcCallbackParams > :: from_query ( query_string)
279- . with_context ( || format ! ( "{SQLPAGE_REDIRECT_URI}: failed to parse OIDC callback parameters from {query_string}" ) ) ?
298+ . with_context ( || {
299+ format ! (
300+ "{SQLPAGE_REDIRECT_URI}: failed to parse OIDC callback parameters from {query_string}"
301+ )
302+ } ) ?
280303 . into_inner ( ) ;
281304 log:: debug!( "Processing OIDC callback with params: {params:?}. Requesting token..." ) ;
282305 let token_response = exchange_code_for_token ( oidc_client, http_client, params) . await ?;
283306 log:: debug!( "Received token response: {token_response:?}" ) ;
284- // TODO: redirect to the original URL instead of /
285- let mut response = build_redirect_response ( format ! ( "/" ) ) ;
307+
308+ let mut response = build_redirect_response ( state . initial_url ) ;
286309 set_auth_cookie ( & mut response, & token_response) ?;
287310 Ok ( response)
288311}
@@ -476,7 +499,7 @@ struct OidcCallbackParams {
476499 state : String ,
477500}
478501
479- fn build_auth_url ( oidc_client : & OidcClient , scopes : & [ Scope ] ) -> String {
502+ fn build_auth_url ( oidc_client : & OidcClient , scopes : & [ Scope ] , initial_url : String ) -> String {
480503 let ( auth_url, csrf_token, nonce) = oidc_client
481504 . authorize_url (
482505 CoreAuthenticationFlow :: AuthorizationCode ,
@@ -487,3 +510,30 @@ fn build_auth_url(oidc_client: &OidcClient, scopes: &[Scope]) -> String {
487510 . url ( ) ;
488511 auth_url. to_string ( )
489512}
513+
514+ #[ derive( Debug , Serialize , Deserialize ) ]
515+ struct OidcLoginState {
516+ #[ serde( rename = "u" ) ]
517+ initial_url : String ,
518+ }
519+
520+ fn create_state_cookie ( request : & ServiceRequest ) -> actix_web:: cookie:: Cookie {
521+ let state = OidcLoginState {
522+ initial_url : request. path ( ) . to_string ( ) ,
523+ } ;
524+ let state_json = serde_json:: to_string ( & state) . unwrap ( ) ;
525+ actix_web:: cookie:: Cookie :: build ( SQLPAGE_STATE_COOKIE_NAME , state_json)
526+ . secure ( true )
527+ . http_only ( true )
528+ . same_site ( actix_web:: cookie:: SameSite :: Lax )
529+ . path ( "/" )
530+ . finish ( )
531+ }
532+
533+ fn get_state_from_cookie ( request : & ServiceRequest ) -> anyhow:: Result < OidcLoginState > {
534+ let state_cookie = request. cookie ( SQLPAGE_STATE_COOKIE_NAME ) . with_context ( || {
535+ format ! ( "No {SQLPAGE_STATE_COOKIE_NAME} cookie found for {SQLPAGE_REDIRECT_URI}" )
536+ } ) ?;
537+ serde_json:: from_str ( state_cookie. value ( ) )
538+ . with_context ( || format ! ( "Failed to parse OIDC state from cookie" ) )
539+ }
0 commit comments