|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import json |
| 4 | +from collections.abc import Iterable, Sequence |
| 5 | +from contextlib import asynccontextmanager |
| 6 | +from datetime import datetime, timedelta |
| 7 | +from typing import Any, cast |
| 8 | + |
3 | 9 | import pytest |
| 10 | +from openai.types.responses.response_output_message_param import ResponseOutputMessageParam |
| 11 | +from openai.types.responses.response_output_text_param import ResponseOutputTextParam |
| 12 | +from openai.types.responses.response_reasoning_item_param import ( |
| 13 | + ResponseReasoningItemParam, |
| 14 | + Summary, |
| 15 | +) |
| 16 | +from sqlalchemy import select, text, update |
| 17 | +from sqlalchemy.sql import Select |
4 | 18 |
|
5 | 19 | pytest.importorskip("sqlalchemy") # Skip tests if SQLAlchemy is not installed |
6 | 20 |
|
|
16 | 30 | DB_URL = "sqlite+aiosqlite:///:memory:" |
17 | 31 |
|
18 | 32 |
|
def _make_message_item(item_id: str, text_value: str) -> TResponseInputItem:
    """Build a completed assistant message item with one ``output_text`` part.

    Args:
        item_id: Value stored under the message's ``"id"`` key.
        text_value: Text stored in the single content part.

    Returns:
        The message dict, cast to ``TResponseInputItem``.
    """
    assistant_message: ResponseOutputMessageParam = {
        "id": item_id,
        "type": "message",
        "role": "assistant",
        "status": "completed",
        "content": [
            {
                "type": "output_text",
                "text": text_value,
                "annotations": [],
            }
        ],
    }
    return cast(TResponseInputItem, assistant_message)
| 47 | + |
| 48 | + |
def _make_reasoning_item(item_id: str, summary_text: str) -> TResponseInputItem:
    """Build a reasoning item whose summary holds one ``summary_text`` entry.

    Args:
        item_id: Value stored under the item's ``"id"`` key.
        summary_text: Text stored in the single summary entry.

    Returns:
        The reasoning dict, cast to ``TResponseInputItem``.
    """
    reasoning_item: ResponseReasoningItemParam = {
        "id": item_id,
        "type": "reasoning",
        "summary": [{"type": "summary_text", "text": summary_text}],
    }
    return cast(TResponseInputItem, reasoning_item)
| 57 | + |
| 58 | + |
def _item_ids(items: Sequence[TResponseInputItem]) -> list[str]:
    """Return the ``"id"`` value of every item, preserving input order.

    Args:
        items: Items that can each be subscripted with the key ``"id"``.

    Returns:
        One id per item, in the same order as ``items``.

    Raises:
        KeyError: If an item has no ``"id"`` key.
    """
    # The casts only narrow types for the checker; nothing is converted at runtime.
    return [cast(str, cast(dict[str, Any], item)["id"]) for item in items]
| 65 | + |
| 66 | + |
19 | 67 | @pytest.fixture |
20 | 68 | def agent() -> Agent: |
21 | 69 | """Fixture for a basic agent with a fake model.""" |
@@ -151,3 +199,195 @@ async def test_add_empty_items_list(): |
151 | 199 |
|
152 | 200 | items_after_add = await session.get_items() |
153 | 201 | assert len(items_after_add) == 0 |
| 202 | + |
| 203 | + |
async def test_get_items_same_timestamp_consistent_order():
    """Check that rows sharing a ``created_at`` value come back in insertion order.

    Three items are stored, two of them are forced onto one timestamp, and the
    session factory is wrapped so that any SELECT on the messages table that
    has an ORDER BY without an ``id`` term gets its rows reversed. ``get_items``
    must still return insertion order, both with and without ``limit``.
    """
    session = SQLAlchemySession.from_url("same_timestamp_test", url=DB_URL, create_tables=True)

    first_item = _make_message_item("older_same_ts", "old")
    tied_reasoning = _make_reasoning_item("rs_same_ts", "...")
    tied_message = _make_message_item("msg_same_ts", "...")
    await session.add_items([first_item])
    await session.add_items([tied_reasoning, tied_message])

    messages = session._messages
    tied_at = datetime(2025, 10, 15, 17, 26, 39, 132483)
    earlier_at = tied_at - timedelta(milliseconds=1)

    async with session._session_factory() as db:
        lookup = await db.execute(
            select(messages.c.id, messages.c.message_data).where(
                messages.c.session_id == session.session_id
            )
        )
        # Map each item's JSON "id" to the primary key of the row that stores it.
        row_id_for = {
            json.loads(payload)["id"]: row_id for row_id, payload in lookup.fetchall()
        }

        # Put the two later items on exactly the same timestamp...
        tied_row_ids = [row_id_for["rs_same_ts"], row_id_for["msg_same_ts"]]
        await db.execute(
            update(messages).where(messages.c.id.in_(tied_row_ids)).values(created_at=tied_at)
        )
        # ...and the first item one millisecond before them.
        await db.execute(
            update(messages)
            .where(messages.c.id == row_id_for["older_same_ts"])
            .values(created_at=earlier_at)
        )
        await db.commit()

    real_factory = session._session_factory

    class FakeResult:
        """Minimal result stand-in that only supports ``all()``."""

        def __init__(self, rows: Iterable[Any]):
            self._rows = list(rows)

        def all(self) -> list[Any]:
            return list(self._rows)

    def needs_shuffle(statement: Any) -> bool:
        """Tell whether ``statement`` is an ordered messages SELECT lacking an id term."""
        if not isinstance(statement, Select):
            return False
        order_terms = list(statement._order_by_clause)
        if not order_terms:
            return False
        id_orderings = (messages.c.id.asc(), messages.c.id.desc())

        def orders_by_id(clause: Any) -> bool:
            try:
                return any(clause.compare(candidate) for candidate in id_orderings)
            except AttributeError:
                return False

        if any(orders_by_id(term) for term in order_terms):
            return False
        # Only shuffle queries that target the messages table.
        from_names = {
            name
            for name in (
                getattr(source, "name", None) for source in statement.get_final_froms()
            )
            if isinstance(name, str)
        }
        own_name = getattr(messages, "name", "")
        if not isinstance(own_name, str):
            own_name = ""
        return own_name in from_names

    @asynccontextmanager
    async def shuffled_session():
        async with real_factory() as inner:
            original_execute = inner.execute

            async def execute_with_shuffle(statement: Any, *args: Any, **kwargs: Any) -> Any:
                result = await original_execute(statement, *args, **kwargs)
                if not needs_shuffle(statement):
                    return result
                # Hand the rows back in reverse to mimic an unstable tie order.
                return FakeResult(list(result.all())[::-1])

            cast(Any, inner).execute = execute_with_shuffle
            try:
                yield inner
            finally:
                cast(Any, inner).execute = original_execute

    session._session_factory = cast(Any, shuffled_session)
    try:
        everything = await session.get_items()
        assert _item_ids(everything) == ["older_same_ts", "rs_same_ts", "msg_same_ts"]

        newest_pair = await session.get_items(limit=2)
        assert _item_ids(newest_pair) == ["rs_same_ts", "msg_same_ts"]
    finally:
        # Always put the real factory back, even when an assertion fails.
        session._session_factory = real_factory
