3636
3737
3838@attr .s
39- class EsAggregationExtensionGetRequest (
40- AggregationExtensionGetRequest , FilterExtensionGetRequest
41- ):
39+ class EsAggregationExtensionGetRequest (AggregationExtensionGetRequest , FilterExtensionGetRequest ):
4240 """Implementation specific query parameters for aggregation precision."""
4341
44- collection_id : Optional [
45- Annotated [str , Path (description = "Collection ID" )]
46- ] = attr .ib (default = None )
42+ collection_id : Optional [Annotated [str , Path (description = "Collection ID" )]] = attr .ib (default = None )
4743
4844 centroid_geohash_grid_frequency_precision : Optional [int ] = attr .ib (default = None )
4945 centroid_geohex_grid_frequency_precision : Optional [int ] = attr .ib (default = None )
@@ -53,9 +49,7 @@ class EsAggregationExtensionGetRequest(
5349 datetime_frequency_interval : Optional [str ] = attr .ib (default = None )
5450
5551
56- class EsAggregationExtensionPostRequest (
57- AggregationExtensionPostRequest , FilterExtensionPostRequest
58- ):
52+ class EsAggregationExtensionPostRequest (AggregationExtensionPostRequest , FilterExtensionPostRequest ):
5953 """Implementation specific query parameters for aggregation precision."""
6054
6155 centroid_geohash_grid_frequency_precision : Optional [int ] = None
@@ -153,9 +147,7 @@ async def get_aggregations(self, collection_id: Optional[str] = None, **kwargs):
153147 )
154148 if await self .database .check_collection_exists (collection_id ) is None :
155149 collection = await self .database .find_collection (collection_id )
156- aggregations = collection .get (
157- "aggregations" , self .DEFAULT_AGGREGATIONS .copy ()
158- )
150+ aggregations = collection .get ("aggregations" , self .DEFAULT_AGGREGATIONS .copy ())
159151 else :
160152 raise IndexError (f"Collection { collection_id } does not exist" )
161153 else :
@@ -168,13 +160,9 @@ async def get_aggregations(self, collection_id: Optional[str] = None, **kwargs):
168160 )
169161
170162 aggregations = self .DEFAULT_AGGREGATIONS
171- return AggregationCollection (
172- type = "AggregationCollection" , aggregations = aggregations , links = links
173- )
163+ return AggregationCollection (type = "AggregationCollection" , aggregations = aggregations , links = links )
174164
175- def extract_precision (
176- self , precision : Union [int , None ], min_value : int , max_value : int
177- ) -> Optional [int ]:
165+ def extract_precision (self , precision : Union [int , None ], min_value : int , max_value : int ) -> Optional [int ]:
178166 """Ensure that the aggregation precision value is withing the a valid range, otherwise return the minumium value."""
179167 if precision is not None :
180168 if precision < min_value or precision > max_value :
@@ -211,9 +199,7 @@ def extract_date_histogram_interval(self, value: Optional[str]) -> str:
211199 return self .DEFAULT_DATETIME_INTERVAL
212200
213201 @staticmethod
214- def _return_date (
215- interval : Optional [Union [DateTimeType , str ]]
216- ) -> Dict [str , Optional [str ]]:
202+ def _return_date (interval : Optional [Union [DateTimeType , str ]]) -> Dict [str , Optional [str ]]:
217203 """
218204 Convert a date interval.
219205
@@ -241,9 +227,7 @@ def _return_date(
241227 if "/" in interval :
242228 parts = interval .split ("/" )
243229 result ["gte" ] = parts [0 ] if parts [0 ] != ".." else None
244- result ["lte" ] = (
245- parts [1 ] if len (parts ) > 1 and parts [1 ] != ".." else None
246- )
230+ result ["lte" ] = parts [1 ] if len (parts ) > 1 and parts [1 ] != ".." else None
247231 else :
248232 converted_time = interval if interval != ".." else None
249233 result ["gte" ] = result ["lte" ] = converted_time
@@ -283,9 +267,7 @@ def frequency_agg(self, es_aggs, name, data_type):
283267
284268 def metric_agg (self , es_aggs , name , data_type ):
285269 """Format an aggregation for a metric aggregation."""
286- value = es_aggs .get (name , {}).get ("value_as_string" ) or es_aggs .get (
287- name , {}
288- ).get ("value" )
270+ value = es_aggs .get (name , {}).get ("value_as_string" ) or es_aggs .get (name , {}).get ("value" )
289271 # ES 7.x does not return datetimes with a 'value_as_string' field
290272 if "datetime" in name and isinstance (value , float ):
291273 value = datetime_to_str (datetime .fromtimestamp (value / 1e3 ))
@@ -331,9 +313,7 @@ def format_datetime(dt):
331313 async def aggregate (
332314 self ,
333315 aggregate_request : Optional [EsAggregationExtensionPostRequest ] = None ,
334- collection_id : Optional [
335- Annotated [str , Path (description = "Collection ID" )]
336- ] = None ,
316+ collection_id : Optional [Annotated [str , Path (description = "Collection ID" )]] = None ,
337317 collections : Optional [List [str ]] = [],
338318 datetime : Optional [DateTimeType ] = None ,
339319 intersects : Optional [str ] = None ,
@@ -389,10 +369,8 @@ async def aggregate(
389369 collection_id = path .split ("/" )[2 ]
390370
391371 filter_lang = "cql2-json"
392- if aggregate_request .filter :
393- aggregate_request .filter = self .get_filter (
394- aggregate_request .filter , filter_lang
395- )
372+ if aggregate_request .filter_expr :
373+ aggregate_request .filter_expr = self .get_filter (aggregate_request .filter_expr , filter_lang )
396374
397375 if collection_id :
398376 if aggregate_request .collections :
@@ -403,25 +381,18 @@ async def aggregate(
403381 else :
404382 aggregate_request .collections = [collection_id ]
405383
406- if (
407- aggregate_request .aggregations is None
408- or aggregate_request .aggregations == []
409- ):
384+ if aggregate_request .aggregations is None or aggregate_request .aggregations == []:
410385 raise HTTPException (
411386 status_code = 400 ,
412387 detail = "No 'aggregations' found. Use '/aggregations' to return available aggregations" ,
413388 )
414389
415390 if aggregate_request .ids :
416- search = self .database .apply_ids_filter (
417- search = search , item_ids = aggregate_request .ids
418- )
391+ search = self .database .apply_ids_filter (search = search , item_ids = aggregate_request .ids )
419392
420393 if aggregate_request .datetime :
421394 datetime_search = self ._return_date (aggregate_request .datetime )
422- search = self .database .apply_datetime_filter (
423- search = search , datetime_search = datetime_search
424- )
395+ search = self .database .apply_datetime_filter (search = search , datetime_search = datetime_search )
425396
426397 if aggregate_request .bbox :
427398 bbox = aggregate_request .bbox
@@ -431,22 +402,14 @@ async def aggregate(
431402 search = self .database .apply_bbox_filter (search = search , bbox = bbox )
432403
433404 if aggregate_request .intersects :
434- search = self .database .apply_intersects_filter (
435- search = search , intersects = aggregate_request .intersects
436- )
405+ search = self .database .apply_intersects_filter (search = search , intersects = aggregate_request .intersects )
437406
438407 if aggregate_request .collections :
439- search = self .database .apply_collections_filter (
440- search = search , collection_ids = aggregate_request .collections
441- )
408+ search = self .database .apply_collections_filter (search = search , collection_ids = aggregate_request .collections )
442409 # validate that aggregations are supported for all collections
443410 for collection_id in aggregate_request .collections :
444- aggs = await self .get_aggregations (
445- collection_id = collection_id , request = request
446- )
447- supported_aggregations = (
448- aggs ["aggregations" ] + self .DEFAULT_AGGREGATIONS
449- )
411+ aggs = await self .get_aggregations (collection_id = collection_id , request = request )
412+ supported_aggregations = aggs ["aggregations" ] + self .DEFAULT_AGGREGATIONS
450413
451414 for agg_name in aggregate_request .aggregations :
452415 if agg_name not in set ([x ["name" ] for x in supported_aggregations ]):
@@ -467,13 +430,9 @@ async def aggregate(
467430
468431 if aggregate_request .filter :
469432 try :
470- search = self .database .apply_cql2_filter (
471- search , aggregate_request .filter
472- )
433+ search = self .database .apply_cql2_filter (search , aggregate_request .filter )
473434 except Exception as e :
474- raise HTTPException (
475- status_code = 400 , detail = f"Error with cql2 filter: { e } "
476- )
435+ raise HTTPException (status_code = 400 , detail = f"Error with cql2 filter: { e } " )
477436
478437 centroid_geohash_grid_precision = self .extract_precision (
479438 aggregate_request .centroid_geohash_grid_frequency_precision ,
@@ -528,20 +487,13 @@ async def aggregate(
528487 if db_response :
529488 result_aggs = db_response .get ("aggregations" , {})
530489 for agg in {
531- frozenset (item .items ()): item
532- for item in supported_aggregations + self .GEO_POINT_AGGREGATIONS
490+ frozenset (item .items ()): item for item in supported_aggregations + self .GEO_POINT_AGGREGATIONS
533491 }.values ():
534492 if agg ["name" ] in aggregate_request .aggregations :
535493 if agg ["name" ].endswith ("_frequency" ):
536- aggs .append (
537- self .frequency_agg (
538- result_aggs , agg ["name" ], agg ["data_type" ]
539- )
540- )
494+ aggs .append (self .frequency_agg (result_aggs , agg ["name" ], agg ["data_type" ]))
541495 else :
542- aggs .append (
543- self .metric_agg (result_aggs , agg ["name" ], agg ["data_type" ])
544- )
496+ aggs .append (self .metric_agg (result_aggs , agg ["name" ], agg ["data_type" ]))
545497 links = [
546498 {"rel" : "root" , "type" : "application/json" , "href" : base_url },
547499 ]
@@ -570,8 +522,6 @@ async def aggregate(
570522 "href" : urljoin (base_url , "aggregate" ),
571523 }
572524 )
573- results = AggregationCollection (
574- type = "AggregationCollection" , aggregations = aggs , links = links
575- )
525+ results = AggregationCollection (type = "AggregationCollection" , aggregations = aggs , links = links )
576526
577527 return results
0 commit comments