Skip to content

Commit 495f870

Browse files
committed
Added httpx_client_kwargs parameter to allow customization of httpx.AsyncClient behaviour
1 parent 4a17c26 commit 495f870

File tree

1 file changed

+64
-20
lines changed

1 file changed

+64
-20
lines changed

fastapi_sso/sso/base.py

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,26 @@
2020
logger = logging.getLogger(__name__)
2121

2222

23+
class HttpxClientKwargsType(TypedDict, total=False):
24+
"""Parameters of :class:`httpx.AsyncClient`"""
25+
verify: bool | str
26+
"""SSL certificates (a.k.a CA bundle) used to verify the identity of
27+
requested hosts. Either `True` (default CA bundle), a path to an SSL
28+
certificate file, an `ssl.SSLContext`, or `False` (which will disable
29+
verification)."""
30+
cert: str | tuple[str, str] | tuple[str, str, str]
31+
"""An SSL certificate used by the requested host to authenticate the
32+
client. Either a path to an SSL certificate file, or two-tuple of
33+
(certificate file, key file), or a three-tuple of (certificate file, key
34+
file, password)."""
35+
proxy: str
36+
"""A proxy URL where all the traffic should be routed."""
37+
proxies: str
38+
"""A dictionary mapping HTTP protocols to proxy URLs."""
39+
timeout: int
40+
"""The timeout configuration to use when sending requests."""
41+
42+
2343
class DiscoveryDocument(TypedDict):
2444
"""Discovery document."""
2545

@@ -291,21 +311,23 @@ async def verify_and_process(
291311
self,
292312
request: Request,
293313
*,
294-
params: Optional[Dict[str, Any]] = None,
295-
headers: Optional[Dict[str, Any]] = None,
296-
redirect_uri: Optional[str] = None,
297-
convert_response: Literal[True] = True,
314+
params: Optional[Dict[str, Any]],
315+
headers: Optional[Dict[str, Any]],
316+
redirect_uri: Optional[str],
317+
convert_response: Literal[True],
318+
httpx_client_kwargs: Optional[HttpxClientKwargsType],
298319
) -> Optional[OpenID]: ...
299320

300321
@overload
301322
async def verify_and_process(
302323
self,
303324
request: Request,
304325
*,
305-
params: Optional[Dict[str, Any]] = None,
306-
headers: Optional[Dict[str, Any]] = None,
307-
redirect_uri: Optional[str] = None,
326+
params: Optional[Dict[str, Any]],
327+
headers: Optional[Dict[str, Any]],
328+
redirect_uri: Optional[str],
308329
convert_response: Literal[False],
330+
httpx_client_kwargs: Optional[HttpxClientKwargsType],
309331
) -> Optional[Dict[str, Any]]: ...
310332

