Commit 29f6e2e

Use filter_expr, not filter, for the aggregation request.
1 parent: bf6f96a

File tree: 1 file changed, +25 −75 lines

stac_fastapi/core/stac_fastapi/core/extensions/aggregation.py

@@ -36,14 +36,10 @@
 
 
 @attr.s
-class EsAggregationExtensionGetRequest(
-    AggregationExtensionGetRequest, FilterExtensionGetRequest
-):
+class EsAggregationExtensionGetRequest(AggregationExtensionGetRequest, FilterExtensionGetRequest):
     """Implementation specific query parameters for aggregation precision."""
 
-    collection_id: Optional[
-        Annotated[str, Path(description="Collection ID")]
-    ] = attr.ib(default=None)
+    collection_id: Optional[Annotated[str, Path(description="Collection ID")]] = attr.ib(default=None)
 
     centroid_geohash_grid_frequency_precision: Optional[int] = attr.ib(default=None)
     centroid_geohex_grid_frequency_precision: Optional[int] = attr.ib(default=None)
@@ -53,9 +49,7 @@ class EsAggregationExtensionGetRequest(
     datetime_frequency_interval: Optional[str] = attr.ib(default=None)
 
 
-class EsAggregationExtensionPostRequest(
-    AggregationExtensionPostRequest, FilterExtensionPostRequest
-):
+class EsAggregationExtensionPostRequest(AggregationExtensionPostRequest, FilterExtensionPostRequest):
     """Implementation specific query parameters for aggregation precision."""
 
     centroid_geohash_grid_frequency_precision: Optional[int] = None
@@ -153,9 +147,7 @@ async def get_aggregations(self, collection_id: Optional[str] = None, **kwargs):
             )
             if await self.database.check_collection_exists(collection_id) is None:
                 collection = await self.database.find_collection(collection_id)
-                aggregations = collection.get(
-                    "aggregations", self.DEFAULT_AGGREGATIONS.copy()
-                )
+                aggregations = collection.get("aggregations", self.DEFAULT_AGGREGATIONS.copy())
             else:
                 raise IndexError(f"Collection {collection_id} does not exist")
         else:
@@ -168,13 +160,9 @@ async def get_aggregations(self, collection_id: Optional[str] = None, **kwargs):
             )
 
             aggregations = self.DEFAULT_AGGREGATIONS
-        return AggregationCollection(
-            type="AggregationCollection", aggregations=aggregations, links=links
-        )
+        return AggregationCollection(type="AggregationCollection", aggregations=aggregations, links=links)
 
-    def extract_precision(
-        self, precision: Union[int, None], min_value: int, max_value: int
-    ) -> Optional[int]:
+    def extract_precision(self, precision: Union[int, None], min_value: int, max_value: int) -> Optional[int]:
         """Ensure that the aggregation precision value is withing the a valid range, otherwise return the minumium value."""
         if precision is not None:
             if precision < min_value or precision > max_value:
@@ -211,9 +199,7 @@ def extract_date_histogram_interval(self, value: Optional[str]) -> str:
         return self.DEFAULT_DATETIME_INTERVAL
 
     @staticmethod
-    def _return_date(
-        interval: Optional[Union[DateTimeType, str]]
-    ) -> Dict[str, Optional[str]]:
+    def _return_date(interval: Optional[Union[DateTimeType, str]]) -> Dict[str, Optional[str]]:
         """
         Convert a date interval.
 
@@ -241,9 +227,7 @@ def _return_date(
         if "/" in interval:
             parts = interval.split("/")
             result["gte"] = parts[0] if parts[0] != ".." else None
-            result["lte"] = (
-                parts[1] if len(parts) > 1 and parts[1] != ".." else None
-            )
+            result["lte"] = parts[1] if len(parts) > 1 and parts[1] != ".." else None
         else:
             converted_time = interval if interval != ".." else None
             result["gte"] = result["lte"] = converted_time
@@ -283,9 +267,7 @@ def frequency_agg(self, es_aggs, name, data_type):
 
     def metric_agg(self, es_aggs, name, data_type):
         """Format an aggregation for a metric aggregation."""
-        value = es_aggs.get(name, {}).get("value_as_string") or es_aggs.get(
-            name, {}
-        ).get("value")
+        value = es_aggs.get(name, {}).get("value_as_string") or es_aggs.get(name, {}).get("value")
         # ES 7.x does not return datetimes with a 'value_as_string' field
         if "datetime" in name and isinstance(value, float):
             value = datetime_to_str(datetime.fromtimestamp(value / 1e3))
@@ -331,9 +313,7 @@ def format_datetime(dt):
     async def aggregate(
         self,
         aggregate_request: Optional[EsAggregationExtensionPostRequest] = None,
-        collection_id: Optional[
-            Annotated[str, Path(description="Collection ID")]
-        ] = None,
+        collection_id: Optional[Annotated[str, Path(description="Collection ID")]] = None,
         collections: Optional[List[str]] = [],
         datetime: Optional[DateTimeType] = None,
         intersects: Optional[str] = None,
@@ -390,9 +370,7 @@ async def aggregate(
 
         filter_lang = "cql2-json"
         if aggregate_request.filter_expr:
-            aggregate_request.filter_expr = self.get_filter(
-                aggregate_request.filter_expr, filter_lang
-            )
+            aggregate_request.filter_expr = self.get_filter(aggregate_request.filter_expr, filter_lang)
 
         if collection_id:
             if aggregate_request.collections:
@@ -403,25 +381,18 @@ async def aggregate(
             else:
                 aggregate_request.collections = [collection_id]
 
-        if (
-            aggregate_request.aggregations is None
-            or aggregate_request.aggregations == []
-        ):
+        if aggregate_request.aggregations is None or aggregate_request.aggregations == []:
             raise HTTPException(
                 status_code=400,
                 detail="No 'aggregations' found. Use '/aggregations' to return available aggregations",
             )
 
         if aggregate_request.ids:
-            search = self.database.apply_ids_filter(
-                search=search, item_ids=aggregate_request.ids
-            )
+            search = self.database.apply_ids_filter(search=search, item_ids=aggregate_request.ids)
 
         if aggregate_request.datetime:
             datetime_search = self._return_date(aggregate_request.datetime)
-            search = self.database.apply_datetime_filter(
-                search=search, datetime_search=datetime_search
-            )
+            search = self.database.apply_datetime_filter(search=search, datetime_search=datetime_search)
 
         if aggregate_request.bbox:
             bbox = aggregate_request.bbox
@@ -431,22 +402,14 @@ async def aggregate(
             search = self.database.apply_bbox_filter(search=search, bbox=bbox)
 
         if aggregate_request.intersects:
-            search = self.database.apply_intersects_filter(
-                search=search, intersects=aggregate_request.intersects
-            )
+            search = self.database.apply_intersects_filter(search=search, intersects=aggregate_request.intersects)
 
         if aggregate_request.collections:
-            search = self.database.apply_collections_filter(
-                search=search, collection_ids=aggregate_request.collections
-            )
+            search = self.database.apply_collections_filter(search=search, collection_ids=aggregate_request.collections)
             # validate that aggregations are supported for all collections
             for collection_id in aggregate_request.collections:
-                aggs = await self.get_aggregations(
-                    collection_id=collection_id, request=request
-                )
-                supported_aggregations = (
-                    aggs["aggregations"] + self.DEFAULT_AGGREGATIONS
-                )
+                aggs = await self.get_aggregations(collection_id=collection_id, request=request)
+                supported_aggregations = aggs["aggregations"] + self.DEFAULT_AGGREGATIONS
 
                 for agg_name in aggregate_request.aggregations:
                     if agg_name not in set([x["name"] for x in supported_aggregations]):
@@ -465,15 +428,11 @@ async def aggregate(
                         detail=f"Aggregation {agg_name} not supported at catalog level",
                     )
 
-        if aggregate_request.filter:
+        if aggregate_request.filter_expr:
             try:
-                search = self.database.apply_cql2_filter(
-                    search, aggregate_request.filter
-                )
+                search = self.database.apply_cql2_filter(search, aggregate_request.filter_expr)
             except Exception as e:
-                raise HTTPException(
-                    status_code=400, detail=f"Error with cql2 filter: {e}"
-                )
+                raise HTTPException(status_code=400, detail=f"Error with cql2 filter: {e}")
 
         centroid_geohash_grid_precision = self.extract_precision(
             aggregate_request.centroid_geohash_grid_frequency_precision,
@@ -528,20 +487,13 @@ async def aggregate(
         if db_response:
             result_aggs = db_response.get("aggregations", {})
             for agg in {
-                frozenset(item.items()): item
-                for item in supported_aggregations + self.GEO_POINT_AGGREGATIONS
+                frozenset(item.items()): item for item in supported_aggregations + self.GEO_POINT_AGGREGATIONS
             }.values():
                 if agg["name"] in aggregate_request.aggregations:
                     if agg["name"].endswith("_frequency"):
-                        aggs.append(
-                            self.frequency_agg(
-                                result_aggs, agg["name"], agg["data_type"]
-                            )
-                        )
+                        aggs.append(self.frequency_agg(result_aggs, agg["name"], agg["data_type"]))
                     else:
-                        aggs.append(
-                            self.metric_agg(result_aggs, agg["name"], agg["data_type"])
-                        )
+                        aggs.append(self.metric_agg(result_aggs, agg["name"], agg["data_type"]))
         links = [
             {"rel": "root", "type": "application/json", "href": base_url},
         ]
@@ -570,8 +522,6 @@ async def aggregate(
                     "href": urljoin(base_url, "aggregate"),
                 }
             )
-        results = AggregationCollection(
-            type="AggregationCollection", aggregations=aggs, links=links
-        )
+        results = AggregationCollection(type="AggregationCollection", aggregations=aggs, links=links)
 
         return results
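
The one behavioral change in this diff is the CQL2 guard near the end: it now reads aggregate_request.filter_expr, the attribute name the filter extension request models define, instead of the stale aggregate_request.filter; every other hunk is pure line-joining reformatting. A minimal, self-contained sketch of why the two call sites must agree on the attribute name — FakeAggregationRequest and apply_filter below are hypothetical stand-ins, not the project's classes:

from typing import Any, Dict, List, Optional

import attr


@attr.s
class FakeAggregationRequest:
    """Hypothetical stand-in for EsAggregationExtensionPostRequest; the
    filter extension exposes the CQL2 expression as filter_expr."""

    aggregations: Optional[List[str]] = attr.ib(default=None)
    filter_expr: Optional[Dict[str, Any]] = attr.ib(default=None)


def apply_filter(search: Dict[str, Any], request: FakeAggregationRequest) -> Dict[str, Any]:
    # Mirrors the corrected guard: in this toy model, checking the
    # nonexistent attribute `request.filter` would raise AttributeError
    # before the expression ever reached the database layer.
    if request.filter_expr:
        search = {**search, "cql2_filter": request.filter_expr}
    return search


req = FakeAggregationRequest(
    aggregations=["total_count"],
    filter_expr={"op": "=", "args": [{"property": "collection"}, "sentinel-2"]},
)
print(apply_filter({}, req))  # {'cql2_filter': {'op': '=', ...}}

Either way — whether the old name raised AttributeError or (if a deprecated alias survived) silently bypassed apply_cql2_filter — the rename makes the normalization at the top of aggregate() and the guard below it consistent.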

0 commit comments