Skip to content

Commit 93eb352

Browse files
committed
Adding headers to MockRequest
1 parent 324907c commit 93eb352

File tree

1 file changed

+17
-42
lines changed

1 file changed

+17
-42
lines changed

stac_fastapi/tests/conftest.py

Lines changed: 17 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(self, item, collection):
7171
class MockRequest:
7272
base_url = "http://test-server"
7373
url = "http://test-server/test"
74+
headers = {}
7475
query_params = {}
7576

7677
def __init__(
@@ -79,11 +80,13 @@ def __init__(
7980
url: str = "XXXX",
8081
app: Optional[Any] = None,
8182
query_params: Dict[str, Any] = {"limit": "10"},
83+
headers: Dict[str, Any] = {"Content-type": "application/json"},
8284
):
8385
self.method = method
8486
self.url = url
8587
self.app = app
8688
self.query_params = query_params
89+
self.headers = headers
8790

8891

8992
class TestSettings(AsyncSettings):
@@ -127,9 +130,7 @@ def test_collection() -> Dict:
127130

128131

129132
async def create_collection(txn_client: TransactionsClient, collection: Dict) -> None:
130-
await txn_client.create_collection(
131-
api.Collection(**dict(collection)), request=MockRequest, refresh=True
132-
)
133+
await txn_client.create_collection(api.Collection(**dict(collection)), request=MockRequest, refresh=True)
133134

134135

135136
async def create_item(txn_client: TransactionsClient, item: Dict) -> None:
@@ -199,18 +200,14 @@ async def app():
199200
settings = AsyncSettings()
200201

201202
aggregation_extension = AggregationExtension(
202-
client=EsAsyncAggregationClient(
203-
database=database, session=None, settings=settings
204-
)
203+
client=EsAsyncAggregationClient(database=database, session=None, settings=settings)
205204
)
206205
aggregation_extension.POST = EsAggregationExtensionPostRequest
207206
aggregation_extension.GET = EsAggregationExtensionGetRequest
208207

209208
search_extensions = [
210209
TransactionExtension(
211-
client=TransactionsClient(
212-
database=database, session=None, settings=settings
213-
),
210+
client=TransactionsClient(database=database, session=None, settings=settings),
214211
settings=settings,
215212
),
216213
SortExtension(),
@@ -244,9 +241,7 @@ async def app_client(app):
244241
await create_index_templates()
245242
await create_collection_index()
246243

247-
async with AsyncClient(
248-
transport=ASGITransport(app=app), base_url="http://test-server"
249-
) as c:
244+
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test-server") as c:
250245
yield c
251246

252247

@@ -255,18 +250,14 @@ async def app_rate_limit():
255250
settings = AsyncSettings()
256251

257252
aggregation_extension = AggregationExtension(
258-
client=EsAsyncAggregationClient(
259-
database=database, session=None, settings=settings
260-
)
253+
client=EsAsyncAggregationClient(database=database, session=None, settings=settings)
261254
)
262255
aggregation_extension.POST = EsAggregationExtensionPostRequest
263256
aggregation_extension.GET = EsAggregationExtensionGetRequest
264257

265258
search_extensions = [
266259
TransactionExtension(
267-
client=TransactionsClient(
268-
database=database, session=None, settings=settings
269-
),
260+
client=TransactionsClient(database=database, session=None, settings=settings),
270261
settings=settings,
271262
),
272263
SortExtension(),
@@ -305,9 +296,7 @@ async def app_client_rate_limit(app_rate_limit):
305296
await create_index_templates()
306297
await create_collection_index()
307298

308-
async with AsyncClient(
309-
transport=ASGITransport(app=app_rate_limit), base_url="http://test-server"
310-
) as c:
299+
async with AsyncClient(transport=ASGITransport(app=app_rate_limit), base_url="http://test-server") as c:
311300
yield c
312301

313302

@@ -349,18 +338,14 @@ async def app_basic_auth():
349338
settings = AsyncSettings()
350339

351340
aggregation_extension = AggregationExtension(
352-
client=EsAsyncAggregationClient(
353-
database=database, session=None, settings=settings
354-
)
341+
client=EsAsyncAggregationClient(database=database, session=None, settings=settings)
355342
)
356343
aggregation_extension.POST = EsAggregationExtensionPostRequest
357344
aggregation_extension.GET = EsAggregationExtensionGetRequest
358345

359346
search_extensions = [
360347
TransactionExtension(
361-
client=TransactionsClient(
362-
database=database, session=None, settings=settings
363-
),
348+
client=TransactionsClient(database=database, session=None, settings=settings),
364349
settings=settings,
365350
),
366351
SortExtension(),
@@ -397,9 +382,7 @@ async def app_client_basic_auth(app_basic_auth):
397382
await create_index_templates()
398383
await create_collection_index()
399384

400-
async with AsyncClient(
401-
transport=ASGITransport(app=app_basic_auth), base_url="http://test-server"
402-
) as c:
385+
async with AsyncClient(transport=ASGITransport(app=app_basic_auth), base_url="http://test-server") as c:
403386
yield c
404387

405388

@@ -440,9 +423,7 @@ async def route_dependencies_app():
440423
settings = AsyncSettings()
441424
extensions = [
442425
TransactionExtension(
443-
client=TransactionsClient(
444-
database=database, session=None, settings=settings
445-
),
426+
client=TransactionsClient(database=database, session=None, settings=settings),
446427
settings=settings,
447428
),
448429
SortExtension(),
@@ -483,14 +464,10 @@ async def route_dependencies_client(route_dependencies_app):
483464

484465

485466
def build_test_app():
486-
TRANSACTIONS_EXTENSIONS = get_bool_env(
487-
"ENABLE_TRANSACTIONS_EXTENSIONS", default=True
488-
)
467+
TRANSACTIONS_EXTENSIONS = get_bool_env("ENABLE_TRANSACTIONS_EXTENSIONS", default=True)
489468
settings = AsyncSettings()
490469
aggregation_extension = AggregationExtension(
491-
client=EsAsyncAggregationClient(
492-
database=database, session=None, settings=settings
493-
)
470+
client=EsAsyncAggregationClient(database=database, session=None, settings=settings)
494471
)
495472
aggregation_extension.POST = EsAggregationExtensionPostRequest
496473
aggregation_extension.GET = EsAggregationExtensionGetRequest
@@ -506,9 +483,7 @@ def build_test_app():
506483
search_extensions.insert(
507484
0,
508485
TransactionExtension(
509-
client=TransactionsClient(
510-
database=database, session=None, settings=settings
511-
),
486+
client=TransactionsClient(database=database, session=None, settings=settings),
512487
settings=settings,
513488
),
514489
)

0 commit comments

Comments
 (0)