|
18 | 18 | from pydantic import BaseModel |
19 | 19 |
|
20 | 20 | from agent_memory_server.config import settings |
21 | | -from agent_memory_server.extraction import extract_discrete_memories |
22 | 21 | from agent_memory_server.llms import get_model_client |
23 | 22 |
|
24 | 23 |
|
@@ -279,133 +278,81 @@ async def create_test_conversation_with_context( |
279 | 278 | async def test_pronoun_grounding_integration_he_him(self): |
280 | 279 | """Integration test for he/him pronoun grounding with real LLM""" |
281 | 280 | example = ContextualGroundingBenchmark.get_pronoun_grounding_examples()[0] |
| 281 | + session_id = f"test-pronoun-{ulid.ULID()}" |
282 | 282 |
|
283 | | - # Create memory record and store it first |
284 | | - memory = await self.create_test_memory_with_context( |
285 | | - example["messages"][:-1], # Context |
286 | | - example["messages"][-1], # Target message with pronouns |
287 | | - example["context_date"], |
| 283 | + # Set up conversation context for cross-message grounding |
| 284 | + await self.create_test_conversation_with_context( |
| 285 | + example["messages"], example["context_date"], session_id |
288 | 286 | ) |
289 | 287 |
|
290 | | - # Store the memory so it can be found by extract_discrete_memories |
291 | | - from agent_memory_server.vectorstore_factory import get_vectorstore_adapter |
292 | | - |
293 | | - adapter = await get_vectorstore_adapter() |
294 | | - await adapter.add_memories([memory]) |
295 | | - |
296 | | - # Extract memories using real LLM |
297 | | - await extract_discrete_memories([memory]) |
298 | | - |
299 | | - # Retrieve all memories to verify extraction occurred |
300 | | - all_memories = await adapter.search_memories( |
301 | | - query="", |
302 | | - limit=50, # Get all memories |
| 288 | + # Use thread-aware extraction |
| 289 | + from agent_memory_server.long_term_memory import ( |
| 290 | + extract_memories_from_session_thread, |
303 | 291 | ) |
304 | 292 |
|
305 | | - # Find the original memory by session_id and verify it was processed |
306 | | - session_memories = [ |
307 | | - m for m in all_memories.memories if m.session_id == memory.session_id |
308 | | - ] |
309 | | - |
310 | | - # Should find the original message memory that was processed |
311 | | - assert ( |
312 | | - len(session_memories) >= 1 |
313 | | - ), f"No memories found in session {memory.session_id}" |
314 | | - |
315 | | - # Find our specific memory in the results |
316 | | - processed_memory = next( |
317 | | - (m for m in session_memories if m.id == memory.id), None |
| 293 | + extracted_memories = await extract_memories_from_session_thread( |
| 294 | + session_id=session_id, |
| 295 | + namespace="test-namespace", |
| 296 | + user_id="test-integration-user", |
318 | 297 | ) |
319 | 298 |
|
320 | | - if processed_memory is None: |
321 | | - # If we can't find by ID, try to find any memory in the session with discrete_memory_extracted = "t" |
322 | | - processed_memory = next( |
323 | | - (m for m in session_memories if m.discrete_memory_extracted == "t"), |
324 | | - None, |
325 | | - ) |
326 | | - |
327 | | - assert ( |
328 | | - processed_memory is not None |
329 | | - ), f"Could not find processed memory {memory.id} in session" |
330 | | - assert processed_memory.discrete_memory_extracted == "t" |
| 299 | + # Verify extraction was successful |
| 300 | + assert len(extracted_memories) >= 1, "Expected at least one extracted memory" |
331 | 301 |
|
332 | | - # Should also find extracted discrete memories |
333 | | - discrete_memories = [ |
334 | | - m |
335 | | - for m in all_memories.memories |
336 | | - if m.memory_type in ["episodic", "semantic"] |
337 | | - ] |
338 | | - assert ( |
339 | | - len(discrete_memories) >= 1 |
340 | | - ), "Expected at least one discrete memory to be extracted" |
| 302 | + # Check that pronoun grounding occurred |
| 303 | + all_memory_text = " ".join([mem.text for mem in extracted_memories]) |
| 304 | + print(f"Extracted memories: {all_memory_text}") |
341 | 305 |
|
342 | | - # Note: Full evaluation with LLM judge will be implemented in subsequent tests |
| 306 | + # Should mention "John" instead of leaving "he/him" unresolved |
| 307 | + assert "john" in all_memory_text.lower(), "Should contain grounded name 'John'" |
343 | 308 |
|
344 | 309 | async def test_temporal_grounding_integration_last_year(self): |
345 | 310 | """Integration test for temporal grounding with real LLM""" |
346 | 311 | example = ContextualGroundingBenchmark.get_temporal_grounding_examples()[0] |
| 312 | + session_id = f"test-temporal-{ulid.ULID()}" |
347 | 313 |
|
348 | | - memory = await self.create_test_memory_with_context( |
349 | | - example["messages"][:-1], example["messages"][-1], example["context_date"] |
| 314 | + # Set up conversation context |
| 315 | + await self.create_test_conversation_with_context( |
| 316 | + example["messages"], example["context_date"], session_id |
350 | 317 | ) |
351 | 318 |
|
352 | | - # Store and extract |
353 | | - from agent_memory_server.vectorstore_factory import get_vectorstore_adapter |
354 | | - |
355 | | - adapter = await get_vectorstore_adapter() |
356 | | - await adapter.add_memories([memory]) |
357 | | - await extract_discrete_memories([memory]) |
358 | | - |
359 | | - # Check extraction was successful - search by session_id since ID search may not work reliably |
360 | | - from agent_memory_server.filters import MemoryType, SessionId |
361 | | - |
362 | | - updated_memories = await adapter.search_memories( |
363 | | - query="", |
364 | | - session_id=SessionId(eq=memory.session_id), |
365 | | - memory_type=MemoryType(eq="message"), |
366 | | - limit=10, |
| 319 | + # Use thread-aware extraction |
| 320 | + from agent_memory_server.long_term_memory import ( |
| 321 | + extract_memories_from_session_thread, |
367 | 322 | ) |
368 | | - # Find our specific memory in the results |
369 | | - target_memory = next( |
370 | | - (m for m in updated_memories.memories if m.id == memory.id), None |
| 323 | + |
| 324 | + extracted_memories = await extract_memories_from_session_thread( |
| 325 | + session_id=session_id, |
| 326 | + namespace="test-namespace", |
| 327 | + user_id="test-integration-user", |
371 | 328 | ) |
372 | | - assert ( |
373 | | - target_memory is not None |
374 | | - ), f"Could not find memory {memory.id} after extraction" |
375 | | - assert target_memory.discrete_memory_extracted == "t" |
| 329 | + |
| 330 | + # Verify extraction was successful |
| 331 | + assert len(extracted_memories) >= 1, "Expected at least one extracted memory" |
376 | 332 |
|
377 | 333 | async def test_spatial_grounding_integration_there(self): |
378 | 334 | """Integration test for spatial grounding with real LLM""" |
379 | 335 | example = ContextualGroundingBenchmark.get_spatial_grounding_examples()[0] |
| 336 | + session_id = f"test-spatial-{ulid.ULID()}" |
380 | 337 |
|
381 | | - memory = await self.create_test_memory_with_context( |
382 | | - example["messages"][:-1], example["messages"][-1], example["context_date"] |
| 338 | + # Set up conversation context |
| 339 | + await self.create_test_conversation_with_context( |
| 340 | + example["messages"], example["context_date"], session_id |
383 | 341 | ) |
384 | 342 |
|
385 | | - # Store and extract |
386 | | - from agent_memory_server.vectorstore_factory import get_vectorstore_adapter |
387 | | - |
388 | | - adapter = await get_vectorstore_adapter() |
389 | | - await adapter.add_memories([memory]) |
390 | | - await extract_discrete_memories([memory]) |
391 | | - |
392 | | - # Check extraction was successful - search by session_id since ID search may not work reliably |
393 | | - from agent_memory_server.filters import MemoryType, SessionId |
394 | | - |
395 | | - updated_memories = await adapter.search_memories( |
396 | | - query="", |
397 | | - session_id=SessionId(eq=memory.session_id), |
398 | | - memory_type=MemoryType(eq="message"), |
399 | | - limit=10, |
| 343 | + # Use thread-aware extraction |
| 344 | + from agent_memory_server.long_term_memory import ( |
| 345 | + extract_memories_from_session_thread, |
400 | 346 | ) |
401 | | - # Find our specific memory in the results |
402 | | - target_memory = next( |
403 | | - (m for m in updated_memories.memories if m.id == memory.id), None |
| 347 | + |
| 348 | + extracted_memories = await extract_memories_from_session_thread( |
| 349 | + session_id=session_id, |
| 350 | + namespace="test-namespace", |
| 351 | + user_id="test-integration-user", |
404 | 352 | ) |
405 | | - assert ( |
406 | | - target_memory is not None |
407 | | - ), f"Could not find memory {memory.id} after extraction" |
408 | | - assert target_memory.discrete_memory_extracted == "t" |
| 353 | + |
| 354 | + # Verify extraction was successful |
| 355 | + assert len(extracted_memories) >= 1, "Expected at least one extracted memory" |
409 | 356 |
|
410 | 357 | @pytest.mark.requires_api_keys |
411 | 358 | async def test_comprehensive_grounding_evaluation_with_judge(self): |
@@ -526,42 +473,26 @@ async def test_model_comparison_grounding_quality(self): |
526 | 473 | settings.generation_model = model |
527 | 474 |
|
528 | 475 | try: |
529 | | - memory = await self.create_test_memory_with_context( |
530 | | - example["messages"][:-1], |
531 | | - example["messages"][-1], |
532 | | - example["context_date"], |
533 | | - ) |
| 476 | + session_id = f"test-model-comparison-{ulid.ULID()}" |
534 | 477 |
|
535 | | - # Store the memory so it can be found by extract_discrete_memories |
536 | | - from agent_memory_server.vectorstore_factory import ( |
537 | | - get_vectorstore_adapter, |
| 478 | + # Set up conversation context |
| 479 | + await self.create_test_conversation_with_context( |
| 480 | + example["messages"], example["context_date"], session_id |
538 | 481 | ) |
539 | 482 |
|
540 | | - adapter = await get_vectorstore_adapter() |
541 | | - await adapter.add_memories([memory]) |
542 | | - |
543 | | - await extract_discrete_memories([memory]) |
544 | | - |
545 | | - # Check if extraction was successful by searching for the memory |
546 | | - from agent_memory_server.filters import MemoryType, SessionId |
547 | | - |
548 | | - updated_memories = await adapter.search_memories( |
549 | | - query="", |
550 | | - session_id=SessionId(eq=memory.session_id), |
551 | | - memory_type=MemoryType(eq="message"), |
552 | | - limit=10, |
| 483 | + # Use thread-aware extraction |
| 484 | + from agent_memory_server.long_term_memory import ( |
| 485 | + extract_memories_from_session_thread, |
553 | 486 | ) |
554 | 487 |
|
555 | | - # Find our specific memory in the results |
556 | | - target_memory = next( |
557 | | - (m for m in updated_memories.memories if m.id == memory.id), |
558 | | - None, |
559 | | - ) |
560 | | - success = ( |
561 | | - target_memory is not None |
562 | | - and target_memory.discrete_memory_extracted == "t" |
| 488 | + extracted_memories = await extract_memories_from_session_thread( |
| 489 | + session_id=session_id, |
| 490 | + namespace="test-namespace", |
| 491 | + user_id="test-integration-user", |
563 | 492 | ) |
564 | 493 |
|
| 494 | + success = len(extracted_memories) >= 1 |
| 495 | + |
565 | 496 | # Record success/failure for this model |
566 | 497 | results_by_model[model] = {"success": success, "model": model} |
567 | 498 |
|
|
0 commit comments