1- from typing import Any , Callable , Mapping , Optional
1+ from typing import (
2+ Any ,
3+ Awaitable ,
4+ Callable ,
5+ Literal ,
6+ Mapping ,
7+ Optional ,
8+ Sequence ,
9+ Tuple ,
10+ Union ,
11+ cast ,
12+ )
213from unittest .mock import AsyncMock , MagicMock
314
415import httpx
516import pytest
617
718from tests .utils .client_configuration import ClientConfiguration
819from tests .utils .list_resource import list_data_to_dicts , list_response_of
20+ from tests .utils .syncify import syncify
21+ from tests .types .test_auto_pagination_function import TestAutoPaginationFunction
922from workos .types .list_resource import WorkOSListResource
1023from workos .utils ._base_http_client import DEFAULT_REQUEST_TIMEOUT
1124from workos .utils .http_client import AsyncHTTPClient , HTTPClient , SyncHTTPClient
1225from workos .utils .request_helper import DEFAULT_LIST_RESPONSE_LIMIT
1326
1427
15- @pytest .fixture
16- def sync_http_client_for_test ():
17- return SyncHTTPClient (
18- api_key = "sk_test" ,
19- base_url = "https://api.workos.test/" ,
20- client_id = "client_b27needthisforssotemxo" ,
21- version = "test" ,
22- )
23-
24-
25- @pytest .fixture
26- def async_http_client_for_test ():
27- return AsyncHTTPClient (
28- api_key = "sk_test" ,
29- base_url = "https://api.workos.test/" ,
30- client_id = "client_b27needthisforssotemxo" ,
31- version = "test" ,
32- )
33-
34-
35- @pytest .fixture
36- def sync_client_configuration_and_http_client_for_test ():
28+ def _get_test_client_setup (
29+ http_client_class_name : str ,
30+ ) -> Tuple [Literal ["async" , "sync" ], ClientConfiguration , HTTPClient ]:
3731 base_url = "https://api.workos.test/"
3832 client_id = "client_b27needthisforssotemxo"
3933
34+ setup_name = None
35+ if http_client_class_name == "AsyncHTTPClient" :
36+ http_client = AsyncHTTPClient (
37+ api_key = "sk_test" ,
38+ base_url = base_url ,
39+ client_id = client_id ,
40+ version = "test" ,
41+ )
42+ setup_name = "async"
43+ elif http_client_class_name == "SyncHTTPClient" :
44+ http_client = SyncHTTPClient (
45+ api_key = "sk_test" ,
46+ base_url = base_url ,
47+ client_id = client_id ,
48+ version = "test" ,
49+ )
50+ setup_name = "sync"
51+ else :
52+ raise ValueError (
53+ f"Invalid HTTP client for test module setup: { http_client_class_name } "
54+ )
55+
4056 client_configuration = ClientConfiguration (
4157 base_url = base_url , client_id = client_id , request_timeout = DEFAULT_REQUEST_TIMEOUT
4258 )
4359
44- http_client = SyncHTTPClient (
45- api_key = "sk_test" ,
46- base_url = base_url ,
47- client_id = client_id ,
48- version = "test" ,
60+ return setup_name , client_configuration , http_client
61+
62+
63+ def pytest_configure (config ) -> None :
64+ config .addinivalue_line (
65+ "markers" ,
66+ "sync_and_async(): mark test to run both sync and async module versions" ,
4967 )
5068
51- return client_configuration , http_client
69+
70+ def pytest_generate_tests (metafunc : pytest .Metafunc ):
71+ for marker in metafunc .definition .iter_markers (name = "sync_and_async" ):
72+ if marker .name == "sync_and_async" :
73+ if len (marker .args ) == 0 :
74+ raise ValueError (
75+ "sync_and_async marker requires argument representing list of modules."
76+ )
77+
78+ # Take in args as a list of module classes. For example:
79+ # @pytest.mark.sync_and_async(Events, AsyncEvents) -> [Events, AsyncEvents]
80+ module_classes = marker .args
81+ ids = []
82+ arg_values = []
83+
84+ for module_class in module_classes :
85+ if module_class is None :
86+ raise ValueError (
87+ f"Invalid module class for sync_and_async marker: { module_class } "
88+ )
89+
90+ # Pull the HTTP client type from the module class annotations and use that
91+ # to pass in the proper test HTTP client
92+ http_client_name = module_class .__annotations__ ["_http_client" ].__name__
93+ setup_name , client_configuration , http_client = _get_test_client_setup (
94+ http_client_name
95+ )
96+
97+ class_kwargs : Mapping [str , Any ] = {"http_client" : http_client }
98+ if module_class .__init__ .__annotations__ .get (
99+ "client_configuration" , None
100+ ):
101+ class_kwargs ["client_configuration" ] = client_configuration
102+
103+ module_instance = module_class (** class_kwargs )
104+
105+ ids .append (setup_name ) # sync or async will be the test ID
106+ arg_names = ["module_instance" ]
107+ arg_values .append ([module_instance ])
108+
109+ metafunc .parametrize (
110+ argnames = arg_names , argvalues = arg_values , ids = ids , scope = "class"
111+ )
52112
53113
54114@pytest .fixture
55- def async_client_configuration_and_http_client_for_test ():
56- base_url = "https://api.workos.test/"
57- client_id = "client_b27needthisforssotemxo"
115+ def sync_http_client_for_test ():
116+ _ , _ , http_client = _get_test_client_setup ( "SyncHTTPClient" )
117+ return http_client
58118
59- client_configuration = ClientConfiguration (
60- base_url = base_url , client_id = client_id , request_timeout = DEFAULT_REQUEST_TIMEOUT
61- )
62119
63- http_client = AsyncHTTPClient (
64- api_key = "sk_test" ,
65- base_url = base_url ,
66- client_id = client_id ,
67- version = "test" ,
68- )
120+ @pytest .fixture
121+ def async_http_client_for_test ():
122+ _ , _ , http_client = _get_test_client_setup ("AsyncHTTPClient" )
123+ return http_client
124+
69125
126+ @pytest .fixture
127+ def sync_client_configuration_and_http_client_for_test ():
128+ _ , client_configuration , http_client = _get_test_client_setup ("SyncHTTPClient" )
129+ return client_configuration , http_client
130+
131+
132+ @pytest .fixture
133+ def async_client_configuration_and_http_client_for_test ():
134+ _ , client_configuration , http_client = _get_test_client_setup ("AsyncHTTPClient" )
70135 return client_configuration , http_client
71136
72137
@@ -177,24 +242,62 @@ def mock_function(*args, **kwargs):
177242
178243
179244@pytest .fixture
180- def test_sync_auto_pagination (capture_and_mock_pagination_request_for_http_client ):
181- def inner (
182- http_client : SyncHTTPClient ,
245+ def test_auto_pagination (
246+ capture_and_mock_pagination_request_for_http_client ,
247+ ) -> TestAutoPaginationFunction :
248+ def _iterate_results_sync (
183249 list_function : Callable [[], WorkOSListResource ],
250+ list_function_params : Optional [Mapping [str , Any ]] = None ,
251+ ) -> Sequence [Any ]:
252+ results = list_function (** list_function_params or {})
253+ all_results = []
254+
255+ for result in results :
256+ all_results .append (result )
257+
258+ return all_results
259+
260+ async def _iterate_results_async (
261+ list_function : Callable [[], Awaitable [WorkOSListResource ]],
262+ list_function_params : Optional [Mapping [str , Any ]] = None ,
263+ ) -> Sequence [Any ]:
264+ results = await list_function (** list_function_params or {})
265+ all_results = []
266+
267+ async for result in results :
268+ all_results .append (result )
269+
270+ return all_results
271+
272+ def inner (
273+ http_client : HTTPClient ,
274+ list_function : Union [
275+ Callable [[], WorkOSListResource ],
276+ Callable [[], Awaitable [WorkOSListResource ]],
277+ ],
184278 expected_all_page_data : dict ,
185279 list_function_params : Optional [Mapping [str , Any ]] = None ,
186- ):
280+ url_path_keys : Optional [Sequence [str ]] = None ,
281+ ) -> None :
187282 request_kwargs = capture_and_mock_pagination_request_for_http_client (
188283 http_client = http_client ,
189284 data_list = expected_all_page_data ,
190285 status_code = 200 ,
191286 )
192287
193- results = list_function (** list_function_params or {})
194288 all_results = []
195-
196- for result in results :
197- all_results .append (result )
289+ if isinstance (http_client , AsyncHTTPClient ):
290+ all_results = syncify (
291+ _iterate_results_async (
292+ cast (Callable [[], Awaitable [WorkOSListResource ]], list_function ),
293+ list_function_params ,
294+ )
295+ )
296+ else :
297+ all_results = _iterate_results_sync (
298+ cast (Callable [[], WorkOSListResource ], list_function ),
299+ list_function_params ,
300+ )
198301
199302 assert len (list (all_results )) == len (expected_all_page_data )
200303 assert (list_data_to_dicts (all_results )) == expected_all_page_data
@@ -207,6 +310,7 @@ def inner(
207310
208311 params = list_function_params or {}
209312 for param in params :
210- assert request_kwargs ["params" ][param ] == params [param ]
313+ if url_path_keys is not None and param not in url_path_keys :
314+ assert request_kwargs ["params" ][param ] == params [param ]
211315
212316 return inner
0 commit comments