Skip to content

Commit fad578e

Browse files
committed
opensearch update
1 parent 9583f5a commit fad578e

File tree

5 files changed

+53
-8
lines changed

5 files changed

+53
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
1010
### Added
1111

1212
- GET `/collections` collection search structured filter extension with support for both cql2-json and cql2-text formats. [#475](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/475)
13+
- GET `/collections` collection search query extension. [#476](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/476)
1314

1415
### Changed
1516

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@
121121
if ENABLE_COLLECTIONS_SEARCH:
122122
# Create collection search extensions
123123
collection_search_extensions = [
124-
# QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
124+
QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
125125
SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]),
126126
FieldsExtension(conformance_classes=[FieldsConformanceClasses.COLLECTIONS]),
127127
CollectionSearchFilterExtension(

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,6 @@ async def get_all_collections(
294294
query_part = {
295295
"bool": {"must_not": [{"term": {field_name: value}}]}
296296
}
297-
print(f"Adding neq query part: {query_part}")
298297
query_parts.append(query_part)
299298
elif op in ["lt", "lte", "gt", "gte"]:
300299
# Range operators

stac_fastapi/opensearch/stac_fastapi/opensearch/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@
121121
if ENABLE_COLLECTIONS_SEARCH:
122122
# Create collection search extensions
123123
collection_search_extensions = [
124-
# QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
124+
QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
125125
SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]),
126126
FieldsExtension(conformance_classes=[FieldsConformanceClasses.COLLECTIONS]),
127127
CollectionSearchFilterExtension(
@@ -170,6 +170,7 @@
170170
post_request_model=post_request_model,
171171
landing_page_id=os.getenv("STAC_FASTAPI_LANDING_PAGE_ID", "stac-fastapi"),
172172
),
173+
"collections_get_request_model": collections_get_request_model,
173174
"search_get_request_model": create_get_request_model(search_extensions),
174175
"search_post_request_model": post_request_model,
175176
"items_get_request_model": items_get_request_model,

stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,16 +161,18 @@ async def get_all_collections(
161161
sort: Optional[List[Dict[str, Any]]] = None,
162162
q: Optional[List[str]] = None,
163163
filter: Optional[Dict[str, Any]] = None,
164+
query: Optional[Dict[str, Dict[str, Any]]] = None,
164165
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
165-
"""Retrieve a list of collections from Opensearch, supporting pagination.
166+
"""Retrieve a list of collections from Elasticsearch, supporting pagination.
166167
167168
Args:
168169
token (Optional[str]): The pagination token.
169170
limit (int): The number of results to return.
170171
request (Request): The FastAPI request object.
171172
sort (Optional[List[Dict[str, Any]]]): Optional sort parameter from the request.
172173
q (Optional[List[str]]): Free text search terms.
173-
filter (Optional[Dict[str, Any]]): Structured query in CQL2 format.
174+
filter (Optional[Dict[str, Any]]): Structured filter in CQL2 format.
175+
query (Optional[Dict[str, Dict[str, Any]]]): Query extension parameters.
174176
175177
Returns:
176178
A tuple of (collections, next pagination token if any).
@@ -193,7 +195,7 @@ async def get_all_collections(
193195
raise HTTPException(
194196
status_code=400,
195197
detail=f"Field '{field}' is not sortable. Sortable fields are: {', '.join(sortable_fields)}. "
196-
+ "Text fields are not sortable by default in Opensearch. "
198+
+ "Text fields are not sortable by default in Elasticsearch. "
197199
+ "To make a field sortable, update the mapping to use 'keyword' type or add a '.keyword' subfield. ",
198200
)
199201
formatted_sort.append({field: {"order": direction}})
@@ -250,11 +252,53 @@ async def get_all_collections(
250252
# Convert string filter to dict if needed
251253
if isinstance(filter, str):
252254
filter = orjson.loads(filter)
253-
# Convert the filter to an Opensearch query using the filter module
255+
# Convert the filter to an Elasticsearch query using the filter module
254256
es_query = filter_module.to_es(await self.get_queryables_mapping(), filter)
255257
query_parts.append(es_query)
256258

257-
# Combine all query parts with AND logic
259+
# Apply query extension if provided
260+
if query:
261+
try:
262+
# Process each field and operator in the query
263+
for field_name, expr in query.items():
264+
for op, value in expr.items():
265+
# Handle different operators
266+
if op == "eq":
267+
# Equality operator
268+
# Use different query types based on field name
269+
if field_name in ["title", "description"]:
270+
# For text fields, use match_phrase for exact phrase matching
271+
query_part = {"match_phrase": {field_name: value}}
272+
else:
273+
# For other fields, use term query for exact matching
274+
query_part = {"term": {field_name: value}}
275+
query_parts.append(query_part)
276+
elif op == "neq":
277+
# Not equal operator
278+
query_part = {
279+
"bool": {"must_not": [{"term": {field_name: value}}]}
280+
}
281+
query_parts.append(query_part)
282+
elif op in ["lt", "lte", "gt", "gte"]:
283+
# Range operators
284+
query_parts.append({"range": {field_name: {op: value}}})
285+
elif op == "in":
286+
# In operator (value should be a list)
287+
if isinstance(value, list):
288+
query_parts.append({"terms": {field_name: value}})
289+
else:
290+
query_parts.append({"term": {field_name: value}})
291+
elif op == "contains":
292+
# Contains operator for arrays
293+
query_parts.append({"term": {field_name: value}})
294+
except Exception as e:
295+
logger = logging.getLogger(__name__)
296+
logger.error(f"Error converting query to Elasticsearch: {e}")
297+
# If there's an error, add a query that matches nothing
298+
query_parts.append({"bool": {"must_not": {"match_all": {}}}})
299+
raise
300+
301+
# Combine all query parts with AND logic if there are multiple
258302
if query_parts:
259303
body["query"] = (
260304
query_parts[0]

0 commit comments

Comments
 (0)