diff --git a/.github/workflows/cicd.yaml b/.github/workflows/cicd.yaml index 0ae6f963..5afebf1b 100644 --- a/.github/workflows/cicd.yaml +++ b/.github/workflows/cicd.yaml @@ -15,6 +15,7 @@ jobs: - {python: '3.12', pypgstac: '0.9.*'} - {python: '3.12', pypgstac: '0.8.*'} - {python: '3.11', pypgstac: '0.8.*'} + - {python: '3.10', pypgstac: '0.8.*'} - {python: '3.9', pypgstac: '0.8.*'} - {python: '3.8', pypgstac: '0.8.*'} diff --git a/docker-compose.yml b/docker-compose.yml index 02158a68..07cc7822 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -35,6 +35,8 @@ services: build: context: . dockerfile: Dockerfile.tests + volumes: + - .:/app environment: - ENVIRONMENT=local - DB_MIN_CONN_SIZE=1 @@ -44,7 +46,7 @@ services: database: container_name: stac-db - image: ghcr.io/stac-utils/pgstac:v0.9.1 + image: ghcr.io/stac-utils/pgstac:v0.9.2 environment: - POSTGRES_USER=username - POSTGRES_PASSWORD=password diff --git a/setup.py b/setup.py index 220cbca4..16edf782 100644 --- a/setup.py +++ b/setup.py @@ -10,9 +10,9 @@ "orjson", "pydantic", "stac_pydantic==3.1.*", - "stac-fastapi.api~=3.0.2", - "stac-fastapi.extensions~=3.0.2", - "stac-fastapi.types~=3.0.2", + "stac-fastapi.api~=3.0.3", + "stac-fastapi.extensions~=3.0.3", + "stac-fastapi.types~=3.0.3", "asyncpg", "buildpg", "brotli_asgi", diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index a8f52fbb..f64095f4 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -21,6 +21,7 @@ from stac_fastapi.extensions.core import ( FieldsExtension, FilterExtension, + OffsetPaginationExtension, SortExtension, TokenPaginationExtension, TransactionExtension, @@ -58,6 +59,7 @@ "sort": SortExtension(), "fields": FieldsExtension(), "filter": FilterExtension(client=FiltersClient()), + "pagination": OffsetPaginationExtension(), } enabled_extensions = ( diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 0917ab16..ff5ec16b 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -25,6 +25,7 @@ from stac_fastapi.pgstac.config import Settings from stac_fastapi.pgstac.models.links import ( CollectionLinks, + CollectionSearchPagingLinks, ItemCollectionLinks, ItemLinks, PagingLinks, @@ -46,8 +47,8 @@ async def all_collections( # noqa: C901 bbox: Optional[BBox] = None, datetime: Optional[DateTimeType] = None, limit: Optional[int] = None, + offset: Optional[int] = None, query: Optional[str] = None, - token: Optional[str] = None, fields: Optional[List[str]] = None, sortby: Optional[str] = None, filter: Optional[str] = None, @@ -64,38 +65,51 @@ async def all_collections( # noqa: C901 """ base_url = get_base_url(request) - # Parse request parameters - base_args = { - "bbox": bbox, - "limit": limit, - "token": token, - "query": orjson.loads(unquote_plus(query)) if query else query, - } - - clean_args = clean_search_args( - base_args=base_args, - datetime=datetime, - fields=fields, - sortby=sortby, - filter_query=filter, - filter_lang=filter_lang, - ) - - async with request.app.state.get_connection(request, "r") as conn: - q, p = render( - """ - SELECT * FROM collection_search(:req::text::jsonb); - """, - req=json.dumps(clean_args), + next_link: Optional[Dict[str, Any]] = None + prev_link: Optional[Dict[str, Any]] = None + collections_result: Collections + + if self.extension_is_enabled("CollectionSearchExtension"): + base_args = { + "bbox": bbox, + "limit": limit, + "offset": offset, + "query": orjson.loads(unquote_plus(query)) if query else query, + } + + clean_args = clean_search_args( + base_args=base_args, + datetime=datetime, + fields=fields, + sortby=sortby, + filter_query=filter, + filter_lang=filter_lang, ) - collections_result: Collections = await conn.fetchval(q, *p) - next: Optional[str] = None - prev: Optional[str] = None + async with request.app.state.get_connection(request, "r") as conn: + q, p = render( + """ + SELECT * FROM collection_search(:req::text::jsonb); + """, + req=json.dumps(clean_args), + ) + collections_result = await conn.fetchval(q, *p) - if links := collections_result.get("links"): - next = collections_result["links"].pop("next") - prev = collections_result["links"].pop("prev") + if links := collections_result.get("links"): + for link in links: + if link["rel"] == "next": + next_link = link + elif link["rel"] == "prev": + prev_link = link + + else: + async with request.app.state.get_connection(request, "r") as conn: + cols = await conn.fetchval( + """ + SELECT * FROM all_collections(); + """ + ) + collections_result = {"collections": cols, "links": []} linked_collections: List[Collection] = [] collections = collections_result["collections"] @@ -120,10 +134,10 @@ async def all_collections( # noqa: C901 linked_collections.append(coll) - links = await PagingLinks( + links = await CollectionSearchPagingLinks( request=request, - next=next, - prev=prev, + next=next_link, + prev=prev_link, ).get_links() return Collections( diff --git a/stac_fastapi/pgstac/models/links.py b/stac_fastapi/pgstac/models/links.py index 0e6d9071..c0ec4455 100644 --- a/stac_fastapi/pgstac/models/links.py +++ b/stac_fastapi/pgstac/models/links.py @@ -173,6 +173,55 @@ def link_prev(self) -> Optional[Dict[str, Any]]: return None +@attr.s +class CollectionSearchPagingLinks(BaseLinks): + next: Optional[Dict[str, Any]] = attr.ib(kw_only=True, default=None) + prev: Optional[Dict[str, Any]] = attr.ib(kw_only=True, default=None) + + def link_next(self) -> Optional[Dict[str, Any]]: + """Create link for next page.""" + if self.next is not None: + method = self.request.method + if method == "GET": + # if offset is equal to default value (0), drop it + if self.next["body"].get("offset", -1) == 0: + _ = self.next["body"].pop("offset") + + href = merge_params(self.url, self.next["body"]) + + # if next link is equal to this link, skip it + if href == self.url: + return None + + return { + "rel": Relations.next.value, + "type": MimeTypes.geojson.value, + "method": method, + "href": href, + } + + return None + + def link_prev(self): + if self.prev is not None: + method = self.request.method + if method == "GET": + href = merge_params(self.url, self.prev["body"]) + + # if prev link is equal to this link, skip it + if href == self.url: + return None + + return { + "rel": Relations.previous.value, + "type": MimeTypes.geojson.value, + "method": method, + "href": href, + } + + return None + + @attr.s class CollectionLinksBase(BaseLinks): """Create inferred links specific to collections.""" diff --git a/tests/conftest.py b/tests/conftest.py index e571cae6..4b180147 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,7 @@ from fastapi import APIRouter from fastapi.responses import ORJSONResponse from httpx import ASGITransport, AsyncClient +from pypgstac import __version__ as pgstac_version from pypgstac.db import PgstacDB from pypgstac.migrate import Migrate from pytest_postgresql.janitor import DatabaseJanitor @@ -26,6 +27,7 @@ CollectionSearchExtension, FieldsExtension, FilterExtension, + OffsetPaginationExtension, SortExtension, TokenPaginationExtension, TransactionExtension, @@ -47,6 +49,12 @@ logger = logging.getLogger(__name__) +requires_pgstac_0_9_2 = pytest.mark.skipif( + tuple(map(int, pgstac_version.split("."))) < (0, 9, 2), + reason="PgSTAC>=0.9.2 required", +) + + @pytest.fixture(scope="session") def event_loop(): return asyncio.get_event_loop() @@ -140,6 +148,7 @@ def api_client(request, database): SortExtension(), FieldsExtension(), FilterExtension(client=FiltersClient()), + OffsetPaginationExtension(), ] collection_search_extension = CollectionSearchExtension.from_extensions( collection_extensions @@ -259,3 +268,48 @@ async def load_test2_item(app_client, load_test_data, load_test2_collection): ) assert resp.status_code == 201 return Item.model_validate(resp.json()) + + +@pytest.fixture( + scope="session", +) +def api_client_no_ext(database): + api_settings = Settings( + postgres_user=database.user, + postgres_pass=database.password, + postgres_host_reader=database.host, + postgres_host_writer=database.host, + postgres_port=database.port, + postgres_dbname=database.dbname, + testing=True, + ) + return StacApi( + settings=api_settings, + extensions=[ + TransactionExtension(client=TransactionsClient(), settings=api_settings) + ], + client=CoreCrudClient(), + ) + + +@pytest.fixture(scope="function") +async def app_no_ext(api_client_no_ext): + logger.info("Creating app Fixture") + time.time() + app = api_client_no_ext.app + await connect_to_db(app) + + yield app + + await close_db_connection(app) + + logger.info("Closed Pools.") + + +@pytest.fixture(scope="function") +async def app_client_no_ext(app_no_ext): + logger.info("creating app_client") + async with AsyncClient( + transport=ASGITransport(app=app_no_ext), base_url="http://test" + ) as c: + yield c diff --git a/tests/resources/test_collection.py b/tests/resources/test_collection.py index 634747bc..dec20881 100644 --- a/tests/resources/test_collection.py +++ b/tests/resources/test_collection.py @@ -4,6 +4,8 @@ import pytest from stac_pydantic import Collection +from ..conftest import requires_pgstac_0_9_2 + async def test_create_collection(app_client, load_test_data: Callable): in_json = load_test_data("test_collection.json") @@ -303,3 +305,203 @@ async def test_get_collections_search( "/collections", ) assert len(resp.json()["collections"]) == 2 + + +@requires_pgstac_0_9_2 +@pytest.mark.asyncio +async def test_all_collections_with_pagination(app_client, load_test_data): + data = load_test_data("test_collection.json") + collection_id = data["id"] + for ii in range(0, 12): + data["id"] = collection_id + f"_{ii}" + resp = await app_client.post( + "/collections", + json=data, + ) + assert resp.status_code == 201 + + resp = await app_client.get("/collections") + cols = resp.json()["collections"] + assert len(cols) == 10 + links = resp.json()["links"] + assert len(links) == 3 + assert {"root", "self", "next"} == {link["rel"] for link in links} + + resp = await app_client.get("/collections", params={"limit": 12}) + cols = resp.json()["collections"] + assert len(cols) == 12 + links = resp.json()["links"] + assert len(links) == 2 + assert {"root", "self"} == {link["rel"] for link in links} + + +@requires_pgstac_0_9_2 +@pytest.mark.asyncio +async def test_all_collections_without_pagination(app_client_no_ext, load_test_data): + data = load_test_data("test_collection.json") + collection_id = data["id"] + for ii in range(0, 12): + data["id"] = collection_id + f"_{ii}" + resp = await app_client_no_ext.post( + "/collections", + json=data, + ) + assert resp.status_code == 201 + + resp = await app_client_no_ext.get("/collections") + cols = resp.json()["collections"] + assert len(cols) == 12 + links = resp.json()["links"] + assert len(links) == 2 + assert {"root", "self"} == {link["rel"] for link in links} + + +@requires_pgstac_0_9_2 +@pytest.mark.asyncio +async def test_get_collections_search_pagination( + app_client, load_test_collection, load_test2_collection +): + resp = await app_client.get("/collections") + cols = resp.json()["collections"] + assert len(cols) == 2 + links = resp.json()["links"] + assert len(links) == 2 + assert {"root", "self"} == {link["rel"] for link in links} + + ################### + # limit should be positive + resp = await app_client.get("/collections", params={"limit": 0}) + assert resp.status_code == 400 + + ################### + # limit=1, should have a `next` link + resp = await app_client.get( + "/collections", + params={"limit": 1}, + ) + cols = resp.json()["collections"] + links = resp.json()["links"] + assert len(cols) == 1 + assert cols[0]["id"] == load_test_collection["id"] + assert len(links) == 3 + assert {"root", "self", "next"} == {link["rel"] for link in links} + next_link = list(filter(lambda link: link["rel"] == "next", links))[0] + assert next_link["href"].endswith("?limit=1&offset=1") + + ################### + # limit=2, there should not be a next link + resp = await app_client.get( + "/collections", + params={"limit": 2}, + ) + cols = resp.json()["collections"] + links = resp.json()["links"] + assert len(cols) == 2 + assert cols[0]["id"] == load_test_collection["id"] + assert cols[1]["id"] == load_test2_collection.id + assert len(links) == 2 + assert {"root", "self"} == {link["rel"] for link in links} + + ################### + # limit=3, there should not be a next/previous link + resp = await app_client.get( + "/collections", + params={"limit": 3}, + ) + cols = resp.json()["collections"] + links = resp.json()["links"] + assert len(cols) == 2 + assert cols[0]["id"] == load_test_collection["id"] + assert cols[1]["id"] == load_test2_collection.id + assert len(links) == 2 + assert {"root", "self"} == {link["rel"] for link in links} + + ################### + # offset=3, because there are 2 collections, we should not have `next` or `prev` links + resp = await app_client.get( + "/collections", + params={"offset": 3}, + ) + cols = resp.json()["collections"] + links = resp.json()["links"] + assert len(cols) == 0 + assert len(links) == 2 + assert {"root", "self"} == {link["rel"] for link in links} + + ################### + # offset=3,limit=1 + resp = await app_client.get( + "/collections", + params={"limit": 1, "offset": 3}, + ) + cols = resp.json()["collections"] + links = resp.json()["links"] + assert len(cols) == 0 + assert len(links) == 3 + assert {"root", "self", "previous"} == {link["rel"] for link in links} + prev_link = list(filter(lambda link: link["rel"] == "previous", links))[0] + assert prev_link["href"].endswith("?limit=1&offset=2") + + ################### + # limit=2, offset=3, there should not be a next link + resp = await app_client.get( + "/collections", + params={"limit": 2, "offset": 3}, + ) + cols = resp.json()["collections"] + links = resp.json()["links"] + assert len(cols) == 0 + assert len(links) == 3 + assert {"root", "self", "previous"} == {link["rel"] for link in links} + prev_link = list(filter(lambda link: link["rel"] == "previous", links))[0] + assert prev_link["href"].endswith("?limit=2&offset=1") + + ################### + # offset=1,limit=1 should have a `previous` link + resp = await app_client.get( + "/collections", + params={"offset": 1, "limit": 1}, + ) + cols = resp.json()["collections"] + links = resp.json()["links"] + assert len(cols) == 1 + assert cols[0]["id"] == load_test2_collection.id + assert len(links) == 3 + assert {"root", "self", "previous"} == {link["rel"] for link in links} + prev_link = list(filter(lambda link: link["rel"] == "previous", links))[0] + assert "offset" in prev_link["href"] + + ################### + # offset=0, should not have next/previous link + resp = await app_client.get( + "/collections", + params={"offset": 0}, + ) + cols = resp.json()["collections"] + links = resp.json()["links"] + assert len(cols) == 2 + assert len(links) == 2 + assert {"root", "self"} == {link["rel"] for link in links} + + +@requires_pgstac_0_9_2 +@pytest.mark.xfail(strict=False) +@pytest.mark.asyncio +async def test_get_collections_search_offset_1( + app_client, load_test_collection, load_test2_collection +): + # BUG: pgstac doesn't return a `prev` link when limit is not set + # offset=1, should have a `previous` link + resp = await app_client.get( + "/collections", + params={"offset": 1}, + ) + cols = resp.json()["collections"] + links = resp.json()["links"] + assert len(cols) == 1 + assert cols[0]["id"] == load_test2_collection.id + assert len(links) == 3 + assert {"root", "self", "previous"} == {link["rel"] for link in links} + prev_link = list(filter(lambda link: link["rel"] == "previous", links))[0] + # offset=0 should not be in the previous link (because it's useless) + assert "offset" not in prev_link["href"]