diff --git a/CHANGES.md b/CHANGES.md index 32c2f1ea..06a2d9e9 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,6 +5,7 @@ ### Fixed - fix root-path handling when setting via env var or on app instance +- Allow `q` parameter to be a `str` not a `list[str]` for Advanced Free-Text extension ### Changed diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index b8e39331..cc650d22 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -24,7 +24,6 @@ CollectionSearchExtension, CollectionSearchFilterExtension, FieldsExtension, - FreeTextExtension, ItemCollectionFilterExtension, OffsetPaginationExtension, SearchFilterExtension, @@ -42,7 +41,7 @@ from stac_fastapi.pgstac.config import Settings from stac_fastapi.pgstac.core import CoreCrudClient, health_check from stac_fastapi.pgstac.db import close_db_connection, connect_to_db -from stac_fastapi.pgstac.extensions import QueryExtension +from stac_fastapi.pgstac.extensions import FreeTextExtension, QueryExtension from stac_fastapi.pgstac.extensions.filter import FiltersClient from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient from stac_fastapi.pgstac.types.search import PgstacSearch diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 7854ad0d..d159ba67 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -54,8 +54,7 @@ async def all_collections( # noqa: C901 sortby: Optional[str] = None, filter_expr: Optional[str] = None, filter_lang: Optional[str] = None, - q: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> Collections: """Cross catalog search (GET). @@ -86,9 +85,15 @@ async def all_collections( # noqa: C901 sortby=sortby, filter_query=filter_expr, filter_lang=filter_lang, - q=q, + **kwargs, ) + # NOTE: `FreeTextExtension` - pgstac will only accept `str` so we need to + # join the list[str] with ` OR ` + # ref: https://github.com/stac-utils/stac-fastapi-pgstac/pull/263 + if q := clean_args.pop("q", None): + clean_args["q"] = " OR ".join(q) if isinstance(q, list) else q + async with request.app.state.get_connection(request, "r") as conn: q, p = render( """ @@ -157,7 +162,10 @@ async def all_collections( # noqa: C901 ) async def get_collection( - self, collection_id: str, request: Request, **kwargs + self, + collection_id: str, + request: Request, + **kwargs: Any, ) -> Collection: """Get collection by id. @@ -202,7 +210,9 @@ async def get_collection( return Collection(**collection) async def _get_base_item( - self, collection_id: str, request: Request + self, + collection_id: str, + request: Request, ) -> Dict[str, Any]: """Get the base item of a collection for use in rehydrating full item collection properties. @@ -359,7 +369,7 @@ async def item_collection( filter_expr: Optional[str] = None, filter_lang: Optional[str] = None, token: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> ItemCollection: """Get all items from a specific collection. @@ -391,6 +401,7 @@ async def item_collection( filter_lang=filter_lang, fields=fields, sortby=sortby, + **kwargs, ) try: @@ -417,7 +428,11 @@ async def item_collection( return ItemCollection(**item_collection) async def get_item( - self, item_id: str, collection_id: str, request: Request, **kwargs + self, + item_id: str, + collection_id: str, + request: Request, + **kwargs: Any, ) -> Item: """Get item by id. @@ -445,7 +460,10 @@ async def get_item( return Item(**item_collection["features"][0]) async def post_search( - self, search_request: PgstacSearch, request: Request, **kwargs + self, + search_request: PgstacSearch, + request: Request, + **kwargs: Any, ) -> ItemCollection: """Cross catalog search (POST). @@ -489,7 +507,7 @@ async def get_search( filter_expr: Optional[str] = None, filter_lang: Optional[str] = None, token: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> ItemCollection: """Cross catalog search (GET). @@ -516,6 +534,7 @@ async def get_search( sortby=sortby, filter_query=filter_expr, filter_lang=filter_lang, + **kwargs, ) try: @@ -550,7 +569,8 @@ def _clean_search_args( # noqa: C901 sortby: Optional[str] = None, filter_query: Optional[str] = None, filter_lang: Optional[str] = None, - q: Optional[List[str]] = None, + q: Optional[Union[str, List[str]]] = None, + **kwargs: Any, ) -> Dict[str, Any]: """Clean up search arguments to match format expected by pgstac""" if filter_query: @@ -596,7 +616,7 @@ def _clean_search_args( # noqa: C901 base_args["fields"] = {"include": includes, "exclude": excludes} if q: - base_args["q"] = " OR ".join(q) + base_args["q"] = q # Remove None values from dict clean = {} diff --git a/stac_fastapi/pgstac/extensions/__init__.py b/stac_fastapi/pgstac/extensions/__init__.py index 00544179..6c2812b6 100644 --- a/stac_fastapi/pgstac/extensions/__init__.py +++ b/stac_fastapi/pgstac/extensions/__init__.py @@ -1,6 +1,7 @@ """pgstac extension customisations.""" from .filter import FiltersClient +from .free_text import FreeTextExtension from .query import QueryExtension -__all__ = ["QueryExtension", "FiltersClient"] +__all__ = ["QueryExtension", "FiltersClient", "FreeTextExtension"] diff --git a/stac_fastapi/pgstac/extensions/free_text.py b/stac_fastapi/pgstac/extensions/free_text.py new file mode 100644 index 00000000..cadab7fe --- /dev/null +++ b/stac_fastapi/pgstac/extensions/free_text.py @@ -0,0 +1,31 @@ +"""Free-Text model for PgSTAC.""" + +from typing import List, Optional + +from pydantic import BaseModel, Field +from pydantic.functional_serializers import PlainSerializer +from stac_fastapi.extensions.core.free_text import ( + FreeTextExtension as FreeTextExtensionBase, +) +from typing_extensions import Annotated + + +class FreeTextExtensionPostRequest(BaseModel): + """Free-text Extension POST request model.""" + + q: Annotated[ + Optional[List[str]], + PlainSerializer(lambda x: " OR ".join(x), return_type=str, when_used="json"), + ] = Field( + None, + description="Parameter to perform free-text queries against STAC metadata", + ) + + +class FreeTextExtension(FreeTextExtensionBase): + """FreeText Extension. + + Override the POST request model to add custom serialization + """ + + POST = FreeTextExtensionPostRequest diff --git a/tests/conftest.py b/tests/conftest.py index 05846bec..d3495936 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,6 @@ import json import logging import os -import time from typing import Callable, Dict from urllib.parse import quote_plus as quote from urllib.parse import urljoin @@ -26,7 +25,7 @@ CollectionSearchExtension, CollectionSearchFilterExtension, FieldsExtension, - FreeTextExtension, + FreeTextAdvancedExtension, ItemCollectionFilterExtension, OffsetPaginationExtension, SearchFilterExtension, @@ -44,7 +43,7 @@ from stac_fastapi.pgstac.config import PostgresSettings, Settings from stac_fastapi.pgstac.core import CoreCrudClient, health_check from stac_fastapi.pgstac.db import close_db_connection, connect_to_db -from stac_fastapi.pgstac.extensions import QueryExtension +from stac_fastapi.pgstac.extensions import FreeTextExtension, QueryExtension from stac_fastapi.pgstac.extensions.filter import FiltersClient from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient from stac_fastapi.pgstac.types.search import PgstacSearch @@ -139,6 +138,7 @@ def api_client(request): FieldsExtension(), SearchFilterExtension(client=FiltersClient()), TokenPaginationExtension(), + FreeTextExtension(), # not recommended by PgSTAC ] application_extensions.extend(search_extensions) @@ -167,6 +167,7 @@ def api_client(request): FieldsExtension(conformance_classes=[FieldsConformanceClasses.ITEMS]), ItemCollectionFilterExtension(client=FiltersClient()), TokenPaginationExtension(), + FreeTextExtension(), # not recommended by PgSTAC ] application_extensions.extend(item_collection_extensions) @@ -207,7 +208,6 @@ async def app(api_client, database): pgdatabase=database.dbname, ) logger.info("Creating app Fixture") - time.time() app = api_client.app await connect_to_db( app, @@ -314,7 +314,6 @@ async def app_no_ext(database): pgdatabase=database.dbname, ) logger.info("Creating app Fixture") - time.time() await connect_to_db( api_client_no_ext.app, postgres_settings=postgres_settings, @@ -354,7 +353,6 @@ async def app_no_transaction(database): pgdatabase=database.dbname, ) logger.info("Creating app Fixture") - time.time() await connect_to_db( api.app, postgres_settings=postgres_settings, @@ -402,3 +400,57 @@ async def default_client(default_app): transport=ASGITransport(app=default_app), base_url="http://test" ) as c: yield c + + +@pytest.fixture(scope="function") +async def app_advanced_freetext(database): + """Default stac-fastapi-pgstac application without only the transaction extensions.""" + api_settings = Settings(testing=True) + + application_extensions = [ + TransactionExtension(client=TransactionsClient(), settings=api_settings) + ] + + collection_extensions = [ + FreeTextAdvancedExtension(), + OffsetPaginationExtension(), + ] + collection_search_extension = CollectionSearchExtension.from_extensions( + collection_extensions + ) + application_extensions.append(collection_search_extension) + + app = StacApi( + settings=api_settings, + extensions=application_extensions, + client=CoreCrudClient(), + health_check=health_check, + collections_get_request_model=collection_search_extension.GET, + ) + + postgres_settings = PostgresSettings( + pguser=database.user, + pgpassword=database.password, + pghost=database.host, + pgport=database.port, + pgdatabase=database.dbname, + ) + logger.info("Creating app Fixture") + await connect_to_db( + app.app, + postgres_settings=postgres_settings, + add_write_connection_pool=True, + ) + yield app.app + await close_db_connection(app.app) + + logger.info("Closed Pools.") + + +@pytest.fixture(scope="function") +async def app_client_advanced_freetext(app_advanced_freetext): + logger.info("creating app_client") + async with AsyncClient( + transport=ASGITransport(app=app_advanced_freetext), base_url="http://test" + ) as c: + yield c diff --git a/tests/data/test_item.json b/tests/data/test_item.json index 1c68b959..cac06d66 100644 --- a/tests/data/test_item.json +++ b/tests/data/test_item.json @@ -34,6 +34,7 @@ "type": "Polygon" }, "properties": { + "description": "Landat 8 imagery radiometrically calibrated and orthorectified using gound points and Digital Elevation Model (DEM) data to correct relief displacement.", "datetime": "2020-02-12T12:30:22Z", "landsat:scene_id": "LC82081612020043LGN00", "landsat:row": "161", diff --git a/tests/resources/test_collection.py b/tests/resources/test_collection.py index 745d4230..013f9baa 100644 --- a/tests/resources/test_collection.py +++ b/tests/resources/test_collection.py @@ -365,6 +365,71 @@ async def test_collection_search_freetext( assert resp.json()["collections"][0]["id"] == load_test2_collection.id resp = await app_client.get( + "/collections", + params={"q": "temperature,calibrated"}, + ) + assert resp.json()["numberReturned"] == 2 + assert resp.json()["numberMatched"] == 2 + assert len(resp.json()["collections"]) == 2 + + resp = await app_client.get( + "/collections", + params={"q": "temperature,yo"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["numberMatched"] == 1 + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + resp = await app_client.get( + "/collections", + params={"q": "nosuchthing"}, + ) + assert len(resp.json()["collections"]) == 0 + + +@requires_pgstac_0_9_2 +@pytest.mark.asyncio +async def test_collection_search_freetext_advanced( + app_client_advanced_freetext, load_test_collection, load_test2_collection +): + # free-text + resp = await app_client_advanced_freetext.get( + "/collections", + params={"q": "temperature"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["numberMatched"] == 1 + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + resp = await app_client_advanced_freetext.get( + "/collections", + params={"q": "temperature,calibrated"}, + ) + assert resp.json()["numberReturned"] == 2 + assert resp.json()["numberMatched"] == 2 + assert len(resp.json()["collections"]) == 2 + + resp = await app_client_advanced_freetext.get( + "/collections", + params={"q": "temperature,yo"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["numberMatched"] == 1 + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + resp = await app_client_advanced_freetext.get( + "/collections", + params={"q": "temperature OR yo"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["numberMatched"] == 1 + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + resp = await app_client_advanced_freetext.get( "/collections", params={"q": "nosuchthing"}, ) diff --git a/tests/resources/test_item.py b/tests/resources/test_item.py index 4ea70193..490d652a 100644 --- a/tests/resources/test_item.py +++ b/tests/resources/test_item.py @@ -18,6 +18,8 @@ from stac_fastapi.pgstac.models.links import CollectionLinks +from ..conftest import requires_pgstac_0_9_2 + async def test_create_collection(app_client, load_test_data: Callable): in_json = load_test_data("test_collection.json") @@ -1689,3 +1691,34 @@ async def test_get_search_link_media(app_client): assert len(links) == 2 get_self_link = next((link for link in links if link["rel"] == "self"), None) assert get_self_link["type"] == "application/geo+json" + + +@requires_pgstac_0_9_2 +@pytest.mark.asyncio +async def test_item_search_freetext(app_client, load_test_data, load_test_collection): + test_item = load_test_data("test_item.json") + resp = await app_client.post( + f"/collections/{test_item['collection']}/items", json=test_item + ) + assert resp.status_code == 201 + + # free-text + resp = await app_client.get( + "/search", + params={"q": "orthorectified"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["features"][0]["id"] == "test-item" + + resp = await app_client.get( + "/search", + params={"q": "orthorectified,yo"}, + ) + assert resp.json()["numberReturned"] == 1 + assert resp.json()["features"][0]["id"] == "test-item" + + resp = await app_client.get( + "/search", + params={"q": "yo"}, + ) + assert resp.json()["numberReturned"] == 0