Skip to content

Commit f9c5aa9

Browse files
committed
post sort
1 parent 3db9018 commit f9c5aa9

File tree

2 files changed

+58
-28
lines changed

2 files changed

+58
-28
lines changed

stac_fastapi/core/stac_fastapi/core/core.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ async def all_collections(
229229
datetime: Optional[str] = None,
230230
limit: Optional[int] = None,
231231
fields: Optional[List[str]] = None,
232-
sortby: Optional[str] = None,
232+
sortby: Optional[Union[str, List[str]]] = None,
233233
filter_expr: Optional[str] = None,
234234
filter_lang: Optional[str] = None,
235235
q: Optional[Union[str, List[str]]] = None,
@@ -255,8 +255,6 @@ async def all_collections(
255255
request = kwargs["request"]
256256
base_url = str(request.base_url)
257257

258-
print("fields get: ", fields)
259-
260258
limit = int(request.query_params.get("limit", os.getenv("STAC_ITEM_LIMIT", 10)))
261259

262260
token = request.query_params.get("token")
@@ -265,18 +263,10 @@ async def all_collections(
265263
includes, excludes = set(), set()
266264
if fields:
267265
for field in fields:
268-
print("Processing field:", field)
269266
if field[0] == "-":
270267
excludes.add(field[1:])
271-
print("Added to excludes:", field[1:])
272268
else:
273269
includes.add(field[1:] if field[0] in "+ " else field)
274-
print(
275-
"Added to includes:", field[1:] if field[0] in "+ " else field
276-
)
277-
print("Final includes:", includes)
278-
print("Final excludes:", excludes)
279-
print("fields get: ", fields)
280270

281271
sort = None
282272
if sortby:
@@ -293,6 +283,7 @@ async def all_collections(
293283
if parsed_sort:
294284
sort = parsed_sort
295285

286+
print("sort: ", sort)
296287
# Convert q to a list if it's a string
297288
q_list = None
298289
if q is not None:
@@ -413,6 +404,8 @@ async def post_all_collections(
413404
Returns:
414405
A Collections object containing all the collections in the database and links to various resources.
415406
"""
407+
# Set the postbody attribute on the request object for PagingLinks
408+
request.postbody = search_request.model_dump(exclude_unset=True)
416409
# Convert fields parameter from POST format to all_collections format
417410
fields = None
418411

@@ -438,16 +431,27 @@ async def post_all_collections(
438431
# Convert sortby parameter from POST format to all_collections format
439432
sortby = None
440433
if hasattr(search_request, "sortby") and search_request.sortby:
441-
sort_strings = []
434+
# Create a list of sort strings in the format expected by all_collections
435+
sortby = []
442436
for sort_item in search_request.sortby:
443-
direction = sort_item.get("direction", "asc")
444-
field = sort_item.get("field")
437+
# Handle different types of sort items
438+
if hasattr(sort_item, "field") and hasattr(sort_item, "direction"):
439+
# This is a Pydantic model with field and direction attributes
440+
field = sort_item.field
441+
direction = sort_item.direction
442+
elif isinstance(sort_item, dict):
443+
# This is a dictionary with field and direction keys
444+
field = sort_item.get("field")
445+
direction = sort_item.get("direction", "asc")
446+
else:
447+
# Skip this item if we can't extract field and direction
448+
continue
449+
445450
if field:
446-
prefix = "-" if direction.lower() == "desc" else "+"
447-
sort_strings.append(f"{prefix}{field}")
448-
# Join the sort strings into a single string
449-
if sort_strings:
450-
sortby = ",".join(sort_strings)
451+
# Create a sort string in the format expected by all_collections
452+
# e.g., "-id" for descending sort on id field
453+
prefix = "-" if direction.lower() == "desc" else ""
454+
sortby.append(f"{prefix}{field}")
451455

452456
# Pass all parameters from search_request to all_collections
453457
return await self.all_collections(

stac_fastapi/tests/api/test_api_search_collections.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -601,8 +601,14 @@ async def test_collections_number_matched_returned(app_client, txn_client, ctx):
601601

602602

603603
@pytest.mark.asyncio
604-
async def test_collections_search_post(app_client, txn_client, ctx):
605-
"""Verify POST /collections-search endpoint works."""
604+
async def test_collections_post(app_client, txn_client, ctx, monkeypatch):
605+
"""Verify POST /collections endpoint works."""
606+
# Turn off the transaction extension to avoid conflict with collections POST endpoint
607+
import os
608+
609+
original_value = os.environ.get("ENABLE_TRANSACTIONS_EXTENSIONS")
610+
monkeypatch.setenv("ENABLE_TRANSACTIONS_EXTENSIONS", "False")
611+
606612
# Create multiple collections with different ids
607613
base_collection = ctx.collection
608614

@@ -641,10 +647,10 @@ async def test_collections_search_post(app_client, txn_client, ctx):
641647
# Check that numberMatched is greater than or equal to numberReturned
642648
assert resp_json["numberMatched"] >= resp_json["numberReturned"]
643649

644-
# Test POST search with query
650+
# Test POST search with sortby
645651
resp = await app_client.post(
646652
"/collections",
647-
json={"query": {"id": {"eq": f"{test_prefix}-1"}}},
653+
json={"sortby": [{"field": "id", "direction": "desc"}]},
648654
)
649655
assert resp.status_code == 200
650656
resp_json = resp.json()
@@ -654,12 +660,32 @@ async def test_collections_search_post(app_client, txn_client, ctx):
654660
c for c in resp_json["collections"] if c["id"].startswith(test_prefix)
655661
]
656662

657-
# Should return only 1 collection
658-
assert len(test_collections) == 1
659-
assert test_collections[0]["id"] == f"{test_prefix}-1"
663+
# Check that collections are sorted by id in descending order
664+
if len(test_collections) >= 2:
665+
assert test_collections[0]["id"] > test_collections[1]["id"]
660666

661667
# Check that numberReturned matches the number of collections returned
662668
assert resp_json["numberReturned"] == len(resp_json["collections"])
663669

664-
# Check that numberMatched matches the number of collections that match the query
665-
assert resp_json["numberMatched"] >= 1
670+
# Test POST search with fields
671+
resp = await app_client.post(
672+
"/collections",
673+
json={"fields": {"exclude": ["stac_version"]}},
674+
)
675+
assert resp.status_code == 200
676+
resp_json = resp.json()
677+
678+
# Filter collections to only include the ones we created for this test
679+
test_collections = [
680+
c for c in resp_json["collections"] if c["id"].startswith(test_prefix)
681+
]
682+
683+
# Check that stac_version is excluded from the collections
684+
for collection in test_collections:
685+
assert "stac_version" not in collection
686+
687+
# Restore the original environment variable value
688+
if original_value is not None:
689+
monkeypatch.setenv("ENABLE_TRANSACTIONS_EXTENSIONS", original_value)
690+
else:
691+
monkeypatch.delenv("ENABLE_TRANSACTIONS_EXTENSIONS", raising=False)

0 commit comments

Comments
 (0)