Skip to content

Commit d833912

Browse files
authored
Reduce sync/async test code duplication (#351)
1 parent 18a8e5c commit d833912

File tree

7 files changed

+309
-955
lines changed

7 files changed

+309
-955
lines changed

tests/conftest.py

Lines changed: 154 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,137 @@
1-
from typing import Any, Callable, Mapping, Optional
1+
from typing import (
2+
Any,
3+
Awaitable,
4+
Callable,
5+
Literal,
6+
Mapping,
7+
Optional,
8+
Sequence,
9+
Tuple,
10+
Union,
11+
cast,
12+
)
213
from unittest.mock import AsyncMock, MagicMock
314

415
import httpx
516
import pytest
617

718
from tests.utils.client_configuration import ClientConfiguration
819
from tests.utils.list_resource import list_data_to_dicts, list_response_of
20+
from tests.utils.syncify import syncify
21+
from tests.types.test_auto_pagination_function import TestAutoPaginationFunction
922
from workos.types.list_resource import WorkOSListResource
1023
from workos.utils._base_http_client import DEFAULT_REQUEST_TIMEOUT
1124
from workos.utils.http_client import AsyncHTTPClient, HTTPClient, SyncHTTPClient
1225
from workos.utils.request_helper import DEFAULT_LIST_RESPONSE_LIMIT
1326

1427

15-
@pytest.fixture
16-
def sync_http_client_for_test():
17-
return SyncHTTPClient(
18-
api_key="sk_test",
19-
base_url="https://api.workos.test/",
20-
client_id="client_b27needthisforssotemxo",
21-
version="test",
22-
)
23-
24-
25-
@pytest.fixture
26-
def async_http_client_for_test():
27-
return AsyncHTTPClient(
28-
api_key="sk_test",
29-
base_url="https://api.workos.test/",
30-
client_id="client_b27needthisforssotemxo",
31-
version="test",
32-
)
33-
34-
35-
@pytest.fixture
36-
def sync_client_configuration_and_http_client_for_test():
28+
def _get_test_client_setup(
29+
http_client_class_name: str,
30+
) -> Tuple[Literal["async", "sync"], ClientConfiguration, HTTPClient]:
3731
base_url = "https://api.workos.test/"
3832
client_id = "client_b27needthisforssotemxo"
3933

34+
setup_name = None
35+
if http_client_class_name == "AsyncHTTPClient":
36+
http_client = AsyncHTTPClient(
37+
api_key="sk_test",
38+
base_url=base_url,
39+
client_id=client_id,
40+
version="test",
41+
)
42+
setup_name = "async"
43+
elif http_client_class_name == "SyncHTTPClient":
44+
http_client = SyncHTTPClient(
45+
api_key="sk_test",
46+
base_url=base_url,
47+
client_id=client_id,
48+
version="test",
49+
)
50+
setup_name = "sync"
51+
else:
52+
raise ValueError(
53+
f"Invalid HTTP client for test module setup: {http_client_class_name}"
54+
)
55+
4056
client_configuration = ClientConfiguration(
4157
base_url=base_url, client_id=client_id, request_timeout=DEFAULT_REQUEST_TIMEOUT
4258
)
4359

44-
http_client = SyncHTTPClient(
45-
api_key="sk_test",
46-
base_url=base_url,
47-
client_id=client_id,
48-
version="test",
60+
return setup_name, client_configuration, http_client
61+
62+
63+
def pytest_configure(config) -> None:
64+
config.addinivalue_line(
65+
"markers",
66+
"sync_and_async(): mark test to run both sync and async module versions",
4967
)
5068

51-
return client_configuration, http_client
69+
70+
def pytest_generate_tests(metafunc: pytest.Metafunc):
71+
for marker in metafunc.definition.iter_markers(name="sync_and_async"):
72+
if marker.name == "sync_and_async":
73+
if len(marker.args) == 0:
74+
raise ValueError(
75+
"sync_and_async marker requires argument representing list of modules."
76+
)
77+
78+
# Take in args as a list of module classes. For example:
79+
# @pytest.mark.sync_and_async(Events, AsyncEvents) -> [Events, AsyncEvents]
80+
module_classes = marker.args
81+
ids = []
82+
arg_values = []
83+
84+
for module_class in module_classes:
85+
if module_class is None:
86+
raise ValueError(
87+
f"Invalid module class for sync_and_async marker: {module_class}"
88+
)
89+
90+
# Pull the HTTP client type from the module class annotations and use that
91+
# to pass in the proper test HTTP client
92+
http_client_name = module_class.__annotations__["_http_client"].__name__
93+
setup_name, client_configuration, http_client = _get_test_client_setup(
94+
http_client_name
95+
)
96+
97+
class_kwargs: Mapping[str, Any] = {"http_client": http_client}
98+
if module_class.__init__.__annotations__.get(
99+
"client_configuration", None
100+
):
101+
class_kwargs["client_configuration"] = client_configuration
102+
103+
module_instance = module_class(**class_kwargs)
104+
105+
ids.append(setup_name) # sync or async will be the test ID
106+
arg_names = ["module_instance"]
107+
arg_values.append([module_instance])
108+
109+
metafunc.parametrize(
110+
argnames=arg_names, argvalues=arg_values, ids=ids, scope="class"
111+
)
52112

53113

54114
@pytest.fixture
55-
def async_client_configuration_and_http_client_for_test():
56-
base_url = "https://api.workos.test/"
57-
client_id = "client_b27needthisforssotemxo"
115+
def sync_http_client_for_test():
116+
_, _, http_client = _get_test_client_setup("SyncHTTPClient")
117+
return http_client
58118

59-
client_configuration = ClientConfiguration(
60-
base_url=base_url, client_id=client_id, request_timeout=DEFAULT_REQUEST_TIMEOUT
61-
)
62119

63-
http_client = AsyncHTTPClient(
64-
api_key="sk_test",
65-
base_url=base_url,
66-
client_id=client_id,
67-
version="test",
68-
)
120+
@pytest.fixture
121+
def async_http_client_for_test():
122+
_, _, http_client = _get_test_client_setup("AsyncHTTPClient")
123+
return http_client
124+
69125

126+
@pytest.fixture
127+
def sync_client_configuration_and_http_client_for_test():
128+
_, client_configuration, http_client = _get_test_client_setup("SyncHTTPClient")
129+
return client_configuration, http_client
130+
131+
132+
@pytest.fixture
133+
def async_client_configuration_and_http_client_for_test():
134+
_, client_configuration, http_client = _get_test_client_setup("AsyncHTTPClient")
70135
return client_configuration, http_client
71136

72137

@@ -177,24 +242,62 @@ def mock_function(*args, **kwargs):
177242

178243

179244
@pytest.fixture
180-
def test_sync_auto_pagination(capture_and_mock_pagination_request_for_http_client):
181-
def inner(
182-
http_client: SyncHTTPClient,
245+
def test_auto_pagination(
246+
capture_and_mock_pagination_request_for_http_client,
247+
) -> TestAutoPaginationFunction:
248+
def _iterate_results_sync(
183249
list_function: Callable[[], WorkOSListResource],
250+
list_function_params: Optional[Mapping[str, Any]] = None,
251+
) -> Sequence[Any]:
252+
results = list_function(**list_function_params or {})
253+
all_results = []
254+
255+
for result in results:
256+
all_results.append(result)
257+
258+
return all_results
259+
260+
async def _iterate_results_async(
261+
list_function: Callable[[], Awaitable[WorkOSListResource]],
262+
list_function_params: Optional[Mapping[str, Any]] = None,
263+
) -> Sequence[Any]:
264+
results = await list_function(**list_function_params or {})
265+
all_results = []
266+
267+
async for result in results:
268+
all_results.append(result)
269+
270+
return all_results
271+
272+
def inner(
273+
http_client: HTTPClient,
274+
list_function: Union[
275+
Callable[[], WorkOSListResource],
276+
Callable[[], Awaitable[WorkOSListResource]],
277+
],
184278
expected_all_page_data: dict,
185279
list_function_params: Optional[Mapping[str, Any]] = None,
186-
):
280+
url_path_keys: Optional[Sequence[str]] = None,
281+
) -> None:
187282
request_kwargs = capture_and_mock_pagination_request_for_http_client(
188283
http_client=http_client,
189284
data_list=expected_all_page_data,
190285
status_code=200,
191286
)
192287

193-
results = list_function(**list_function_params or {})
194288
all_results = []
195-
196-
for result in results:
197-
all_results.append(result)
289+
if isinstance(http_client, AsyncHTTPClient):
290+
all_results = syncify(
291+
_iterate_results_async(
292+
cast(Callable[[], Awaitable[WorkOSListResource]], list_function),
293+
list_function_params,
294+
)
295+
)
296+
else:
297+
all_results = _iterate_results_sync(
298+
cast(Callable[[], WorkOSListResource], list_function),
299+
list_function_params,
300+
)
198301

199302
assert len(list(all_results)) == len(expected_all_page_data)
200303
assert (list_data_to_dicts(all_results)) == expected_all_page_data
@@ -207,6 +310,7 @@ def inner(
207310

208311
params = list_function_params or {}
209312
for param in params:
210-
assert request_kwargs["params"][param] == params[param]
313+
if url_path_keys is not None and param not in url_path_keys:
314+
assert request_kwargs["params"][param] == params[param]
211315

212316
return inner

tests/test_directory_sync.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import pytest
22

3+
from tests.types.test_auto_pagination_function import TestAutoPaginationFunction
34
from tests.utils.list_resource import list_data_to_dicts, list_response_of
4-
from workos.directory_sync import AsyncDirectorySync, DirectorySync
55
from tests.utils.fixtures.mock_directory import MockDirectory
66
from tests.utils.fixtures.mock_directory_user import MockDirectoryUser
77
from tests.utils.fixtures.mock_directory_group import MockDirectoryGroup
8+
from workos.directory_sync import AsyncDirectorySync, DirectorySync
89

910

1011
def api_directory_to_sdk(directory):
@@ -295,27 +296,33 @@ def test_primary_email_none(
295296
assert me == None
296297

297298
def test_list_directories_auto_pagination(
298-
self, mock_directories_multiple_data_pages, test_sync_auto_pagination
299+
self,
300+
mock_directories_multiple_data_pages,
301+
test_auto_pagination: TestAutoPaginationFunction,
299302
):
300-
test_sync_auto_pagination(
303+
test_auto_pagination(
301304
http_client=self.http_client,
302305
list_function=self.directory_sync.list_directories,
303306
expected_all_page_data=mock_directories_multiple_data_pages,
304307
)
305308

306309
def test_directory_users_auto_pagination(
307-
self, mock_directory_users_multiple_data_pages, test_sync_auto_pagination
310+
self,
311+
mock_directory_users_multiple_data_pages,
312+
test_auto_pagination: TestAutoPaginationFunction,
308313
):
309-
test_sync_auto_pagination(
314+
test_auto_pagination(
310315
http_client=self.http_client,
311316
list_function=self.directory_sync.list_users,
312317
expected_all_page_data=mock_directory_users_multiple_data_pages,
313318
)
314319

315320
def test_directory_user_groups_auto_pagination(
316-
self, mock_directory_groups_multiple_data_pages, test_sync_auto_pagination
321+
self,
322+
mock_directory_groups_multiple_data_pages,
323+
test_auto_pagination: TestAutoPaginationFunction,
317324
):
318-
test_sync_auto_pagination(
325+
test_auto_pagination(
319326
http_client=self.http_client,
320327
list_function=self.directory_sync.list_groups,
321328
expected_all_page_data=mock_directory_groups_multiple_data_pages,

tests/test_events.py

Lines changed: 13 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1+
from typing import Union
12
import pytest
23

34
from tests.utils.fixtures.mock_event import MockEvent
4-
from workos.events import AsyncEvents, Events
5+
from tests.utils.syncify import syncify
6+
from workos.events import AsyncEvents, Events, EventsListResource
57

68

9+
@pytest.mark.sync_and_async(Events, AsyncEvents)
710
class TestEvents(object):
8-
@pytest.fixture(autouse=True)
9-
def setup(self, sync_http_client_for_test):
10-
self.http_client = sync_http_client_for_test
11-
self.events = Events(http_client=self.http_client)
12-
1311
@pytest.fixture
1412
def mock_events(self):
1513
events = [MockEvent(id=str(i)).dict() for i in range(10)]
@@ -22,49 +20,22 @@ def mock_events(self):
2220
},
2321
}
2422

25-
def test_list_events(self, mock_events, capture_and_mock_http_client_request):
23+
def test_list_events(
24+
self,
25+
module_instance: Union[Events, AsyncEvents],
26+
mock_events: EventsListResource,
27+
capture_and_mock_http_client_request,
28+
):
2629
request_kwargs = capture_and_mock_http_client_request(
27-
http_client=self.http_client,
30+
http_client=module_instance._http_client,
2831
status_code=200,
2932
response_dict=mock_events,
3033
)
3134

32-
events = self.events.list_events(events=["dsync.activated"])
33-
34-
assert request_kwargs["url"].endswith("/events")
35-
assert request_kwargs["method"] == "get"
36-
assert request_kwargs["params"] == {"events": ["dsync.activated"], "limit": 10}
37-
assert events.dict() == mock_events
38-
39-
40-
@pytest.mark.asyncio
41-
class TestAsyncEvents(object):
42-
@pytest.fixture(autouse=True)
43-
def setup(self, async_http_client_for_test):
44-
self.http_client = async_http_client_for_test
45-
self.events = AsyncEvents(http_client=self.http_client)
46-
47-
@pytest.fixture
48-
def mock_events(self):
49-
events = [MockEvent(id=str(i)).dict() for i in range(10)]
50-
51-
return {
52-
"object": "list",
53-
"data": events,
54-
"list_metadata": {
55-
"after": None,
56-
},
57-
}
58-
59-
async def test_list_events(self, mock_events, capture_and_mock_http_client_request):
60-
request_kwargs = capture_and_mock_http_client_request(
61-
http_client=self.http_client,
62-
status_code=200,
63-
response_dict=mock_events,
35+
events: EventsListResource = syncify(
36+
module_instance.list_events(events=["dsync.activated"])
6437
)
6538

66-
events = await self.events.list_events(events=["dsync.activated"])
67-
6839
assert request_kwargs["url"].endswith("/events")
6940
assert request_kwargs["method"] == "get"
7041
assert request_kwargs["params"] == {"events": ["dsync.activated"], "limit": 10}

0 commit comments

Comments
 (0)