Skip to content

Commit 29c1045

Browse files
committed
httpx.AsyncClient factory method to customize client
Argument `get_httpx_client` incorporated to `SSOBase` to allow customization of `httpx.AsyncClient` used to call auth provider
1 parent a85929e commit 29c1045

File tree

2 files changed

+6
-36
lines changed

2 files changed

+6
-36
lines changed

fastapi_sso/sso/base.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
import warnings
77
from types import TracebackType
8-
from typing import Any, ClassVar, Dict, List, Literal, Optional, Type, TypedDict, Union, overload
8+
from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, TypedDict, Union, overload
99

1010
import httpx
1111
import pydantic
@@ -20,26 +20,6 @@
2020
logger = logging.getLogger(__name__)
2121

2222

23-
class HttpxClientKwargsType(TypedDict, total=False):
24-
"""Parameters of :class:`httpx.AsyncClient`"""
25-
verify: bool | str
26-
"""SSL certificates (a.k.a CA bundle) used to verify the identity of
27-
requested hosts. Either `True` (default CA bundle), a path to an SSL
28-
certificate file, an `ssl.SSLContext`, or `False` (which will disable
29-
verification)."""
30-
cert: str | tuple[str, str] | tuple[str, str, str]
31-
"""An SSL certificate used by the requested host to authenticate the
32-
client. Either a path to an SSL certificate file, or two-tuple of
33-
(certificate file, key file), or a three-tuple of (certificate file, key
34-
file, password)."""
35-
proxy: str
36-
"""A proxy URL where all the traffic should be routed."""
37-
proxies: str
38-
"""A dictionary mapping HTTP protocols to proxy URLs."""
39-
timeout: int
40-
"""The timeout configuration to use when sending requests."""
41-
42-
4323
class DiscoveryDocument(TypedDict):
4424
"""Discovery document."""
4525

@@ -97,12 +77,14 @@ def __init__(
9777
allow_insecure_http: bool = False,
9878
use_state: bool = False,
9979
scope: Optional[List[str]] = None,
80+
get_async_client: Optional[Callable[[], httpx.AsyncClient]] = None,
10081
):
10182
"""Base class (mixin) for all SSO providers."""
10283
self.client_id: str = client_id
10384
self.client_secret: str = client_secret
10485
self.redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = redirect_uri
10586
self.allow_insecure_http: bool = allow_insecure_http
87+
self.get_async_client: Callable[[], httpx.AsyncClient] = get_async_client or httpx.AsyncClient
10688
self._oauth_client: Optional[WebApplicationClient] = None
10789
self._generated_state: Optional[str] = None
10890

@@ -315,7 +297,6 @@ async def verify_and_process(
315297
headers: Optional[Dict[str, Any]],
316298
redirect_uri: Optional[str],
317299
convert_response: Literal[True],
318-
httpx_client_kwargs: Optional[HttpxClientKwargsType],
319300
) -> Optional[OpenID]: ...
320301

321302
@overload
@@ -327,7 +308,6 @@ async def verify_and_process(
327308
headers: Optional[Dict[str, Any]],
328309
redirect_uri: Optional[str],
329310
convert_response: Literal[False],
330-
httpx_client_kwargs: Optional[HttpxClientKwargsType],
331311
) -> Optional[Dict[str, Any]]: ...
332312

