55import os
66import warnings
77from types import TracebackType
8- from typing import Any , ClassVar , Dict , List , Literal , Optional , Type , TypedDict , Union , overload
8+ from typing import Any , Callable , ClassVar , Dict , List , Literal , Optional , Type , TypedDict , Union , overload
99
1010import httpx
1111import pydantic
2020logger = logging .getLogger (__name__ )
2121
2222
23- class HttpxClientKwargsType (TypedDict , total = False ):
24- """Parameters of :class:`httpx.AsyncClient`"""
25- verify : bool | str
26- """SSL certificates (a.k.a CA bundle) used to verify the identity of
27- requested hosts. Either `True` (default CA bundle), a path to an SSL
28- certificate file, an `ssl.SSLContext`, or `False` (which will disable
29- verification)."""
30- cert : str | tuple [str , str ] | tuple [str , str , str ]
31- """An SSL certificate used by the requested host to authenticate the
32- client. Either a path to an SSL certificate file, or two-tuple of
33- (certificate file, key file), or a three-tuple of (certificate file, key
34- file, password)."""
35- proxy : str
36- """A proxy URL where all the traffic should be routed."""
37- proxies : str
38- """A dictionary mapping HTTP protocols to proxy URLs."""
39- timeout : int
40- """The timeout configuration to use when sending requests."""
41-
42-
4323class DiscoveryDocument (TypedDict ):
4424 """Discovery document."""
4525
@@ -97,12 +77,14 @@ def __init__(
9777 allow_insecure_http : bool = False ,
9878 use_state : bool = False ,
9979 scope : Optional [List [str ]] = None ,
80+ get_async_client : Optional [Callable [[], httpx .AsyncClient ]] = None ,
10081 ):
10182 """Base class (mixin) for all SSO providers."""
10283 self .client_id : str = client_id
10384 self .client_secret : str = client_secret
10485 self .redirect_uri : Optional [Union [pydantic .AnyHttpUrl , str ]] = redirect_uri
10586 self .allow_insecure_http : bool = allow_insecure_http
87+ self .get_async_client : Callable [[], httpx .AsyncClient ] = get_async_client or httpx .AsyncClient
10688 self ._oauth_client : Optional [WebApplicationClient ] = None
10789 self ._generated_state : Optional [str ] = None
10890
@@ -315,7 +297,6 @@ async def verify_and_process(
315297 headers : Optional [Dict [str , Any ]],
316298 redirect_uri : Optional [str ],
317299 convert_response : Literal [True ],
318- httpx_client_kwargs : Optional [HttpxClientKwargsType ],
319300 ) -> Optional [OpenID ]: ...
320301
321302 @overload
@@ -327,7 +308,6 @@ async def verify_and_process(
327308 headers : Optional [Dict [str , Any ]],
328309 redirect_uri : Optional [str ],
329310 convert_response : Literal [False ],
330- httpx_client_kwargs : Optional [HttpxClientKwargsType ],
331311 ) -> Optional [Dict [str , Any ]]: ...
332312
333313 async def verify_and_process (
@@ -338,7 +318,6 @@ async def verify_and_process(
338318 headers : Optional [Dict [str , Any ]] = None ,
339319 redirect_uri : Optional [str ] = None ,
340320 convert_response : bool = True ,
341- httpx_client_kwargs : Optional [HttpxClientKwargsType ] = None
342321 ) -> Union [Optional [OpenID ], Optional [Dict [str , Any ]]]:
343322 """Processes the login given a FastAPI (Starlette) Request object. This should be used for the /callback path.
344323
@@ -348,7 +327,6 @@ async def verify_and_process(
348327 headers (Optional[Dict[str, Any]]): Additional headers to pass to the provider.
349328 redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
350329 convert_response (bool): If True, userinfo response is converted to OpenID object.
351- httpx_client_kwargs (HttpxClientKwargsType): Extra keyword-arguments passed to :class:`httpx.AsyncClient`.
352330
353331 Raises:
354332 SSOLoginError: If the 'code' parameter is not found in the callback request.
@@ -383,7 +361,6 @@ async def verify_and_process(
383361 redirect_uri = redirect_uri ,
384362 pkce_code_verifier = pkce_code_verifier ,
385363 convert_response = convert_response ,
386- httpx_client_kwargs = httpx_client_kwargs ,
387364 )
388365
389366 def __enter__ (self ) -> "SSOBase" :
@@ -420,7 +397,6 @@ async def process_login(
420397 redirect_uri : Optional [str ],
421398 pkce_code_verifier : Optional [str ],
422399 convert_response : Literal [True ],
423- httpx_client_kwargs : Optional [HttpxClientKwargsType ],
424400 ) -> Optional [OpenID ]: ...
425401
426402 @overload
@@ -434,7 +410,6 @@ async def process_login(
434410 redirect_uri : Optional [str ],
435411 pkce_code_verifier : Optional [str ],
436412 convert_response : Literal [False ],
437- httpx_client_kwargs : Optional [HttpxClientKwargsType ],
438413 ) -> Optional [Dict [str , Any ]]: ...
439414
440415 @overload
@@ -448,7 +423,6 @@ async def process_login(
448423 redirect_uri : Optional [str ],
449424 pkce_code_verifier : Optional [str ],
450425 convert_response : bool ,
451- httpx_client_kwargs : Optional [HttpxClientKwargsType ],
452426 ) -> Union [Optional [OpenID ], Optional [Dict [str , Any ]]]: ...
453427
454428 async def process_login (
@@ -461,7 +435,6 @@ async def process_login(
461435 redirect_uri : Optional [str ] = None ,
462436 pkce_code_verifier : Optional [str ] = None ,
463437 convert_response : bool = True ,
464- httpx_client_kwargs : Optional [HttpxClientKwargsType ] = None ,
465438 ) -> Union [Optional [OpenID ], Optional [Dict [str , Any ]]]:
466439 """Processes login from the callback endpoint to verify the user and request user info endpoint.
467440 It's a lower-level method, typically, you should use `verify_and_process` instead.
@@ -474,7 +447,6 @@ async def process_login(
474447 redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
475448 pkce_code_verifier (Optional[str]): A PKCE code verifier sent to the server to verify the login request.
476449 convert_response (bool): If True, userinfo response is converted to OpenID object.
477- httpx_client_kwargs (HttpxClientKwargsType): Extra keyword-arguments passed to :class:`httpx.AsyncClient`.
478450
479451 Raises:
480452 ReusedOauthClientWarning: If the SSO object is reused, which is not safe and caused security issues.
@@ -494,7 +466,6 @@ async def process_login(
494466 ),
495467 ReusedOauthClientWarning ,
496468 )
497- httpx_client_kwargs = httpx_client_kwargs or {}
498469 params = params or {}
499470 params .update (self ._extra_query_params )
500471 additional_headers = additional_headers or {}
@@ -527,7 +498,7 @@ async def process_login(
527498
528499 auth = httpx .BasicAuth (self .client_id , self .client_secret )
529500
530- async with httpx . AsyncClient ( ** httpx_client_kwargs ) as session :
501+ async with self . get_async_client ( ) as session :
531502 response = await session .post (token_url , headers = headers , content = body , auth = auth )
532503 content = response .json ()
533504 self ._refresh_token = content .get ("refresh_token" )
0 commit comments