Commit aca0d76

abrookins and claude committed
Fix contextual grounding integration tests to use thread-aware extraction
- Replace create_test_memory_with_context() with create_test_conversation_with_context()
- Set up proper WorkingMemory with individual MemoryMessage objects
- Use extract_memories_from_session_thread() instead of extract_discrete_memories()
- Enable cross-message contextual grounding testing

Results show pronoun grounding now works: 'I told him about...' → 'User told John about...'

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 8147121 commit aca0d76
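
For orientation, the flow these tests now exercise can be sketched as a standalone coroutine. This is a minimal sketch assembled from the calls that appear in the diff below; it assumes the same `MemoryMessage`, `WorkingMemory`, `set_working_memory`, and `extract_memories_from_session_thread` signatures the tests use, and the hard-coded user/namespace values are just the test fixtures, not required constants.

```python
from datetime import datetime, timezone

import ulid

from agent_memory_server.long_term_memory import extract_memories_from_session_thread
from agent_memory_server.models import MemoryMessage, WorkingMemory
from agent_memory_server.working_memory import set_working_memory


async def ground_conversation(messages: list[str]) -> list[str]:
    """Store a conversation as working memory, then run thread-aware extraction."""
    session_id = f"test-grounding-{ulid.ULID()}"
    now = datetime.now(timezone.utc).isoformat()

    # One MemoryMessage per turn, alternating user/assistant roles, so the
    # extractor sees the whole thread rather than one concatenated blob.
    working_memory = WorkingMemory(
        session_id=session_id,
        user_id="test-integration-user",
        namespace="test-namespace",
        messages=[
            MemoryMessage(
                id=str(ulid.ULID()),
                role="user" if i % 2 == 0 else "assistant",
                content=text,
                timestamp=now,
                discrete_memory_extracted="f",
            )
            for i, text in enumerate(messages)
        ],
        memories=[],
    )
    await set_working_memory(working_memory)

    # Thread-aware extraction can ground pronouns in later messages against
    # names introduced earlier in the same session.
    extracted = await extract_memories_from_session_thread(
        session_id=session_id,
        namespace="test-namespace",
        user_id="test-integration-user",
    )
    return [mem.text for mem in extracted]
```

Fed a conversation along the lines of ["I met John for coffee.", "How is he doing?", "I told him about the launch."], the extracted memories should come back grounded roughly as "User told John about the launch" rather than carrying an unresolved "him", which is the behavior the commit message reports.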

File tree

1 file changed: +47 -36 lines changed


tests/test_contextual_grounding_integration.py

Lines changed: 47 additions & 36 deletions
@@ -20,7 +20,6 @@
 from agent_memory_server.config import settings
 from agent_memory_server.extraction import extract_discrete_memories
 from agent_memory_server.llms import get_model_client
-from agent_memory_server.models import MemoryRecord, MemoryTypeEnum


 class GroundingEvaluationResult(BaseModel):
@@ -244,23 +243,39 @@ async def evaluate_grounding(
 class TestContextualGroundingIntegration:
     """Integration tests for contextual grounding with real LLM calls"""

-    async def create_test_memory_with_context(
-        self, context_messages: list[str], target_message: str, context_date: datetime
-    ) -> MemoryRecord:
-        """Create a memory record with conversational context"""
-        # Combine context messages and target message
-        full_conversation = "\n".join(context_messages + [target_message])
-
-        return MemoryRecord(
-            id=str(ulid.ULID()),
-            text=full_conversation,
-            memory_type=MemoryTypeEnum.MESSAGE,
-            discrete_memory_extracted="f",
-            session_id=f"test-integration-session-{ulid.ULID()}",
+    async def create_test_conversation_with_context(
+        self, all_messages: list[str], context_date: datetime, session_id: str
+    ) -> str:
+        """Create a test conversation with proper working memory setup for cross-message grounding"""
+        from agent_memory_server.models import MemoryMessage, WorkingMemory
+        from agent_memory_server.working_memory import set_working_memory
+
+        # Create individual MemoryMessage objects for each message in the conversation
+        messages = []
+        for i, message_text in enumerate(all_messages):
+            messages.append(
+                MemoryMessage(
+                    id=str(ulid.ULID()),
+                    role="user" if i % 2 == 0 else "assistant",
+                    content=message_text,
+                    timestamp=context_date.isoformat(),
+                    discrete_memory_extracted="f",
+                )
+            )
+
+        # Create working memory with the conversation
+        working_memory = WorkingMemory(
+            session_id=session_id,
             user_id="test-integration-user",
-            timestamp=context_date.isoformat(),
+            namespace="test-namespace",
+            messages=messages,
+            memories=[],
         )

+        # Store in working memory for thread-aware extraction
+        await set_working_memory(working_memory)
+        return session_id
+
     async def test_pronoun_grounding_integration_he_him(self):
         """Integration test for he/him pronoun grounding with real LLM"""
         example = ContextualGroundingBenchmark.get_pronoun_grounding_examples()[0]
@@ -407,35 +422,31 @@ async def test_comprehensive_grounding_evaluation_with_judge(self):
         ] # Just first 2 for integration testing

         for example in sample_examples:
-            # Create memory and extract with real LLM
-            memory = await self.create_test_memory_with_context(
-                example["messages"][:-1],
-                example["messages"][-1],
-                example["context_date"],
+            # Create a unique session for this test
+            session_id = f"test-grounding-{ulid.ULID()}"
+
+            # Set up proper conversation context for cross-message grounding
+            await self.create_test_conversation_with_context(
+                example["messages"], example["context_date"], session_id
             )

             original_text = example["messages"][-1]

-            # Store and extract
-            from agent_memory_server.vectorstore_factory import get_vectorstore_adapter
-
-            adapter = await get_vectorstore_adapter()
-            await adapter.add_memories([memory])
-            await extract_discrete_memories([memory])
+            # Use thread-aware extraction (the whole point of our implementation!)
+            from agent_memory_server.long_term_memory import (
+                extract_memories_from_session_thread,
+            )

-            # Retrieve all extracted discrete memories to get the grounded text
-            all_memories = await adapter.search_memories(query="", limit=50)
-            discrete_memories = [
-                m
-                for m in all_memories.memories
-                if m.memory_type in ["episodic", "semantic"]
-                and m.session_id == memory.session_id
-            ]
+            extracted_memories = await extract_memories_from_session_thread(
+                session_id=session_id,
+                namespace="test-namespace",
+                user_id="test-integration-user",
+            )

             # Combine the grounded memories into a single text for evaluation
             grounded_text = (
-                " ".join([dm.text for dm in discrete_memories])
-                if discrete_memories
+                " ".join([mem.text for mem in extracted_memories])
+                if extracted_memories
                 else original_text
             )

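The hunks above only show how the comprehensive judge-based test was migrated; the individual grounding tests in `TestContextualGroundingIntegration` presumably follow the same pattern with the new helper. A rough, illustrative sketch of such a method is below; the conversation text and the final assertion are assumptions for illustration, not the repository's actual test body, and it relies on module-level `ulid`, `datetime`, and `timezone` imports being available in the test file.

```python
async def test_pronoun_grounding_sketch(self):
    """Illustrative only: 'him' in the last message should resolve to 'John'."""
    conversation = [
        "I had lunch with John today.",       # user
        "Nice, how is he doing?",             # assistant
        "I told him about the new project.",  # user
    ]
    session_id = f"test-grounding-{ulid.ULID()}"

    # Store the conversation as working memory via the new helper.
    await self.create_test_conversation_with_context(
        conversation, datetime.now(timezone.utc), session_id
    )

    # Extract from the whole session thread so earlier messages can ground
    # the pronoun in the last one.
    from agent_memory_server.long_term_memory import (
        extract_memories_from_session_thread,
    )

    extracted = await extract_memories_from_session_thread(
        session_id=session_id,
        namespace="test-namespace",
        user_id="test-integration-user",
    )
    grounded_text = " ".join(mem.text for mem in extracted) if extracted else ""

    # Real LLM output is non-deterministic, so a loose containment check
    # (or the LLM-judge evaluation used elsewhere in this file) is the idea.
    assert "John" in grounded_text
```

Because the helper assigns roles by alternating index ("user" for even positions, "assistant" for odd), even-indexed strings in the conversation list become user turns, so the grounding target ("I told him about...") lands on a user message, matching the example in the commit message.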