3636
3737
3838@attr .s
39- class EsAggregationExtensionGetRequest (AggregationExtensionGetRequest , FilterExtensionGetRequest ):
39+ class EsAggregationExtensionGetRequest (
40+ AggregationExtensionGetRequest , FilterExtensionGetRequest
41+ ):
4042 """Implementation specific query parameters for aggregation precision."""
4143
42- collection_id : Optional [Annotated [str , Path (description = "Collection ID" )]] = attr .ib (default = None )
44+ collection_id : Optional [
45+ Annotated [str , Path (description = "Collection ID" )]
46+ ] = attr .ib (default = None )
4347
4448 centroid_geohash_grid_frequency_precision : Optional [int ] = attr .ib (default = None )
4549 centroid_geohex_grid_frequency_precision : Optional [int ] = attr .ib (default = None )
@@ -49,7 +53,9 @@ class EsAggregationExtensionGetRequest(AggregationExtensionGetRequest, FilterExt
4953 datetime_frequency_interval : Optional [str ] = attr .ib (default = None )
5054
5155
52- class EsAggregationExtensionPostRequest (AggregationExtensionPostRequest , FilterExtensionPostRequest ):
56+ class EsAggregationExtensionPostRequest (
57+ AggregationExtensionPostRequest , FilterExtensionPostRequest
58+ ):
5359 """Implementation specific query parameters for aggregation precision."""
5460
5561 centroid_geohash_grid_frequency_precision : Optional [int ] = None
@@ -147,7 +153,9 @@ async def get_aggregations(self, collection_id: Optional[str] = None, **kwargs):
147153 )
148154 if await self .database .check_collection_exists (collection_id ) is None :
149155 collection = await self .database .find_collection (collection_id )
150- aggregations = collection .get ("aggregations" , self .DEFAULT_AGGREGATIONS .copy ())
156+ aggregations = collection .get (
157+ "aggregations" , self .DEFAULT_AGGREGATIONS .copy ()
158+ )
151159 else :
152160 raise IndexError (f"Collection { collection_id } does not exist" )
153161 else :
@@ -160,9 +168,13 @@ async def get_aggregations(self, collection_id: Optional[str] = None, **kwargs):
160168 )
161169
162170 aggregations = self .DEFAULT_AGGREGATIONS
163- return AggregationCollection (type = "AggregationCollection" , aggregations = aggregations , links = links )
171+ return AggregationCollection (
172+ type = "AggregationCollection" , aggregations = aggregations , links = links
173+ )
164174
165- def extract_precision (self , precision : Union [int , None ], min_value : int , max_value : int ) -> Optional [int ]:
175+ def extract_precision (
176+ self , precision : Union [int , None ], min_value : int , max_value : int
177+ ) -> Optional [int ]:
166178 """Ensure that the aggregation precision value is within a valid range, otherwise return the minimum value."""
167179 if precision is not None :
168180 if precision < min_value or precision > max_value :
@@ -199,7 +211,9 @@ def extract_date_histogram_interval(self, value: Optional[str]) -> str:
199211 return self .DEFAULT_DATETIME_INTERVAL
200212
201213 @staticmethod
202- def _return_date (interval : Optional [Union [DateTimeType , str ]]) -> Dict [str , Optional [str ]]:
214+ def _return_date (
215+ interval : Optional [Union [DateTimeType , str ]]
216+ ) -> Dict [str , Optional [str ]]:
203217 """
204218 Convert a date interval.
205219
@@ -227,7 +241,9 @@ def _return_date(interval: Optional[Union[DateTimeType, str]]) -> Dict[str, Opti
227241 if "/" in interval :
228242 parts = interval .split ("/" )
229243 result ["gte" ] = parts [0 ] if parts [0 ] != ".." else None
230- result ["lte" ] = parts [1 ] if len (parts ) > 1 and parts [1 ] != ".." else None
244+ result ["lte" ] = (
245+ parts [1 ] if len (parts ) > 1 and parts [1 ] != ".." else None
246+ )
231247 else :
232248 converted_time = interval if interval != ".." else None
233249 result ["gte" ] = result ["lte" ] = converted_time
@@ -267,7 +283,9 @@ def frequency_agg(self, es_aggs, name, data_type):
267283
268284 def metric_agg (self , es_aggs , name , data_type ):
269285 """Format an aggregation for a metric aggregation."""
270- value = es_aggs .get (name , {}).get ("value_as_string" ) or es_aggs .get (name , {}).get ("value" )
286+ value = es_aggs .get (name , {}).get ("value_as_string" ) or es_aggs .get (
287+ name , {}
288+ ).get ("value" )
271289 # ES 7.x does not return datetimes with a 'value_as_string' field
272290 if "datetime" in name and isinstance (value , float ):
273291 value = datetime_to_str (datetime .fromtimestamp (value / 1e3 ))
@@ -313,7 +331,9 @@ def format_datetime(dt):
313331 async def aggregate (
314332 self ,
315333 aggregate_request : Optional [EsAggregationExtensionPostRequest ] = None ,
316- collection_id : Optional [Annotated [str , Path (description = "Collection ID" )]] = None ,
334+ collection_id : Optional [
335+ Annotated [str , Path (description = "Collection ID" )]
336+ ] = None ,
317337 collections : Optional [List [str ]] = [],
318338 datetime : Optional [DateTimeType ] = None ,
319339 intersects : Optional [str ] = None ,
@@ -370,7 +390,9 @@ async def aggregate(
370390
371391 filter_lang = "cql2-json"
372392 if aggregate_request .filter_expr :
373- aggregate_request .filter_expr = self .get_filter (aggregate_request .filter_expr , filter_lang )
393+ aggregate_request .filter_expr = self .get_filter (
394+ aggregate_request .filter_expr , filter_lang
395+ )
374396
375397 if collection_id :
376398 if aggregate_request .collections :
@@ -381,18 +403,25 @@ async def aggregate(
381403 else :
382404 aggregate_request .collections = [collection_id ]
383405
384- if aggregate_request .aggregations is None or aggregate_request .aggregations == []:
406+ if (
407+ aggregate_request .aggregations is None
408+ or aggregate_request .aggregations == []
409+ ):
385410 raise HTTPException (
386411 status_code = 400 ,
387412 detail = "No 'aggregations' found. Use '/aggregations' to return available aggregations" ,
388413 )
389414
390415 if aggregate_request .ids :
391- search = self .database .apply_ids_filter (search = search , item_ids = aggregate_request .ids )
416+ search = self .database .apply_ids_filter (
417+ search = search , item_ids = aggregate_request .ids
418+ )
392419
393420 if aggregate_request .datetime :
394421 datetime_search = self ._return_date (aggregate_request .datetime )
395- search = self .database .apply_datetime_filter (search = search , datetime_search = datetime_search )
422+ search = self .database .apply_datetime_filter (
423+ search = search , datetime_search = datetime_search
424+ )
396425
397426 if aggregate_request .bbox :
398427 bbox = aggregate_request .bbox
@@ -402,14 +431,22 @@ async def aggregate(
402431 search = self .database .apply_bbox_filter (search = search , bbox = bbox )
403432
404433 if aggregate_request .intersects :
405- search = self .database .apply_intersects_filter (search = search , intersects = aggregate_request .intersects )
434+ search = self .database .apply_intersects_filter (
435+ search = search , intersects = aggregate_request .intersects
436+ )
406437
407438 if aggregate_request .collections :
408- search = self .database .apply_collections_filter (search = search , collection_ids = aggregate_request .collections )
439+ search = self .database .apply_collections_filter (
440+ search = search , collection_ids = aggregate_request .collections
441+ )
409442 # validate that aggregations are supported for all collections
410443 for collection_id in aggregate_request .collections :
411- aggs = await self .get_aggregations (collection_id = collection_id , request = request )
412- supported_aggregations = aggs ["aggregations" ] + self .DEFAULT_AGGREGATIONS
444+ aggs = await self .get_aggregations (
445+ collection_id = collection_id , request = request
446+ )
447+ supported_aggregations = (
448+ aggs ["aggregations" ] + self .DEFAULT_AGGREGATIONS
449+ )
413450
414451 for agg_name in aggregate_request .aggregations :
415452 if agg_name not in set ([x ["name" ] for x in supported_aggregations ]):
@@ -430,9 +467,13 @@ async def aggregate(
430467
431468 if aggregate_request .filter_expr :
432469 try :
433- search = self .database .apply_cql2_filter (search , aggregate_request .filter_expr )
470+ search = self .database .apply_cql2_filter (
471+ search , aggregate_request .filter_expr
472+ )
434473 except Exception as e :
435- raise HTTPException (status_code = 400 , detail = f"Error with cql2 filter: { e } " )
474+ raise HTTPException (
475+ status_code = 400 , detail = f"Error with cql2 filter: { e } "
476+ )
436477
437478 centroid_geohash_grid_precision = self .extract_precision (
438479 aggregate_request .centroid_geohash_grid_frequency_precision ,
@@ -487,13 +528,20 @@ async def aggregate(
487528 if db_response :
488529 result_aggs = db_response .get ("aggregations" , {})
489530 for agg in {
490- frozenset (item .items ()): item for item in supported_aggregations + self .GEO_POINT_AGGREGATIONS
531+ frozenset (item .items ()): item
532+ for item in supported_aggregations + self .GEO_POINT_AGGREGATIONS
491533 }.values ():
492534 if agg ["name" ] in aggregate_request .aggregations :
493535 if agg ["name" ].endswith ("_frequency" ):
494- aggs .append (self .frequency_agg (result_aggs , agg ["name" ], agg ["data_type" ]))
536+ aggs .append (
537+ self .frequency_agg (
538+ result_aggs , agg ["name" ], agg ["data_type" ]
539+ )
540+ )
495541 else :
496- aggs .append (self .metric_agg (result_aggs , agg ["name" ], agg ["data_type" ]))
542+ aggs .append (
543+ self .metric_agg (result_aggs , agg ["name" ], agg ["data_type" ])
544+ )
497545 links = [
498546 {"rel" : "root" , "type" : "application/json" , "href" : base_url },
499547 ]
@@ -522,6 +570,8 @@ async def aggregate(
522570 "href" : urljoin (base_url , "aggregate" ),
523571 }
524572 )
525- results = AggregationCollection (type = "AggregationCollection" , aggregations = aggs , links = links )
573+ results = AggregationCollection (
574+ type = "AggregationCollection" , aggregations = aggs , links = links
575+ )
526576
527577 return results
0 commit comments