Commit 985fb96

Yuri Zmytrakov authored and committed
test
1 parent df18fa2 commit 985fb96

File tree

4 files changed: +112 -19 lines changed

stac_fastapi/core/stac_fastapi/core/core.py

Lines changed: 11 additions & 6 deletions

@@ -238,7 +238,7 @@ async def all_collections(self, **kwargs) -> stac_types.Collections:
         limit = int(request.query_params.get("limit", os.getenv("STAC_ITEM_LIMIT", 10)))
         token = request.query_params.get("token")
 
-        collections, next_token = await self.database.get_all_collections(
+        collections, next_token, prev_token = await self.database.get_all_collections(
             token=token, limit=limit, request=request
         )
 
@@ -255,6 +255,10 @@ async def all_collections(self, **kwargs) -> stac_types.Collections:
         if next_token:
             next_link = PagingLinks(next=next_token, request=request).link_next()
             links.append(next_link)
+
+        if prev_token:
+            prev_token = PagingLinks(prev=prev_token, request=request).link_previous()
+            links.append(prev_token)
 
         return stac_types.Collections(collections=collections, links=links)
 
@@ -343,7 +347,7 @@ async def item_collection(
         search = self.database.apply_bbox_filter(search=search, bbox=bbox)
 
         limit = int(request.query_params.get("limit", os.getenv("STAC_ITEM_LIMIT", 10)))
-        items, maybe_count, next_token = await self.database.execute_search(
+        items, maybe_count, next_token, prev_token = await self.database.execute_search(
             search=search,
             limit=limit,
             sort=None,
@@ -356,7 +360,8 @@ async def item_collection(
             self.item_serializer.db_to_stac(item, base_url=base_url) for item in items
         ]
 
-        links = await PagingLinks(request=request, next=next_token).get_links()
+        links = await PagingLinks(request=request, next=next_token, prev=prev_token, collection_id=collection_id).get_links()
+
 
         return stac_types.ItemCollection(
             type="FeatureCollection",
@@ -567,8 +572,8 @@ async def post_search(
         limit = 10
         if search_request.limit:
            limit = search_request.limit
-
-        items, maybe_count, next_token = await self.database.execute_search(
+
+        items, maybe_count, next_token, prev_token = await self.database.execute_search(
             search=search,
             limit=limit,
             token=search_request.token,
@@ -593,7 +598,7 @@ async def post_search(
             )
             for item in items
         ]
-        links = await PagingLinks(request=request, next=next_token).get_links()
+        links = await PagingLinks(request=request, next=next_token, prev=prev_token).get_links()
 
         return stac_types.ItemCollection(
             type="FeatureCollection",

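For context, a minimal client-side sketch of how the new previous links would typically be consumed, assuming a local stac-fastapi deployment at http://localhost:8080 and that the previous relation serializes to the string "previous"; the base URL, limit, and helper below are illustrative and not part of the commit:

import requests  # assumed to be installed; any HTTP client works

BASE_URL = "http://localhost:8080"  # hypothetical deployment URL


def find_link(links, rel):
    """Return the first link whose 'rel' matches, or None."""
    return next((link for link in links if link.get("rel") == rel), None)


# Page forward once, then use the rel="previous" link to page back.
page_one = requests.get(f"{BASE_URL}/collections", params={"limit": 5}).json()
next_link = find_link(page_one.get("links", []), "next")

if next_link:
    page_two = requests.get(next_link["href"]).json()
    prev_link = find_link(page_two.get("links", []), "previous")
    if prev_link:
        earlier_page = requests.get(prev_link["href"]).json()
        print([c["id"] for c in earlier_page.get("collections", [])])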
stac_fastapi/core/stac_fastapi/core/models/links.py

Lines changed: 48 additions & 0 deletions

@@ -179,6 +179,8 @@ class PagingLinks(BaseLinks):
     """Create links for paging."""
 
     next: Optional[str] = attr.ib(kw_only=True, default=None)
+    prev: Optional[str] = attr.ib(kw_only=True, default=None)
+    collection_id: Optional[str] = attr.ib(default=None)
 
     def link_next(self) -> Optional[Dict[str, Any]]:
         """Create link for next page."""
@@ -203,3 +205,49 @@ def link_next(self) -> Optional[Dict[str, Any]]:
             }
 
         return None
+
+
+    def link_previous(self) -> Optional[Dict[str, Any]]:
+        """Create link for previous page."""
+        if not self.prev:
+            return None
+
+        method = self.request.method
+        if method == "GET":
+            href = merge_params(self.url, {"token": self.prev})
+            link = dict(
+                rel=Relations.previous.value,
+                type=MimeTypes.json.value,
+                method=method,
+                href=href,
+            )
+            return link
+        if method == "POST":
+            return {
+                "rel": Relations.previous,
+                "type": MimeTypes.json.value,
+                "method": method,
+                "href": f"{self.request.url}",
+                "body": {**self.request.postbody, "token": self.prev},
+            }
+        return None
+
+    def link_parent(self) -> Optional[Dict[str, Any]]:
+        """Create the `parent` link if collection_id is provided."""
+        if self.collection_id:
+            return dict(
+                rel=Relations.parent.value,
+                type=MimeTypes.json.value,
+                href=urljoin(self.base_url, f"collections/{self.collection_id}"),
+            )
+        return None
+
+    def link_collection(self) -> Optional[Dict[str, Any]]:
+        """Create the `collection` link if collection_id is provided."""
+        if self.collection_id:
+            return dict(
+                rel=Relations.collection.value,
+                type=MimeTypes.json.value,
+                href=urljoin(self.base_url, f"collections/{self.collection_id}"),
+            )
+        return None

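As a rough illustration of what the GET branch of link_previous produces, here is a standalone sketch that merges a token parameter into a request URL using only the standard library; it only approximates the module's merge_params helper, which is not shown in this diff, and the URL and token value are made up:

from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit


def merge_token(url: str, token: str) -> str:
    """Return `url` with a pagination token merged into its query string."""
    scheme, netloc, path, query, fragment = urlsplit(url)
    params = parse_qs(query)
    params["token"] = [token]
    return urlunsplit((scheme, netloc, path, urlencode(params, doseq=True), fragment))


# A previous-page link for a GET request would then look roughly like this:
href = merge_token("http://localhost:8080/collections?limit=5", "prev-token-value")
prev_link = {
    "rel": "previous",
    "type": "application/json",
    "method": "GET",
    "href": href,  # http://localhost:8080/collections?limit=5&token=prev-token-value
}
print(prev_link)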
stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py

Lines changed: 20 additions & 5 deletions

@@ -171,7 +171,7 @@ def __attrs_post_init__(self):
 
     async def get_all_collections(
         self, token: Optional[str], limit: int, request: Request
-    ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
+    ) -> Tuple[List[Dict[str, Any]], Optional[str], Optional[str]]:
         """Retrieve a list of all collections from Elasticsearch, supporting pagination.
 
         Args:
@@ -189,11 +189,15 @@ async def get_all_collections(
             index=COLLECTIONS_INDEX,
             body={
                 "sort": [{"id": {"order": "asc"}}],
-                "size": limit,
+                "size": limit + 1,
                 **({"search_after": search_after} if search_after is not None else {}),
             },
         )
 
+        has_more = len(hits) > limit
+        if has_more:
+            hits = hits[:limit]
+
         hits = response["hits"]["hits"]
         collections = [
             self.collection_serializer.db_to_stac(
@@ -205,8 +209,12 @@ async def get_all_collections(
         next_token = None
         if len(hits) == limit:
             next_token = hits[-1]["sort"][0]
+
+        prev_token = None
+        if token:
+            prev_token = token
 
-        return collections, next_token
+        return collections, next_token, prev_token
 
     async def get_one_item(self, collection_id: str, item_id: str) -> Dict:
         """Retrieve a single item from the database.
@@ -506,7 +514,8 @@ async def execute_search(
         collection_ids: Optional[List[str]],
         datetime_search: Dict[str, Optional[str]],
         ignore_unavailable: bool = True,
-    ) -> Tuple[Iterable[Dict[str, Any]], Optional[int], Optional[str]]:
+    ) -> Tuple[Iterable[Dict[str, Any]], Optional[int], Optional[str], Optional[str]]:
+
        """Execute a search query with limit and other optional parameters.
 
        Args:
@@ -579,6 +588,11 @@
             if hits and (sort_array := hits[limit - 1].get("sort")):
                 next_token = urlsafe_b64encode(orjson.dumps(sort_array)).decode()
 
+        prev_token = None
+        if token and hits:
+            if hits and (first_item_sort := hits[0].get("sort")):
+                prev_token = urlsafe_b64encode(orjson.dumps(first_item_sort)).decode()
+
         matched = (
             es_response["hits"]["total"]["value"]
             if es_response["hits"]["total"]["relation"] == "eq"
@@ -590,7 +604,8 @@
         except Exception as e:
             logger.error(f"Count task failed: {e}")
 
-        return items, matched, next_token
+        return items, matched, next_token, prev_token
+
 
     """ AGGREGATE LOGIC """
 

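Both backends encode the search_after sort values into an opaque, URL-safe page token. A small standalone sketch of that round trip, using the same orjson plus base64 combination seen in the diff (the sample sort values below are made up):

from base64 import urlsafe_b64decode, urlsafe_b64encode

import orjson  # same serializer the database layers use


def encode_token(sort_values: list) -> str:
    """Serialize an Elasticsearch/OpenSearch `sort` array into an opaque page token."""
    return urlsafe_b64encode(orjson.dumps(sort_values)).decode()


def decode_token(token: str) -> list:
    """Recover the `search_after` values from a page token.

    Falls back to treating the token as a single raw sort value, mirroring the
    try/except added in the OpenSearch `get_all_collections`.
    """
    try:
        return orjson.loads(urlsafe_b64decode(token))
    except (ValueError, TypeError, orjson.JSONDecodeError):
        return [token]


sort_array = [1736899200000, "item-42"]  # e.g. a datetime sort value plus an id tiebreaker
token = encode_token(sort_array)
assert decode_token(token) == sort_array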
stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py

Lines changed: 33 additions & 8 deletions

@@ -154,7 +154,7 @@ def __attrs_post_init__(self):
 
     async def get_all_collections(
         self, token: Optional[str], limit: int, request: Request
-    ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
+    ) -> Tuple[List[Dict[str, Any]], Optional[str], Optional[str]]:
         """
         Retrieve a list of all collections from Opensearch, supporting pagination.
 
@@ -167,20 +167,29 @@ async def get_all_collections(
         """
         search_body = {
             "sort": [{"id": {"order": "asc"}}],
-            "size": limit,
+            "size": limit + 1,
         }
+
+        search_after = None
 
         # Only add search_after to the query if token is not None and not empty
         if token:
-            search_after = [token]
+            try:
+                search_after = orjson.loads(urlsafe_b64decode(token))
+            except (ValueError, TypeError, orjson.JSONDecodeError):
+                search_after = [token]
+        if search_after:
             search_body["search_after"] = search_after
 
         response = await self.client.search(
             index=COLLECTIONS_INDEX,
             body=search_body,
         )
 
-        hits = response["hits"]["hits"]
+        hits = response["hits"]["hits"]
+        has_more = len(hits) > limit
+        if has_more:
+            hits = hits[:limit]
         collections = [
             self.collection_serializer.db_to_stac(
                 collection=hit["_source"], request=request, extensions=self.extensions
@@ -189,11 +198,20 @@ async def get_all_collections(
         ]
 
         next_token = None
-        if len(hits) == limit:
+        if has_more and hits:
             # Ensure we have a valid sort value for next_token
             next_token_values = hits[-1].get("sort")
             if next_token_values:
-                next_token = next_token_values[0]
+                next_token = urlsafe_b64encode(orjson.dumps(next_token_values)).decode()
+
+        prev_token = None
+        if token and hits:
+            first_token_values = hits[0].get("sort")
+            if first_token_values:
+                prev_token = urlsafe_b64encode(orjson.dumps(first_token_values)).decode()
+
+        return collections, next_token, prev_token
+
 
         return collections, next_token
 
@@ -495,7 +513,9 @@ async def execute_search(
         collection_ids: Optional[List[str]],
         datetime_search: Dict[str, Optional[str]],
         ignore_unavailable: bool = True,
-    ) -> Tuple[Iterable[Dict[str, Any]], Optional[int], Optional[str]]:
+    ) -> Tuple[Iterable[Dict[str, Any]], Optional[int], Optional[str], Optional[str]]:
+
+
         """Execute a search query with limit and other optional parameters.
 
         Args:
@@ -570,10 +590,15 @@
         items = (hit["_source"] for hit in hits[:limit])
 
         next_token = None
+        prev_token = None
         if len(hits) > limit and limit < max_result_window:
             if hits and (sort_array := hits[limit - 1].get("sort")):
                 next_token = urlsafe_b64encode(orjson.dumps(sort_array)).decode()
 
+        if token and hits:
+            if hits and (first_item_sort := hits[0].get("sort")):
+                prev_token = urlsafe_b64encode(orjson.dumps(first_item_sort)).decode()
+
         matched = (
             es_response["hits"]["total"]["value"]
             if es_response["hits"]["total"]["relation"] == "eq"
@@ -585,7 +610,7 @@
         except Exception as e:
             logger.error(f"Count task failed: {e}")
 
-        return items, matched, next_token
+        return items, matched, next_token, prev_token
 
     """ AGGREGATE LOGIC """
 

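The "size": limit + 1 over-fetch used above is a common way to detect whether another page exists without issuing a separate count query. A minimal standalone sketch of the pattern, with plain lists standing in for search hits (the sample data is invented):

from typing import List, Optional, Tuple


def paginate(hits: List[dict], limit: int) -> Tuple[List[dict], Optional[dict]]:
    """Keep at most `limit` hits; the overflow hit only signals that more exist.

    The last kept hit's sort values would seed the next page's search_after,
    and hence the next-page token.
    """
    has_more = len(hits) > limit
    page = hits[:limit] if has_more else hits
    next_seed = page[-1] if has_more and page else None
    return page, next_seed


# Pretend the upstream query asked for size = limit + 1 = 6 and got a full response.
hits = [{"id": f"collection-{i}", "sort": [f"collection-{i}"]} for i in range(6)]
page, next_seed = paginate(hits, limit=5)
print(len(page), next_seed["sort"] if next_seed else None)  # 5 ['collection-4']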