Skip to content

Commit 71e79cb

Browse files
authored
feat: moar extensions (#24)
* --wip-- [skip ci] * feat: fields * feat: add sortby * feat: advertise extensions * refactor: constant
1 parent c34cddb commit 71e79cb

File tree

3 files changed

+82
-19
lines changed

3 files changed

+82
-19
lines changed

src/stac_fastapi/geoparquet/api.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,29 +11,29 @@
1111

1212
import stac_fastapi.api.models
1313
from stac_fastapi.api.app import StacApi
14+
from stac_fastapi.extensions.core.fields import FieldsExtension
1415
from stac_fastapi.extensions.core.filter import SearchFilterExtension
1516
from stac_fastapi.extensions.core.pagination import OffsetPaginationExtension
17+
from stac_fastapi.extensions.core.sort import SortExtension
1618
from stac_fastapi.types.search import BaseSearchPostRequest
1719

1820
from .client import Client
1921
from .search import FixedSearchGetRequest
2022
from .settings import Settings
2123

2224
GEOPARQUET_MEDIA_TYPE = "application/vnd.apache.parquet"
23-
24-
GetSearchRequestModel = stac_fastapi.api.models.create_request_model(
25-
model_name="SearchGetRequest",
26-
base_model=FixedSearchGetRequest,
27-
extensions=[SearchFilterExtension()],
28-
mixins=[OffsetPaginationExtension().GET],
29-
request_type="GET",
25+
EXTENSIONS = [
26+
OffsetPaginationExtension(),
27+
SearchFilterExtension(),
28+
FieldsExtension(),
29+
SortExtension(),
30+
]
31+
32+
GetSearchRequestModel = stac_fastapi.api.models.create_get_request_model(
33+
base_model=FixedSearchGetRequest, extensions=EXTENSIONS
3034
)
31-
PostSearchRequestModel = stac_fastapi.api.models.create_request_model(
32-
model_name="SearchPostRequest",
33-
base_model=BaseSearchPostRequest,
34-
extensions=[SearchFilterExtension()],
35-
mixins=[OffsetPaginationExtension().POST],
36-
request_type="POST",
35+
PostSearchRequestModel = stac_fastapi.api.models.create_post_request_model(
36+
base_model=BaseSearchPostRequest, extensions=EXTENSIONS
3737
)
3838

3939

@@ -135,6 +135,7 @@ def create(
135135
),
136136
search_get_request_model=GetSearchRequestModel,
137137
search_post_request_model=PostSearchRequestModel,
138+
extensions=EXTENSIONS,
138139
)
139140
return api
140141

src/stac_fastapi/geoparquet/client.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,38 @@ def search(
165165
collections = list(hrefs.keys())
166166

167167
search_dict = search.model_dump(exclude_none=True, by_alias=True)
168-
kwargs.pop("filter_crs", None)
169-
if filter_expr := kwargs.pop("filter_expr", None):
170-
kwargs["filter"] = filter_expr
171-
if "filter" not in kwargs:
172-
kwargs.pop("filter_lang", None)
173168
search_dict.update(**kwargs)
174169

170+
search_dict.pop("filter_crs", None)
171+
if filter_expr := search_dict.pop("filter_expr", None):
172+
search_dict["filter"] = filter_expr
173+
if filter_lang := search_dict.pop("filter_lang", None):
174+
search_dict["filter-lang"] = filter_lang
175+
if "filter" not in search_dict:
176+
search_dict.pop("filter_lang", None)
177+
search_dict.pop("filter-lang", None)
178+
if fields := search_dict.pop("fields", None):
179+
if isinstance(fields, list):
180+
include = []
181+
exclude = []
182+
for field in fields:
183+
if field.startswith("-"):
184+
exclude.append(field)
185+
else:
186+
include.append(field)
187+
search_dict.update({"include": include, "exclude": exclude})
188+
elif isinstance(fields, dict):
189+
search_dict.update(
190+
{
191+
"include": list(fields.get("include", [])),
192+
"exclude": list(fields.get("exclude", [])),
193+
}
194+
)
195+
else:
196+
raise HTTPException(400, f"unexpected fields type: {fields}")
197+
if sortby := search_dict.pop("sortby", None):
198+
search_dict["sortby"] = sortby
199+
175200
limit = search_dict.get("limit", DEFAULT_LIMIT)
176201
offset = search_dict.get("offset", 0) or 0
177202
items: list[dict[str, Any]] = []

tests/test_search.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,45 @@ def test_paging_filter(client: TestClient) -> None:
123123
"offset": ["1"],
124124
"collections": ["naip,naip-10,openaerialmap-10,openaerialmap"],
125125
"filter": ["naip:year='2022'"],
126-
"filter_lang": ["cql2-text"],
126+
"filter-lang": ["cql2-text"],
127127
}
128128
response = client.get("/search", params=url.query)
129129
assert response.status_code == 200
130130
assert response.json()["features"][0]["id"] == "ne_m_4110263_sw_13_060_20220820"
131+
132+
133+
def test_fields_get(client: TestClient) -> None:
134+
response = client.get(
135+
"/search", params={"collections": "naip", "limit": "1", "fields": "id,geometry"}
136+
)
137+
response.raise_for_status()
138+
data = response.json()
139+
assert "properties" not in data["features"][0]
140+
141+
142+
def test_fields_post(client: TestClient) -> None:
143+
response = client.post(
144+
"/search",
145+
json={
146+
"collections": ["naip"],
147+
"limit": "1",
148+
"fields": {"include": ["id", "geometry"]},
149+
},
150+
)
151+
response.raise_for_status()
152+
data = response.json()
153+
assert "properties" not in data["features"][0]
154+
155+
156+
def test_sort_get(client: TestClient) -> None:
157+
response = client.get("/search", params={"limit": "1", "sortby": "datetime"})
158+
response.raise_for_status()
159+
160+
161+
def test_sort_post(client: TestClient) -> None:
162+
response = client.post(
163+
"/search",
164+
json={"limit": "1", "sortby": [{"field": "datetime", "direction": "asc"}]},
165+
)
166+
print(response.json())
167+
response.raise_for_status()

0 commit comments

Comments
 (0)