Skip to content

Commit 915c775

Browse files
committed
token fix
1 parent 5f26425 commit 915c775

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,21 @@ async def get_all_collections(
228228
"size": limit,
229229
}
230230

231+
# Handle search_after token - split by '|' to get all sort values
232+
search_after = None
231233
if token:
232-
body["search_after"] = [token]
234+
try:
235+
# The token should be a pipe-separated string of sort values
236+
# e.g., "2023-01-01T00:00:00Z|collection-1"
237+
search_after = token.split("|")
238+
# If the number of sort fields doesn't match token parts, ignore the token
239+
if len(search_after) != len(formatted_sort):
240+
search_after = None
241+
except Exception:
242+
search_after = None
243+
244+
if search_after is not None:
245+
body["search_after"] = search_after
233246

234247
# Build the query part of the body
235248
query_parts = []
@@ -353,7 +366,8 @@ async def get_all_collections(
353366
if len(hits) == limit:
354367
next_token_values = hits[-1].get("sort")
355368
if next_token_values:
356-
next_token = next_token_values[0]
369+
# Join all sort values with '|' to create the token
370+
next_token = "|".join(str(val) for val in next_token_values)
357371

358372
# Get the total count of collections
359373
matched = (

stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -164,17 +164,16 @@ async def get_all_collections(
164164
query: Optional[Dict[str, Dict[str, Any]]] = None,
165165
datetime: Optional[str] = None,
166166
) -> Tuple[List[Dict[str, Any]], Optional[str], Optional[int]]:
167-
"""
168-
Retrieve a list of collections from Opensearch, supporting pagination.
167+
"""Retrieve a list of collections from OpenSearch, supporting pagination.
169168
170169
Args:
171170
token (Optional[str]): The pagination token.
172171
limit (int): The number of results to return.
173172
request (Request): The FastAPI request object.
174173
sort (Optional[List[Dict[str, Any]]]): Optional sort parameter from the request.
175174
q (Optional[List[str]]): Free text search terms.
176-
filter (Optional[Dict[str, Any]]): Structured filter in CQL2 format.
177175
query (Optional[Dict[str, Dict[str, Any]]]): Query extension parameters.
176+
filter (Optional[Dict[str, Any]]): Structured query in CQL2 format.
178177
datetime (Optional[str]): Temporal filter.
179178
180179
Returns:
@@ -213,8 +212,21 @@ async def get_all_collections(
213212
"size": limit,
214213
}
215214

215+
# Handle search_after token - split by '|' to get all sort values
216+
search_after = None
216217
if token:
217-
body["search_after"] = [token]
218+
try:
219+
# The token should be a pipe-separated string of sort values
220+
# e.g., "2023-01-01T00:00:00Z|collection-1"
221+
search_after = token.split("|")
222+
# If the number of sort fields doesn't match token parts, ignore the token
223+
if len(search_after) != len(formatted_sort):
224+
search_after = None
225+
except Exception:
226+
search_after = None
227+
228+
if search_after is not None:
229+
body["search_after"] = search_after
218230

219231
# Build the query part of the body
220232
query_parts = []
@@ -279,13 +291,15 @@ async def get_all_collections(
279291
search_dict = search.to_dict()
280292
if "query" in search_dict:
281293
query_parts.append(search_dict["query"])
294+
282295
except Exception as e:
283296
logger = logging.getLogger(__name__)
284297
logger.error(f"Error converting query to OpenSearch: {e}")
285298
# If there's an error, add a query that matches nothing
286299
query_parts.append({"bool": {"must_not": {"match_all": {}}}})
287300
raise
288301

302+
# Combine all query parts with AND logic if there are multiple
289303
datetime_filter = None
290304
if datetime:
291305
datetime_filter = self._apply_collection_datetime_filter(datetime)
@@ -336,7 +350,8 @@ async def get_all_collections(
336350
if len(hits) == limit:
337351
next_token_values = hits[-1].get("sort")
338352
if next_token_values:
339-
next_token = next_token_values[0]
353+
# Join all sort values with '|' to create the token
354+
next_token = "|".join(str(val) for val in next_token_values)
340355

341356
# Get the total count of collections
342357
matched = (

0 commit comments

Comments
 (0)