|
18 | 18 | from pydantic import BaseModel
|
19 | 19 |
|
20 | 20 | from agent_memory_server.config import settings
|
21 |
| -from agent_memory_server.extraction import extract_discrete_memories |
22 | 21 | from agent_memory_server.llms import get_model_client
|
23 | 22 |
|
24 | 23 |
|
@@ -279,133 +278,81 @@ async def create_test_conversation_with_context(
|
279 | 278 | async def test_pronoun_grounding_integration_he_him(self):
|
280 | 279 | """Integration test for he/him pronoun grounding with real LLM"""
|
281 | 280 | example = ContextualGroundingBenchmark.get_pronoun_grounding_examples()[0]
|
| 281 | + session_id = f"test-pronoun-{ulid.ULID()}" |
282 | 282 |
|
283 |
| - # Create memory record and store it first |
284 |
| - memory = await self.create_test_memory_with_context( |
285 |
| - example["messages"][:-1], # Context |
286 |
| - example["messages"][-1], # Target message with pronouns |
287 |
| - example["context_date"], |
| 283 | + # Set up conversation context for cross-message grounding |
| 284 | + await self.create_test_conversation_with_context( |
| 285 | + example["messages"], example["context_date"], session_id |
288 | 286 | )
|
289 | 287 |
|
290 |
| - # Store the memory so it can be found by extract_discrete_memories |
291 |
| - from agent_memory_server.vectorstore_factory import get_vectorstore_adapter |
292 |
| - |
293 |
| - adapter = await get_vectorstore_adapter() |
294 |
| - await adapter.add_memories([memory]) |
295 |
| - |
296 |
| - # Extract memories using real LLM |
297 |
| - await extract_discrete_memories([memory]) |
298 |
| - |
299 |
| - # Retrieve all memories to verify extraction occurred |
300 |
| - all_memories = await adapter.search_memories( |
301 |
| - query="", |
302 |
| - limit=50, # Get all memories |
| 288 | + # Use thread-aware extraction |
| 289 | + from agent_memory_server.long_term_memory import ( |
| 290 | + extract_memories_from_session_thread, |
303 | 291 | )
|
304 | 292 |
|
305 |
| - # Find the original memory by session_id and verify it was processed |
306 |
| - session_memories = [ |
307 |
| - m for m in all_memories.memories if m.session_id == memory.session_id |
308 |
| - ] |
309 |
| - |
310 |
| - # Should find the original message memory that was processed |
311 |
| - assert ( |
312 |
| - len(session_memories) >= 1 |
313 |
| - ), f"No memories found in session {memory.session_id}" |
314 |
| - |
315 |
| - # Find our specific memory in the results |
316 |
| - processed_memory = next( |
317 |
| - (m for m in session_memories if m.id == memory.id), None |
| 293 | + extracted_memories = await extract_memories_from_session_thread( |
| 294 | + session_id=session_id, |
| 295 | + namespace="test-namespace", |
| 296 | + user_id="test-integration-user", |
318 | 297 | )
|
319 | 298 |
|
320 |
| - if processed_memory is None: |
321 |
| - # If we can't find by ID, try to find any memory in the session with discrete_memory_extracted = "t" |
322 |
| - processed_memory = next( |
323 |
| - (m for m in session_memories if m.discrete_memory_extracted == "t"), |
324 |
| - None, |
325 |
| - ) |
326 |
| - |
327 |
| - assert ( |
328 |
| - processed_memory is not None |
329 |
| - ), f"Could not find processed memory {memory.id} in session" |
330 |
| - assert processed_memory.discrete_memory_extracted == "t" |
| 299 | + # Verify extraction was successful |
| 300 | + assert len(extracted_memories) >= 1, "Expected at least one extracted memory" |
331 | 301 |
|
332 |
| - # Should also find extracted discrete memories |
333 |
| - discrete_memories = [ |
334 |
| - m |
335 |
| - for m in all_memories.memories |
336 |
| - if m.memory_type in ["episodic", "semantic"] |
337 |
| - ] |
338 |
| - assert ( |
339 |
| - len(discrete_memories) >= 1 |
340 |
| - ), "Expected at least one discrete memory to be extracted" |
| 302 | + # Check that pronoun grounding occurred |
| 303 | + all_memory_text = " ".join([mem.text for mem in extracted_memories]) |
| 304 | + print(f"Extracted memories: {all_memory_text}") |
341 | 305 |
|
342 |
| - # Note: Full evaluation with LLM judge will be implemented in subsequent tests |
| 306 | + # Should mention "John" instead of leaving "he/him" unresolved |
| 307 | + assert "john" in all_memory_text.lower(), "Should contain grounded name 'John'" |
343 | 308 |
|
344 | 309 | async def test_temporal_grounding_integration_last_year(self):
|
345 | 310 | """Integration test for temporal grounding with real LLM"""
|
346 | 311 | example = ContextualGroundingBenchmark.get_temporal_grounding_examples()[0]
|
| 312 | + session_id = f"test-temporal-{ulid.ULID()}" |
347 | 313 |
|
348 |
| - memory = await self.create_test_memory_with_context( |
349 |
| - example["messages"][:-1], example["messages"][-1], example["context_date"] |
| 314 | + # Set up conversation context |
| 315 | + await self.create_test_conversation_with_context( |
| 316 | + example["messages"], example["context_date"], session_id |
350 | 317 | )
|
351 | 318 |
|
352 |
| - # Store and extract |
353 |
| - from agent_memory_server.vectorstore_factory import get_vectorstore_adapter |
354 |
| - |
355 |
| - adapter = await get_vectorstore_adapter() |
356 |
| - await adapter.add_memories([memory]) |
357 |
| - await extract_discrete_memories([memory]) |
358 |
| - |
359 |
| - # Check extraction was successful - search by session_id since ID search may not work reliably |
360 |
| - from agent_memory_server.filters import MemoryType, SessionId |
361 |
| - |
362 |
| - updated_memories = await adapter.search_memories( |
363 |
| - query="", |
364 |
| - session_id=SessionId(eq=memory.session_id), |
365 |
| - memory_type=MemoryType(eq="message"), |
366 |
| - limit=10, |
| 319 | + # Use thread-aware extraction |
| 320 | + from agent_memory_server.long_term_memory import ( |
| 321 | + extract_memories_from_session_thread, |
367 | 322 | )
|
368 |
| - # Find our specific memory in the results |
369 |
| - target_memory = next( |
370 |
| - (m for m in updated_memories.memories if m.id == memory.id), None |
| 323 | + |
| 324 | + extracted_memories = await extract_memories_from_session_thread( |
| 325 | + session_id=session_id, |
| 326 | + namespace="test-namespace", |
| 327 | + user_id="test-integration-user", |
371 | 328 | )
|
372 |
| - assert ( |
373 |
| - target_memory is not None |
374 |
| - ), f"Could not find memory {memory.id} after extraction" |
375 |
| - assert target_memory.discrete_memory_extracted == "t" |
| 329 | + |
| 330 | + # Verify extraction was successful |
| 331 | + assert len(extracted_memories) >= 1, "Expected at least one extracted memory" |
376 | 332 |
|
377 | 333 | async def test_spatial_grounding_integration_there(self):
|
378 | 334 | """Integration test for spatial grounding with real LLM"""
|
379 | 335 | example = ContextualGroundingBenchmark.get_spatial_grounding_examples()[0]
|
| 336 | + session_id = f"test-spatial-{ulid.ULID()}" |
380 | 337 |
|
381 |
| - memory = await self.create_test_memory_with_context( |
382 |
| - example["messages"][:-1], example["messages"][-1], example["context_date"] |
| 338 | + # Set up conversation context |
| 339 | + await self.create_test_conversation_with_context( |
| 340 | + example["messages"], example["context_date"], session_id |
383 | 341 | )
|
384 | 342 |
|
385 |
| - # Store and extract |
386 |
| - from agent_memory_server.vectorstore_factory import get_vectorstore_adapter |
387 |
| - |
388 |
| - adapter = await get_vectorstore_adapter() |
389 |
| - await adapter.add_memories([memory]) |
390 |
| - await extract_discrete_memories([memory]) |
391 |
| - |
392 |
| - # Check extraction was successful - search by session_id since ID search may not work reliably |
393 |
| - from agent_memory_server.filters import MemoryType, SessionId |
394 |
| - |
395 |
| - updated_memories = await adapter.search_memories( |
396 |
| - query="", |
397 |
| - session_id=SessionId(eq=memory.session_id), |
398 |
| - memory_type=MemoryType(eq="message"), |
399 |
| - limit=10, |
| 343 | + # Use thread-aware extraction |
| 344 | + from agent_memory_server.long_term_memory import ( |
| 345 | + extract_memories_from_session_thread, |
400 | 346 | )
|
401 |
| - # Find our specific memory in the results |
402 |
| - target_memory = next( |
403 |
| - (m for m in updated_memories.memories if m.id == memory.id), None |
| 347 | + |
| 348 | + extracted_memories = await extract_memories_from_session_thread( |
| 349 | + session_id=session_id, |
| 350 | + namespace="test-namespace", |
| 351 | + user_id="test-integration-user", |
404 | 352 | )
|
405 |
| - assert ( |
406 |
| - target_memory is not None |
407 |
| - ), f"Could not find memory {memory.id} after extraction" |
408 |
| - assert target_memory.discrete_memory_extracted == "t" |
| 353 | + |
| 354 | + # Verify extraction was successful |
| 355 | + assert len(extracted_memories) >= 1, "Expected at least one extracted memory" |
409 | 356 |
|
410 | 357 | @pytest.mark.requires_api_keys
|
411 | 358 | async def test_comprehensive_grounding_evaluation_with_judge(self):
|
@@ -526,42 +473,26 @@ async def test_model_comparison_grounding_quality(self):
|
526 | 473 | settings.generation_model = model
|
527 | 474 |
|
528 | 475 | try:
|
529 |
| - memory = await self.create_test_memory_with_context( |
530 |
| - example["messages"][:-1], |
531 |
| - example["messages"][-1], |
532 |
| - example["context_date"], |
533 |
| - ) |
| 476 | + session_id = f"test-model-comparison-{ulid.ULID()}" |
534 | 477 |
|
535 |
| - # Store the memory so it can be found by extract_discrete_memories |
536 |
| - from agent_memory_server.vectorstore_factory import ( |
537 |
| - get_vectorstore_adapter, |
| 478 | + # Set up conversation context |
| 479 | + await self.create_test_conversation_with_context( |
| 480 | + example["messages"], example["context_date"], session_id |
538 | 481 | )
|
539 | 482 |
|
540 |
| - adapter = await get_vectorstore_adapter() |
541 |
| - await adapter.add_memories([memory]) |
542 |
| - |
543 |
| - await extract_discrete_memories([memory]) |
544 |
| - |
545 |
| - # Check if extraction was successful by searching for the memory |
546 |
| - from agent_memory_server.filters import MemoryType, SessionId |
547 |
| - |
548 |
| - updated_memories = await adapter.search_memories( |
549 |
| - query="", |
550 |
| - session_id=SessionId(eq=memory.session_id), |
551 |
| - memory_type=MemoryType(eq="message"), |
552 |
| - limit=10, |
| 483 | + # Use thread-aware extraction |
| 484 | + from agent_memory_server.long_term_memory import ( |
| 485 | + extract_memories_from_session_thread, |
553 | 486 | )
|
554 | 487 |
|
555 |
| - # Find our specific memory in the results |
556 |
| - target_memory = next( |
557 |
| - (m for m in updated_memories.memories if m.id == memory.id), |
558 |
| - None, |
559 |
| - ) |
560 |
| - success = ( |
561 |
| - target_memory is not None |
562 |
| - and target_memory.discrete_memory_extracted == "t" |
| 488 | + extracted_memories = await extract_memories_from_session_thread( |
| 489 | + session_id=session_id, |
| 490 | + namespace="test-namespace", |
| 491 | + user_id="test-integration-user", |
563 | 492 | )
|
564 | 493 |
|
| 494 | + success = len(extracted_memories) >= 1 |
| 495 | + |
565 | 496 | # Record success/failure for this model
|
566 | 497 | results_by_model[model] = {"success": success, "model": model}
|
567 | 498 |
|
|
0 commit comments