Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,13 @@ repos:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
hooks:
- id: mypy
language_version: python
exclude: tests/.*
additional_dependencies:
- types-attrs
- pydantic
10 changes: 10 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,18 @@

## [Unreleased]

### Changed

- switched from `attr.s` to `attrs.define` for dataclasses definition.
- Extension's classes now HAVE TO be defined with `@define(slots=False)`
- switched from `attr.id` to `attrs.field` for dataclasses's attributes definition
- remove support of `cql-json` in Filter extension

### Added

- `py.typed` files for each sub-modules
- type checking in `pre-commit`

## [5.2.1] - 2025-04-18

### Fixed
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ section-order = ["future", "standard-library", "third-party", "first-party", "lo
[tool.ruff.format]
quote-style = "double"

[tool.mypy]
ignore_missing_imports = true
namespace_packages = true
explicit_package_bases = true
exclude = ["tests", ".venv"]

[tool.bumpversion]
current_version = "5.2.1"
parse = """(?x)
Expand Down
56 changes: 28 additions & 28 deletions stac_fastapi/api/stac_fastapi/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union

import attr
import attrs
from brotli_asgi import BrotliMiddleware
from fastapi import APIRouter, FastAPI
from fastapi.params import Depends
Expand Down Expand Up @@ -39,7 +39,7 @@
from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest


@attr.s
@attrs.define
class StacApi:
"""StacApi factory.

Expand Down Expand Up @@ -75,30 +75,30 @@ class StacApi:

"""

settings: ApiSettings = attr.ib()
client: Union[AsyncBaseCoreClient, BaseCoreClient] = attr.ib()
extensions: List[ApiExtension] = attr.ib(default=attr.Factory(list))
exceptions: Dict[Type[Exception], int] = attr.ib(
default=attr.Factory(lambda: DEFAULT_STATUS_CODES)
settings: ApiSettings = attrs.field()
client: Union[AsyncBaseCoreClient, BaseCoreClient] = attrs.field()
extensions: List[ApiExtension] = attrs.field(factory=list)
exceptions: Dict[Type[Exception], int] = attrs.field(
factory=lambda: DEFAULT_STATUS_CODES
)
title: str = attr.ib(
default=attr.Factory(
title: str = attrs.field(
default=attrs.Factory(
lambda self: self.settings.stac_fastapi_title, takes_self=True
)
)
api_version: str = attr.ib(
default=attr.Factory(
api_version: str = attrs.field(
default=attrs.Factory(
lambda self: self.settings.stac_fastapi_version, takes_self=True
)
)
stac_version: str = attr.ib(default=STAC_VERSION)
description: str = attr.ib(
default=attr.Factory(
stac_version: str = attrs.field(default=STAC_VERSION)
description: str = attrs.field(
default=attrs.Factory(
lambda self: self.settings.stac_fastapi_description, takes_self=True
)
)
app: FastAPI = attr.ib(
default=attr.Factory(
app: FastAPI = attrs.field(
default=attrs.Factory(
lambda self: FastAPI(
openapi_url=self.settings.openapi_url,
docs_url=self.settings.docs_url,
Expand All @@ -112,29 +112,29 @@ class StacApi:
),
converter=update_openapi,
)
router: APIRouter = attr.ib(default=attr.Factory(APIRouter))
search_get_request_model: Type[BaseSearchGetRequest] = attr.ib(
router: APIRouter = attrs.field(default=attrs.Factory(APIRouter))
search_get_request_model: Type[BaseSearchGetRequest] = attrs.field(
default=BaseSearchGetRequest
)
search_post_request_model: Type[BaseSearchPostRequest] = attr.ib(
search_post_request_model: Type[BaseSearchPostRequest] = attrs.field(
default=BaseSearchPostRequest
)
collections_get_request_model: Type[APIRequest] = attr.ib(default=EmptyRequest)
collection_get_request_model: Type[APIRequest] = attr.ib(default=CollectionUri)
items_get_request_model: Type[APIRequest] = attr.ib(default=ItemCollectionUri)
item_get_request_model: Type[APIRequest] = attr.ib(default=ItemUri)
response_class: Type[Response] = attr.ib(default=JSONResponse)
middlewares: List[Middleware] = attr.ib(
default=attr.Factory(
collections_get_request_model: Type[APIRequest] = attrs.field(default=EmptyRequest)
collection_get_request_model: Type[APIRequest] = attrs.field(default=CollectionUri)
items_get_request_model: Type[APIRequest] = attrs.field(default=ItemCollectionUri)
item_get_request_model: Type[APIRequest] = attrs.field(default=ItemUri)
response_class: Type[Response] = attrs.field(default=JSONResponse)
middlewares: List[Middleware] = attrs.field(
default=attrs.Factory(
lambda: [
Middleware(BrotliMiddleware),
Middleware(CORSMiddleware),
Middleware(ProxyHeaderMiddleware),
]
)
)
route_dependencies: List[Tuple[List[Scope], List[Depends]]] = attr.ib(default=[])
health_check: Union[Callable[[], Dict], Callable[[], Awaitable[Dict]]] = attr.ib(
route_dependencies: List[Tuple[List[Scope], List[Depends]]] = attrs.field(default=[])
health_check: Union[Callable[[], Dict], Callable[[], Awaitable[Dict]]] = attrs.field(
default=lambda: {"status": "UP"}
)

Expand Down
33 changes: 19 additions & 14 deletions stac_fastapi/api/stac_fastapi/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import List, Literal, Optional, Type, Union

import attr
import attrs
from fastapi import Path, Query
from pydantic import BaseModel, create_model
from stac_pydantic.shared import BBox
Expand Down Expand Up @@ -49,14 +49,19 @@ def create_request_model(

# Handle GET requests
if all([issubclass(m, APIRequest) for m in models]):
return attr.make_class(model_name, attrs={}, bases=tuple(models))
return attrs.make_class(
model_name,
attrs={**{}},
bases=tuple(models),
)

# Handle POST requests
elif all([issubclass(m, BaseModel) for m in models]):
for model in models:
for k, field_info in model.model_fields.items():
fields[k] = (field_info.annotation, field_info)
return create_model(model_name, **fields, __base__=base_model)

return create_model(model_name, **fields, __base__=base_model) # type: ignore

raise TypeError("Mixed Request Model types. Check extension request types.")

Expand Down Expand Up @@ -88,41 +93,41 @@ def create_post_request_model(
)


@attr.s
@attrs.define
class CollectionUri(APIRequest):
"""Get or delete collection."""

collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
collection_id: Annotated[str, Path(description="Collection ID")] = attrs.field()


@attr.s
@attrs.define
class ItemUri(APIRequest):
"""Get or delete item."""

collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
item_id: Annotated[str, Path(description="Item ID")] = attr.ib()
collection_id: Annotated[str, Path(description="Collection ID")] = attrs.field()
item_id: Annotated[str, Path(description="Item ID")] = attrs.field()


@attr.s
@attrs.define
class EmptyRequest(APIRequest):
"""Empty request."""

...


@attr.s
@attrs.define
class ItemCollectionUri(APIRequest, DatetimeMixin):
"""Get item collection."""

collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
collection_id: Annotated[str, Path(description="Collection ID")] = attrs.field()
limit: Annotated[
Optional[Limit],
Query(
description="Limits the number of results that are included in each page of the response (capped to 10_000)." # noqa: E501
),
] = attr.ib(default=10)
bbox: Optional[BBox] = attr.ib(default=None, converter=_bbox_converter)
datetime: DateTimeQueryType = attr.ib(default=None, validator=_validate_datetime)
] = attrs.field(default=10)
bbox: Optional[BBox] = attrs.field(default=None, converter=_bbox_converter)
datetime: DateTimeQueryType = attrs.field(default=None, validator=_validate_datetime)


class GeoJSONResponse(JSONResponse):
Expand Down
Empty file.
26 changes: 11 additions & 15 deletions stac_fastapi/api/stac_fastapi/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Type, TypedDict, Union
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, TypedDict, Union

from fastapi import Depends, FastAPI, params
from fastapi.datastructures import DefaultPlaceholder
Expand Down Expand Up @@ -48,30 +48,26 @@ def create_async_endpoint(
if not inspect.iscoroutinefunction(func):
func = sync_to_async(func)

if issubclass(request_model, APIRequest):
_endpoint: Callable[[Any, Any], Awaitable[Any]]

if isinstance(request_model, dict):

async def _endpoint(
request: Request,
request_data: request_model = Depends(), # type:ignore
request_data: Dict[str, Any],
):
"""Endpoint."""
return _wrap_response(await func(request=request, **request_data.kwargs()))
return _wrap_response(await func(request_data, request=request))

elif issubclass(request_model, BaseModel):
elif issubclass(request_model, APIRequest):

async def _endpoint(
request: Request,
request_data: request_model, # type:ignore
):
async def _endpoint(request: Request, request_data=Depends(request_model)):
"""Endpoint."""
return _wrap_response(await func(request_data, request=request))
return _wrap_response(await func(request=request, **request_data.kwargs()))

else:
elif issubclass(request_model, BaseModel):

async def _endpoint(
request: Request,
request_data: Dict[str, Any], # type:ignore
):
async def _endpoint(request: Request, request_data: request_model): # type: ignore
"""Endpoint."""
return _wrap_response(await func(request_data, request=request))

Expand Down
26 changes: 13 additions & 13 deletions stac_fastapi/api/tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional, Union

import attr
import attrs
import pytest
from fastapi import Path, Query
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -385,25 +385,25 @@ def item_collection(
def test_request_model(AsyncTestCoreClient):
"""Test if request models are passed correctly."""

@attr.s
@attrs.define
class CollectionsRequest(APIRequest):
user: Annotated[str, Query(...)] = attr.ib()
user: Annotated[str, Query(...)] = attrs.field()

@attr.s
@attrs.define
class CollectionRequest(APIRequest):
collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
user: Annotated[str, Query(...)] = attr.ib()
collection_id: Annotated[str, Path(description="Collection ID")] = attrs.field()
user: Annotated[str, Query(...)] = attrs.field()

@attr.s
@attrs.define
class ItemsRequest(APIRequest):
collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
user: Annotated[str, Query(...)] = attr.ib()
collection_id: Annotated[str, Path(description="Collection ID")] = attrs.field()
user: Annotated[str, Query(...)] = attrs.field()

@attr.s
@attrs.define
class ItemRequest(APIRequest):
collection_id: Annotated[str, Path(description="Collection ID")] = attr.ib()
item_id: Annotated[str, Path(description="Item ID")] = attr.ib()
user: Annotated[str, Query(...)] = attr.ib()
collection_id: Annotated[str, Path(description="Collection ID")] = attrs.field()
item_id: Annotated[str, Path(description="Item ID")] = attrs.field()
user: Annotated[str, Query(...)] = attrs.field()

test_app = app.StacApi(
settings=ApiSettings(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from typing import List, Union

import attr
import attrs
from fastapi import APIRouter, FastAPI

from stac_fastapi.api.models import CollectionUri, EmptyRequest
Expand All @@ -23,7 +23,7 @@ class AggregationConformanceClasses(str, Enum):
AGGREGATION = "https://api.stacspec.org/v0.3.0/aggregation"


@attr.s
@attrs.define
class AggregationExtension(ApiExtension):
"""Aggregation Extension.

Expand Down Expand Up @@ -53,14 +53,14 @@ class AggregationExtension(ApiExtension):
GET = AggregationExtensionGetRequest
POST = AggregationExtensionPostRequest

client: Union[AsyncBaseAggregationClient, BaseAggregationClient] = attr.ib(
client: Union[AsyncBaseAggregationClient, BaseAggregationClient] = attrs.field(
factory=BaseAggregationClient
)

conformance_classes: List[str] = attr.ib(
default=[AggregationConformanceClasses.AGGREGATION]
conformance_classes: List[str] = attrs.field(
default=[AggregationConformanceClasses.AGGREGATION.value]
)
router: APIRouter = attr.ib(factory=APIRouter)
router: APIRouter = attrs.field(factory=APIRouter)

def register(self, app: FastAPI) -> None:
"""Register the extension with a FastAPI application.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import abc
from typing import List, Optional, Union

import attr
import attrs
from geojson_pydantic.geometries import Geometry
from stac_pydantic.shared import BBox

Expand All @@ -12,7 +12,7 @@
from .types import Aggregation, AggregationCollection


@attr.s
@attrs.define
class BaseAggregationClient(abc.ABC):
"""Defines a pattern for implementing the STAC aggregation extension."""

Expand Down Expand Up @@ -67,7 +67,7 @@ def aggregate(
)


@attr.s
@attrs.define
class AsyncBaseAggregationClient(abc.ABC):
"""Defines an async pattern for implementing the STAC aggregation extension."""

Expand Down
Loading