Skip to content

Commit c8902d6

Browse files
committed
update stacql filter
1 parent 24142a6 commit c8902d6

File tree

2 files changed

+63
-67
lines changed

2 files changed

+63
-67
lines changed

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -275,38 +275,23 @@ async def get_all_collections(
275275
# Apply query extension if provided
276276
if query:
277277
try:
278+
# First create a search object to apply filters
279+
search = Search(index=COLLECTIONS_INDEX)
280+
278281
# Process each field and operator in the query
279282
for field_name, expr in query.items():
280283
for op, value in expr.items():
281-
# Handle different operators
282-
if op == "eq":
283-
# Equality operator
284-
# Use different query types based on field name
285-
if field_name in ["title", "description"]:
286-
# For text fields, use match_phrase for exact phrase matching
287-
query_part = {"match_phrase": {field_name: value}}
288-
else:
289-
# For other fields, use term query for exact matching
290-
query_part = {"term": {field_name: value}}
291-
query_parts.append(query_part)
292-
elif op == "neq":
293-
# Not equal operator
294-
query_part = {
295-
"bool": {"must_not": [{"term": {field_name: value}}]}
296-
}
297-
query_parts.append(query_part)
298-
elif op in ["lt", "lte", "gt", "gte"]:
299-
# Range operators
300-
query_parts.append({"range": {field_name: {op: value}}})
301-
elif op == "in":
302-
# In operator (value should be a list)
303-
if isinstance(value, list):
304-
query_parts.append({"terms": {field_name: value}})
305-
else:
306-
query_parts.append({"term": {field_name: value}})
307-
elif op == "contains":
308-
# Contains operator for arrays
309-
query_parts.append({"term": {field_name: value}})
284+
# For collections, we don't need to prefix with 'properties__'
285+
field = field_name
286+
# Apply the filter using apply_stacql_filter
287+
search = self.apply_stacql_filter(
288+
search=search, op=op, field=field, value=value
289+
)
290+
291+
# Convert the search object to a query dict and add it to query_parts
292+
search_dict = search.to_dict()
293+
if "query" in search_dict:
294+
query_parts.append(search_dict["query"])
310295

311296
except Exception as e:
312297
logger = logging.getLogger(__name__)
@@ -607,18 +592,31 @@ def apply_stacql_filter(search: Search, op: str, field: str, value: float):
607592
608593
Args:
609594
search (Search): The search object to apply the filter to.
610-
op (str): The comparison operator to use. Can be 'eq' (equal), 'gt' (greater than), 'gte' (greater than or equal),
611-
'lt' (less than), or 'lte' (less than or equal).
595+
op (str): The comparison operator to use. Can be 'eq' (equal), 'ne'/'neq' (not equal), 'gt' (greater than),
596+
'gte' (greater than or equal), 'lt' (less than), or 'lte' (less than or equal).
612597
field (str): The field to perform the comparison on.
613598
value (float): The value to compare the field against.
614599
615600
Returns:
616601
search (Search): The search object with the specified filter applied.
617602
"""
618-
if op != "eq":
603+
if op == "eq":
604+
search = search.filter("term", **{field: value})
605+
elif op == "ne" or op == "neq":
606+
# For not equal, use a bool query with must_not
607+
search = search.exclude("term", **{field: value})
608+
elif op in ["gt", "gte", "lt", "lte"]:
609+
# For range operators
619610
key_filter = {field: {op: value}}
620611
search = search.filter(Q("range", **key_filter))
621-
else:
612+
elif op == "in":
613+
# For in operator (value should be a list)
614+
if isinstance(value, list):
615+
search = search.filter("terms", **{field: value})
616+
else:
617+
search = search.filter("term", **{field: value})
618+
elif op == "contains":
619+
# For contains operator (for arrays)
622620
search = search.filter("term", **{field: value})
623621

624622
return search

stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -259,38 +259,23 @@ async def get_all_collections(
259259
# Apply query extension if provided
260260
if query:
261261
try:
262+
# First create a search object to apply filters
263+
search = Search(index=COLLECTIONS_INDEX)
264+
262265
# Process each field and operator in the query
263266
for field_name, expr in query.items():
264267
for op, value in expr.items():
265-
# Handle different operators
266-
if op == "eq":
267-
# Equality operator
268-
# Use different query types based on field name
269-
if field_name in ["title", "description"]:
270-
# For text fields, use match_phrase for exact phrase matching
271-
query_part = {"match_phrase": {field_name: value}}
272-
else:
273-
# For other fields, use term query for exact matching
274-
query_part = {"term": {field_name: value}}
275-
query_parts.append(query_part)
276-
elif op == "neq":
277-
# Not equal operator
278-
query_part = {
279-
"bool": {"must_not": [{"term": {field_name: value}}]}
280-
}
281-
query_parts.append(query_part)
282-
elif op in ["lt", "lte", "gt", "gte"]:
283-
# Range operators
284-
query_parts.append({"range": {field_name: {op: value}}})
285-
elif op == "in":
286-
# In operator (value should be a list)
287-
if isinstance(value, list):
288-
query_parts.append({"terms": {field_name: value}})
289-
else:
290-
query_parts.append({"term": {field_name: value}})
291-
elif op == "contains":
292-
# Contains operator for arrays
293-
query_parts.append({"term": {field_name: value}})
268+
# For collections, we don't need to prefix with 'properties__'
269+
field = field_name
270+
# Apply the filter using apply_stacql_filter
271+
search = self.apply_stacql_filter(
272+
search=search, op=op, field=field, value=value
273+
)
274+
275+
# Convert the search object to a query dict and add it to query_parts
276+
search_dict = search.to_dict()
277+
if "query" in search_dict:
278+
query_parts.append(search_dict["query"])
294279
except Exception as e:
295280
logger = logging.getLogger(__name__)
296281
logger.error(f"Error converting query to Elasticsearch: {e}")
@@ -608,18 +593,31 @@ def apply_stacql_filter(search: Search, op: str, field: str, value: float):
608593
609594
Args:
610595
search (Search): The search object to apply the filter to.
611-
op (str): The comparison operator to use. Can be 'eq' (equal), 'gt' (greater than), 'gte' (greater than or equal),
612-
'lt' (less than), or 'lte' (less than or equal).
596+
op (str): The comparison operator to use. Can be 'eq' (equal), 'ne'/'neq' (not equal), 'gt' (greater than),
597+
'gte' (greater than or equal), 'lt' (less than), or 'lte' (less than or equal).
613598
field (str): The field to perform the comparison on.
614599
value (float): The value to compare the field against.
615600
616601
Returns:
617602
search (Search): The search object with the specified filter applied.
618603
"""
619-
if op != "eq":
620-
key_filter = {field: {f"{op}": value}}
604+
if op == "eq":
605+
search = search.filter("term", **{field: value})
606+
elif op == "ne" or op == "neq":
607+
# For not equal, use a bool query with must_not
608+
search = search.exclude("term", **{field: value})
609+
elif op in ["gt", "gte", "lt", "lte"]:
610+
# For range operators
611+
key_filter = {field: {op: value}}
621612
search = search.filter(Q("range", **key_filter))
622-
else:
613+
elif op == "in":
614+
# For in operator (value should be a list)
615+
if isinstance(value, list):
616+
search = search.filter("terms", **{field: value})
617+
else:
618+
search = search.filter("term", **{field: value})
619+
elif op == "contains":
620+
# For contains operator (for arrays)
623621
search = search.filter("term", **{field: value})
624622

625623
return search

0 commit comments

Comments
 (0)