| 309 | + |
| 310 | + |
async def test_pop_item_same_timestamp_returns_latest():
    """Test that pop_item returns the newest item when timestamps tie.

    Two items are added in one call, every row of the session is then given
    the same ``created_at`` value, and ``pop_item`` must remove the item that
    was added last while leaving the earlier one in place.
    """
    session_id = "same_timestamp_pop_test"
    session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)

    reasoning_item = _make_reasoning_item("rs_pop_same_ts", "...")
    message_item = _make_message_item("msg_pop_same_ts", "...")
    await session.add_items([reasoning_item, message_item])

    # Use the session's own table object and a real datetime instead of raw SQL
    # with a hard-coded table name and a string timestamp, matching how
    # test_get_items_same_timestamp_consistent_order rewrites created_at.
    shared = datetime(2025, 10, 15, 17, 26, 39, 132483)
    async with session._session_factory() as sess:
        await sess.execute(
            update(session._messages)
            .where(session._messages.c.session_id == session.session_id)
            .values(created_at=shared)
        )
        await sess.commit()

    popped = await session.pop_item()
    assert popped is not None
    assert cast(dict[str, Any], popped)["id"] == "msg_pop_same_ts"

    remaining = await session.get_items()
    assert _item_ids(remaining) == ["rs_pop_same_ts"]
| 340 | + |
| 341 | + |
async def test_get_items_orders_by_id_for_ties():
    """Check that get_items orders by ``id`` after ``created_at``.

    The session factory is wrapped so every statement passed to ``execute``
    is captured; the ORDER BY terms of the first two captured statements are
    then compared against the expected ascending and descending forms.
    """
    session = SQLAlchemySession.from_url("order_by_id_test", url=DB_URL, create_tables=True)

    seeded = [
        _make_reasoning_item("rs_first", "..."),
        _make_message_item("msg_second", "..."),
    ]
    await session.add_items(seeded)

    real_factory = session._session_factory
    recorded: list[Any] = []

    def order_terms(statement: Any) -> list[str]:
        # Render each ORDER BY term as text so it can be compared literally.
        return [str(term) for term in statement._order_by_clause]

    @asynccontextmanager
    async def wrapped_session():
        async with real_factory() as inner:
            original_execute = inner.execute

            async def recording_execute(statement: Any, *args: Any, **kwargs: Any) -> Any:
                # Remember the statement, then run it unchanged.
                recorded.append(statement)
                return await original_execute(statement, *args, **kwargs)

            cast(Any, inner).execute = recording_execute
            try:
                yield inner
            finally:
                cast(Any, inner).execute = original_execute

    session._session_factory = cast(Any, wrapped_session)
    try:
        all_items = await session.get_items()
        limited_items = await session.get_items(limit=2)
    finally:
        # Restore the real factory whether or not the calls succeeded.
        session._session_factory = real_factory

    assert len(recorded) >= 2

    # The unlimited read comes first and is expected to sort ascending.
    assert order_terms(recorded[0]) == [
        "agent_messages.created_at ASC",
        "agent_messages.id ASC",
    ]

    # The limited read comes second and is expected to sort descending.
    assert order_terms(recorded[1]) == [
        "agent_messages.created_at DESC",
        "agent_messages.id DESC",
    ]

    assert _item_ids(all_items) == ["rs_first", "msg_second"]
    assert _item_ids(limited_items) == ["rs_first", "msg_second"]
0 commit comments