Skip to content

feat: add enhanced memory management foundation #626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/strands/agent/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Memory management system for agents.

This package provides enhanced memory management capabilities including:
- Memory categorization (active, cached, archived)
- Memory lifecycle management with automatic cleanup
- Memory usage monitoring and metrics
- Configurable memory thresholds and policies

The memory management system is designed to be backward compatible with
existing AgentState usage while providing advanced memory optimization
capabilities for complex multi-agent scenarios.
"""

from .config import MemoryCategory, MemoryConfig, MemoryThresholds
from .enhanced_state import EnhancedAgentState
from .lifecycle import MemoryLifecycleManager
from .metrics import MemoryMetrics, MemoryUsageStats

__all__ = [
"MemoryConfig",
"MemoryCategory",
"MemoryThresholds",
"EnhancedAgentState",
"MemoryLifecycleManager",
"MemoryMetrics",
"MemoryUsageStats",
]
97 changes: 97 additions & 0 deletions src/strands/agent/memory/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""Memory management configuration and types."""

from dataclasses import dataclass
from enum import Enum
from typing import Optional


class MemoryCategory(Enum):
"""Categories for memory classification."""

ACTIVE = "active" # Currently active/working memory
CACHED = "cached" # Recently used but not active
ARCHIVED = "archived" # Historical data for potential retrieval
METADATA = "metadata" # System metadata and statistics


@dataclass
class MemoryThresholds:
"""Memory management thresholds and limits."""

# Size thresholds (in estimated tokens/bytes)
active_memory_limit: int = 8192 # ~8K tokens for active memory
cached_memory_limit: int = 32768 # ~32K tokens for cached memory
total_memory_limit: int = 131072 # ~128K tokens total limit

# Cleanup thresholds (percentages)
cleanup_threshold: float = 0.8 # Start cleanup at 80% of limit
emergency_threshold: float = 0.95 # Emergency cleanup at 95%

# Time-based thresholds (seconds)
cache_ttl: int = 3600 # Cache TTL: 1 hour
archive_after: int = 86400 # Archive after: 24 hours

# Cleanup ratios (how much to remove during cleanup)
cleanup_ratio: float = 0.3 # Remove 30% during cleanup
emergency_cleanup_ratio: float = 0.5 # Remove 50% during emergency


@dataclass
class MemoryConfig:
"""Configuration for memory management system."""

# Feature toggles
enable_categorization: bool = True # Enable memory categorization
enable_lifecycle: bool = True # Enable automatic lifecycle management
enable_metrics: bool = True # Enable memory metrics collection
enable_archival: bool = True # Enable memory archival

# Thresholds configuration
thresholds: Optional[MemoryThresholds] = None

# Cleanup strategy
cleanup_strategy: str = "lru" # LRU, FIFO, or custom

# Validation settings
strict_validation: bool = True # Strict JSON validation

def __post_init__(self) -> None:
"""Initialize default thresholds if not provided."""
if self.thresholds is None:
self.thresholds = MemoryThresholds()

@classmethod
def conservative(cls) -> "MemoryConfig":
"""Create conservative memory configuration with lower limits."""
return cls(
thresholds=MemoryThresholds(
active_memory_limit=4096,
cached_memory_limit=16384,
total_memory_limit=65536,
cleanup_threshold=0.7,
cleanup_ratio=0.4,
)
)

@classmethod
def aggressive(cls) -> "MemoryConfig":
"""Create aggressive memory configuration with higher limits."""
return cls(
thresholds=MemoryThresholds(
active_memory_limit=16384,
cached_memory_limit=65536,
total_memory_limit=262144,
cleanup_threshold=0.9,
cleanup_ratio=0.2,
)
)

@classmethod
def minimal(cls) -> "MemoryConfig":
"""Create minimal memory configuration with basic features only."""
return cls(
enable_lifecycle=False,
enable_metrics=False,
enable_archival=False,
thresholds=MemoryThresholds(active_memory_limit=2048, cached_memory_limit=8192, total_memory_limit=32768),
)
217 changes: 217 additions & 0 deletions src/strands/agent/memory/enhanced_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
"""Enhanced agent state with memory management capabilities."""

import copy
from typing import Any, Dict, Optional

from ..state import AgentState
from .config import MemoryCategory, MemoryConfig
from .lifecycle import MemoryLifecycleManager


class EnhancedAgentState(AgentState):
"""Enhanced AgentState with memory categorization, lifecycle management, and metrics.

This class extends the base AgentState to provide:
- Memory categorization (active, cached, archived, metadata)
- Automatic memory lifecycle management
- Memory usage monitoring and metrics
- Configurable memory thresholds and cleanup policies

The enhanced state maintains full backward compatibility with the base AgentState
interface while adding advanced memory management capabilities.
"""

def __init__(self, initial_state: Optional[Dict[str, Any]] = None, memory_config: Optional[MemoryConfig] = None):
"""Initialize EnhancedAgentState.

Args:
initial_state: Initial state dictionary (backward compatibility)
memory_config: Memory management configuration
"""
# Initialize base AgentState for backward compatibility
super().__init__(initial_state)

