77import sys
88import warnings
99from types import TracebackType
10- from typing import Any , ClassVar , Dict , List , Literal , Optional , Type , TypedDict , TypeVar , Union , overload
10+ from typing import Any , Callable , ClassVar , Dict , List , Literal , Optional , Type , TypedDict , TypeVar , Union , overload
1111
1212import httpx
1313import pydantic
@@ -110,12 +110,14 @@ def __init__(
110110 allow_insecure_http : bool = False ,
111111 use_state : bool = False ,
112112 scope : Optional [List [str ]] = None ,
113+ get_async_client : Optional [Callable [[], httpx .AsyncClient ]] = None ,
113114 ):
114115 """Base class (mixin) for all SSO providers."""
115116 self .client_id : str = client_id
116117 self .client_secret : str = client_secret
117118 self .redirect_uri : Optional [Union [pydantic .AnyHttpUrl , str ]] = redirect_uri
118119 self .allow_insecure_http : bool = allow_insecure_http
120+ self .get_async_client : Callable [[], httpx .AsyncClient ] = get_async_client or httpx .AsyncClient
119121 self ._login_lock = asyncio .Lock ()
120122 self ._in_stack = False
121123 self ._oauth_client : Optional [WebApplicationClient ] = None
@@ -330,10 +332,10 @@ async def verify_and_process(
330332 self ,
331333 request : Request ,
332334 * ,
333- params : Optional [Dict [str , Any ]] = None ,
334- headers : Optional [Dict [str , Any ]] = None ,
335- redirect_uri : Optional [str ] = None ,
336- convert_response : Literal [True ] = True ,
335+ params : Optional [Dict [str , Any ]],
336+ headers : Optional [Dict [str , Any ]],
337+ redirect_uri : Optional [str ],
338+ convert_response : Literal [True ],
337339 ) -> Optional [OpenID ]: ...
338340
339341 @overload
@@ -458,11 +460,11 @@ async def process_login(
458460 code : str ,
459461 request : Request ,
460462 * ,
461- params : Optional [Dict [str , Any ]] = None ,
462- additional_headers : Optional [Dict [str , Any ]] = None ,
463- redirect_uri : Optional [str ] = None ,
464- pkce_code_verifier : Optional [str ] = None ,
465- convert_response : Literal [True ] = True ,
463+ params : Optional [Dict [str , Any ]],
464+ additional_headers : Optional [Dict [str , Any ]],
465+ redirect_uri : Optional [str ],
466+ pkce_code_verifier : Optional [str ],
467+ convert_response : Literal [True ],
466468 ) -> Optional [OpenID ]: ...
467469
468470 @overload
@@ -471,10 +473,10 @@ async def process_login(
471473 code : str ,
472474 request : Request ,
473475 * ,
474- params : Optional [Dict [str , Any ]] = None ,
475- additional_headers : Optional [Dict [str , Any ]] = None ,
476- redirect_uri : Optional [str ] = None ,
477- pkce_code_verifier : Optional [str ] = None ,
476+ params : Optional [Dict [str , Any ]],
477+ additional_headers : Optional [Dict [str , Any ]],
478+ redirect_uri : Optional [str ],
479+ pkce_code_verifier : Optional [str ],
478480 convert_response : Literal [False ],
479481 ) -> Optional [Dict [str , Any ]]: ...
480482
@@ -552,7 +554,7 @@ async def process_login(
552554
553555 auth = httpx .BasicAuth (self .client_id , self .client_secret )
554556
555- async with httpx . AsyncClient () as session :
557+ async with self . get_async_client () as session :
556558 response = await session .post (token_url , headers = headers , content = body , auth = auth )
557559 content = response .json ()
558560 self ._refresh_token = content .get ("refresh_token" )
0 commit comments