|
1 | 1 | """Collection-Search extension.""" |
2 | 2 |
|
3 | 3 | from enum import Enum |
4 | | -from typing import List, Optional, Union |
| 4 | +from typing import List, Optional, Type, Union |
5 | 5 |
|
6 | 6 | import attr |
7 | 7 | from fastapi import APIRouter, FastAPI |
| 8 | +from pydantic import BaseModel |
8 | 9 | from stac_pydantic.api.collections import Collections |
9 | 10 | from stac_pydantic.shared import MimeTypes |
10 | 11 |
|
11 | 12 | from stac_fastapi.api.models import GeoJSONResponse, create_request_model |
12 | 13 | from stac_fastapi.api.routes import create_async_endpoint |
13 | 14 | from stac_fastapi.types.config import ApiSettings |
14 | 15 | from stac_fastapi.types.extension import ApiExtension |
| 16 | +from stac_fastapi.types.search import APIRequest |
15 | 17 |
|
16 | 18 | from .client import AsyncBaseCollectionSearchClient, BaseCollectionSearchClient |
17 | 19 | from .request import BaseCollectionSearchGetRequest, BaseCollectionSearchPostRequest |
@@ -47,8 +49,8 @@ class CollectionSearchExtension(ApiExtension): |
47 | 49 | the extension |
48 | 50 | """ |
49 | 51 |
|
50 | | - GET: BaseCollectionSearchGetRequest = attr.ib(default=BaseCollectionSearchGetRequest) # type: ignore |
51 | | - POST = attr.ib(init=False) |
| 52 | + GET: Type[APIRequest] = attr.ib(default=BaseCollectionSearchGetRequest) |
| 53 | + POST: Optional[Type[BaseModel]] = attr.ib(init=False) |
52 | 54 |
|
53 | 55 | conformance_classes: List[str] = attr.ib( |
54 | 56 | default=[ |
@@ -93,7 +95,7 @@ def from_extensions( |
93 | 95 | ) |
94 | 96 |
|
95 | 97 | return cls( |
96 | | - GET=get_request_model, |
| 98 | + GET=get_request_model, # type: ignore |
97 | 99 | conformance_classes=conformance_classes, |
98 | 100 | schema_href=schema_href, |
99 | 101 | ) |
@@ -127,10 +129,8 @@ class CollectionSearchPostExtension(CollectionSearchExtension): |
127 | 129 | schema_href: Optional[str] = attr.ib(default=None) |
128 | 130 | router: APIRouter = attr.ib(factory=APIRouter) |
129 | 131 |
|
130 | | - GET: BaseCollectionSearchGetRequest = attr.ib(default=BaseCollectionSearchGetRequest) # type: ignore |
131 | | - POST: BaseCollectionSearchPostRequest = attr.ib( # type: ignore |
132 | | - default=BaseCollectionSearchPostRequest |
133 | | - ) |
| 132 | + GET: Type[APIRequest] = attr.ib(default=BaseCollectionSearchGetRequest) |
| 133 | + POST: Type[BaseModel] = attr.ib(default=BaseCollectionSearchPostRequest) |
134 | 134 |
|
135 | 135 | def register(self, app: FastAPI) -> None: |
136 | 136 | """Register the extension with a FastAPI application. |
@@ -198,8 +198,8 @@ def from_extensions( # type: ignore |
198 | 198 | return cls( |
199 | 199 | client=client, |
200 | 200 | settings=settings, |
201 | | - GET=get_request_model, |
202 | | - POST=post_request_model, |
| 201 | + GET=get_request_model, # type: ignore |
| 202 | + POST=post_request_model, # type: ignore |
203 | 203 | conformance_classes=conformance_classes, |
204 | 204 | router=router or APIRouter(), |
205 | 205 | schema_href=schema_href, |
|
0 commit comments