# Initialize memory management
self.memory_config = memory_config or MemoryConfig()
self.memory_manager = MemoryLifecycleManager(self.memory_config)

# Migrate existing state to memory manager if provided
if initial_state:
for key, value in initial_state.items():
self.memory_manager.add_item(key, value, MemoryCategory.ACTIVE)

def set(self, key: str, value: Any, category: MemoryCategory = MemoryCategory.ACTIVE) -> None:
"""Set a value in the state with optional memory category.

Args:
key: The key to store the value under
value: The value to store (must be JSON serializable)
category: Memory category for the value

Raises:
ValueError: If key is invalid, or if value is not JSON serializable
"""
# Validate using parent class methods
self._validate_key(key)
self._validate_json_serializable(value)

# Store in both base state (for backward compatibility) and memory manager
super().set(key, value)

if self.memory_config.enable_categorization:
self.memory_manager.add_item(key, value, category)

def get(self, key: Optional[str] = None, category: Optional[MemoryCategory] = None) -> Any:
"""Get a value or entire state, optionally filtered by category.

Args:
key: The key to retrieve (if None, returns entire state object)
category: Optional memory category filter

Returns:
The stored value, filtered state dict, or None if not found
"""
if key is None:
# Return entire state, optionally filtered by category
if category is not None and self.memory_config.enable_categorization:
return copy.deepcopy(self.memory_manager.get_items_by_category(category))
else:
# Backward compatibility: return base state
return super().get()
else:
# Return specific key
if self.memory_config.enable_categorization:
return self.memory_manager.get_item(key)
else:
return super().get(key)

def delete(self, key: str) -> None:
"""Delete a specific key from the state.

Args:
key: The key to delete
"""
self._validate_key(key)

# Delete from both base state and memory manager
super().delete(key)

if self.memory_config.enable_categorization:
self.memory_manager.remove_item(key)

def get_by_category(self, category: MemoryCategory) -> Dict[str, Any]:
"""Get all items in a specific memory category.

Args:
category: The memory category to retrieve

Returns:
Dictionary of all items in the specified category
"""
if not self.memory_config.enable_categorization:
# If categorization disabled, return all items for any category
return super().get() or {}

return self.memory_manager.get_items_by_category(category)

def set_metadata(self, key: str, value: Any) -> None:
"""Set a metadata value (convenience method).

Args:
key: The key to store the metadata under
value: The metadata value
"""
self.set(key, value, MemoryCategory.METADATA)

def get_active_memory(self) -> Dict[str, Any]:
"""Get all active memory items (convenience method).

Returns:
Dictionary of all active memory items
"""
return self.get_by_category(MemoryCategory.ACTIVE)

def get_cached_memory(self) -> Dict[str, Any]:
"""Get all cached memory items (convenience method).

Returns:
Dictionary of all cached memory items
"""
return self.get_by_category(MemoryCategory.CACHED)

def cleanup_memory(self, force: bool = False) -> int:
"""Perform memory cleanup and return number of items removed.

Args:
force: Force cleanup even if lifecycle management is disabled

Returns:
Number of items removed during cleanup
"""
if not self.memory_config.enable_lifecycle and not force:
return 0

removed_count = self.memory_manager.cleanup_memory(force)

# Sync base state with memory manager after cleanup
self._sync_base_state()

return removed_count

def get_memory_stats(self) -> Dict[str, Any]:
"""Get comprehensive memory usage statistics.

Returns:
Dictionary containing memory usage statistics and metrics
"""
if not self.memory_config.enable_metrics:
# Return basic stats if metrics disabled
all_items = super().get() or {}
return {
"total_items": len(all_items),
"categories_enabled": False,
"lifecycle_enabled": self.memory_config.enable_lifecycle,
"metrics_enabled": False,
}

return self.memory_manager.get_memory_report()

def optimize_memory(self) -> Dict[str, Any]:
"""Optimize memory usage and return optimization results.

Returns:
Dictionary containing optimization results and statistics
"""
if not self.memory_config.enable_lifecycle:
return {"optimization_skipped": True, "reason": "lifecycle_disabled"}

optimization_results = self.memory_manager.optimize_memory()

# Sync base state after optimization
self._sync_base_state()

return optimization_results

def _sync_base_state(self) -> None:
"""Synchronize base state with memory manager state."""
if self.memory_config.enable_categorization:
# Update base state to match memory manager
all_items = self.memory_manager.get_all_items()
self._state = copy.deepcopy(all_items)

def configure_memory(self, config: MemoryConfig) -> None:
"""Update memory configuration.

Args:
config: New memory configuration
"""
self.memory_config = config
self.memory_manager.config = config

def get_memory_config(self) -> MemoryConfig:
"""Get current memory configuration.

Returns:
Current memory configuration
"""
return self.memory_config
Loading