Skip to content

Commit bd6c15f

Browse files
committed
refactor extension map and add filter extension for item_collection
1 parent af98395 commit bd6c15f

File tree

2 files changed

+71
-50
lines changed

2 files changed

+71
-50
lines changed

stac_fastapi/pgstac/app.py

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -43,22 +43,28 @@
4343
from stac_fastapi.pgstac.types.search import PgstacSearch
4444

4545
settings = Settings()
46-
extensions_map = {
46+
47+
# transaction extensions
48+
trans_extensions_map = {
4749
"transaction": TransactionExtension(
4850
client=TransactionsClient(),
4951
settings=settings,
5052
response_class=ORJSONResponse,
5153
),
54+
"bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()),
55+
}
56+
57+
# search extensions
58+
search_extensions_map = {
5259
"query": QueryExtension(),
5360
"sort": SortExtension(),
5461
"fields": FieldsExtension(),
55-
"pagination": TokenPaginationExtension(),
5662
"filter": FilterExtension(client=FiltersClient()),
57-
"bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()),
63+
"pagination": TokenPaginationExtension(),
5864
}
5965

60-
# some extensions are supported in combination with the collection search extension
61-
collection_extensions_map = {
66+
# collection_search extensions
67+
cs_extensions_map = {
6268
"query": QueryExtension(),
6369
"sort": SortExtension(),
6470
"fields": FieldsExtension(),
@@ -67,44 +73,68 @@
6773
"pagination": OffsetPaginationExtension(),
6874
}
6975

76+
# item_collection extensions
77+
itm_col_extensions_map = {
78+
"filter": FilterExtension(client=FiltersClient()),
79+
"pagination": TokenPaginationExtension(),
80+
}
81+
82+
known_extensions = {
83+
*trans_extensions_map.keys(),
84+
*search_extensions_map.keys(),
85+
*cs_extensions_map.keys(),
86+
*itm_col_extensions_map.keys(),
87+
"collection_search",
88+
}
89+
7090
enabled_extensions = (
7191
os.environ["ENABLED_EXTENSIONS"].split(",")
7292
if "ENABLED_EXTENSIONS" in os.environ
73-
else list(extensions_map.keys()) + ["collection_search"]
93+
else known_extensions
7494
)
75-
extensions = [
76-
extension for key, extension in extensions_map.items() if key in enabled_extensions
95+
96+
application_extensions = [
97+
extension
98+
for key, extension in trans_extensions_map.items()
99+
if key in enabled_extensions
77100
]
78101

79-
items_get_request_model = (
80-
create_request_model(
102+
# /search models
103+
search_extensions = [
104+
extension
105+
for key, extension in search_extensions_map.items()
106+
if key in enabled_extensions
107+
]
108+
post_request_model = create_post_request_model(search_extensions, base_model=PgstacSearch)
109+
get_request_model = create_get_request_model(search_extensions)
110+
application_extensions.extend(search_extensions)
111+
112+
# /collections/{collectionId}/items model
113+
items_get_request_model = ItemCollectionUri
114+
itm_col_extensions = [
115+
extension
116+
for key, extension in itm_col_extensions_map.items()
117+
if key in enabled_extensions
118+
]
119+
if itm_col_extensions:
120+
items_get_request_model = create_request_model(
81121
model_name="ItemCollectionUri",
82122
base_model=ItemCollectionUri,
83-
mixins=[TokenPaginationExtension().GET],
123+
extensions=itm_col_extensions,
84124
request_type="GET",
85125
)
86-
if any(isinstance(ext, TokenPaginationExtension) for ext in extensions)
87-
else ItemCollectionUri
88-
)
89-
90-
collection_search_extension = (
91-
CollectionSearchExtension.from_extensions(
92-
[
93-
extension
94-
for key, extension in collection_extensions_map.items()
95-
if key in enabled_extensions
96-
]
97-
)
98-
if "collection_search" in enabled_extensions
99-
else None
100-
)
101-
102-
collections_get_request_model = (
103-
collection_search_extension.GET if collection_search_extension else EmptyRequest
104-
)
105126

106-
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
107-
get_request_model = create_get_request_model(extensions)
127+
# /collections model
128+
collections_get_request_model = EmptyRequest
129+
if "collection_search" in enabled_extensions:
130+
cs_extensions = [
131+
extension
132+
for key, extension in cs_extensions_map.items()
133+
if key in enabled_extensions
134+
]
135+
collection_search_extension = CollectionSearchExtension.from_extensions(cs_extensions)
136+
collections_get_request_model = collection_search_extension.GET
137+
application_extensions.append(collection_search_extension)
108138

109139

110140
@asynccontextmanager
@@ -127,9 +157,7 @@ async def lifespan(app: FastAPI):
127157
api = StacApi(
128158
app=update_openapi(fastapp),
129159
settings=settings,
130-
extensions=extensions + [collection_search_extension]
131-
if collection_search_extension
132-
else extensions,
160+
extensions=application_extensions,
133161
client=CoreCrudClient(pgstac_search_model=post_request_model),
134162
response_class=ORJSONResponse,
135163
items_get_request_model=items_get_request_model,

stac_fastapi/pgstac/core.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,10 @@ async def item_collection(
342342
bbox: Optional[BBox] = None,
343343
datetime: Optional[str] = None,
344344
limit: Optional[int] = None,
345+
# Extensions
345346
token: Optional[str] = None,
347+
filter_expr: Optional[str] = None,
348+
filter_lang: Optional[str] = None,
346349
**kwargs,
347350
) -> ItemCollection:
348351
"""Get all items from a specific collection.
@@ -368,21 +371,11 @@ async def item_collection(
368371
"token": token,
369372
}
370373

371-
if self.extension_is_enabled("FilterExtension"):
372-
filter_lang = kwargs.get("filter_lang", None)
373-
filter_query = kwargs.get("filter_expr", None)
374-
if filter_query:
375-
if filter_lang == "cql2-text":
376-
filter_query = to_cql2(parse_cql2_text(filter_query))
377-
filter_lang = "cql2-json"
378-
379-
base_args["filter"] = orjson.loads(filter_query)
380-
base_args["filter-lang"] = filter_lang
381-
382-
clean = {}
383-
for k, v in base_args.items():
384-
if v is not None and v != []:
385-
clean[k] = v
374+
clean = self._clean_search_args(
375+
base_args=base_args,
376+
filter_query=filter_expr,
377+
filter_lang=filter_lang,
378+
)
386379

387380
search_request = self.pgstac_search_model(**clean)
388381
item_collection = await self._search_base(search_request, request=request)

0 commit comments

Comments
 (0)