Skip to content

Commit c7d9d98

Browse files
committed
pre-commit.
1 parent 4d067fb commit c7d9d98

File tree

1 file changed

+121
-39
lines changed
  • stac_fastapi/core/stac_fastapi/core

1 file changed

+121
-39
lines changed

stac_fastapi/core/stac_fastapi/core/core.py

Lines changed: 121 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,16 @@ class CoreClient(AsyncBaseCoreClient):
7171
"""
7272

7373
database: BaseDatabaseLogic = attr.ib()
74-
base_conformance_classes: List[str] = attr.ib(factory=lambda: BASE_CONFORMANCE_CLASSES)
74+
base_conformance_classes: List[str] = attr.ib(
75+
factory=lambda: BASE_CONFORMANCE_CLASSES
76+
)
7577
extensions: List[ApiExtension] = attr.ib(default=attr.Factory(list))
7678

7779
session: Session = attr.ib(default=attr.Factory(Session.create_from_env))
7880
item_serializer: Type[ItemSerializer] = attr.ib(default=ItemSerializer)
79-
collection_serializer: Type[CollectionSerializer] = attr.ib(default=CollectionSerializer)
81+
collection_serializer: Type[CollectionSerializer] = attr.ib(
82+
default=CollectionSerializer
83+
)
8084
post_request_model = attr.ib(default=BaseSearchPostRequest)
8185
stac_version: str = attr.ib(default=STAC_VERSION)
8286
landing_page_id: str = attr.ib(default="stac-fastapi")
@@ -200,7 +204,9 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
200204
"rel": "service-desc",
201205
"type": "application/vnd.oai.openapi+json;version=3.0",
202206
"title": "OpenAPI service description",
203-
"href": urljoin(str(request.base_url), request.app.openapi_url.lstrip("/")),
207+
"href": urljoin(
208+
str(request.base_url), request.app.openapi_url.lstrip("/")
209+
),
204210
}
205211
)
206212

@@ -210,7 +216,9 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
210216
"rel": "service-doc",
211217
"type": "text/html",
212218
"title": "OpenAPI service documentation",
213-
"href": urljoin(str(request.base_url), request.app.docs_url.lstrip("/")),
219+
"href": urljoin(
220+
str(request.base_url), request.app.docs_url.lstrip("/")
221+
),
214222
}
215223
)
216224

@@ -230,7 +238,9 @@ async def all_collections(self, **kwargs) -> stac_types.Collections:
230238
limit = int(request.query_params.get("limit", os.getenv("STAC_ITEM_LIMIT", 10)))
231239
token = request.query_params.get("token")
232240

233-
collections, next_token = await self.database.get_all_collections(token=token, limit=limit, request=request)
241+
collections, next_token = await self.database.get_all_collections(
242+
token=token, limit=limit, request=request
243+
)
234244

235245
links = [
236246
{"rel": Relations.root.value, "type": MimeTypes.json, "href": base_url},
@@ -248,7 +258,9 @@ async def all_collections(self, **kwargs) -> stac_types.Collections:
248258

249259
return stac_types.Collections(collections=collections, links=links)
250260

251-
async def get_collection(self, collection_id: str, **kwargs) -> stac_types.Collection:
261+
async def get_collection(
262+
self, collection_id: str, **kwargs
263+
) -> stac_types.Collection:
252264
"""Get a collection from the database by its id.
253265
254266
Args:
@@ -301,16 +313,22 @@ async def item_collection(
301313

302314
base_url = str(request.base_url)
303315

304-
collection = await self.get_collection(collection_id=collection_id, request=request)
316+
collection = await self.get_collection(
317+
collection_id=collection_id, request=request
318+
)
305319
collection_id = collection.get("id")
306320
if collection_id is None:
307321
raise HTTPException(status_code=404, detail="Collection not found")
308322

309323
search = self.database.make_search()
310-
search = self.database.apply_collections_filter(search=search, collection_ids=[collection_id])
324+
search = self.database.apply_collections_filter(
325+
search=search, collection_ids=[collection_id]
326+
)
311327

312328
try:
313-
search, datetime_search = self.database.apply_datetime_filter(search=search, datetime=datetime)
329+
search, datetime_search = self.database.apply_datetime_filter(
330+
search=search, datetime=datetime
331+
)
314332
except (ValueError, TypeError) as e:
315333
# Handle invalid interval formats if return_date fails
316334
msg = f"Invalid interval format: {datetime}, error: {e}"
@@ -334,7 +352,9 @@ async def item_collection(
334352
datetime_search=datetime_search,
335353
)
336354

337-
items = [self.item_serializer.db_to_stac(item, base_url=base_url) for item in items]
355+
items = [
356+
self.item_serializer.db_to_stac(item, base_url=base_url) for item in items
357+
]
338358

339359
links = await PagingLinks(request=request, next=next_token).get_links()
340360

@@ -346,7 +366,9 @@ async def item_collection(
346366
numMatched=maybe_count,
347367
)
348368

349-
async def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac_types.Item:
369+
async def get_item(
370+
self, item_id: str, collection_id: str, **kwargs
371+
) -> stac_types.Item:
350372
"""Get an item from the database based on its id and collection id.
351373
352374
Args:
@@ -361,7 +383,9 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac_typ
361383
NotFoundError: If the item does not exist in the specified collection.
362384
"""
363385
base_url = str(kwargs["request"].base_url)
364-
item = await self.database.get_one_item(item_id=item_id, collection_id=collection_id)
386+
item = await self.database.get_one_item(
387+
item_id=item_id, collection_id=collection_id
388+
)
365389
return self.item_serializer.db_to_stac(item, base_url)
366390

367391
async def get_search(
@@ -423,13 +447,16 @@ async def get_search(
423447

424448
if sortby:
425449
base_args["sortby"] = [
426-
{"field": sort[1:], "direction": "desc" if sort[0] == "-" else "asc"} for sort in sortby
450+
{"field": sort[1:], "direction": "desc" if sort[0] == "-" else "asc"}
451+
for sort in sortby
427452
]
428453

429454
if filter_expr:
430455
base_args["filter_lang"] = "cql2-json"
431456
base_args["filter"] = orjson.loads(
432-
unquote_plus(filter_expr) if filter_lang == "cql2-json" else to_cql2(parse_cql2_text(filter_expr))
457+
unquote_plus(filter_expr)
458+
if filter_lang == "cql2-json"
459+
else to_cql2(parse_cql2_text(filter_expr))
433460
)
434461

435462
if fields:
@@ -445,12 +472,16 @@ async def get_search(
445472
try:
446473
search_request = self.post_request_model(**base_args)
447474
except ValidationError as e:
448-
raise HTTPException(status_code=400, detail=f"Invalid parameters provided: {e}")
475+
raise HTTPException(
476+
status_code=400, detail=f"Invalid parameters provided: {e}"
477+
)
449478
resp = await self.post_search(search_request=search_request, request=request)
450479

451480
return resp
452481

453-
async def post_search(self, search_request: BaseSearchPostRequest, request: Request) -> stac_types.ItemCollection:
482+
async def post_search(
483+
self, search_request: BaseSearchPostRequest, request: Request
484+
) -> stac_types.ItemCollection:
454485
"""
455486
Perform a POST search on the catalog.
456487
@@ -469,10 +500,14 @@ async def post_search(self, search_request: BaseSearchPostRequest, request: Requ
469500
search = self.database.make_search()
470501

471502
if search_request.ids:
472-
search = self.database.apply_ids_filter(search=search, item_ids=search_request.ids)
503+
search = self.database.apply_ids_filter(
504+
search=search, item_ids=search_request.ids
505+
)
473506

474507
if search_request.collections:
475-
search = self.database.apply_collections_filter(search=search, collection_ids=search_request.collections)
508+
search = self.database.apply_collections_filter(
509+
search=search, collection_ids=search_request.collections
510+
)
476511

477512
try:
478513
search, datetime_search = self.database.apply_datetime_filter(
@@ -492,30 +527,38 @@ async def post_search(self, search_request: BaseSearchPostRequest, request: Requ
492527
search = self.database.apply_bbox_filter(search=search, bbox=bbox)
493528

494529
if search_request.intersects:
495-
search = self.database.apply_intersects_filter(search=search, intersects=search_request.intersects)
530+
search = self.database.apply_intersects_filter(
531+
search=search, intersects=search_request.intersects
532+
)
496533

497534
if search_request.query:
498535
for field_name, expr in search_request.query.items():
499536
field = "properties__" + field_name
500537
for op, value in expr.items():
501538
# Convert enum to string
502539
operator = op.value if isinstance(op, Enum) else op
503-
search = self.database.apply_stacql_filter(search=search, op=operator, field=field, value=value)
540+
search = self.database.apply_stacql_filter(
541+
search=search, op=operator, field=field, value=value
542+
)
504543

505544
# only cql2_json is supported here
506545
if hasattr(search_request, "filter_expr"):
507546
cql2_filter = getattr(search_request, "filter_expr", None)
508547
try:
509548
search = await self.database.apply_cql2_filter(search, cql2_filter)
510549
except Exception as e:
511-
raise HTTPException(status_code=400, detail=f"Error with cql2_json filter: {e}")
550+
raise HTTPException(
551+
status_code=400, detail=f"Error with cql2_json filter: {e}"
552+
)
512553

513554
if hasattr(search_request, "q"):
514555
free_text_queries = getattr(search_request, "q", None)
515556
try:
516557
search = self.database.apply_free_text_filter(search, free_text_queries)
517558
except Exception as e:
518-
raise HTTPException(status_code=400, detail=f"Error with free text query: {e}")
559+
raise HTTPException(
560+
status_code=400, detail=f"Error with free text query: {e}"
561+
)
519562

520563
sort = None
521564
if search_request.sortby:
@@ -534,7 +577,11 @@ async def post_search(self, search_request: BaseSearchPostRequest, request: Requ
534577
datetime_search=datetime_search,
535578
)
536579

537-
fields = getattr(search_request, "fields", None) if self.extension_is_enabled("FieldsExtension") else None
580+
fields = (
581+
getattr(search_request, "fields", None)
582+
if self.extension_is_enabled("FieldsExtension")
583+
else None
584+
)
538585
include: Set[str] = fields.include if fields and fields.include else set()
539586
exclude: Set[str] = fields.exclude if fields and fields.exclude else set()
540587

@@ -593,10 +640,15 @@ async def create_item(
593640

594641
# Handle FeatureCollection (bulk insert)
595642
if item_dict["type"] == "FeatureCollection":
596-
bulk_client = BulkTransactionsClient(database=self.database, settings=self.settings)
643+
bulk_client = BulkTransactionsClient(
644+
database=self.database, settings=self.settings
645+
)
597646
features = item_dict["features"]
598647
processed_items = [
599-
bulk_client.preprocess_item(feature, base_url, BulkTransactionMethod.INSERT) for feature in features
648+
bulk_client.preprocess_item(
649+
feature, base_url, BulkTransactionMethod.INSERT
650+
)
651+
for feature in features
600652
]
601653
attempted = len(processed_items)
602654

@@ -610,15 +662,21 @@ async def create_item(
610662
f"Bulk async operation encountered errors for collection {collection_id}: {errors} (attempted {attempted})"
611663
)
612664
else:
613-
logger.info(f"Bulk async operation succeeded with {success} actions for collection {collection_id}.")
665+
logger.info(
666+
f"Bulk async operation succeeded with {success} actions for collection {collection_id}."
667+
)
614668
return f"Successfully added {success} Items. {attempted - success} errors occurred."
615669

616670
# Handle single item
617-
await self.database.create_item(item_dict, base_url=base_url, exist_ok=False, **kwargs)
671+
await self.database.create_item(
672+
item_dict, base_url=base_url, exist_ok=False, **kwargs
673+
)
618674
return ItemSerializer.db_to_stac(item_dict, base_url)
619675

620676
@overrides
621-
async def update_item(self, collection_id: str, item_id: str, item: Item, **kwargs) -> stac_types.Item:
677+
async def update_item(
678+
self, collection_id: str, item_id: str, item: Item, **kwargs
679+
) -> stac_types.Item:
622680
"""Update an item in the collection.
623681
624682
Args:
@@ -640,7 +698,9 @@ async def update_item(self, collection_id: str, item_id: str, item: Item, **kwar
640698
now = datetime_type.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
641699
item["properties"]["updated"] = now
642700

643-
await self.database.create_item(item, base_url=base_url, exist_ok=True, **kwargs)
701+
await self.database.create_item(
702+
item, base_url=base_url, exist_ok=True, **kwargs
703+
)
644704

645705
return ItemSerializer.db_to_stac(item, base_url)
646706

@@ -697,7 +757,9 @@ async def patch_item(
697757
if item:
698758
return ItemSerializer.db_to_stac(item, base_url=base_url)
699759

700-
raise NotImplementedError(f"Content-Type: {content_type} and body: {patch} combination not implemented")
760+
raise NotImplementedError(
761+
f"Content-Type: {content_type} and body: {patch} combination not implemented"
762+
)
701763

702764
@overrides
703765
async def delete_item(self, item_id: str, collection_id: str, **kwargs) -> None:
@@ -710,11 +772,15 @@ async def delete_item(self, item_id: str, collection_id: str, **kwargs) -> None:
710772
Returns:
711773
None: Returns 204 No Content on successful deletion
712774
"""
713-
await self.database.delete_item(item_id=item_id, collection_id=collection_id, **kwargs)
775+
await self.database.delete_item(
776+
item_id=item_id, collection_id=collection_id, **kwargs
777+
)
714778
return None
715779

716780
@overrides
717-
async def create_collection(self, collection: Collection, **kwargs) -> stac_types.Collection:
781+
async def create_collection(
782+
self, collection: Collection, **kwargs
783+
) -> stac_types.Collection:
718784
"""Create a new collection in the database.
719785
720786
Args:
@@ -739,7 +805,9 @@ async def create_collection(self, collection: Collection, **kwargs) -> stac_type
739805
)
740806

741807
@overrides
742-
async def update_collection(self, collection_id: str, collection: Collection, **kwargs) -> stac_types.Collection:
808+
async def update_collection(
809+
self, collection_id: str, collection: Collection, **kwargs
810+
) -> stac_types.Collection:
743811
"""
744812
Update a collection.
745813
@@ -764,7 +832,9 @@ async def update_collection(self, collection_id: str, collection: Collection, **
764832
request = kwargs["request"]
765833

766834
collection = self.database.collection_serializer.stac_to_db(collection, request)
767-
await self.database.update_collection(collection_id=collection_id, collection=collection, **kwargs)
835+
await self.database.update_collection(
836+
collection_id=collection_id, collection=collection, **kwargs
837+
)
768838

769839
return CollectionSerializer.db_to_stac(
770840
collection,
@@ -821,7 +891,9 @@ async def patch_collection(
821891
extensions=[type(ext).__name__ for ext in self.database.extensions],
822892
)
823893

824-
raise NotImplementedError(f"Content-Type: {content_type} and body: {patch} combination not implemented")
894+
raise NotImplementedError(
895+
f"Content-Type: {content_type} and body: {patch} combination not implemented"
896+
)
825897

826898
@overrides
827899
async def delete_collection(self, collection_id: str, **kwargs) -> None:
@@ -860,7 +932,9 @@ def __attrs_post_init__(self):
860932
"""Create es engine."""
861933
self.client = self.settings.create_client
862934

863-
def preprocess_item(self, item: stac_types.Item, base_url, method: BulkTransactionMethod) -> stac_types.Item:
935+
def preprocess_item(
936+
self, item: stac_types.Item, base_url, method: BulkTransactionMethod
937+
) -> stac_types.Item:
864938
"""Preprocess an item to match the data model.
865939
866940
Args:
@@ -872,10 +946,14 @@ def preprocess_item(self, item: stac_types.Item, base_url, method: BulkTransacti
872946
The preprocessed item.
873947
"""
874948
exist_ok = method == BulkTransactionMethod.UPSERT
875-
return self.database.bulk_sync_prep_create_item(item=item, base_url=base_url, exist_ok=exist_ok)
949+
return self.database.bulk_sync_prep_create_item(
950+
item=item, base_url=base_url, exist_ok=exist_ok
951+
)
876952

877953
@overrides
878-
def bulk_item_insert(self, items: Items, chunk_size: Optional[int] = None, **kwargs) -> str:
954+
def bulk_item_insert(
955+
self, items: Items, chunk_size: Optional[int] = None, **kwargs
956+
) -> str:
879957
"""Perform a bulk insertion of items into the database using Elasticsearch.
880958
881959
Args:
@@ -896,7 +974,11 @@ def bulk_item_insert(self, items: Items, chunk_size: Optional[int] = None, **kwa
896974
for item in items.items.values():
897975
try:
898976
validated = Item(**item) if not isinstance(item, Item) else item
899-
processed_items.append(self.preprocess_item(validated.model_dump(mode="json"), base_url, items.method))
977+
processed_items.append(
978+
self.preprocess_item(
979+
validated.model_dump(mode="json"), base_url, items.method
980+
)
981+
)
900982
except ValidationError:
901983
# Immediately raise on the first invalid item (strict mode)
902984
raise

0 commit comments

Comments
 (0)