Skip to content

Commit 6e83227

Browse files
committed
new route scratch
1 parent f9c5aa9 commit 6e83227

File tree

6 files changed

+274
-20
lines changed

6 files changed

+274
-20
lines changed

stac_fastapi/core/stac_fastapi/core/core.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ async def all_collections(
283283
if parsed_sort:
284284
sort = parsed_sort
285285

286-
print("sort: ", sort)
286+
# sort is now ready for use
287287
# Convert q to a list if it's a string
288288
q_list = None
289289
if q is not None:
@@ -404,11 +404,13 @@ async def post_all_collections(
404404
Returns:
405405
A Collections object containing all the collections in the database and links to various resources.
406406
"""
407+
# Debug print
408+
print("search_request: ", search_request)
407409
# Set the postbody attribute on the request object for PagingLinks
408410
request.postbody = search_request.model_dump(exclude_unset=True)
409-
# Convert fields parameter from POST format to all_collections format
410-
fields = None
411411

412+
fields = None
413+
# Check for fields attribute (legacy format)
412414
if hasattr(search_request, "fields") and search_request.fields:
413415
fields = []
414416

@@ -428,8 +430,29 @@ async def post_all_collections(
428430
for field in search_request.fields.exclude:
429431
fields.append(f"-{field}")
430432

433+
# Check for field attribute (ExtendedSearch format)
434+
if hasattr(search_request, "field") and search_request.field:
435+
fields = []
436+
437+
# Handle include fields
438+
if (
439+
hasattr(search_request.field, "includes")
440+
and search_request.field.includes
441+
):
442+
for field in search_request.field.includes:
443+
fields.append(f"+{field}")
444+
445+
# Handle exclude fields
446+
if (
447+
hasattr(search_request.field, "excludes")
448+
and search_request.field.excludes
449+
):
450+
for field in search_request.field.excludes:
451+
fields.append(f"-{field}")
452+
431453
# Convert sortby parameter from POST format to all_collections format
432454
sortby = None
455+
# Check for sortby attribute
433456
if hasattr(search_request, "sortby") and search_request.sortby:
434457
# Create a list of sort strings in the format expected by all_collections
435458
sortby = []
Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
"""elasticsearch extensions modifications."""
22

3+
from .collections_search import CollectionsSearchEndpointExtension
34
from .query import Operator, QueryableTypes, QueryExtension
45

5-
__all__ = ["Operator", "QueryableTypes", "QueryExtension"]
6+
__all__ = [
7+
"Operator",
8+
"QueryableTypes",
9+
"QueryExtension",
10+
"CollectionsSearchEndpointExtension",
11+
]
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
"""Collections search extension."""
2+
3+
from typing import List, Optional, Type, Union
4+
5+
from fastapi import APIRouter, FastAPI, Request
6+
from fastapi.responses import JSONResponse
7+
from pydantic import BaseModel
8+
from starlette.responses import Response
9+
10+
from stac_fastapi.api.models import APIRequest
11+
from stac_fastapi.types.core import BaseCoreClient
12+
from stac_fastapi.types.extension import ApiExtension
13+
from stac_fastapi.types.stac import Collections
14+
15+
16+
class CollectionsSearchEndpointExtension(ApiExtension):
17+
"""Collections search endpoint extension.
18+
19+
This extension adds a dedicated /collections-search endpoint for collection search operations.
20+
"""
21+
22+
def __init__(
23+
self,
24+
client: Optional[BaseCoreClient] = None,
25+
settings: dict = None,
26+
GET: Optional[Type[Union[BaseModel, APIRequest]]] = None,
27+
POST: Optional[Type[Union[BaseModel, APIRequest]]] = None,
28+
conformance_classes: Optional[List[str]] = None,
29+
):
30+
"""Initialize the extension.
31+
32+
Args:
33+
client: Optional BaseCoreClient instance to use for this extension.
34+
settings: Dictionary of settings to pass to the extension.
35+
GET: Optional GET request model.
36+
POST: Optional POST request model.
37+
conformance_classes: Optional list of conformance classes to add to the API.
38+
"""
39+
super().__init__()
40+
self.client = client
41+
self.settings = settings or {}
42+
self.GET = GET
43+
self.POST = POST
44+
self.conformance_classes = conformance_classes or []
45+
self.router = APIRouter()
46+
self.create_endpoints()
47+
48+
def register(self, app: FastAPI) -> None:
49+
"""Register the extension with a FastAPI application.
50+
51+
Args:
52+
app: target FastAPI application.
53+
54+
Returns:
55+
None
56+
"""
57+
app.include_router(self.router)
58+
59+
def create_endpoints(self) -> None:
60+
"""Create endpoints for the extension."""
61+
if self.GET:
62+
self.router.add_api_route(
63+
name="Get Collections Search",
64+
path="/collections-search",
65+
response_model=None,
66+
response_class=JSONResponse,
67+
methods=["GET"],
68+
endpoint=self.collections_search_get_endpoint,
69+
**(self.settings if isinstance(self.settings, dict) else {}),
70+
)
71+
72+
if self.POST:
73+
self.router.add_api_route(
74+
name="Post Collections Search",
75+
path="/collections-search",
76+
response_model=None,
77+
response_class=JSONResponse,
78+
methods=["POST"],
79+
endpoint=self.collections_search_post_endpoint,
80+
**(self.settings if isinstance(self.settings, dict) else {}),
81+
)
82+
83+
async def collections_search_get_endpoint(
84+
self, request: Request
85+
) -> Union[Collections, Response]:
86+
"""GET /collections-search endpoint.
87+
88+
Args:
89+
request: Request object.
90+
91+
Returns:
92+
Collections: Collections object.
93+
"""
94+
# Extract query parameters from the request
95+
params = dict(request.query_params)
96+
97+
# Convert query parameters to appropriate types
98+
if "limit" in params:
99+
try:
100+
params["limit"] = int(params["limit"])
101+
except ValueError:
102+
pass
103+
104+
# Handle fields parameter
105+
if "fields" in params:
106+
fields_str = params.pop("fields")
107+
fields = fields_str.split(",")
108+
params["fields"] = fields
109+
110+
# Handle sortby parameter
111+
if "sortby" in params:
112+
sortby_str = params.pop("sortby")
113+
sortby = sortby_str.split(",")
114+
params["sortby"] = sortby
115+
116+
collections = await self.client.all_collections(request=request, **params)
117+
return collections
118+
119+
async def collections_search_post_endpoint(
120+
self, request: Request, body: dict
121+
) -> Union[Collections, Response]:
122+
"""POST /collections-search endpoint.
123+
124+
Args:
125+
request: Request object.
126+
body: Search request body.
127+
128+
Returns:
129+
Collections: Collections object.
130+
"""
131+
from stac_pydantic.api.search import ExtendedSearch
132+
133+
# Convert the dict to an ExtendedSearch model
134+
search_request = ExtendedSearch.model_validate(body)
135+
136+
# Check if fields are present in the body
137+
if "fields" in body:
138+
# Extract fields from body and add them to search_request
139+
if hasattr(search_request, "field"):
140+
from stac_pydantic.api.extensions.fields import FieldsExtension
141+
142+
fields_data = body["fields"]
143+
search_request.field = FieldsExtension(
144+
includes=fields_data.get("include"),
145+
excludes=fields_data.get("exclude"),
146+
)
147+
148+
# Set the postbody on the request for pagination links
149+
request.postbody = body
150+
151+
collections = await self.client.post_all_collections(
152+
search_request=search_request, request=request
153+
)
154+
return collections
155+
156+
@classmethod
157+
def from_extensions(
158+
cls, extensions: List[ApiExtension]
159+
) -> "CollectionsSearchEndpointExtension":
160+
"""Create a CollectionsSearchEndpointExtension from a list of extensions.
161+
162+
Args:
163+
extensions: List of extensions to include in the CollectionsSearchEndpointExtension.
164+
165+
Returns:
166+
CollectionsSearchEndpointExtension: A new CollectionsSearchEndpointExtension instance.
167+
"""
168+
from stac_fastapi.api.models import (
169+
create_get_request_model,
170+
create_post_request_model,
171+
)
172+
173+
get_model = create_get_request_model(extensions)
174+
post_model = create_post_request_model(extensions)
175+
176+
return cls(
177+
GET=get_model,
178+
POST=post_model,
179+
conformance_classes=[
180+
ext.conformance_classes
181+
for ext in extensions
182+
if hasattr(ext, "conformance_classes")
183+
],
184+
)

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
EsAggregationExtensionGetRequest,
2424
EsAggregationExtensionPostRequest,
2525
)
26+
from stac_fastapi.core.extensions.collections_search import (
27+
CollectionsSearchEndpointExtension,
28+
)
2629
from stac_fastapi.core.extensions.fields import FieldsExtension
2730
from stac_fastapi.core.rate_limit import setup_rate_limit
2831
from stac_fastapi.core.route_dependencies import get_route_dependencies
@@ -163,8 +166,31 @@
163166
],
164167
)
165168

169+
# Initialize collections-search endpoint extension
170+
collections_search_endpoint_ext = CollectionsSearchEndpointExtension(
171+
client=CoreClient(
172+
database=database_logic,
173+
session=session,
174+
post_request_model=collection_search_post_request_model,
175+
landing_page_id=os.getenv("STAC_FASTAPI_LANDING_PAGE_ID", "stac-fastapi"),
176+
),
177+
settings=settings,
178+
GET=collections_get_request_model,
179+
POST=collection_search_post_request_model,
180+
conformance_classes=[
181+
"https://api.stacspec.org/v1.0.0-rc.1/collection-search",
182+
"http://www.opengis.net/spec/ogcapi-common-2/1.0/conf/simple-query",
183+
"https://api.stacspec.org/v1.0.0-rc.1/collection-search#filter",
184+
"https://api.stacspec.org/v1.0.0-rc.1/collection-search#free-text",
185+
"https://api.stacspec.org/v1.0.0-rc.1/collection-search#query",
186+
"https://api.stacspec.org/v1.0.0-rc.1/collection-search#sort",
187+
"https://api.stacspec.org/v1.0.0-rc.1/collection-search#fields",
188+
],
189+
)
190+
166191
extensions.append(collection_search_ext)
167192
extensions.append(collection_search_post_ext)
193+
extensions.append(collections_search_endpoint_ext)
168194

169195
database_logic.extensions = [type(ext).__name__ for ext in extensions]
170196

stac_fastapi/opensearch/stac_fastapi/opensearch/app.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
EsAggregationExtensionGetRequest,
2424
EsAggregationExtensionPostRequest,
2525
)
26+
from stac_fastapi.core.extensions.collections_search import (
27+
CollectionsSearchEndpointExtension,
28+
)
2629
from stac_fastapi.core.extensions.fields import FieldsExtension
2730
from stac_fastapi.core.rate_limit import setup_rate_limit
2831
from stac_fastapi.core.route_dependencies import get_route_dependencies
@@ -163,8 +166,31 @@
163166
],
164167
)
165168

169+
# Initialize collections-search endpoint extension
170+
collections_search_endpoint_ext = CollectionsSearchEndpointExtension(
171+
client=CoreClient(
172+
database=database_logic,
173+
session=session,
174+
post_request_model=collection_search_post_request_model,
175+
landing_page_id=os.getenv("STAC_FASTAPI_LANDING_PAGE_ID", "stac-fastapi"),
176+
),
177+
settings=settings,
178+
GET=collections_get_request_model,
179+
POST=collection_search_post_request_model,
180+
conformance_classes=[
181+
"https://api.stacspec.org/v1.0.0-rc.1/collection-search",
182+
"http://www.opengis.net/spec/ogcapi-common-2/1.0/conf/simple-query",
183+
"https://api.stacspec.org/v1.0.0-rc.1/collection-search#filter",
184+
"https://api.stacspec.org/v1.0.0-rc.1/collection-search#free-text",
185+
"https://api.stacspec.org/v1.0.0-rc.1/collection-search#query",
186+
"https://api.stacspec.org/v1.0.0-rc.1/collection-search#sort",
187+
"https://api.stacspec.org/v1.0.0-rc.1/collection-search#fields",
188+
],
189+
)
190+
166191
extensions.append(collection_search_ext)
167192
extensions.append(collection_search_post_ext)
193+
extensions.append(collections_search_endpoint_ext)
168194

169195
database_logic.extensions = [type(ext).__name__ for ext in extensions]
170196

stac_fastapi/tests/api/test_api_search_collections.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -601,13 +601,8 @@ async def test_collections_number_matched_returned(app_client, txn_client, ctx):
601601

602602

603603
@pytest.mark.asyncio
604-
async def test_collections_post(app_client, txn_client, ctx, monkeypatch):
605-
"""Verify POST /collections endpoint works."""
606-
# Turn off the transaction extension to avoid conflict with collections POST endpoint
607-
import os
608-
609-
original_value = os.environ.get("ENABLE_TRANSACTIONS_EXTENSIONS")
610-
monkeypatch.setenv("ENABLE_TRANSACTIONS_EXTENSIONS", "False")
604+
async def test_collections_post(app_client, txn_client, ctx):
605+
"""Verify POST /collections-search endpoint works."""
611606

612607
# Create multiple collections with different ids
613608
base_collection = ctx.collection
@@ -627,7 +622,7 @@ async def test_collections_post(app_client, txn_client, ctx, monkeypatch):
627622

628623
# Test basic POST search
629624
resp = await app_client.post(
630-
"/collections",
625+
"/collections-search",
631626
json={"limit": 5},
632627
)
633628
assert resp.status_code == 200
@@ -649,7 +644,7 @@ async def test_collections_post(app_client, txn_client, ctx, monkeypatch):
649644

650645
# Test POST search with sortby
651646
resp = await app_client.post(
652-
"/collections",
647+
"/collections-search",
653648
json={"sortby": [{"field": "id", "direction": "desc"}]},
654649
)
655650
assert resp.status_code == 200
@@ -669,7 +664,7 @@ async def test_collections_post(app_client, txn_client, ctx, monkeypatch):
669664

670665
# Test POST search with fields
671666
resp = await app_client.post(
672-
"/collections",
667+
"/collections-search",
673668
json={"fields": {"exclude": ["stac_version"]}},
674669
)
675670
assert resp.status_code == 200
@@ -683,9 +678,3 @@ async def test_collections_post(app_client, txn_client, ctx, monkeypatch):
683678
# Check that stac_version is excluded from the collections
684679
for collection in test_collections:
685680
assert "stac_version" not in collection
686-
687-
# Restore the original environment variable value
688-
if original_value is not None:
689-
monkeypatch.setenv("ENABLE_TRANSACTIONS_EXTENSIONS", original_value)
690-
else:
691-
monkeypatch.delenv("ENABLE_TRANSACTIONS_EXTENSIONS", raising=False)

0 commit comments

Comments
 (0)