Skip to content

Commit bf6f96a

Browse files
committed
pre-commit.
1 parent 93350c7 commit bf6f96a

File tree

2 files changed

+157
-53
lines changed

2 files changed

+157
-53
lines changed

stac_fastapi/core/stac_fastapi/core/extensions/aggregation.py

Lines changed: 74 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,14 @@
3636

3737

3838
@attr.s
39-
class EsAggregationExtensionGetRequest(AggregationExtensionGetRequest, FilterExtensionGetRequest):
39+
class EsAggregationExtensionGetRequest(
40+
AggregationExtensionGetRequest, FilterExtensionGetRequest
41+
):
4042
"""Implementation specific query parameters for aggregation precision."""
4143

42-
collection_id: Optional[Annotated[str, Path(description="Collection ID")]] = attr.ib(default=None)
44+
collection_id: Optional[
45+
Annotated[str, Path(description="Collection ID")]
46+
] = attr.ib(default=None)
4347

4448
centroid_geohash_grid_frequency_precision: Optional[int] = attr.ib(default=None)
4549
centroid_geohex_grid_frequency_precision: Optional[int] = attr.ib(default=None)
@@ -49,7 +53,9 @@ class EsAggregationExtensionGetRequest(AggregationExtensionGetRequest, FilterExt
4953
datetime_frequency_interval: Optional[str] = attr.ib(default=None)
5054

5155

52-
class EsAggregationExtensionPostRequest(AggregationExtensionPostRequest, FilterExtensionPostRequest):
56+
class EsAggregationExtensionPostRequest(
57+
AggregationExtensionPostRequest, FilterExtensionPostRequest
58+
):
5359
"""Implementation specific query parameters for aggregation precision."""
5460

5561
centroid_geohash_grid_frequency_precision: Optional[int] = None
@@ -147,7 +153,9 @@ async def get_aggregations(self, collection_id: Optional[str] = None, **kwargs):
147153
)
148154
if await self.database.check_collection_exists(collection_id) is None:
149155
collection = await self.database.find_collection(collection_id)
150-
aggregations = collection.get("aggregations", self.DEFAULT_AGGREGATIONS.copy())
156+
aggregations = collection.get(
157+
"aggregations", self.DEFAULT_AGGREGATIONS.copy()
158+
)
151159
else:
152160
raise IndexError(f"Collection {collection_id} does not exist")
153161
else:
@@ -160,9 +168,13 @@ async def get_aggregations(self, collection_id: Optional[str] = None, **kwargs):
160168
)
161169

162170
aggregations = self.DEFAULT_AGGREGATIONS
163-
return AggregationCollection(type="AggregationCollection", aggregations=aggregations, links=links)
171+
return AggregationCollection(
172+
type="AggregationCollection", aggregations=aggregations, links=links
173+
)
164174

165-
def extract_precision(self, precision: Union[int, None], min_value: int, max_value: int) -> Optional[int]:
175+
def extract_precision(
176+
self, precision: Union[int, None], min_value: int, max_value: int
177+
) -> Optional[int]:
166178
"""Ensure that the aggregation precision value is withing the a valid range, otherwise return the minumium value."""
167179
if precision is not None:
168180
if precision < min_value or precision > max_value:
@@ -199,7 +211,9 @@ def extract_date_histogram_interval(self, value: Optional[str]) -> str:
199211
return self.DEFAULT_DATETIME_INTERVAL
200212

