diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 9c993dc..4a1d351 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -65,42 +65,51 @@ async def all_collections( # noqa: C901 """ base_url = get_base_url(request) - # Parse request parameters - base_args = { - "bbox": bbox, - "limit": limit, - "offset": offset, - "query": orjson.loads(unquote_plus(query)) if query else query, - } + next_link: Optional[Dict[str, Any]] = None + prev_link: Optional[Dict[str, Any]] = None + collections_result: Collections + + if self.extension_is_enabled("CollectionSearchExtension"): + base_args = { + "bbox": bbox, + "limit": limit, + "offset": offset, + "query": orjson.loads(unquote_plus(query)) if query else query, + } + + clean_args = clean_search_args( + base_args=base_args, + datetime=datetime, + fields=fields, + sortby=sortby, + filter_query=filter, + filter_lang=filter_lang, + ) - clean_args = clean_search_args( - base_args=base_args, - datetime=datetime, - fields=fields, - sortby=sortby, - filter_query=filter, - filter_lang=filter_lang, - ) + async with request.app.state.get_connection(request, "r") as conn: + q, p = render( + """ + SELECT * FROM collection_search(:req::text::jsonb); + """, + req=json.dumps(clean_args), + ) + collections_result = await conn.fetchval(q, *p) - async with request.app.state.get_connection(request, "r") as conn: - q, p = render( - """ - SELECT * FROM collection_search(:req::text::jsonb); - """, - req=json.dumps(clean_args), - ) - collections_result: Collections = await conn.fetchval(q, *p) + if links := collections_result.get("links"): + for link in links: + if link["rel"] == "next": + next_link = link + elif link["rel"] == "prev": + prev_link = link - next_link: Optional[Dict[str, Any]] = None - prev_link: Optional[Dict[str, Any]] = None - if links := collections_result.get("links"): - next_link = None - prev_link = None - for link in links: - if link["rel"] == "next": - next_link = link - elif link["rel"] == "prev": - prev_link = link + else: + async with request.app.state.get_connection(request, "r") as conn: + cols = await conn.fetchval( + """ + SELECT * FROM all_collections(); + """ + ) + collections_result = {"collections": cols, "links": []} linked_collections: List[Collection] = [] collections = collections_result["collections"] diff --git a/tests/conftest.py b/tests/conftest.py index 3d998c1..632a89d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -268,3 +268,48 @@ async def load_test2_item(app_client, load_test_data, load_test2_collection): ) assert resp.status_code == 201 return Item.model_validate(resp.json()) + + +@pytest.fixture( + scope="session", +) +def api_client_no_ext(database): + api_settings = Settings( + postgres_user=database.user, + postgres_pass=database.password, + postgres_host_reader=database.host, + postgres_host_writer=database.host, + postgres_port=database.port, + postgres_dbname=database.dbname, + testing=True, + ) + return StacApi( + settings=api_settings, + extensions=[ + TransactionExtension(client=TransactionsClient(), settings=api_settings) + ], + client=CoreCrudClient(), + ) + + +@pytest.fixture(scope="function") +async def app_no_ext(api_client_no_ext): + logger.info("Creating app Fixture") + time.time() + app = api_client_no_ext.app + await connect_to_db(app) + + yield app + + await close_db_connection(app) + + logger.info("Closed Pools.") + + +@pytest.fixture(scope="function") +async def app_client_no_ext(app_no_ext): + logger.info("creating app_client") + async with AsyncClient( + transport=ASGITransport(app=app_no_ext), base_url="http://test" + ) as c: + yield c diff --git a/tests/resources/test_collection.py b/tests/resources/test_collection.py index e808a7c..1a5ad68 100644 --- a/tests/resources/test_collection.py +++ b/tests/resources/test_collection.py @@ -307,6 +307,55 @@ async def test_get_collections_search( assert len(resp.json()["collections"]) == 2 +@requires_pgstac_0_9_2 +@pytest.mark.asyncio +async def test_all_collections_with_pagination(app_client, load_test_data): + data = load_test_data("test_collection.json") + collection_id = data["id"] + for ii in range(0, 12): + data["id"] = collection_id + f"_{ii}" + resp = await app_client.post( + "/collections", + json=data, + ) + assert resp.status_code == 201 + + resp = await app_client.get("/collections") + cols = resp.json()["collections"] + assert len(cols) == 10 + links = resp.json()["links"] + assert len(links) == 3 + assert {"root", "self", "next"} == {link["rel"] for link in links} + + resp = await app_client.get("/collections", params={"limit": 12}) + cols = resp.json()["collections"] + assert len(cols) == 12 + links = resp.json()["links"] + assert len(links) == 2 + assert {"root", "self"} == {link["rel"] for link in links} + + +@requires_pgstac_0_9_2 +@pytest.mark.asyncio +async def test_all_collections_without_pagination(app_client_no_ext, load_test_data): + data = load_test_data("test_collection.json") + collection_id = data["id"] + for ii in range(0, 12): + data["id"] = collection_id + f"_{ii}" + resp = await app_client_no_ext.post( + "/collections", + json=data, + ) + assert resp.status_code == 201 + + resp = await app_client_no_ext.get("/collections") + cols = resp.json()["collections"] + assert len(cols) == 12 + links = resp.json()["links"] + assert len(links) == 2 + assert {"root", "self"} == {link["rel"] for link in links} + + @requires_pgstac_0_9_2 @pytest.mark.asyncio async def test_get_collections_search_pagination(