Skip to content

Commit 99776ce

Browse files
Yuri ZmytrakovYuri Zmytrakov
authored andcommitted
fix: Add relations in OS STAC pagination
This commit adds missing relations to OS STAC pagination: - `collections/{collection_id}/items`: added `collection` and `parent` relations to the endpoint. - `search`, `collections/{collection_id}/items`, `collections`: added the `previous` relation when following the `next` link.
1 parent 8974c38 commit 99776ce

File tree

4 files changed

+111
-19
lines changed

4 files changed

+111
-19
lines changed

stac_fastapi/core/stac_fastapi/core/core.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ async def all_collections(self, **kwargs) -> stac_types.Collections:
238238
limit = int(request.query_params.get("limit", os.getenv("STAC_ITEM_LIMIT", 10)))
239239
token = request.query_params.get("token")
240240

241-
collections, next_token = await self.database.get_all_collections(
241+
collections, next_token, prev_token = await self.database.get_all_collections(
242242
token=token, limit=limit, request=request
243243
)
244244

@@ -256,6 +256,10 @@ async def all_collections(self, **kwargs) -> stac_types.Collections:
256256
next_link = PagingLinks(next=next_token, request=request).link_next()
257257
links.append(next_link)
258258

259+
if prev_token:
260+
prev_token = PagingLinks(prev=prev_token, request=request).link_previous()
261+
links.append(prev_token)
262+
259263
return stac_types.Collections(collections=collections, links=links)
260264

261265
async def get_collection(
@@ -343,7 +347,7 @@ async def item_collection(
343347
search = self.database.apply_bbox_filter(search=search, bbox=bbox)
344348

345349
limit = int(request.query_params.get("limit", os.getenv("STAC_ITEM_LIMIT", 10)))
346-
items, maybe_count, next_token = await self.database.execute_search(
350+
items, maybe_count, next_token, prev_token = await self.database.execute_search(
347351
search=search,
348352
limit=limit,
349353
sort=None,
@@ -356,7 +360,12 @@ async def item_collection(
356360
self.item_serializer.db_to_stac(item, base_url=base_url) for item in items
357361
]
358362

359-
links = await PagingLinks(request=request, next=next_token).get_links()
363+
links = await PagingLinks(
364+
request=request,
365+
next=next_token,
366+
prev=prev_token,
367+
collection_id=collection_id,
368+
).get_links()
360369

361370
return stac_types.ItemCollection(
362371
type="FeatureCollection",
@@ -568,7 +577,7 @@ async def post_search(
568577
if search_request.limit:
569578
limit = search_request.limit
570579

571-
items, maybe_count, next_token = await self.database.execute_search(
580+
items, maybe_count, next_token, prev_token = await self.database.execute_search(
572581
search=search,
573582
limit=limit,
574583
token=search_request.token,
@@ -593,7 +602,9 @@ async def post_search(
593602
)
594603
for item in items
595604
]
596-
links = await PagingLinks(request=request, next=next_token).get_links()
605+
links = await PagingLinks(
606+
request=request, next=next_token, prev=prev_token
607+
).get_links()
597608

598609
return stac_types.ItemCollection(
599610
type="FeatureCollection",

stac_fastapi/core/stac_fastapi/core/models/links.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ class PagingLinks(BaseLinks):
179179
"""Create links for paging."""
180180

181181
next: Optional[str] = attr.ib(kw_only=True, default=None)
182+
prev: Optional[str] = attr.ib(kw_only=True, default=None)
183+
collection_id: Optional[str] = attr.ib(default=None)
182184

183185
def link_next(self) -> Optional[Dict[str, Any]]:
184186
"""Create link for next page."""
@@ -203,3 +205,48 @@ def link_next(self) -> Optional[Dict[str, Any]]:
203205
}
204206

205207
return None
208+
209+
def link_previous(self) -> Optional[Dict[str, Any]]:
210+
"""Create link for previous page."""
211+
if not self.prev:
212+
return None
213+
214+
method = self.request.method
215+
if method == "GET":
216+
href = merge_params(self.url, {"token": self.prev})
217+
link = dict(
218+
rel=Relations.previous.value,
219+
type=MimeTypes.json.value,
220+
method=method,
221+
href=href,
222+
)
223+
return link
224+
if method == "POST":
225+
return {
226+
"rel": Relations.previous,
227+
"type": MimeTypes.json.value,
228+
"method": method,
229+
"href": f"{self.request.url}",
230+
"body": {**self.request.postbody, "token": self.prev},
231+
}
232+
return None
233+
234+
def link_parent(self) -> Optional[Dict[str, Any]]:
235+
"""Create the `parent` link if collection_id is provided."""
236+
if self.collection_id:
237+
return dict(
238+
rel=Relations.parent.value,
239+
type=MimeTypes.json.value,
240+
href=urljoin(self.base_url, f"collections/{self.collection_id}"),
241+
)
242+
return None
243+
244+
def link_collection(self) -> Optional[Dict[str, Any]]:
245+
"""Create the `collection` link if collection_id is provided."""
246+
if self.collection_id:
247+
return dict(
248+
rel=Relations.collection.value,
249+
type=MimeTypes.json.value,
250+
href=urljoin(self.base_url, f"collections/{self.collection_id}"),
251+
)
252+
return None

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def __attrs_post_init__(self):
171171

172172
async def get_all_collections(
173173
self, token: Optional[str], limit: int, request: Request
174-
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
174+
) -> Tuple[List[Dict[str, Any]], Optional[str], Optional[str]]:
175175
"""Retrieve a list of all collections from Elasticsearch, supporting pagination.
176176
177177
Args:
@@ -189,12 +189,17 @@ async def get_all_collections(
189189
index=COLLECTIONS_INDEX,
190190
body={
191191
"sort": [{"id": {"order": "asc"}}],
192-
"size": limit,
192+
"size": limit + 1,
193193
**({"search_after": search_after} if search_after is not None else {}),
194194
},
195195
)
196196

197197
hits = response["hits"]["hits"]
198+
199+
has_more = len(hits) > limit
200+
if has_more:
201+
hits = hits[:limit]
202+
198203
collections = [
199204
self.collection_serializer.db_to_stac(
200205
collection=hit["_source"], request=request, extensions=self.extensions
@@ -206,7 +211,11 @@ async def get_all_collections(
206211
if len(hits) == limit:
207212
next_token = hits[-1]["sort"][0]
208213

209-
return collections, next_token
214+
prev_token = None
215+
if token:
216+
prev_token = token
217+
218+
return collections, next_token, prev_token
210219

211220
async def get_one_item(self, collection_id: str, item_id: str) -> Dict:
212221
"""Retrieve a single item from the database.
@@ -506,7 +515,7 @@ async def execute_search(
506515
collection_ids: Optional[List[str]],
507516
datetime_search: Dict[str, Optional[str]],
508517
ignore_unavailable: bool = True,
509-
) -> Tuple[Iterable[Dict[str, Any]], Optional[int], Optional[str]]:
518+
) -> Tuple[Iterable[Dict[str, Any]], Optional[int], Optional[str], Optional[str]]:
510519
"""Execute a search query with limit and other optional parameters.
511520
512521
Args:
@@ -579,6 +588,10 @@ async def execute_search(
579588
if hits and (sort_array := hits[limit - 1].get("sort")):
580589
next_token = urlsafe_b64encode(orjson.dumps(sort_array)).decode()
581590

591+
prev_token = None
592+
if token and hits:
593+
if hits and (first_item_sort := hits[0].get("sort")):
594+
prev_token = urlsafe_b64encode(orjson.dumps(first_item_sort)).decode()
582595
matched = (
583596
es_response["hits"]["total"]["value"]
584597
if es_response["hits"]["total"]["relation"] == "eq"
@@ -590,7 +603,7 @@ async def execute_search(
590603
except Exception as e:
591604
logger.error(f"Count task failed: {e}")
592605

593-
return items, matched, next_token
606+
return items, matched, next_token, prev_token
594607

595608
""" AGGREGATE LOGIC """
596609

stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def __attrs_post_init__(self):
154154

155155
async def get_all_collections(
156156
self, token: Optional[str], limit: int, request: Request
157-
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
157+
) -> Tuple[List[Dict[str, Any]], Optional[str], Optional[str]]:
158158
"""
159159
Retrieve a list of all collections from Opensearch, supporting pagination.
160160
@@ -167,12 +167,16 @@ async def get_all_collections(
167167
"""
168168
search_body = {
169169
"sort": [{"id": {"order": "asc"}}],
170-
"size": limit,
170+
"size": limit + 1,
171171
}
172172

173-
# Only add search_after to the query if token is not None and not empty
173+
search_after = None
174174
if token:
175-
search_after = [token]
175+
try:
176+
search_after = orjson.loads(urlsafe_b64decode(token))
177+
except (ValueError, TypeError, orjson.JSONDecodeError):
178+
search_after = [token]
179+
if search_after:
176180
search_body["search_after"] = search_after
177181

178182
response = await self.client.search(
@@ -181,6 +185,9 @@ async def get_all_collections(
181185
)
182186

183187
hits = response["hits"]["hits"]
188+
has_more = len(hits) > limit
189+
if has_more:
190+
hits = hits[:limit]
184191
collections = [
185192
self.collection_serializer.db_to_stac(
186193
collection=hit["_source"], request=request, extensions=self.extensions
@@ -189,13 +196,21 @@ async def get_all_collections(
189196
]
190197

191198
next_token = None
192-
if len(hits) == limit:
199+
if has_more and hits:
193200
# Ensure we have a valid sort value for next_token
194201
next_token_values = hits[-1].get("sort")
195202
if next_token_values:
196-
next_token = next_token_values[0]
203+
next_token = urlsafe_b64encode(orjson.dumps(next_token_values)).decode()
197204

198-
return collections, next_token
205+
prev_token = None
206+
if token and hits:
207+
first_token_values = hits[0].get("sort")
208+
if first_token_values:
209+
prev_token = urlsafe_b64encode(
210+
orjson.dumps(first_token_values)
211+
).decode()
212+
213+
return collections, next_token, prev_token
199214

200215
async def get_one_item(self, collection_id: str, item_id: str) -> Dict:
201216
"""Retrieve a single item from the database.
@@ -495,7 +510,7 @@ async def execute_search(
495510
collection_ids: Optional[List[str]],
496511
datetime_search: Dict[str, Optional[str]],
497512
ignore_unavailable: bool = True,
498-
) -> Tuple[Iterable[Dict[str, Any]], Optional[int], Optional[str]]:
513+
) -> Tuple[Iterable[Dict[str, Any]], Optional[int], Optional[str], Optional[str]]:
499514
"""Execute a search query with limit and other optional parameters.
500515
501516
Args:
@@ -514,6 +529,7 @@ async def execute_search(
514529
- The total number of results (if the count could be computed), or None if the count could not be
515530
computed.
516531
- The token to be used to retrieve the next set of results, or None if there are no more results.
532+
- The token to be used to retrieve the previous set of results, or None if this is the first page.
517533
518534
Raises:
519535
NotFoundError: If the collections specified in `collection_ids` do not exist.
@@ -570,10 +586,15 @@ async def execute_search(
570586
items = (hit["_source"] for hit in hits[:limit])
571587

572588
next_token = None
589+
prev_token = None
573590
if len(hits) > limit and limit < max_result_window:
574591
if hits and (sort_array := hits[limit - 1].get("sort")):
575592
next_token = urlsafe_b64encode(orjson.dumps(sort_array)).decode()
576593

594+
if token and hits:
595+
if hits and (first_item_sort := hits[0].get("sort")):
596+
prev_token = urlsafe_b64encode(orjson.dumps(first_item_sort)).decode()
597+
577598
matched = (
578599
es_response["hits"]["total"]["value"]
579600
if es_response["hits"]["total"]["relation"] == "eq"
@@ -585,7 +606,7 @@ async def execute_search(
585606
except Exception as e:
586607
logger.error(f"Count task failed: {e}")
587608

588-
return items, matched, next_token
609+
return items, matched, next_token, prev_token
589610

590611
""" AGGREGATE LOGIC """
591612

0 commit comments

Comments
 (0)