201213
@staticmethod
202-
def _return_date(interval: Optional[Union[DateTimeType, str]]) -> Dict[str, Optional[str]]:
214+
def _return_date(
215+
interval: Optional[Union[DateTimeType, str]]
216+
) -> Dict[str, Optional[str]]:
203217
"""
204218
Convert a date interval.
205219
@@ -227,7 +241,9 @@ def _return_date(interval: Optional[Union[DateTimeType, str]]) -> Dict[str, Opti
227241
if "/" in interval:
228242
parts = interval.split("/")
229243
result["gte"] = parts[0] if parts[0] != ".." else None
230-
result["lte"] = parts[1] if len(parts) > 1 and parts[1] != ".." else None
244+
result["lte"] = (
245+
parts[1] if len(parts) > 1 and parts[1] != ".." else None
246+
)
231247
else:
232248
converted_time = interval if interval != ".." else None
233249
result["gte"] = result["lte"] = converted_time
@@ -267,7 +283,9 @@ def frequency_agg(self, es_aggs, name, data_type):
267283

268284
def metric_agg(self, es_aggs, name, data_type):
269285
"""Format an aggregation for a metric aggregation."""
270-
value = es_aggs.get(name, {}).get("value_as_string") or es_aggs.get(name, {}).get("value")
286+
value = es_aggs.get(name, {}).get("value_as_string") or es_aggs.get(
287+
name, {}
288+
).get("value")
271289
# ES 7.x does not return datetimes with a 'value_as_string' field
272290
if "datetime" in name and isinstance(value, float):
273291
value = datetime_to_str(datetime.fromtimestamp(value / 1e3))
@@ -313,7 +331,9 @@ def format_datetime(dt):
313331
async def aggregate(
314332
self,
315333
aggregate_request: Optional[EsAggregationExtensionPostRequest] = None,
316-
collection_id: Optional[Annotated[str, Path(description="Collection ID")]] = None,
334+
collection_id: Optional[
335+
Annotated[str, Path(description="Collection ID")]
336+
] = None,
317337
collections: Optional[List[str]] = [],
318338
datetime: Optional[DateTimeType] = None,
319339
intersects: Optional[str] = None,
@@ -370,7 +390,9 @@ async def aggregate(
370390

371391
filter_lang = "cql2-json"
372392
if aggregate_request.filter_expr:
373-
aggregate_request.filter_expr = self.get_filter(aggregate_request.filter_expr, filter_lang)
393+
aggregate_request.filter_expr = self.get_filter(
394+
aggregate_request.filter_expr, filter_lang
395+
)
374396

375397
if collection_id:
376398
if aggregate_request.collections:
@@ -381,18 +403,25 @@ async def aggregate(
381403
else:
382404
aggregate_request.collections = [collection_id]
383405

384-
if aggregate_request.aggregations is None or aggregate_request.aggregations == []:
406+
if (
407+
aggregate_request.aggregations is None
408+
or aggregate_request.aggregations == []
409+
):
385410
raise HTTPException(
386411
status_code=400,
387412
detail="No 'aggregations' found. Use '/aggregations' to return available aggregations",
388413
)
389414

390415
if aggregate_request.ids:
391-
search = self.database.apply_ids_filter(search=search, item_ids=aggregate_request.ids)
416+
search = self.database.apply_ids_filter(
417+
search=search, item_ids=aggregate_request.ids
418+
)
392419

393420
if aggregate_request.datetime:
394421
datetime_search = self._return_date(aggregate_request.datetime)
395-
search = self.database.apply_datetime_filter(search=search, datetime_search=datetime_search)
422+
search = self.database.apply_datetime_filter(
423+
search=search, datetime_search=datetime_search
424+
)
396425

397426
if aggregate_request.bbox:
398427
bbox = aggregate_request.bbox
@@ -402,14 +431,22 @@ async def aggregate(
402431
search = self.database.apply_bbox_filter(search=search, bbox=bbox)
403432

404433
if aggregate_request.intersects:
405-
search = self.database.apply_intersects_filter(search=search, intersects=aggregate_request.intersects)
434+
search = self.database.apply_intersects_filter(
435+
search=search, intersects=aggregate_request.intersects
436+
)
406437

407438
if aggregate_request.collections:
408-
search = self.database.apply_collections_filter(search=search, collection_ids=aggregate_request.collections)
439+
search = self.database.apply_collections_filter(
440+
search=search, collection_ids=aggregate_request.collections
441+
)
409442
# validate that aggregations are supported for all collections
410443
for collection_id in aggregate_request.collections:
411-
aggs = await self.get_aggregations(collection_id=collection_id, request=request)
412-
supported_aggregations = aggs["aggregations"] + self.DEFAULT_AGGREGATIONS
444+
aggs = await self.get_aggregations(
445+
collection_id=collection_id, request=request
446+
)
447+
supported_aggregations = (
448+
aggs["aggregations"] + self.DEFAULT_AGGREGATIONS
449+
)
413450

414451
for agg_name in aggregate_request.aggregations:
415452
if agg_name not in set([x["name"] for x in supported_aggregations]):
@@ -430,9 +467,13 @@ async def aggregate(
430467

431468
if aggregate_request.filter:
432469
try:
433-
search = self.database.apply_cql2_filter(search, aggregate_request.filter)
470+
search = self.database.apply_cql2_filter(
471+
search, aggregate_request.filter
472+
)
434473
except Exception as e:
435-
raise HTTPException(status_code=400, detail=f"Error with cql2 filter: {e}")
474+
raise HTTPException(
475+
status_code=400, detail=f"Error with cql2 filter: {e}"
476+
)
436477

437478
centroid_geohash_grid_precision = self.extract_precision(
438479
aggregate_request.centroid_geohash_grid_frequency_precision,
@@ -487,13 +528,20 @@ async def aggregate(
487528
if db_response:
488529
result_aggs = db_response.get("aggregations", {})
489530
for agg in {
490-
frozenset(item.items()): item for item in supported_aggregations + self.GEO_POINT_AGGREGATIONS
531+
frozenset(item.items()): item
532+
for item in supported_aggregations + self.GEO_POINT_AGGREGATIONS
491533
}.values():
492534
if agg["name"] in aggregate_request.aggregations:
493535
if agg["name"].endswith("_frequency"):
494-
aggs.append(self.frequency_agg(result_aggs, agg["name"], agg["data_type"]))
536+
aggs.append(
537+
self.frequency_agg(
538+
result_aggs, agg["name"], agg["data_type"]
539+
)
540+
)
495541
else:
496-
aggs.append(self.metric_agg(result_aggs, agg["name"], agg["data_type"]))
542+
aggs.append(
543+
self.metric_agg(result_aggs, agg["name"], agg["data_type"])
544+
)
497545
links = [
498546
{"rel": "root", "type": "application/json", "href": base_url},
499547
]
@@ -522,6 +570,8 @@ async def aggregate(
522570
"href": urljoin(base_url, "aggregate"),
523571
}
524572
)
525-
results = AggregationCollection(type="AggregationCollection", aggregations=aggs, links=links)
573+
results = AggregationCollection(
574+
type="AggregationCollection", aggregations=aggs, links=links
575+
)
526576

527577
return results

0 commit comments

Comments
 (0)