1313import time
1414from collections .abc import AsyncGenerator , Awaitable , Callable
1515from dataclasses import dataclass , field
16- from typing import Protocol
16+ from typing import Any , Protocol
1717from urllib .parse import urlencode , urljoin , urlparse
1818
1919import anyio
@@ -88,8 +88,8 @@ class OAuthContext:
8888 server_url : str
8989 client_metadata : OAuthClientMetadata
9090 storage : TokenStorage
91- redirect_handler : Callable [[str ], Awaitable [None ]]
92- callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]]
91+ redirect_handler : Callable [[str ], Awaitable [None ]] | None
92+ callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]] | None
9393 timeout : float = 300.0
9494
9595 # Discovered metadata
@@ -189,8 +189,8 @@ def __init__(
189189 server_url : str ,
190190 client_metadata : OAuthClientMetadata ,
191191 storage : TokenStorage ,
192- redirect_handler : Callable [[str ], Awaitable [None ]],
193- callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]],
192+ redirect_handler : Callable [[str ], Awaitable [None ]] | None = None ,
193+ callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]] | None = None ,
194194 timeout : float = 300.0 ,
195195 ):
196196 """Initialize OAuth2 authentication."""
@@ -351,8 +351,21 @@ async def _handle_registration_response(self, response: httpx.Response) -> None:
351351 except ValidationError as e :
352352 raise OAuthRegistrationError (f"Invalid registration response: { e } " )
353353
354- async def _perform_authorization (self ) -> tuple [str , str ]:
354+ async def _perform_authorization (self ) -> httpx .Request :
355+ """Perform the authorization flow."""
356+ auth_code , code_verifier = await self ._perform_authorization_code_grant ()
357+ token_request = await self ._exchange_token_authorization_code (auth_code , code_verifier )
358+ return token_request
359+
360+ async def _perform_authorization_code_grant (self ) -> tuple [str , str ]:
355361 """Perform the authorization redirect and get auth code."""
362+ if self .context .client_metadata .redirect_uris is None :
363+ raise OAuthFlowError ("No redirect URIs provided for authorization code grant" )
364+ if not self .context .redirect_handler :
365+ raise OAuthFlowError ("No redirect handler provided for authorization code grant" )
366+ if not self .context .callback_handler :
367+ raise OAuthFlowError ("No callback handler provided for authorization code grant" )
368+
356369 if self .context .oauth_metadata and self .context .oauth_metadata .authorization_endpoint :
357370 auth_endpoint = str (self .context .oauth_metadata .authorization_endpoint )
358371 else :
@@ -397,24 +410,34 @@ async def _perform_authorization(self) -> tuple[str, str]:
397410 # Return auth code and code verifier for token exchange
398411 return auth_code , pkce_params .code_verifier
399412
400- async def _exchange_token (self , auth_code : str , code_verifier : str ) -> httpx .Request :
401- """Build token exchange request."""
402- if not self .context .client_info :
403- raise OAuthFlowError ("Missing client info" )
404-
413+ def _get_token_endpoint (self ) -> str :
405414 if self .context .oauth_metadata and self .context .oauth_metadata .token_endpoint :
406415 token_url = str (self .context .oauth_metadata .token_endpoint )
407416 else :
408417 auth_base_url = self .context .get_authorization_base_url (self .context .server_url )
409418 token_url = urljoin (auth_base_url , "/token" )
419+ return token_url
420+
421+ async def _exchange_token_authorization_code (
422+ self , auth_code : str , code_verifier : str , * , token_data : dict [str , Any ] | None = {}
423+ ) -> httpx .Request :
424+ """Build token exchange request for authorization_code flow."""
425+ if self .context .client_metadata .redirect_uris is None :
426+ raise OAuthFlowError ("No redirect URIs provided for authorization code grant" )
427+ if not self .context .client_info :
428+ raise OAuthFlowError ("Missing client info" )
410429
411- token_data = {
412- "grant_type" : "authorization_code" ,
413- "code" : auth_code ,
414- "redirect_uri" : str (self .context .client_metadata .redirect_uris [0 ]),
415- "client_id" : self .context .client_info .client_id ,
416- "code_verifier" : code_verifier ,
417- }
430+ token_url = self ._get_token_endpoint ()
431+ token_data = token_data or {}
432+ token_data .update (
433+ {
434+ "grant_type" : "authorization_code" ,
435+ "code" : auth_code ,
436+ "redirect_uri" : str (self .context .client_metadata .redirect_uris [0 ]),
437+ "client_id" : self .context .client_info .client_id ,
438+ "code_verifier" : code_verifier ,
439+ }
440+ )
418441
419442 # Only include resource param if conditions are met
420443 if self .context .should_include_resource_param (self .context .protocol_version ):
@@ -430,7 +453,9 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req
430453 async def _handle_token_response (self , response : httpx .Response ) -> None :
431454 """Handle token exchange response."""
432455 if response .status_code != 200 :
433- raise OAuthTokenError (f"Token exchange failed: { response .status_code } " )
456+ body = await response .aread ()
457+ body = body .decode ("utf-8" )
458+ raise OAuthTokenError (f"Token exchange failed ({ response .status_code } ): { body } " )
434459
435460 try :
436461 content = await response .aread ()
@@ -577,12 +602,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
577602 registration_response = yield registration_request
578603 await self ._handle_registration_response (registration_response )
579604
580- # Step 5: Perform authorization
581- auth_code , code_verifier = await self ._perform_authorization ()
582-
583- # Step 6: Exchange authorization code for tokens
584- token_request = await self ._exchange_token (auth_code , code_verifier )
585- token_response = yield token_request
605+ # Step 5: Perform authorization and complete token exchange
606+ token_response = yield await self ._perform_authorization ()
586607 await self ._handle_token_response (token_response )
587608 except Exception :
588609 logger .exception ("OAuth flow error" )
@@ -601,17 +622,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
601622 # Step 2a: Update the required scopes
602623 self ._select_scopes (response )
603624
604- # Step 2b: Perform (re-)authorization
605- auth_code , code_verifier = await self ._perform_authorization ()
606-
607- # Step 2c: Exchange authorization code for tokens
608- token_request = await self ._exchange_token (auth_code , code_verifier )
609- token_response = yield token_request
625+ # Step 2b: Perform (re-)authorization and token exchange
626+ token_response = yield await self ._perform_authorization ()
610627 await self ._handle_token_response (token_response )
611628 except Exception :
612629 logger .exception ("OAuth flow error" )
613630 raise
614631
615- # Retry with new tokens
616- self ._add_auth_header (request )
617- yield request
632+ # Retry with new tokens
633+ self ._add_auth_header (request )
634+ yield request
0 commit comments