Skip to content

Commit 754939b

Browse files
abrookins and claude committed
Fix remaining integration tests to use thread-aware extraction
- Update test_pronoun_grounding_integration_he_him
- Update test_temporal_grounding_integration_last_year
- Update test_spatial_grounding_integration_there
- Update test_model_comparison_grounding_quality
- All tests now use create_test_conversation_with_context() and extract_memories_from_session_thread()

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent aca0d76 commit 754939b

File tree

1 file changed

+61
-130
lines changed

1 file changed

+61
-130
lines changed

tests/test_contextual_grounding_integration.py

Lines changed: 61 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from pydantic import BaseModel
1919

2020
from agent_memory_server.config import settings
21-
from agent_memory_server.extraction import extract_discrete_memories
2221
from agent_memory_server.llms import get_model_client
2322

2423

@@ -279,133 +278,81 @@ async def create_test_conversation_with_context(
279278
async def test_pronoun_grounding_integration_he_him(self):
280279
"""Integration test for he/him pronoun grounding with real LLM"""
281280
example = ContextualGroundingBenchmark.get_pronoun_grounding_examples()[0]
281+
session_id = f"test-pronoun-{ulid.ULID()}"
282282

283-
# Create memory record and store it first
284-
memory = await self.create_test_memory_with_context(
285-
example["messages"][:-1], # Context
286-
example["messages"][-1], # Target message with pronouns
287-
example["context_date"],
283+
# Set up conversation context for cross-message grounding
284+
await self.create_test_conversation_with_context(
285+
example["messages"], example["context_date"], session_id
288286
)
289287

290-
# Store the memory so it can be found by extract_discrete_memories
291-
from agent_memory_server.vectorstore_factory import get_vectorstore_adapter
292-
293-
adapter = await get_vectorstore_adapter()
294-
await adapter.add_memories([memory])
295-
296-
# Extract memories using real LLM
297-
await extract_discrete_memories([memory])
298-
299-
# Retrieve all memories to verify extraction occurred
300-
all_memories = await adapter.search_memories(
301-
query="",
302-
limit=50, # Get all memories
288+
# Use thread-aware extraction
289+
from agent_memory_server.long_term_memory import (
290+
extract_memories_from_session_thread,
303291
)
304292

305-
# Find the original memory by session_id and verify it was processed
306-
session_memories = [
307-
m for m in all_memories.memories if m.session_id == memory.session_id
308-
]
309-
310-
# Should find the original message memory that was processed
311-
assert (
312-
len(session_memories) >= 1
313-
), f"No memories found in session {memory.session_id}"
314-
315-
# Find our specific memory in the results
316-
processed_memory = next(
317-
(m for m in session_memories if m.id == memory.id), None
293+
extracted_memories = await extract_memories_from_session_thread(
294+
session_id=session_id,
295+
namespace="test-namespace",
296+
user_id="test-integration-user",
318297
)
319298

320-
if processed_memory is None:
321-
# If we can't find by ID, try to find any memory in the session with discrete_memory_extracted = "t"
322-
processed_memory = next(
323-
(m for m in session_memories if m.discrete_memory_extracted == "t"),
324-
None,
325-
)
326-
327-
assert (
328-
processed_memory is not None
329-
), f"Could not find processed memory {memory.id} in session"
330-
assert processed_memory.discrete_memory_extracted == "t"
299+
# Verify extraction was successful
300+
assert len(extracted_memories) >= 1, "Expected at least one extracted memory"
331301

332-
# Should also find extracted discrete memories
333-
discrete_memories = [
334-
m
335-
for m in all_memories.memories
336-
if m.memory_type in ["episodic", "semantic"]
337-
]
338-
assert (
339-
len(discrete_memories) >= 1
340-
), "Expected at least one discrete memory to be extracted"
302+
# Check that pronoun grounding occurred
303+
all_memory_text = " ".join([mem.text for mem in extracted_memories])
304+
print(f"Extracted memories: {all_memory_text}")
341305

342-
# Note: Full evaluation with LLM judge will be implemented in subsequent tests
306+
# Should mention "John" instead of leaving "he/him" unresolved
307+
assert "john" in all_memory_text.lower(), "Should contain grounded name 'John'"
343308

344309
async def test_temporal_grounding_integration_last_year(self):
345310
"""Integration test for temporal grounding with real LLM"""
346311
example = ContextualGroundingBenchmark.get_temporal_grounding_examples()[0]
312+
session_id = f"test-temporal-{ulid.ULID()}"
347313

348-
memory = await self.create_test_memory_with_context(
349-
example["messages"][:-1], example["messages"][-1], example["context_date"]
314+
# Set up conversation context
315+
await self.create_test_conversation_with_context(
316+
example["messages"], example["context_date"], session_id
350317
)
351318

352-
# Store and extract
353-
from agent_memory_server.vectorstore_factory import get_vectorstore_adapter
354-
355-
adapter = await get_vectorstore_adapter()
356-
await adapter.add_memories([memory])
357-
await extract_discrete_memories([memory])
358-
359-
# Check extraction was successful - search by session_id since ID search may not work reliably
360-
from agent_memory_server.filters import MemoryType, SessionId
361-
362-
updated_memories = await adapter.search_memories(
363-
query="",
364-
session_id=SessionId(eq=memory.session_id),
365-
memory_type=MemoryType(eq="message"),
366-
limit=10,
319+
# Use thread-aware extraction
320+
from agent_memory_server.long_term_memory import (
321+
extract_memories_from_session_thread,
367322
)
368-
# Find our specific memory in the results
369-
target_memory = next(
370-
(m for m in updated_memories.memories if m.id == memory.id), None
323+
324+
extracted_memories = await extract_memories_from_session_thread(
325+
session_id=session_id,
326+
namespace="test-namespace",
327+
user_id="test-integration-user",
371328
)
372-
assert (
373-
target_memory is not None
374-
), f"Could not find memory {memory.id} after extraction"
375-
assert target_memory.discrete_memory_extracted == "t"
329+
330+
# Verify extraction was successful
331+
assert len(extracted_memories) >= 1, "Expected at least one extracted memory"
376332

377333
async def test_spatial_grounding_integration_there(self):
378334
"""Integration test for spatial grounding with real LLM"""
379335
example = ContextualGroundingBenchmark.get_spatial_grounding_examples()[0]
336+
session_id = f"test-spatial-{ulid.ULID()}"
380337

381-
memory = await self.create_test_memory_with_context(
382-
example["messages"][:-1], example["messages"][-1], example["context_date"]
338+
# Set up conversation context
339+
await self.create_test_conversation_with_context(
340+
example["messages"], example["context_date"], session_id
383341
)
384342

385-
# Store and extract
386-
from agent_memory_server.vectorstore_factory import get_vectorstore_adapter
387-
388-
adapter = await get_vectorstore_adapter()
389-
await adapter.add_memories([memory])
390-
await extract_discrete_memories([memory])
391-
392-
# Check extraction was successful - search by session_id since ID search may not work reliably
393-
from agent_memory_server.filters import MemoryType, SessionId
394-
395-
updated_memories = await adapter.search_memories(
396-
query="",
397-
session_id=SessionId(eq=memory.session_id),
398-
memory_type=MemoryType(eq="message"),
399-
limit=10,
343+
# Use thread-aware extraction
344+
from agent_memory_server.long_term_memory import (
345+
extract_memories_from_session_thread,
400346
)
401-
# Find our specific memory in the results
402-
target_memory = next(
403-
(m for m in updated_memories.memories if m.id == memory.id), None
347+
348+
extracted_memories = await extract_memories_from_session_thread(
349+
session_id=session_id,
350+
namespace="test-namespace",
351+
user_id="test-integration-user",
404352
)
405-
assert (
406-
target_memory is not None
407-
), f"Could not find memory {memory.id} after extraction"
408-
assert target_memory.discrete_memory_extracted == "t"
353+
354+
# Verify extraction was successful
355+
assert len(extracted_memories) >= 1, "Expected at least one extracted memory"
409356

410357
@pytest.mark.requires_api_keys
411358
async def test_comprehensive_grounding_evaluation_with_judge(self):
@@ -526,42 +473,26 @@ async def test_model_comparison_grounding_quality(self):
526473
settings.generation_model = model
527474

528475
try:
529-
memory = await self.create_test_memory_with_context(
530-
example["messages"][:-1],
531-
example["messages"][-1],
532-
example["context_date"],
533-
)
476+
session_id = f"test-model-comparison-{ulid.ULID()}"
534477

535-
# Store the memory so it can be found by extract_discrete_memories
536-
from agent_memory_server.vectorstore_factory import (
537-
get_vectorstore_adapter,
478+
# Set up conversation context
479+
await self.create_test_conversation_with_context(
480+
example["messages"], example["context_date"], session_id
538481
)
539482

540-
adapter = await get_vectorstore_adapter()
541-
await adapter.add_memories([memory])
542-
543-
await extract_discrete_memories([memory])
544-
545-
# Check if extraction was successful by searching for the memory
546-
from agent_memory_server.filters import MemoryType, SessionId
547-
548-
updated_memories = await adapter.search_memories(
549-
query="",
550-
session_id=SessionId(eq=memory.session_id),
551-
memory_type=MemoryType(eq="message"),
552-
limit=10,
483+
# Use thread-aware extraction
484+
from agent_memory_server.long_term_memory import (
485+
extract_memories_from_session_thread,
553486
)
554487

555-
# Find our specific memory in the results
556-
target_memory = next(
557-
(m for m in updated_memories.memories if m.id == memory.id),
558-
None,
559-
)
560-
success = (
561-
target_memory is not None
562-
and target_memory.discrete_memory_extracted == "t"
488+
extracted_memories = await extract_memories_from_session_thread(
489+
session_id=session_id,
490+
namespace="test-namespace",
491+
user_id="test-integration-user",
563492
)
564493

494+
success = len(extracted_memories) >= 1
495+
565496
# Record success/failure for this model
566497
results_by_model[model] = {"success": success, "model": model}
567498

0 commit comments

Comments
 (0)