Commit b585819

filter scratch

1 parent 57afb55 · commit b585819

3 files changed: +185 −7 lines

stac_fastapi/core/stac_fastapi/core/core.py

Lines changed: 22 additions & 2 deletions
@@ -228,6 +228,7 @@ async def all_collections(
         self,
         fields: Optional[List[str]] = None,
         sortby: Optional[str] = None,
+        filter_expr: Optional[str] = None,
         q: Optional[Union[str, List[str]]] = None,
         **kwargs,
     ) -> stac_types.Collections:
@@ -236,12 +237,14 @@ async def all_collections(
         Args:
             fields (Optional[List[str]]): Fields to include or exclude from the results.
             sortby (Optional[str]): Sorting options for the results.
-            q (Optional[List[str]]): Free text search terms.
+            filter_expr (Optional[str]): Structured filter in CQL2 format.
+            q (Optional[Union[str, List[str]]]): Free text search terms.
             **kwargs: Keyword arguments from the request.
 
         Returns:
             A Collections object containing all the collections in the database and links to various resources.
         """
+        print("filter: ", filter_expr)
         request = kwargs["request"]
         base_url = str(request.base_url)
         limit = int(request.query_params.get("limit", os.getenv("STAC_ITEM_LIMIT", 10)))
@@ -276,8 +279,25 @@ async def all_collections(
         if q is not None:
             q_list = [q] if isinstance(q, str) else q
 
+        # Parse the filter parameter if provided
+        parsed_filter = None
+        if filter_expr is not None:
+            try:
+                import orjson
+
+                parsed_filter = orjson.loads(filter_expr)
+            except Exception as e:
+                raise HTTPException(
+                    status_code=400, detail=f"Invalid filter parameter: {e}"
+                )
+
         collections, next_token = await self.database.get_all_collections(
-            token=token, limit=limit, request=request, sort=sort, q=q_list
+            token=token,
+            limit=limit,
+            request=request,
+            sort=sort,
+            q=q_list,
+            filter=parsed_filter,
         )
 
         # Apply field filtering if fields parameter was provided
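
For context, `filter_expr` arrives on GET /collections as a raw CQL2 JSON string, and the handler turns it into a dict before handing it to the database layer. Below is a minimal standalone sketch of just that parsing step, assuming `orjson` is installed and `HTTPException` is FastAPI's, mirroring the 400 behavior in the diff:

from typing import Any, Dict, Optional

import orjson
from fastapi import HTTPException


def parse_filter_expr(filter_expr: Optional[str]) -> Optional[Dict[str, Any]]:
    """Parse a CQL2 JSON filter string into a dict, or return None when absent."""
    if filter_expr is None:
        return None
    try:
        return orjson.loads(filter_expr)
    except Exception as e:
        # Same 400 response shape as the all_collections handler above
        raise HTTPException(status_code=400, detail=f"Invalid filter parameter: {e}")


# Example: the equality filter exercised by the new test further down
parsed = parse_filter_expr('{"op": "=", "args": [{"property": "id"}, "my-collection"]}')
assert parsed == {"op": "=", "args": [{"property": "id"}, "my-collection"]}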

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py

Lines changed: 97 additions & 4 deletions
@@ -176,6 +176,7 @@ async def get_all_collections(
         request: Request,
         sort: Optional[List[Dict[str, Any]]] = None,
         q: Optional[List[str]] = None,
+        filter: Optional[Dict[str, Any]] = None,
     ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
         """Retrieve a list of collections from Elasticsearch, supporting pagination.
 
@@ -185,6 +186,7 @@ async def get_all_collections(
             request (Request): The FastAPI request object.
             sort (Optional[List[Dict[str, Any]]]): Optional sort parameter from the request.
             q (Optional[List[str]]): Free text search terms.
+            filter (Optional[Dict[str, Any]]): Structured query in CQL2 format.
 
         Returns:
             A tuple of (collections, next pagination token if any).
@@ -225,6 +227,9 @@ async def get_all_collections(
         if token:
             body["search_after"] = [token]
 
+        # Build the query part of the body
+        query_parts = []
+
         # Apply free text query if provided
         if q:
             # For collections, we want to search across all relevant fields
@@ -251,10 +256,98 @@ async def get_all_collections(
                 }
             )
 
-            # Add the query to the body using bool query with should clauses
-            body["query"] = {
-                "bool": {"should": should_clauses, "minimum_should_match": 1}
-            }
+            # Add the free text query to the query parts
+            query_parts.append(
+                {"bool": {"should": should_clauses, "minimum_should_match": 1}}
+            )
+
+        # Apply structured filter if provided
+        if filter:
+            try:
+                # For simple direct query handling without using to_es
+                # This is a simplified approach that handles common filter patterns
+                if isinstance(filter, dict):
+                    # Check if this is a CQL2 filter with op and args
+                    if "op" in filter and "args" in filter:
+                        op = filter.get("op")
+                        args = filter.get("args")
+
+                        # Handle equality operator
+                        if (
+                            op == "="
+                            and len(args) == 2
+                            and isinstance(args[0], dict)
+                            and "property" in args[0]
+                        ):
+                            field = args[0]["property"]
+                            value = args[1]
+
+                            # Handle different field types
+                            if field == "id":
+                                # Direct match on ID field
+                                query_parts.append({"term": {"id": value}})
+                            elif field == "title":
+                                # Match on title field
+                                query_parts.append({"match": {"title": value}})
+                            elif field == "description":
+                                # Match on description field
+                                query_parts.append({"match": {"description": value}})
+                            else:
+                                # For other fields, try a multi-match query
+                                query_parts.append(
+                                    {
+                                        "multi_match": {
+                                            "query": value,
+                                            "fields": [field, f"{field}.*"],
+                                            "type": "best_fields",
+                                        }
+                                    }
+                                )
+
+                        # Handle regex operator
+                        elif (
+                            op == "=~"
+                            and len(args) == 2
+                            and isinstance(args[0], dict)
+                            and "property" in args[0]
+                        ):
+                            field = args[0]["property"]
+                            pattern = args[1].replace(".*", "*")
+
+                            # Use wildcard query for pattern matching
+                            query_parts.append(
+                                {
+                                    "wildcard": {
+                                        field: {
+                                            "value": pattern,
+                                            "case_insensitive": True,
+                                        }
+                                    }
+                                }
+                            )
+
+                        # For other operators, use a match_all query as fallback
+                        else:
+                            query_parts.append({"match_all": {}})
+                    else:
+                        # Not a valid CQL2 filter
+                        query_parts.append({"match_all": {}})
+                else:
+                    # Not a dictionary
+                    query_parts.append({"match_all": {}})
+            except Exception as e:
+                logger = logging.getLogger(__name__)
+                logger.error(f"Error converting filter to Elasticsearch: {e}")
+                # If there's an error, add a query that matches nothing
+                query_parts.append({"bool": {"must_not": {"match_all": {}}}})
+                raise
+
+        # Combine all query parts with AND logic if there are multiple
+        if query_parts:
+            if len(query_parts) == 1:
+                body["query"] = query_parts[0]
+            else:
+                body["query"] = {"bool": {"must": query_parts}}
 
         # Execute the search
         response = await self.client.search(
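
To make the combination logic concrete: `query_parts` collects one clause from the free-text `q` handling and one from the CQL2 filter, and when both are present they are wrapped in a top-level bool/must (AND). The shape below is illustrative only; the actual `should` clauses depend on which free-text fields the method searches:

# Approximate shape of body["query"] when both q=["sentinel"] and an id filter are set
example_body_query = {
    "bool": {
        "must": [
            # from q=["sentinel"]: free-text should-clauses (abbreviated here)
            {
                "bool": {
                    "should": [{"match": {"title": "sentinel"}}],
                    "minimum_should_match": 1,
                }
            },
            # from filter={"op": "=", "args": [{"property": "id"}, "my-collection"]}
            {"term": {"id": "my-collection"}},
        ]
    }
}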

stac_fastapi/tests/api/test_api_search_collections.py

Lines changed: 66 additions & 1 deletion
@@ -163,7 +163,7 @@ async def test_collections_free_text_search_get(app_client, txn_client, load_test_data):
     # Use unique prefixes to avoid conflicts between tests
     test_prefix = f"q-get-{uuid.uuid4().hex[:8]}"
 
-    # Create collections with different content to test free text search
+    # Create collections with different content to test structured filter
     test_collections = [
         {
             "id": f"{test_prefix}-sentinel",
@@ -226,3 +226,68 @@ async def test_collections_free_text_search_get(app_client, txn_client, load_test_data):
     # Should only find the landsat collection
     assert len(found_collections) == 1
     assert found_collections[0]["id"] == f"{test_prefix}-modis"
+
+
+@pytest.mark.asyncio
+async def test_collections_filter_search(app_client, txn_client, load_test_data):
+    """Verify GET /collections honors the filter parameter for structured search."""
+    # Create multiple collections with different content
+    base_collection = load_test_data("test_collection.json")
+
+    # Use unique prefixes to avoid conflicts between tests
+    test_prefix = f"filter-{uuid.uuid4().hex[:8]}"
+
+    # Create collections with different content to test structured filter
+    test_collections = [
+        {
+            "id": f"{test_prefix}-sentinel",
+            "title": "Sentinel-2 Collection",
+            "description": "Collection of Sentinel-2 data",
+            "summaries": {"platform": ["sentinel-2a", "sentinel-2b"]},
+        },
+        {
+            "id": f"{test_prefix}-landsat",
+            "title": "Landsat Collection",
+            "description": "Collection of Landsat data",
+            "summaries": {"platform": ["landsat-8", "landsat-9"]},
+        },
+        {
+            "id": f"{test_prefix}-modis",
+            "title": "MODIS Collection",
+            "description": "Collection of MODIS data",
+            "summaries": {"platform": ["terra", "aqua"]},
+        },
+    ]
+
+    for i, coll in enumerate(test_collections):
+        test_collection = base_collection.copy()
+        test_collection["id"] = coll["id"]
+        test_collection["title"] = coll["title"]
+        test_collection["description"] = coll["description"]
+        test_collection["summaries"] = coll["summaries"]
+        await create_collection(txn_client, test_collection)
+
+    # Test structured filter for collections with specific ID
+    import json
+
+    # Create a simple filter for exact ID match - similar to what works in Postman
+    filter_expr = {"op": "=", "args": [{"property": "id"}, f"{test_prefix}-sentinel"]}
+
+    # Convert to JSON string for URL parameter
+    filter_json = json.dumps(filter_expr)
+
+    # Use the exact format that works in Postman
+    resp = await app_client.get(
+        f"/collections?filter={filter_json}",
+    )
+    assert resp.status_code == 200
+    resp_json = resp.json()
+
+    # Filter collections to only include the ones we created for this test
+    found_collections = [
+        c for c in resp_json["collections"] if c["id"].startswith(test_prefix)
+    ]
+
+    # Should only find the sentinel collection
+    assert len(found_collections) == 1
+    assert found_collections[0]["id"] == f"{test_prefix}-sentinel"
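
The test passes the JSON filter directly in the URL, which the test client accepts; a real HTTP client would normally percent-encode it. A hypothetical client-side call against a running instance (the host and port are illustrative, not part of this change):

import json
from urllib.parse import quote

import httpx

filter_expr = {"op": "=", "args": [{"property": "id"}, "my-collection"]}
url = "http://localhost:8080/collections?filter=" + quote(json.dumps(filter_expr))

resp = httpx.get(url)
resp.raise_for_status()
matching_ids = [c["id"] for c in resp.json()["collections"]]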