333313
async def verify_and_process(
@@ -338,7 +318,6 @@ async def verify_and_process(
338318
headers: Optional[Dict[str, Any]] = None,
339319
redirect_uri: Optional[str] = None,
340320
convert_response: bool = True,
341-
httpx_client_kwargs: Optional[HttpxClientKwargsType] = None
342321
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]:
343322
"""Processes the login given a FastAPI (Starlette) Request object. This should be used for the /callback path.
344323
@@ -348,7 +327,6 @@ async def verify_and_process(
348327
headers (Optional[Dict[str, Any]]): Additional headers to pass to the provider.
349328
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
350329
convert_response (bool): If True, userinfo response is converted to OpenID object.
351-
httpx_client_kwargs (HttpxClientKwargsType): Extra keyword-arguments passed to :class:`httpx.AsyncClient`.
352330
353331
Raises:
354332
SSOLoginError: If the 'code' parameter is not found in the callback request.
@@ -383,7 +361,6 @@ async def verify_and_process(
383361
redirect_uri=redirect_uri,
384362
pkce_code_verifier=pkce_code_verifier,
385363
convert_response=convert_response,
386-
httpx_client_kwargs=httpx_client_kwargs,
387364
)
388365

389366
def __enter__(self) -> "SSOBase":
@@ -420,7 +397,6 @@ async def process_login(
420397
redirect_uri: Optional[str],
421398
pkce_code_verifier: Optional[str],
422399
convert_response: Literal[True],
423-
httpx_client_kwargs: Optional[HttpxClientKwargsType],
424400
) -> Optional[OpenID]: ...
425401

426402
@overload
@@ -434,7 +410,6 @@ async def process_login(
434410
redirect_uri: Optional[str],
435411
pkce_code_verifier: Optional[str],
436412
convert_response: Literal[False],
437-
httpx_client_kwargs: Optional[HttpxClientKwargsType],
438413
) -> Optional[Dict[str, Any]]: ...
439414

440415
@overload
@@ -448,7 +423,6 @@ async def process_login(
448423
redirect_uri: Optional[str],
449424
pkce_code_verifier: Optional[str],
450425
convert_response: bool,
451-
httpx_client_kwargs: Optional[HttpxClientKwargsType],
452426
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]: ...
453427

454428
async def process_login(
@@ -461,7 +435,6 @@ async def process_login(
461435
redirect_uri: Optional[str] = None,
462436
pkce_code_verifier: Optional[str] = None,
463437
convert_response: bool = True,
464-
httpx_client_kwargs: Optional[HttpxClientKwargsType] = None,
465438
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]:
466439
"""Processes login from the callback endpoint to verify the user and request user info endpoint.
467440
It's a lower-level method, typically, you should use `verify_and_process` instead.
@@ -474,7 +447,6 @@ async def process_login(
474447
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
475448
pkce_code_verifier (Optional[str]): A PKCE code verifier sent to the server to verify the login request.
476449
convert_response (bool): If True, userinfo response is converted to OpenID object.
477-
httpx_client_kwargs (HttpxClientKwargsType): Extra keyword-arguments passed to :class:`httpx.AsyncClient`.
478450
479451
Raises:
480452
ReusedOauthClientWarning: If the SSO object is reused, which is not safe and caused security issues.
@@ -494,7 +466,6 @@ async def process_login(
494466
),
495467
ReusedOauthClientWarning,
496468
)
497-
httpx_client_kwargs = httpx_client_kwargs or {}
498469
params = params or {}
499470
params.update(self._extra_query_params)
500471
additional_headers = additional_headers or {}
@@ -527,7 +498,7 @@ async def process_login(
527498

528499
auth = httpx.BasicAuth(self.client_id, self.client_secret)
529500

530-
async with httpx.AsyncClient(**httpx_client_kwargs) as session:
501+
async with self.get_async_client() as session:
531502
response = await session.post(token_url, headers=headers, content=body, auth=auth)
532503
content = response.json()
533504
self._refresh_token = content.get("refresh_token")

tests/test_providers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,7 @@ async def test_login_url_scope_additional(self, Provider: Type[SSOBase]):
138138
async def test_process_login(self, Provider: Type[SSOBase], monkeypatch: pytest.MonkeyPatch):
139139
sso = Provider("client_id", "client_secret")
140140
FakeAsyncClient = make_fake_async_client(
141-
returns_post=Response(url="https://localhost", json_content={"access_token": "token"}),
142-
returns_get=Response(
141+
returns_post=Response(url="https://localhost", json_content={"access_token": "token"}), returns_get=Response(
143142
url="https://localhost",
144143
json_content=AnythingDict(
145144
{"token_endpoint": "https://localhost", "userinfo_endpoint": "https://localhost"}
@@ -151,7 +150,7 @@ async def fake_openid_from_response(_, __):
151150
return OpenID(id="test", email="[email protected]", display_name="Test")
152151

153152
with sso:
154-
monkeypatch.setattr("httpx.AsyncClient", FakeAsyncClient)
153+
monkeypatch.setattr(sso, "get_async_client", FakeAsyncClient)
155154
monkeypatch.setattr(sso, "openid_from_response", fake_openid_from_response)
156155
request = Request(url="https://localhost?code=code&state=unique")
157156
await sso.process_login("code", request)

0 commit comments

Comments
 (0)