Skip to content

Commit 7ecd26f

Browse files
committed
test sort collections
1 parent df6a32d commit 7ecd26f

File tree

4 files changed

+114
-13
lines changed

4 files changed

+114
-13
lines changed

stac_fastapi/core/stac_fastapi/core/base_database_logic.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Base database logic."""
22

33
import abc
4-
from typing import Any, Dict, Iterable, List, Optional
4+
from typing import Any, Dict, Iterable, List, Optional, Tuple
55

66

77
class BaseDatabaseLogic(abc.ABC):
@@ -14,9 +14,23 @@ class BaseDatabaseLogic(abc.ABC):
1414

1515
@abc.abstractmethod
1616
async def get_all_collections(
17-
self, token: Optional[str], limit: int
18-
) -> Iterable[Dict[str, Any]]:
19-
"""Retrieve a list of all collections from the database."""
17+
self,
18+
token: Optional[str],
19+
limit: int,
20+
request: Any = None,
21+
sort: Optional[List[Dict[str, Any]]] = None,
22+
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
23+
"""Retrieve a list of collections from the database, supporting pagination.
24+
25+
Args:
26+
token (Optional[str]): The pagination token.
27+
limit (int): The number of results to return.
28+
request (Any, optional): The FastAPI request object. Defaults to None.
29+
sort (Optional[List[Dict[str, Any]]], optional): Optional sort parameter. Defaults to None.
30+
31+
Returns:
32+
A tuple of (collections, next pagination token if any).
33+
"""
2034
pass
2135

2236
@abc.abstractmethod

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,7 @@ async def get_all_collections(
187187
Returns:
188188
A tuple of (collections, next pagination token if any).
189189
"""
190-
search_after = None
191-
if token:
192-
search_after = [token]
190+
search_after = token
193191

194192
formatted_sort = None
195193
if sort:
@@ -202,13 +200,22 @@ async def get_all_collections(
202200
# Always include id as a secondary sort to ensure consistent pagination
203201
formatted_sort.setdefault("id", {"order": "asc"})
204202

203+
# Use a collections-specific default sort that doesn't rely on properties.datetime
204+
collections_default_sort = {"id": {"order": "asc"}}
205+
206+
# Build the search body step by step to avoid type errors
207+
body = {
208+
"sort": formatted_sort or collections_default_sort,
209+
"size": limit,
210+
}
211+
212+
# Only add search_after if we have a token
213+
if search_after is not None:
214+
body["search_after"] = search_after # type: ignore
215+
205216
response = await self.client.search(
206217
index=COLLECTIONS_INDEX,
207-
body={
208-
"sort": formatted_sort or DEFAULT_SORT,
209-
"size": limit,
210-
**({"search_after": search_after} if search_after is not None else {}),
211-
},
218+
body=body,
212219
)
213220

214221
hits = response["hits"]["hits"]

stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ async def get_all_collections(
171171
Returns:
172172
A tuple of (collections, next pagination token if any).
173173
"""
174+
collections_default_sort = [{"id": {"order": "asc"}}]
174175
formatted_sort = []
175176
if sort:
176177
for item in sort:
@@ -182,7 +183,7 @@ async def get_all_collections(
182183
if not any("id" in item for item in formatted_sort):
183184
formatted_sort.append({"id": {"order": "asc"}})
184185
else:
185-
formatted_sort = [{"id": {"order": "asc"}}]
186+
formatted_sort = collections_default_sort
186187

187188
search_body = {
188189
"sort": formatted_sort,
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import uuid
2+
3+
import pytest
4+
5+
from ..conftest import create_collection
6+
7+
8+
@pytest.mark.asyncio
9+
async def test_collections_sort_id_asc(app_client, txn_client, load_test_data):
10+
"""Verify GET /collections honors ascending sort on id."""
11+
# Create multiple collections with different ids
12+
base_collection = load_test_data("test_collection.json")
13+
14+
# Create collections with ids in a specific order to test sorting
15+
# Use unique prefixes to avoid conflicts between tests
16+
test_prefix = f"asc-{uuid.uuid4().hex[:8]}"
17+
collection_ids = [f"{test_prefix}-c", f"{test_prefix}-a", f"{test_prefix}-b"]
18+
19+
for i, coll_id in enumerate(collection_ids):
20+
test_collection = base_collection.copy()
21+
test_collection["id"] = coll_id
22+
test_collection["title"] = f"Test Collection {i}"
23+
await create_collection(txn_client, test_collection)
24+
25+
# Test ascending sort by id
26+
resp = await app_client.get(
27+
"/collections",
28+
params=[("sortby", "+id")],
29+
)
30+
assert resp.status_code == 200
31+
resp_json = resp.json()
32+
33+
# Filter collections to only include the ones we created for this test
34+
test_collections = [
35+
c for c in resp_json["collections"] if c["id"].startswith(test_prefix)
36+
]
37+
38+
# Collections should be sorted alphabetically by id
39+
sorted_ids = sorted(collection_ids)
40+
assert len(test_collections) == len(collection_ids)
41+
for i, expected_id in enumerate(sorted_ids):
42+
assert test_collections[i]["id"] == expected_id
43+
44+
45+
@pytest.mark.asyncio
46+
async def test_collections_sort_id_desc(app_client, txn_client, load_test_data):
47+
"""Verify GET /collections honors descending sort on id."""
48+
# Create multiple collections with different ids
49+
base_collection = load_test_data("test_collection.json")
50+
51+
# Create collections with ids in a specific order to test sorting
52+
# Use unique prefixes to avoid conflicts between tests
53+
test_prefix = f"desc-{uuid.uuid4().hex[:8]}"
54+
collection_ids = [f"{test_prefix}-c", f"{test_prefix}-a", f"{test_prefix}-b"]
55+
56+
for i, coll_id in enumerate(collection_ids):
57+
test_collection = base_collection.copy()
58+
test_collection["id"] = coll_id
59+
test_collection["title"] = f"Test Collection {i}"
60+
await create_collection(txn_client, test_collection)
61+
62+
# Test descending sort by id
63+
resp = await app_client.get(
64+
"/collections",
65+
params=[("sortby", "-id")],
66+
)
67+
assert resp.status_code == 200
68+
resp_json = resp.json()
69+
70+
# Filter collections to only include the ones we created for this test
71+
test_collections = [
72+
c for c in resp_json["collections"] if c["id"].startswith(test_prefix)
73+
]
74+
75+
# Collections should be sorted in reverse alphabetical order by id
76+
sorted_ids = sorted(collection_ids, reverse=True)
77+
assert len(test_collections) == len(collection_ids)
78+
for i, expected_id in enumerate(sorted_ids):
79+
assert test_collections[i]["id"] == expected_id

0 commit comments

Comments
 (0)