311333
async def verify_and_process(
@@ -315,7 +337,8 @@ async def verify_and_process(
315337
params: Optional[Dict[str, Any]] = None,
316338
headers: Optional[Dict[str, Any]] = None,
317339
redirect_uri: Optional[str] = None,
318-
convert_response: Union[Literal[True], Literal[False]] = True,
340+
convert_response: bool = True,
341+
httpx_client_kwargs: Optional[HttpxClientKwargsType] = None
319342
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]:
320343
"""Processes the login given a FastAPI (Starlette) Request object. This should be used for the /callback path.
321344
@@ -325,6 +348,7 @@ async def verify_and_process(
325348
headers (Optional[Dict[str, Any]]): Additional headers to pass to the provider.
326349
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
327350
convert_response (bool): If True, userinfo response is converted to OpenID object.
351+
httpx_client_kwargs (HttpxClientKwargsType): Extra keyword-arguments passed to :class:`httpx.AsyncClient`.
328352
329353
Raises:
330354
SSOLoginError: If the 'code' parameter is not found in the callback request.
@@ -334,7 +358,7 @@ async def verify_and_process(
334358
Optional[Dict[str, Any]]: The original JSON response from the API.
335359
"""
336360
headers = headers or {}
337-
code = request.query_params.get("code")
361+
code: Optional[str] = request.query_params.get("code")
338362
if code is None:
339363
logger.debug(
340364
"Callback request:\n\tURI: %s\n\tHeaders: %s\n\tQuery params: %s",
@@ -359,6 +383,7 @@ async def verify_and_process(
359383
redirect_uri=redirect_uri,
360384
pkce_code_verifier=pkce_code_verifier,
361385
convert_response=convert_response,
386+
httpx_client_kwargs=httpx_client_kwargs,
362387
)
363388

364389
def __enter__(self) -> "SSOBase":
@@ -390,11 +415,12 @@ async def process_login(
390415
code: str,
391416
request: Request,
392417
*,
393-
params: Optional[Dict[str, Any]] = None,
394-
additional_headers: Optional[Dict[str, Any]] = None,
395-
redirect_uri: Optional[str] = None,
396-
pkce_code_verifier: Optional[str] = None,
397-
convert_response: Literal[True] = True,
418+
params: Optional[Dict[str, Any]],
419+
additional_headers: Optional[Dict[str, Any]],
420+
redirect_uri: Optional[str],
421+
pkce_code_verifier: Optional[str],
422+
convert_response: Literal[True],
423+
httpx_client_kwargs: Optional[HttpxClientKwargsType],
398424
) -> Optional[OpenID]: ...
399425

400426
@overload
@@ -403,13 +429,28 @@ async def process_login(
403429
code: str,
404430
request: Request,
405431
*,
406-
params: Optional[Dict[str, Any]] = None,
407-
additional_headers: Optional[Dict[str, Any]] = None,
408-
redirect_uri: Optional[str] = None,
409-
pkce_code_verifier: Optional[str] = None,
432+
params: Optional[Dict[str, Any]],
433+
additional_headers: Optional[Dict[str, Any]],
434+
redirect_uri: Optional[str],
435+
pkce_code_verifier: Optional[str],
410436
convert_response: Literal[False],
437+
httpx_client_kwargs: Optional[HttpxClientKwargsType],
411438
) -> Optional[Dict[str, Any]]: ...
412439

440+
@overload
441+
async def process_login(
442+
self,
443+
code: str,
444+
request: Request,
445+
*,
446+
params: Optional[Dict[str, Any]],
447+
additional_headers: Optional[Dict[str, Any]],
448+
redirect_uri: Optional[str],
449+
pkce_code_verifier: Optional[str],
450+
convert_response: bool,
451+
httpx_client_kwargs: Optional[HttpxClientKwargsType],
452+
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]: ...
453+
413454
async def process_login(
414455
self,
415456
code: str,
@@ -419,7 +460,8 @@ async def process_login(
419460
additional_headers: Optional[Dict[str, Any]] = None,
420461
redirect_uri: Optional[str] = None,
421462
pkce_code_verifier: Optional[str] = None,
422-
convert_response: Union[Literal[True], Literal[False]] = True,
463+
convert_response: bool = True,
464+
httpx_client_kwargs: Optional[HttpxClientKwargsType] = None,
423465
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]:
424466
"""Processes login from the callback endpoint to verify the user and request user info endpoint.
425467
It's a lower-level method, typically, you should use `verify_and_process` instead.
@@ -432,6 +474,7 @@ async def process_login(
432474
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
433475
pkce_code_verifier (Optional[str]): A PKCE code verifier sent to the server to verify the login request.
434476
convert_response (bool): If True, userinfo response is converted to OpenID object.
477+
httpx_client_kwargs (HttpxClientKwargsType): Extra keyword-arguments passed to :class:`httpx.AsyncClient`.
435478
436479
Raises:
437480
ReusedOauthClientWarning: If the SSO object is reused, which is not safe and caused security issues.
@@ -451,6 +494,7 @@ async def process_login(
451494
),
452495
ReusedOauthClientWarning,
453496
)
497+
httpx_client_kwargs = httpx_client_kwargs or {}
454498
params = params or {}
455499
params.update(self._extra_query_params)
456500
additional_headers = additional_headers or {}
@@ -483,7 +527,7 @@ async def process_login(
483527

484528
auth = httpx.BasicAuth(self.client_id, self.client_secret)
485529

486-
async with httpx.AsyncClient() as session:
530+
async with httpx.AsyncClient(**httpx_client_kwargs) as session:
487531
response = await session.post(token_url, headers=headers, content=body, auth=auth)
488532
content = response.json()
489533
self._refresh_token = content.get("refresh_token")

0 commit comments

Comments
 (0)