@@ -71,6 +71,7 @@ def __init__(self, item, collection):
7171class MockRequest :
7272 base_url = "http://test-server"
7373 url = "http://test-server/test"
74+ headers = {}
7475 query_params = {}
7576
7677 def __init__ (
@@ -79,11 +80,13 @@ def __init__(
7980 url : str = "XXXX" ,
8081 app : Optional [Any ] = None ,
8182 query_params : Dict [str , Any ] = {"limit" : "10" },
83+ headers : Dict [str , Any ] = {"Content-type" : "application/json" },
8284 ):
8385 self .method = method
8486 self .url = url
8587 self .app = app
8688 self .query_params = query_params
89+ self .headers = headers
8790
8891
8992class TestSettings (AsyncSettings ):
@@ -127,9 +130,7 @@ def test_collection() -> Dict:
127130
128131
129132async def create_collection (txn_client : TransactionsClient , collection : Dict ) -> None :
130- await txn_client .create_collection (
131- api .Collection (** dict (collection )), request = MockRequest , refresh = True
132- )
133+ await txn_client .create_collection (api .Collection (** dict (collection )), request = MockRequest , refresh = True )
133134
134135
135136async def create_item (txn_client : TransactionsClient , item : Dict ) -> None :
@@ -199,18 +200,14 @@ async def app():
199200 settings = AsyncSettings ()
200201
201202 aggregation_extension = AggregationExtension (
202- client = EsAsyncAggregationClient (
203- database = database , session = None , settings = settings
204- )
203+ client = EsAsyncAggregationClient (database = database , session = None , settings = settings )
205204 )
206205 aggregation_extension .POST = EsAggregationExtensionPostRequest
207206 aggregation_extension .GET = EsAggregationExtensionGetRequest
208207
209208 search_extensions = [
210209 TransactionExtension (
211- client = TransactionsClient (
212- database = database , session = None , settings = settings
213- ),
210+ client = TransactionsClient (database = database , session = None , settings = settings ),
214211 settings = settings ,
215212 ),
216213 SortExtension (),
@@ -244,9 +241,7 @@ async def app_client(app):
244241 await create_index_templates ()
245242 await create_collection_index ()
246243
247- async with AsyncClient (
248- transport = ASGITransport (app = app ), base_url = "http://test-server"
249- ) as c :
244+ async with AsyncClient (transport = ASGITransport (app = app ), base_url = "http://test-server" ) as c :
250245 yield c
251246
252247
@@ -255,18 +250,14 @@ async def app_rate_limit():
255250 settings = AsyncSettings ()
256251
257252 aggregation_extension = AggregationExtension (
258- client = EsAsyncAggregationClient (
259- database = database , session = None , settings = settings
260- )
253+ client = EsAsyncAggregationClient (database = database , session = None , settings = settings )
261254 )
262255 aggregation_extension .POST = EsAggregationExtensionPostRequest
263256 aggregation_extension .GET = EsAggregationExtensionGetRequest
264257
265258 search_extensions = [
266259 TransactionExtension (
267- client = TransactionsClient (
268- database = database , session = None , settings = settings
269- ),
260+ client = TransactionsClient (database = database , session = None , settings = settings ),
270261 settings = settings ,
271262 ),
272263 SortExtension (),
@@ -305,9 +296,7 @@ async def app_client_rate_limit(app_rate_limit):
305296 await create_index_templates ()
306297 await create_collection_index ()
307298
308- async with AsyncClient (
309- transport = ASGITransport (app = app_rate_limit ), base_url = "http://test-server"
310- ) as c :
299+ async with AsyncClient (transport = ASGITransport (app = app_rate_limit ), base_url = "http://test-server" ) as c :
311300 yield c
312301
313302
@@ -349,18 +338,14 @@ async def app_basic_auth():
349338 settings = AsyncSettings ()
350339
351340 aggregation_extension = AggregationExtension (
352- client = EsAsyncAggregationClient (
353- database = database , session = None , settings = settings
354- )
341+ client = EsAsyncAggregationClient (database = database , session = None , settings = settings )
355342 )
356343 aggregation_extension .POST = EsAggregationExtensionPostRequest
357344 aggregation_extension .GET = EsAggregationExtensionGetRequest
358345
359346 search_extensions = [
360347 TransactionExtension (
361- client = TransactionsClient (
362- database = database , session = None , settings = settings
363- ),
348+ client = TransactionsClient (database = database , session = None , settings = settings ),
364349 settings = settings ,
365350 ),
366351 SortExtension (),
@@ -397,9 +382,7 @@ async def app_client_basic_auth(app_basic_auth):
397382 await create_index_templates ()
398383 await create_collection_index ()
399384
400- async with AsyncClient (
401- transport = ASGITransport (app = app_basic_auth ), base_url = "http://test-server"
402- ) as c :
385+ async with AsyncClient (transport = ASGITransport (app = app_basic_auth ), base_url = "http://test-server" ) as c :
403386 yield c
404387
405388
@@ -440,9 +423,7 @@ async def route_dependencies_app():
440423 settings = AsyncSettings ()
441424 extensions = [
442425 TransactionExtension (
443- client = TransactionsClient (
444- database = database , session = None , settings = settings
445- ),
426+ client = TransactionsClient (database = database , session = None , settings = settings ),
446427 settings = settings ,
447428 ),
448429 SortExtension (),
@@ -483,14 +464,10 @@ async def route_dependencies_client(route_dependencies_app):
483464
484465
485466def build_test_app ():
486- TRANSACTIONS_EXTENSIONS = get_bool_env (
487- "ENABLE_TRANSACTIONS_EXTENSIONS" , default = True
488- )
467+ TRANSACTIONS_EXTENSIONS = get_bool_env ("ENABLE_TRANSACTIONS_EXTENSIONS" , default = True )
489468 settings = AsyncSettings ()
490469 aggregation_extension = AggregationExtension (
491- client = EsAsyncAggregationClient (
492- database = database , session = None , settings = settings
493- )
470+ client = EsAsyncAggregationClient (database = database , session = None , settings = settings )
494471 )
495472 aggregation_extension .POST = EsAggregationExtensionPostRequest
496473 aggregation_extension .GET = EsAggregationExtensionGetRequest
@@ -506,9 +483,7 @@ def build_test_app():
506483 search_extensions .insert (
507484 0 ,
508485 TransactionExtension (
509- client = TransactionsClient (
510- database = database , session = None , settings = settings
511- ),
486+ client = TransactionsClient (database = database , session = None , settings = settings ),
512487 settings = settings ,
513488 ),
514489 )
0 commit comments