diff --git a/CHANGES.md b/CHANGES.md index 02363fc28..bf70ff60c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,10 @@ ## [Unreleased] +### Added + +* Add `from_extensions()` method to `CollectionSearchExtension` and `CollectionSearchPostExtension` extensions to build the class based on a list of available extensions. + ## [3.0.1] - 2024-08-27 ### Changed diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/collection_search.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/collection_search.py index 2927cd822..2a5f7cf4d 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/collection_search.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/collection_search/collection_search.py @@ -1,5 +1,6 @@ """Collection-Search extension.""" +import warnings from enum import Enum from typing import List, Optional, Union @@ -8,7 +9,7 @@ from stac_pydantic.api.collections import Collections from stac_pydantic.shared import MimeTypes -from stac_fastapi.api.models import GeoJSONResponse +from stac_fastapi.api.models import GeoJSONResponse, create_request_model from stac_fastapi.api.routes import create_async_endpoint from stac_fastapi.types.config import ApiSettings from stac_fastapi.types.extension import ApiExtension @@ -71,6 +72,48 @@ def register(self, app: FastAPI) -> None: """ pass + @classmethod + def from_extensions( + cls, + extensions: List[ApiExtension], + schema_href: Optional[str] = None, + ) -> "CollectionSearchExtension": + """Create CollectionSearchExtension object from extensions.""" + supported_extensions = { + "FreeTextExtension": ConformanceClasses.FREETEXT, + "FreeTextAdvancedExtension": ConformanceClasses.FREETEXT, + "QueryExtension": ConformanceClasses.QUERY, + "SortExtension": ConformanceClasses.SORT, + "FieldsExtension": ConformanceClasses.FIELDS, + "FilterExtension": ConformanceClasses.FILTER, + } + conformance_classes = [ + ConformanceClasses.COLLECTIONSEARCH, + ConformanceClasses.BASIS, + ] + for ext in extensions: + conf = supported_extensions.get(ext.__class__.__name__, None) + if not conf: + warnings.warn( + f"Conformance class for `{ext.__class__.__name__}` extension not found.", # noqa: E501 + UserWarning, + ) + else: + conformance_classes.append(conf) + + get_request_model = create_request_model( + model_name="CollectionsGetRequest", + base_model=BaseCollectionSearchGetRequest, + extensions=extensions, + request_type="GET", + ) + + return cls( + GET=get_request_model, + conformance_classes=conformance_classes, + schema_href=schema_href, + ) + @attr.s class CollectionSearchPostExtension(CollectionSearchExtension): @@ -132,3 +175,60 @@ def register(self, app: FastAPI) -> None: endpoint=create_async_endpoint(self.client.post_all_collections, self.POST), ) app.include_router(self.router) + + @classmethod + def from_extensions( + cls, + extensions: List[ApiExtension], + *, + client: Union[AsyncBaseCollectionSearchClient, BaseCollectionSearchClient], + settings: ApiSettings, + schema_href: Optional[str] = None, + router: Optional[APIRouter] = None, + ) -> "CollectionSearchPostExtension": + """Create CollectionSearchPostExtension object from extensions.""" + supported_extensions = { + "FreeTextExtension": ConformanceClasses.FREETEXT, + "FreeTextAdvancedExtension": ConformanceClasses.FREETEXT, + "QueryExtension": ConformanceClasses.QUERY, + "SortExtension": ConformanceClasses.SORT, + "FieldsExtension": ConformanceClasses.FIELDS, + "FilterExtension": ConformanceClasses.FILTER, + } + conformance_classes = [ + ConformanceClasses.COLLECTIONSEARCH, + ConformanceClasses.BASIS, + ] + for ext in extensions: + conf = supported_extensions.get(ext.__class__.__name__, None) + if not conf: + warnings.warn( + f"Conformance class for `{ext.__class__.__name__}` extension not found.", # noqa: E501 + UserWarning, + ) + else: + conformance_classes.append(conf) + + get_request_model = create_request_model( + model_name="CollectionsGetRequest", + base_model=BaseCollectionSearchGetRequest, + extensions=extensions, + request_type="GET", + ) + + post_request_model = create_request_model( + model_name="CollectionsPostRequest", + base_model=BaseCollectionSearchPostRequest, + extensions=extensions, + request_type="POST", + ) + + return cls( + client=client, + settings=settings, + GET=get_request_model, + POST=post_request_model, + conformance_classes=conformance_classes, + router=router or APIRouter(), + schema_href=schema_href, + ) diff --git a/stac_fastapi/extensions/tests/test_collection_search.py b/stac_fastapi/extensions/tests/test_collection_search.py index edc292210..b23219956 100644 --- a/stac_fastapi/extensions/tests/test_collection_search.py +++ b/stac_fastapi/extensions/tests/test_collection_search.py @@ -2,13 +2,21 @@ from urllib.parse import quote_plus import attr +import pytest from starlette.testclient import TestClient from stac_fastapi.api.app import StacApi from stac_fastapi.api.models import create_request_model from stac_fastapi.extensions.core import ( + AggregationExtension, CollectionSearchExtension, CollectionSearchPostExtension, + FieldsExtension, + FilterExtension, + FreeTextAdvancedExtension, + FreeTextExtension, + QueryExtension, + SortExtension, ) from stac_fastapi.extensions.core.collection_search import ConformanceClasses from stac_fastapi.extensions.core.collection_search.client import ( @@ -302,8 +310,8 @@ def test_collection_search_extension_post_models(): client=DummyCoreClient(), extensions=[ CollectionSearchPostExtension( - settings=settings, client=DummyPostClient(), + settings=settings, GET=get_request_model, POST=post_request_model, conformance_classes=[ @@ -392,3 +400,112 @@ def test_collection_search_extension_post_models(): assert response_dict["query"] assert response_dict["sortby"] assert response_dict["fields"] + + +@pytest.mark.parametrize( + "extensions", + [ + # with FreeTextExtension + [ + FieldsExtension(), + FilterExtension(), + FreeTextExtension(), + QueryExtension(), + SortExtension(), + ], + # with FreeTextAdvancedExtension + [ + FieldsExtension(), + FilterExtension(), + FreeTextAdvancedExtension(), + QueryExtension(), + SortExtension(), + ], + ], +) +def test_from_extensions_methods(extensions): + """ + Make sure `from_extensions` create the correct + models and adds desired conformances classes. + """ + ext = CollectionSearchExtension.from_extensions( + extensions, + ) + collection_search = ext.GET() + assert collection_search.__class__.__name__ == "CollectionsGetRequest" + assert hasattr(collection_search, "bbox") + assert hasattr(collection_search, "datetime") + assert hasattr(collection_search, "limit") + assert hasattr(collection_search, "fields") + assert hasattr(collection_search, "q") + assert hasattr(collection_search, "sortby") + assert hasattr(collection_search, "filter") + assert ext.conformance_classes == [ + ConformanceClasses.COLLECTIONSEARCH, + ConformanceClasses.BASIS, + ConformanceClasses.FIELDS, + ConformanceClasses.FILTER, + ConformanceClasses.FREETEXT, + ConformanceClasses.QUERY, + ConformanceClasses.SORT, + ] + + ext = CollectionSearchPostExtension.from_extensions( + extensions, + client=DummyPostClient(), + settings=ApiSettings(), + ) + collection_search = ext.POST() + assert collection_search.__class__.__name__ == "CollectionsPostRequest" + assert hasattr(collection_search, "bbox") + assert hasattr(collection_search, "datetime") + assert hasattr(collection_search, "limit") + assert hasattr(collection_search, "fields") + assert hasattr(collection_search, "q") + assert hasattr(collection_search, "sortby") + assert hasattr(collection_search, "filter") + assert ext.conformance_classes == [ + ConformanceClasses.COLLECTIONSEARCH, + ConformanceClasses.BASIS, + ConformanceClasses.FIELDS, + ConformanceClasses.FILTER, + ConformanceClasses.FREETEXT, + ConformanceClasses.QUERY, + ConformanceClasses.SORT, + ] + + +def test_from_extensions_methods_invalid(): + """Should raise warnings for invalid extensions.""" + extensions = [ + AggregationExtension(), + ] + with pytest.warns((UserWarning)): + ext = CollectionSearchExtension.from_extensions( + extensions, + ) + collection_search = ext.GET() + assert collection_search.__class__.__name__ == "CollectionsGetRequest" + assert hasattr(collection_search, "bbox") + assert hasattr(collection_search, "datetime") + assert hasattr(collection_search, "limit") + assert ext.conformance_classes == [ + ConformanceClasses.COLLECTIONSEARCH, + ConformanceClasses.BASIS, + ] + + with pytest.warns((UserWarning)): + ext = CollectionSearchPostExtension.from_extensions( + extensions, + client=DummyPostClient(), + settings=ApiSettings(), + ) + collection_search = ext.POST() + assert collection_search.__class__.__name__ == "CollectionsPostRequest" + assert hasattr(collection_search, "bbox") + assert hasattr(collection_search, "datetime") + assert hasattr(collection_search, "limit") + assert ext.conformance_classes == [ + ConformanceClasses.COLLECTIONSEARCH, + ConformanceClasses.BASIS, + ]