Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions stac_fastapi/pgstac/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,12 @@
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
get_request_model = create_get_request_model(extensions)

# will only use parameters defined in collections_get_request_model
collection_search_model = create_post_request_model(extensions, base_model=PgstacSearch)

api = StacApi(
settings=settings,
extensions=extensions + [collection_search_extension],
client=CoreCrudClient(
post_request_model=post_request_model, # type: ignore
collection_request_model=collection_search_model, # type: ignore
collections_get_request_model=collections_get_request_model, # type: ignore
),
response_class=ORJSONResponse,
items_get_request_model=items_get_request_model,
Expand Down
20 changes: 13 additions & 7 deletions stac_fastapi/pgstac/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Item crud client."""

import json
import re
from typing import Any, Dict, List, Optional, Set, Union
from urllib.parse import unquote_plus, urljoin
Expand All @@ -13,7 +14,7 @@
from pygeofilter.backends.cql2_json import to_cql2
from pygeofilter.parsers.cql2_text import parse as parse_cql2_text
from pypgstac.hydration import hydrate
from stac_fastapi.api.models import JSONResponse
from stac_fastapi.api.models import APIRequest, EmptyRequest, JSONResponse
from stac_fastapi.types.core import AsyncBaseCoreClient, Relations
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError
from stac_fastapi.types.requests import get_base_url
Expand All @@ -38,7 +39,7 @@
class CoreCrudClient(AsyncBaseCoreClient):
"""Client for core endpoints defined by stac."""

collection_request_model = attr.ib(default=PgstacSearch)
collections_get_request_model: APIRequest = attr.ib(default=EmptyRequest)

async def all_collections( # noqa: C901
self,
Expand Down Expand Up @@ -83,7 +84,8 @@ async def all_collections( # noqa: C901

# Do the request
try:
search_request = self.collection_request_model(**clean)
search_request = self.collections_get_request_model(**clean)
print(search_request)
except ValidationError as e:
raise HTTPException(
status_code=400, detail=f"Invalid parameters provided {e}"
Expand All @@ -93,7 +95,7 @@ async def all_collections( # noqa: C901

async def _collection_search_base( # noqa: C901
self,
search_request: PgstacSearch,
search_request: APIRequest,
request: Request,
) -> Collections:
"""Cross catalog search (GET).
Expand All @@ -107,8 +109,12 @@ async def _collection_search_base( # noqa: C901
All collections which match the search criteria.
"""
base_url = get_base_url(request)
search_request_json = search_request.model_dump_json(
exclude_none=True, by_alias=True
search_request_json = json.dumps(
{
key: value
for key, value in search_request.__dict__.items()
if value is not None
}
)

try:
Expand Down Expand Up @@ -533,7 +539,7 @@ def clean_search_args( # noqa: C901
filter_lang = "cql2-json"

base_args["filter"] = orjson.loads(filter_query)
base_args["filter-lang"] = filter_lang
base_args["filter_lang"] = filter_lang

if datetime:
base_args["datetime"] = format_datetime_range(datetime)
Expand Down
10 changes: 4 additions & 6 deletions tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,23 +730,21 @@ async def get_collection(
]
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
get_request_model = create_get_request_model(extensions)
collection_search_model = create_post_request_model(
extensions, base_model=PgstacSearch
)

collection_search_extension = CollectionSearchExtension.from_extensions(
extensions=extensions
)
collections_get_request_model = collection_search_extension.GET

api = StacApi(
client=Client(
post_request_model=post_request_model,
collection_request_model=collection_search_model,
collections_get_request_model=collection_search_extension.GET,
),
settings=settings,
extensions=extensions,
search_post_request_model=post_request_model,
search_get_request_model=get_request_model,
collections_get_request_model=collections_get_request_model,
collections_get_request_model=collection_search_extension.GET,
)
app = api.app
await connect_to_db(app)
Expand Down
5 changes: 1 addition & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,13 @@ def api_client(request, database):
)

collections_get_request_model = collection_search_extension.GET
collection_search_model = create_post_request_model(
extensions, base_model=PgstacSearch
)

api = StacApi(
settings=api_settings,
extensions=extensions + [collection_search_extension],
client=CoreCrudClient(
post_request_model=search_post_request_model,
collection_request_model=collection_search_model,
collections_get_request_model=collections_get_request_model,
),
items_get_request_model=items_get_request_model,
search_get_request_model=search_get_request_model,
Expand Down
Loading