Skip to content

Commit 038db36

Browse files
committed
query ext scratch
1 parent ab62cd8 commit 038db36

File tree

3 files changed

+203
-14
lines changed

3 files changed

+203
-14
lines changed

stac_fastapi/core/stac_fastapi/core/core.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ async def all_collections(
231231
filter_expr: Optional[str] = None,
232232
filter_lang: Optional[str] = None,
233233
q: Optional[Union[str, List[str]]] = None,
234+
query: Optional[str] = None,
234235
**kwargs,
235236
) -> stac_types.Collections:
236237
"""Read all collections from the database.
@@ -239,6 +240,7 @@ async def all_collections(
239240
fields (Optional[List[str]]): Fields to include or exclude from the results.
240241
sortby (Optional[str]): Sorting options for the results.
241242
filter_expr (Optional[str]): Structured filter expression in CQL2 JSON or CQL2-text format.
243+
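query (Optional[str]): Legacy query parameter (deprecated). A JSON-encoded object mapping field names to operator/value pairs, e.g. '{"id": {"eq": "my-collection"}}'; invalid JSON results in a 400 error.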
242244
filter_lang (Optional[str]): Must be 'cql2-json' or 'cql2-text' if specified, other values will result in an error.
243245
q (Optional[Union[str, List[str]]]): Free text search terms.
244246
**kwargs: Keyword arguments from the request.
@@ -280,19 +282,32 @@ async def all_collections(
280282
if q is not None:
281283
q_list = [q] if isinstance(q, str) else q
282284

285+
# Parse the query parameter if provided
286+
parsed_query = None
287+
if query is not None:
288+
try:
289+
import orjson
290+
291+
parsed_query = orjson.loads(query)
292+
except Exception as e:
293+
raise HTTPException(
294+
status_code=400, detail=f"Invalid query parameter: {e}"
295+
)
296+
283297
# Parse the filter parameter if provided
284298
parsed_filter = None
285299
if filter_expr is not None:
286300
try:
287-
# Check if filter_lang is specified and not one of the supported formats
301+
# Only raise an error for explicitly unsupported filter languages
302+
# Allow None, cql2-json, and cql2-text (we'll treat it as JSON)
288303
if filter_lang is not None and filter_lang not in [
289304
"cql2-json",
290305
"cql2-text",
291306
]:
292307
# Raise an error for unsupported filter languages
293308
raise HTTPException(
294309
status_code=400,
295-
detail=f"Input should be 'cql2-json' or 'cql2-text' for collections. Got '{filter_lang}'.",
310+
detail=f"Only 'cql2-json' and 'cql2-text' filter languages are supported for collections. Got '{filter_lang}'.",
296311
)
297312

298313
# Handle different filter formats
@@ -335,6 +350,7 @@ async def all_collections(
335350
sort=sort,
336351
q=q_list,
337352
filter=parsed_filter,
353+
query=parsed_query,
338354
)
339355

340356
# Apply field filtering if fields parameter was provided

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ async def get_all_collections(
177177
sort: Optional[List[Dict[str, Any]]] = None,
178178
q: Optional[List[str]] = None,
179179
filter: Optional[Dict[str, Any]] = None,
180+
query: Optional[Dict[str, Dict[str, Any]]] = None,
180181
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
181182
"""Retrieve a list of collections from Elasticsearch, supporting pagination.
182183
@@ -186,7 +187,8 @@ async def get_all_collections(
186187
request (Request): The FastAPI request object.
187188
sort (Optional[List[Dict[str, Any]]]): Optional sort parameter from the request.
188189
q (Optional[List[str]]): Free text search terms.
189-
filter (Optional[Dict[str, Any]]): Structured query in CQL2 format.
190+
filter (Optional[Dict[str, Any]]): Structured filter in CQL2 format.
191+
query (Optional[Dict[str, Dict[str, Any]]]): Query extension parameters.
190192
191193
Returns:
192194
A tuple of (collections, next pagination token if any).
@@ -270,7 +272,50 @@ async def get_all_collections(
270272
es_query = filter_module.to_es(await self.get_queryables_mapping(), filter)
271273
query_parts.append(es_query)
272274

273-
# Combine all query parts with AND logic
275+
# Apply query extension if provided
276+
if query:
277+
try:
278+
# Process each field and operator in the query
279+
for field_name, expr in query.items():
280+
for op, value in expr.items():
281+
# Handle different operators
282+
if op == "eq":
283+
# Equality operator
284+
# Use different query types based on field name
285+
if field_name in ["title", "description"]:
286+
# For text fields, use match_phrase for exact phrase matching
287+
query_part = {"match_phrase": {field_name: value}}
288+
else:
289+
# For other fields, use term query for exact matching
290+
query_part = {"term": {field_name: value}}
291+
query_parts.append(query_part)
292+
elif op == "neq":
293+
# Not equal operator
294+
query_part = {
295+
"bool": {"must_not": [{"term": {field_name: value}}]}
296+
}
297+
print(f"Adding neq query part: {query_part}")
298+
query_parts.append(query_part)
299+
elif op in ["lt", "lte", "gt", "gte"]:
300+
# Range operators
301+
query_parts.append({"range": {field_name: {op: value}}})
302+
elif op == "in":
303+
# In operator (value should be a list)
304+
if isinstance(value, list):
305+
query_parts.append({"terms": {field_name: value}})
306+
else:
307+
query_parts.append({"term": {field_name: value}})
308+
elif op == "contains":
309+
# Contains operator for arrays
310+
query_parts.append({"term": {field_name: value}})
311+
except Exception as e:
312+
logger = logging.getLogger(__name__)
313+
logger.error(f"Error converting query to Elasticsearch: {e}")
314+
# If there's an error, add a query that matches nothing
315+
query_parts.append({"bool": {"must_not": {"match_all": {}}}})
316+
raise
317+
318+
# Combine all query parts with AND logic if there are multiple
274319
if query_parts:
275320
body["query"] = (
276321
query_parts[0]

stac_fastapi/tests/api/test_api_search_collections.py

Lines changed: 138 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88

99
@pytest.mark.asyncio
10-
async def test_collections_sort_id_asc(app_client, txn_client, load_test_data):
10+
async def test_collections_sort_id_asc(app_client, txn_client, ctx):
1111
"""Verify GET /collections honors ascending sort on id."""
1212
# Create multiple collections with different ids
13-
base_collection = load_test_data("test_collection.json")
13+
base_collection = ctx.collection
1414

1515
# Create collections with ids in a specific order to test sorting
1616
# Use unique prefixes to avoid conflicts between tests
@@ -23,6 +23,8 @@ async def test_collections_sort_id_asc(app_client, txn_client, load_test_data):
2323
test_collection["title"] = f"Test Collection {i}"
2424
await create_collection(txn_client, test_collection)
2525

26+
await refresh_indices(txn_client)
27+
2628
# Test ascending sort by id
2729
resp = await app_client.get(
2830
"/collections",
@@ -44,10 +46,10 @@ async def test_collections_sort_id_asc(app_client, txn_client, load_test_data):
4446

4547

4648
@pytest.mark.asyncio
47-
async def test_collections_sort_id_desc(app_client, txn_client, load_test_data):
49+
async def test_collections_sort_id_desc(app_client, txn_client, ctx):
4850
"""Verify GET /collections honors descending sort on id."""
4951
# Create multiple collections with different ids
50-
base_collection = load_test_data("test_collection.json")
52+
base_collection = ctx.collection
5153

5254
# Create collections with ids in a specific order to test sorting
5355
# Use unique prefixes to avoid conflicts between tests
@@ -60,6 +62,8 @@ async def test_collections_sort_id_desc(app_client, txn_client, load_test_data):
6062
test_collection["title"] = f"Test Collection {i}"
6163
await create_collection(txn_client, test_collection)
6264

65+
await refresh_indices(txn_client)
66+
6367
# Test descending sort by id
6468
resp = await app_client.get(
6569
"/collections",
@@ -81,10 +85,10 @@ async def test_collections_sort_id_desc(app_client, txn_client, load_test_data):
8185

8286

8387
@pytest.mark.asyncio
84-
async def test_collections_fields(app_client, txn_client, load_test_data):
88+
async def test_collections_fields(app_client, txn_client, ctx):
8589
"""Verify GET /collections honors the fields parameter."""
8690
# Create multiple collections with different ids
87-
base_collection = load_test_data("test_collection.json")
91+
base_collection = ctx.collection
8892

8993
# Create collections with ids in a specific order to test fields
9094
# Use unique prefixes to avoid conflicts between tests
@@ -98,6 +102,8 @@ async def test_collections_fields(app_client, txn_client, load_test_data):
98102
test_collection["description"] = f"Description for collection {i}"
99103
await create_collection(txn_client, test_collection)
100104

105+
await refresh_indices(txn_client)
106+
101107
# Test include fields parameter
102108
resp = await app_client.get(
103109
"/collections",
@@ -156,10 +162,10 @@ async def test_collections_fields(app_client, txn_client, load_test_data):
156162

157163

158164
@pytest.mark.asyncio
159-
async def test_collections_free_text_search_get(app_client, txn_client, load_test_data):
165+
async def test_collections_free_text_search_get(app_client, txn_client, ctx):
160166
"""Verify GET /collections honors the q parameter for free text search."""
161167
# Create multiple collections with different content
162-
base_collection = load_test_data("test_collection.json")
168+
base_collection = ctx.collection
163169

164170
# Use unique prefixes to avoid conflicts between tests
165171
test_prefix = f"q-get-{uuid.uuid4().hex[:8]}"
@@ -193,6 +199,8 @@ async def test_collections_free_text_search_get(app_client, txn_client, load_tes
193199
test_collection["summaries"] = coll["summaries"]
194200
await create_collection(txn_client, test_collection)
195201

202+
await refresh_indices(txn_client)
203+
196204
# Test free text search for "sentinel"
197205
resp = await app_client.get(
198206
"/collections",
@@ -229,10 +237,10 @@ async def test_collections_free_text_search_get(app_client, txn_client, load_tes
229237

230238

231239
@pytest.mark.asyncio
232-
async def test_collections_filter_search(app_client, txn_client, load_test_data):
240+
async def test_collections_filter_search(app_client, txn_client, ctx):
233241
"""Verify GET /collections honors the filter parameter for structured search."""
234242
# Create multiple collections with different content
235-
base_collection = load_test_data("test_collection.json")
243+
base_collection = ctx.collection
236244

237245
# Use unique prefixes to avoid conflicts between tests
238246
test_prefix = f"filter-{uuid.uuid4().hex[:8]}"
@@ -313,3 +321,123 @@ async def test_collections_filter_search(app_client, txn_client, load_test_data)
313321
assert (
314322
len(found_collections) >= 1
315323
), f"Expected at least 1 collection with ID {test_collection_id} using LIKE filter"
324+
325+
326+
@pytest.mark.asyncio
async def test_collections_query_extension(app_client, txn_client, ctx):
    """Verify GET /collections honors the query extension.

    Three collections are created under a unique id prefix. The test then
    checks that the ``eq`` and ``neq`` operators, applied to the ``id``
    field through the ``query`` request parameter, select the expected
    subset of those collections.
    """
    import json

    base_collection = ctx.collection

    # Use a unique prefix to avoid conflicts between tests; only ids that
    # start with it are considered when checking responses.
    test_prefix = f"query-ext-{uuid.uuid4().hex[:8]}"
    sentinel_id = f"{test_prefix}-sentinel"
    landsat_id = f"{test_prefix}-landsat"
    modis_id = f"{test_prefix}-modis"

    # Collections with different content to exercise the query extension
    test_collections = [
        {
            "id": sentinel_id,
            "title": "Sentinel-2 Collection",
            "description": "Collection of Sentinel-2 data",
            "summaries": {"platform": ["sentinel-2a", "sentinel-2b"]},
        },
        {
            "id": landsat_id,
            "title": "Landsat Collection",
            "description": "Collection of Landsat data",
            "summaries": {"platform": ["landsat-8", "landsat-9"]},
        },
        {
            "id": modis_id,
            "title": "MODIS Collection",
            "description": "Collection of MODIS data",
            "summaries": {"platform": ["terra", "aqua"]},
        },
    ]

    for overrides in test_collections:
        # Shallow copy is enough: every overridden key is replaced wholesale.
        test_collection = base_collection.copy()
        test_collection.update(overrides)
        await create_collection(txn_client, test_collection)

    # NOTE(review): presumably makes the collections created above visible
    # to search before querying - confirm against refresh_indices.
    await refresh_indices(txn_client)

    async def query_ids(query):
        """Run GET /collections with ``query`` and return this test's ids."""
        resp = await app_client.get(
            "/collections",
            params=[("query", json.dumps(query))],
        )
        assert resp.status_code == 200
        # Ignore collections that other tests may have left behind.
        return [
            c["id"]
            for c in resp.json()["collections"]
            if c["id"].startswith(test_prefix)
        ]

    # Equal operator on id: should find only the sentinel collection
    found_ids = await query_ids({"id": {"eq": sentinel_id}})
    assert found_ids == [
        sentinel_id
    ], f"Expected only {sentinel_id} for eq query, got {found_ids}"

    # Not-equal operator on id: should find landsat and modis collections
    # but not sentinel
    found_ids = await query_ids({"id": {"neq": sentinel_id}})
    assert len(found_ids) == 2, f"Expected 2 collections for neq query, got {found_ids}"
    assert sentinel_id not in found_ids
    assert landsat_id in found_ids
    assert modis_id in found_ids

0 commit comments

Comments
 (0)