Skip to content

Commit e4d3ea1

Browse files
committed
feat: httpx.AsyncClient factory method to customize client
Argument `get_httpx_client` incorporated to `SSOBase` to allow customization of `httpx.AsyncClient` used to call auth provider
1 parent 7510234 commit e4d3ea1

File tree

4 files changed

+45
-44
lines changed

4 files changed

+45
-44
lines changed

examples/seznam.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ async def auth_init():
3131
async def auth_callback(request: Request):
3232
"""Verify login"""
3333
with sso:
34-
user = await sso.verify_and_process(request, params={"client_secret": CLIENT_SECRET}) # <- "client_secret" parameter is needed!
34+
user = await sso.verify_and_process(
35+
request, params={"client_secret": CLIENT_SECRET}
36+
) # <- "client_secret" parameter is needed!
3537
return user
3638

3739

fastapi_sso/sso/base.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sys
88
import warnings
99
from types import TracebackType
10-
from typing import Any, ClassVar, Dict, List, Literal, Optional, Type, TypedDict, TypeVar, Union, overload
10+
from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, TypedDict, TypeVar, Union, overload
1111

1212
import httpx
1313
import pydantic
@@ -110,12 +110,14 @@ def __init__(
110110
allow_insecure_http: bool = False,
111111
use_state: bool = False,
112112
scope: Optional[List[str]] = None,
113+
get_async_client: Optional[Callable[[], httpx.AsyncClient]] = None,
113114
):
114115
"""Base class (mixin) for all SSO providers."""
115116
self.client_id: str = client_id
116117
self.client_secret: str = client_secret
117118
self.redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = redirect_uri
118119
self.allow_insecure_http: bool = allow_insecure_http
120+
self.get_async_client: Callable[[], httpx.AsyncClient] = get_async_client or httpx.AsyncClient
119121
self._login_lock = asyncio.Lock()
120122
self._in_stack = False
121123
self._oauth_client: Optional[WebApplicationClient] = None
@@ -330,10 +332,10 @@ async def verify_and_process(
330332
self,
331333
request: Request,
332334
*,
333-
params: Optional[Dict[str, Any]] = None,
334-
headers: Optional[Dict[str, Any]] = None,
335-
redirect_uri: Optional[str] = None,
336-
convert_response: Literal[True] = True,
335+
params: Optional[Dict[str, Any]],
336+
headers: Optional[Dict[str, Any]],
337+
redirect_uri: Optional[str],
338+
convert_response: Literal[True],
337339
) -> Optional[OpenID]: ...
338340

339341
@overload
@@ -458,11 +460,11 @@ async def process_login(
458460
code: str,
459461
request: Request,
460462
*,
461-
params: Optional[Dict[str, Any]] = None,
462-
additional_headers: Optional[Dict[str, Any]] = None,
463-
redirect_uri: Optional[str] = None,
464-
pkce_code_verifier: Optional[str] = None,
465-
convert_response: Literal[True] = True,
463+
params: Optional[Dict[str, Any]],
464+
additional_headers: Optional[Dict[str, Any]],
465+
redirect_uri: Optional[str],
466+
pkce_code_verifier: Optional[str],
467+
convert_response: Literal[True],
466468
) -> Optional[OpenID]: ...
467469

468470
@overload
@@ -471,10 +473,10 @@ async def process_login(
471473
code: str,
472474
request: Request,
473475
*,
474-
params: Optional[Dict[str, Any]] = None,
475-
additional_headers: Optional[Dict[str, Any]] = None,
476-
redirect_uri: Optional[str] = None,
477-
pkce_code_verifier: Optional[str] = None,
476+
params: Optional[Dict[str, Any]],
477+
additional_headers: Optional[Dict[str, Any]],
478+
redirect_uri: Optional[str],
479+
pkce_code_verifier: Optional[str],
478480
convert_response: Literal[False],
479481
) -> Optional[Dict[str, Any]]: ...
480482

@@ -552,7 +554,7 @@ async def process_login(
552554

553555
auth = httpx.BasicAuth(self.client_id, self.client_secret)
554556

555-
async with httpx.AsyncClient() as session:
557+
async with self.get_async_client() as session:
556558
response = await session.post(token_url, headers=headers, content=body, auth=auth)
557559
content = response.json()
558560
self._refresh_token = content.get("refresh_token")

tests/test_providers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ async def test_process_login(self, Provider: Type[SSOBase], monkeypatch: pytest.
151151
async def fake_openid_from_response(_, __):
152152
return OpenID(id="test", email="[email protected]", display_name="Test")
153153

154-
async with sso:
155-
monkeypatch.setattr("httpx.AsyncClient", FakeAsyncClient)
154+
with sso:
155+
monkeypatch.setattr(sso, "get_async_client", FakeAsyncClient)
156156
monkeypatch.setattr(sso, "openid_from_response", fake_openid_from_response)
157157
request = Request(url="https://localhost?code=code&state=unique")
158158
await sso.process_login("code", request)

tests/test_race_condition.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -54,29 +54,26 @@ async def get(self, *args, **kwargs):
5454
await asyncio.sleep(0)
5555
return Response(token="")
5656

57-
with patch("fastapi_sso.sso.base.httpx") as httpx:
58-
httpx.AsyncClient = AsyncClient
59-
60-
first_response = Response(token="first_token") # noqa: S106
61-
second_response = Response(token="second_token") # noqa: S106
62-
63-
AsyncClient.post_responses = [second_response, first_response] # reversed order because of `pop`
64-
65-
async def process_login():
66-
# this coro will be executed concurrently.
67-
# completely not caring about the params
68-
request = Mock()
69-
request.url = URL("https://url.com?state=state&code=code")
70-
async with provider:
71-
await provider.process_login(
72-
code="code", request=request, params=dict(state="state"), convert_response=False
73-
)
74-
return provider.access_token
75-
76-
# process login concurrently twice
77-
tasks = [process_login(), process_login()]
78-
results = await asyncio.gather(*tasks)
79-
80-
# we would want to get the first and second tokens,
81-
# but we see that the first request actually obtained the second token as well
82-
assert results == [first_response.token, second_response.token]
57+
first_response = Response(token="first_token") # noqa: S106
58+
second_response = Response(token="second_token") # noqa: S106
59+
AsyncClient.post_responses = [second_response, first_response] # reversed order because of `pop`
60+
provider.get_async_client = AsyncClient
61+
62+
async def process_login():
63+
# this coro will be executed concurrently.
64+
# completely not caring about the params
65+
request = Mock()
66+
request.url = URL("https://url.com?state=state&code=code")
67+
async with provider:
68+
await provider.process_login(
69+
code="code", request=request, params=dict(state="state"), convert_response=False
70+
)
71+
return provider.access_token
72+
73+
# process login concurrently twice
74+
tasks = [process_login(), process_login()]
75+
results = await asyncio.gather(*tasks)
76+
77+
# we would want to get the first and second tokens,
78+
# but we see that the first request actually obtained the second token as well
79+
assert results == [first_response.token, second_response.token]

0 commit comments

Comments
 (0)