Skip to content

Commit 0ff6cc0

Browse files
committed
add numberMatched, numberReturned
1 parent f34fabc commit 0ff6cc0

File tree

4 files changed

+159
-15
lines changed

4 files changed

+159
-15
lines changed

stac_fastapi/core/stac_fastapi/core/core.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ async def all_collections(
350350
if datetime:
351351
parsed_datetime = format_datetime_range(date_str=datetime)
352352

353-
collections, next_token = await self.database.get_all_collections(
353+
collections, next_token, maybe_count = await self.database.get_all_collections(
354354
token=token,
355355
limit=limit,
356356
request=request,
@@ -384,7 +384,12 @@ async def all_collections(
384384
next_link = PagingLinks(next=next_token, request=request).link_next()
385385
links.append(next_link)
386386

387-
return stac_types.Collections(collections=filtered_collections, links=links)
387+
return stac_types.Collections(
388+
collections=filtered_collections,
389+
links=links,
390+
numberMatched=maybe_count,
391+
numberReturned=len(filtered_collections),
392+
)
388393

389394
async def get_collection(
390395
self, collection_id: str, **kwargs

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ async def get_all_collections(
179179
filter: Optional[Dict[str, Any]] = None,
180180
query: Optional[Dict[str, Dict[str, Any]]] = None,
181181
datetime: Optional[str] = None,
182-
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
182+
) -> Tuple[List[Dict[str, Any]], Optional[str], Optional[int]]:
183183
"""Retrieve a list of collections from Elasticsearch, supporting pagination.
184184
185185
Args:
@@ -317,12 +317,30 @@ async def get_all_collections(
317317
else {"bool": {"must": query_parts}}
318318
)
319319

320-
# Execute the search
321-
response = await self.client.search(
322-
index=COLLECTIONS_INDEX,
323-
body=body,
320+
# Build a size-0 copy of the body without search_after (any sort clause is left in place); NOTE(review): the count call below uses only body["query"], not this copy
321+
count_body = body.copy()
322+
if "search_after" in count_body:
323+
del count_body["search_after"]
324+
count_body["size"] = 0
325+
326+
# Create async tasks for both search and count
327+
search_task = asyncio.create_task(
328+
self.client.search(
329+
index=COLLECTIONS_INDEX,
330+
body=body,
331+
)
332+
)
333+
334+
count_task = asyncio.create_task(
335+
self.client.count(
336+
index=COLLECTIONS_INDEX,
337+
body={"query": body.get("query", {"match_all": {}})},
338+
)
324339
)
325340

341+
# Wait for search task to complete
342+
response = await search_task
343+
326344
hits = response["hits"]["hits"]
327345
collections = [
328346
self.collection_serializer.db_to_stac(
@@ -337,7 +355,22 @@ async def get_all_collections(
337355
if next_token_values:
338356
next_token = next_token_values[0]
339357

340-
return collections, next_token
358+
# Get the total count of collections
359+
matched = (
360+
response["hits"]["total"]["value"]
361+
if response["hits"]["total"]["relation"] == "eq"
362+
else None
363+
)
364+
365+
# Use the count result only if the count task has already finished; the task is not awaited here, otherwise the search total above is kept
366+
if count_task.done():
367+
try:
368+
matched = count_task.result().get("count")
369+
except Exception as e:
370+
logger = logging.getLogger(__name__)
371+
logger.error(f"Count task failed: {e}")
372+
373+
return collections, next_token, matched
341374

342375
@staticmethod
343376
def _apply_collection_datetime_filter(

stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,9 @@ async def get_all_collections(
163163
filter: Optional[Dict[str, Any]] = None,
164164
query: Optional[Dict[str, Dict[str, Any]]] = None,
165165
datetime: Optional[str] = None,
166-
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
167-
"""Retrieve a list of collections from OpenSearch, supporting pagination.
166+
) -> Tuple[List[Dict[str, Any]], Optional[str], Optional[int]]:
167+
"""
168+
Retrieve a list of collections from OpenSearch, supporting pagination.
168169
169170
Args:
170171
token (Optional[str]): The pagination token.
@@ -299,12 +300,30 @@ async def get_all_collections(
299300
else {"bool": {"must": query_parts}}
300301
)
301302

302-
# Execute the search
303-
response = await self.client.search(
304-
index=COLLECTIONS_INDEX,
305-
body=body,
303+
# Build a size-0 copy of the body without search_after (any sort clause is left in place); NOTE(review): the count call below uses only body["query"], not this copy
304+
count_body = body.copy()
305+
if "search_after" in count_body:
306+
del count_body["search_after"]
307+
count_body["size"] = 0
308+
309+
# Create async tasks for both search and count
310+
search_task = asyncio.create_task(
311+
self.client.search(
312+
index=COLLECTIONS_INDEX,
313+
body=body,
314+
)
315+
)
316+
317+
count_task = asyncio.create_task(
318+
self.client.count(
319+
index=COLLECTIONS_INDEX,
320+
body={"query": body.get("query", {"match_all": {}})},
321+
)
306322
)
307323

324+
# Wait for search task to complete
325+
response = await search_task
326+
308327
hits = response["hits"]["hits"]
309328
collections = [
310329
self.collection_serializer.db_to_stac(
@@ -319,7 +338,22 @@ async def get_all_collections(
319338
if next_token_values:
320339
next_token = next_token_values[0]
321340

322-
return collections, next_token
341+
# Get the total count of collections
342+
matched = (
343+
response["hits"]["total"]["value"]
344+
if response["hits"]["total"]["relation"] == "eq"
345+
else None
346+
)
347+
348+
# Use the count result only if the count task has already finished; the task is not awaited here, otherwise the search total above is kept
349+
if count_task.done():
350+
try:
351+
matched = count_task.result().get("count")
352+
except Exception as e:
353+
logger = logging.getLogger(__name__)
354+
logger.error(f"Count task failed: {e}")
355+
356+
return collections, next_token, matched
323357

324358
async def get_one_item(self, collection_id: str, item_id: str) -> Dict:
325359
"""Retrieve a single item from the database.

stac_fastapi/tests/api/test_api_search_collections.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ async def test_collections_query_extension(app_client, txn_client, ctx):
439439
assert f"{test_prefix}-modis" in found_ids
440440

441441

442+
@pytest.mark.asyncio
442443
async def test_collections_datetime_filter(app_client, load_test_data, txn_client):
443444
"""Test filtering collections by datetime."""
444445
# Create a test collection with a specific temporal extent
@@ -526,3 +527,74 @@ async def test_collections_datetime_filter(app_client, load_test_data, txn_clien
526527
found_collections = [c for c in resp_json["collections"] if c["id"] == test_collection_id]
527528
assert len(found_collections) == 1, f"Expected to find collection {test_collection_id} with open-ended past range to a date within its range"
528529
"""
530+
531+
532+
@pytest.mark.asyncio
533+
async def test_collections_number_matched_returned(app_client, txn_client, ctx):
534+
"""Verify GET /collections returns correct numberMatched and numberReturned values."""
535+
# Create multiple collections with different ids
536+
base_collection = ctx.collection
537+
538+
# Create collections with ids in a specific order to test pagination
539+
# Use unique prefixes to avoid conflicts between tests
540+
test_prefix = f"count-{uuid.uuid4().hex[:8]}"
541+
collection_ids = [f"{test_prefix}-{i}" for i in range(10)]
542+
543+
for i, coll_id in enumerate(collection_ids):
544+
test_collection = base_collection.copy()
545+
test_collection["id"] = coll_id
546+
test_collection["title"] = f"Test Collection {i}"
547+
await create_collection(txn_client, test_collection)
548+
549+
await refresh_indices(txn_client)
550+
551+
# Test with limit=5
552+
resp = await app_client.get(
553+
"/collections",
554+
params=[("limit", "5")],
555+
)
556+
assert resp.status_code == 200
557+
resp_json = resp.json()
558+
559+
# Filter collections to only include the ones we created for this test
560+
test_collections = [
561+
c for c in resp_json["collections"] if c["id"].startswith(test_prefix)
562+
]
563+
564+
# Should return 5 collections
565+
assert len(test_collections) == 5
566+
567+
# Check that numberReturned matches the number of collections returned
568+
assert resp_json["numberReturned"] == len(resp_json["collections"])
569+
570+
# Check that numberMatched is greater than or equal to numberReturned
571+
# (since there might be other collections in the database)
572+
assert resp_json["numberMatched"] >= resp_json["numberReturned"]
573+
574+
# Check that numberMatched includes at least all our test collections
575+
assert resp_json["numberMatched"] >= len(collection_ids)
576+
577+
# Now test with a query that should match only some collections
578+
query = {"id": {"eq": f"{test_prefix}-1"}}
579+
resp = await app_client.get(
580+
"/collections",
581+
params=[("query", json.dumps(query))],
582+
)
583+
assert resp.status_code == 200
584+
resp_json = resp.json()
585+
586+
# Filter collections to only include the ones we created for this test
587+
test_collections = [
588+
c for c in resp_json["collections"] if c["id"].startswith(test_prefix)
589+
]
590+
591+
# Should return only 1 collection
592+
assert len(test_collections) == 1
593+
assert test_collections[0]["id"] == f"{test_prefix}-1"
594+
595+
# Check that numberReturned matches the number of collections returned
596+
assert resp_json["numberReturned"] == len(resp_json["collections"])
597+
598+
# Check that numberMatched counts the collection that matches the query
599+
# (expected to be 1; the assertion below only enforces a lower bound)
600+
assert resp_json["numberMatched"] >= 1

0 commit comments

Comments
 (0)