diff --git a/src/strands/agent/memory/__init__.py b/src/strands/agent/memory/__init__.py new file mode 100644 index 000000000..bc26f2dff --- /dev/null +++ b/src/strands/agent/memory/__init__.py @@ -0,0 +1,27 @@ +"""Memory management system for agents. + +This package provides enhanced memory management capabilities including: +- Memory categorization (active, cached, archived) +- Memory lifecycle management with automatic cleanup +- Memory usage monitoring and metrics +- Configurable memory thresholds and policies + +The memory management system is designed to be backward compatible with +existing AgentState usage while providing advanced memory optimization +capabilities for complex multi-agent scenarios. +""" + +from .config import MemoryCategory, MemoryConfig, MemoryThresholds +from .enhanced_state import EnhancedAgentState +from .lifecycle import MemoryLifecycleManager +from .metrics import MemoryMetrics, MemoryUsageStats + +__all__ = [ + "MemoryConfig", + "MemoryCategory", + "MemoryThresholds", + "EnhancedAgentState", + "MemoryLifecycleManager", + "MemoryMetrics", + "MemoryUsageStats", +] diff --git a/src/strands/agent/memory/config.py b/src/strands/agent/memory/config.py new file mode 100644 index 000000000..99580e331 --- /dev/null +++ b/src/strands/agent/memory/config.py @@ -0,0 +1,97 @@ +"""Memory management configuration and types.""" + +from dataclasses import dataclass +from enum import Enum +from typing import Optional + + +class MemoryCategory(Enum): + """Categories for memory classification.""" + + ACTIVE = "active" # Currently active/working memory + CACHED = "cached" # Recently used but not active + ARCHIVED = "archived" # Historical data for potential retrieval + METADATA = "metadata" # System metadata and statistics + + +@dataclass +class MemoryThresholds: + """Memory management thresholds and limits.""" + + # Size thresholds (in estimated tokens/bytes) + active_memory_limit: int = 8192 # ~8K tokens for active memory + cached_memory_limit: int = 32768 # ~32K tokens for cached memory + total_memory_limit: int = 131072 # ~128K tokens total limit + + # Cleanup thresholds (percentages) + cleanup_threshold: float = 0.8 # Start cleanup at 80% of limit + emergency_threshold: float = 0.95 # Emergency cleanup at 95% + + # Time-based thresholds (seconds) + cache_ttl: int = 3600 # Cache TTL: 1 hour + archive_after: int = 86400 # Archive after: 24 hours + + # Cleanup ratios (how much to remove during cleanup) + cleanup_ratio: float = 0.3 # Remove 30% during cleanup + emergency_cleanup_ratio: float = 0.5 # Remove 50% during emergency + + +@dataclass +class MemoryConfig: + """Configuration for memory management system.""" + + # Feature toggles + enable_categorization: bool = True # Enable memory categorization + enable_lifecycle: bool = True # Enable automatic lifecycle management + enable_metrics: bool = True # Enable memory metrics collection + enable_archival: bool = True # Enable memory archival + + # Thresholds configuration + thresholds: Optional[MemoryThresholds] = None + + # Cleanup strategy + cleanup_strategy: str = "lru" # LRU, FIFO, or custom + + # Validation settings + strict_validation: bool = True # Strict JSON validation + + def __post_init__(self) -> None: + """Initialize default thresholds if not provided.""" + if self.thresholds is None: + self.thresholds = MemoryThresholds() + + @classmethod + def conservative(cls) -> "MemoryConfig": + """Create conservative memory configuration with lower limits.""" + return cls( + thresholds=MemoryThresholds( + active_memory_limit=4096, + cached_memory_limit=16384, + total_memory_limit=65536, + cleanup_threshold=0.7, + cleanup_ratio=0.4, + ) + ) + + @classmethod + def aggressive(cls) -> "MemoryConfig": + """Create aggressive memory configuration with higher limits.""" + return cls( + thresholds=MemoryThresholds( + active_memory_limit=16384, + cached_memory_limit=65536, + total_memory_limit=262144, + cleanup_threshold=0.9, + cleanup_ratio=0.2, + ) + ) + + @classmethod + def minimal(cls) -> "MemoryConfig": + """Create minimal memory configuration with basic features only.""" + return cls( + enable_lifecycle=False, + enable_metrics=False, + enable_archival=False, + thresholds=MemoryThresholds(active_memory_limit=2048, cached_memory_limit=8192, total_memory_limit=32768), + ) diff --git a/src/strands/agent/memory/enhanced_state.py b/src/strands/agent/memory/enhanced_state.py new file mode 100644 index 000000000..3289add0a --- /dev/null +++ b/src/strands/agent/memory/enhanced_state.py @@ -0,0 +1,217 @@ +"""Enhanced agent state with memory management capabilities.""" + +import copy +from typing import Any, Dict, Optional + +from ..state import AgentState +from .config import MemoryCategory, MemoryConfig +from .lifecycle import MemoryLifecycleManager + + +class EnhancedAgentState(AgentState): + """Enhanced AgentState with memory categorization, lifecycle management, and metrics. + + This class extends the base AgentState to provide: + - Memory categorization (active, cached, archived, metadata) + - Automatic memory lifecycle management + - Memory usage monitoring and metrics + - Configurable memory thresholds and cleanup policies + + The enhanced state maintains full backward compatibility with the base AgentState + interface while adding advanced memory management capabilities. + """ + + def __init__(self, initial_state: Optional[Dict[str, Any]] = None, memory_config: Optional[MemoryConfig] = None): + """Initialize EnhancedAgentState. + + Args: + initial_state: Initial state dictionary (backward compatibility) + memory_config: Memory management configuration + """ + # Initialize base AgentState for backward compatibility + super().__init__(initial_state) + + # Initialize memory management + self.memory_config = memory_config or MemoryConfig() + self.memory_manager = MemoryLifecycleManager(self.memory_config) + + # Migrate existing state to memory manager if provided + if initial_state: + for key, value in initial_state.items(): + self.memory_manager.add_item(key, value, MemoryCategory.ACTIVE) + + def set(self, key: str, value: Any, category: MemoryCategory = MemoryCategory.ACTIVE) -> None: + """Set a value in the state with optional memory category. + + Args: + key: The key to store the value under + value: The value to store (must be JSON serializable) + category: Memory category for the value + + Raises: + ValueError: If key is invalid, or if value is not JSON serializable + """ + # Validate using parent class methods + self._validate_key(key) + self._validate_json_serializable(value) + + # Store in both base state (for backward compatibility) and memory manager + super().set(key, value) + + if self.memory_config.enable_categorization: + self.memory_manager.add_item(key, value, category) + + def get(self, key: Optional[str] = None, category: Optional[MemoryCategory] = None) -> Any: + """Get a value or entire state, optionally filtered by category. + + Args: + key: The key to retrieve (if None, returns entire state object) + category: Optional memory category filter + + Returns: + The stored value, filtered state dict, or None if not found + """ + if key is None: + # Return entire state, optionally filtered by category + if category is not None and self.memory_config.enable_categorization: + return copy.deepcopy(self.memory_manager.get_items_by_category(category)) + else: + # Backward compatibility: return base state + return super().get() + else: + # Return specific key + if self.memory_config.enable_categorization: + return self.memory_manager.get_item(key) + else: + return super().get(key) + + def delete(self, key: str) -> None: + """Delete a specific key from the state. + + Args: + key: The key to delete + """ + self._validate_key(key) + + # Delete from both base state and memory manager + super().delete(key) + + if self.memory_config.enable_categorization: + self.memory_manager.remove_item(key) + + def get_by_category(self, category: MemoryCategory) -> Dict[str, Any]: + """Get all items in a specific memory category. + + Args: + category: The memory category to retrieve + + Returns: + Dictionary of all items in the specified category + """ + if not self.memory_config.enable_categorization: + # If categorization disabled, return all items for any category + return super().get() or {} + + return self.memory_manager.get_items_by_category(category) + + def set_metadata(self, key: str, value: Any) -> None: + """Set a metadata value (convenience method). + + Args: + key: The key to store the metadata under + value: The metadata value + """ + self.set(key, value, MemoryCategory.METADATA) + + def get_active_memory(self) -> Dict[str, Any]: + """Get all active memory items (convenience method). + + Returns: + Dictionary of all active memory items + """ + return self.get_by_category(MemoryCategory.ACTIVE) + + def get_cached_memory(self) -> Dict[str, Any]: + """Get all cached memory items (convenience method). + + Returns: + Dictionary of all cached memory items + """ + return self.get_by_category(MemoryCategory.CACHED) + + def cleanup_memory(self, force: bool = False) -> int: + """Perform memory cleanup and return number of items removed. + + Args: + force: Force cleanup even if lifecycle management is disabled + + Returns: + Number of items removed during cleanup + """ + if not self.memory_config.enable_lifecycle and not force: + return 0 + + removed_count = self.memory_manager.cleanup_memory(force) + + # Sync base state with memory manager after cleanup + self._sync_base_state() + + return removed_count + + def get_memory_stats(self) -> Dict[str, Any]: + """Get comprehensive memory usage statistics. + + Returns: + Dictionary containing memory usage statistics and metrics + """ + if not self.memory_config.enable_metrics: + # Return basic stats if metrics disabled + all_items = super().get() or {} + return { + "total_items": len(all_items), + "categories_enabled": False, + "lifecycle_enabled": self.memory_config.enable_lifecycle, + "metrics_enabled": False, + } + + return self.memory_manager.get_memory_report() + + def optimize_memory(self) -> Dict[str, Any]: + """Optimize memory usage and return optimization results. + + Returns: + Dictionary containing optimization results and statistics + """ + if not self.memory_config.enable_lifecycle: + return {"optimization_skipped": True, "reason": "lifecycle_disabled"} + + optimization_results = self.memory_manager.optimize_memory() + + # Sync base state after optimization + self._sync_base_state() + + return optimization_results + + def _sync_base_state(self) -> None: + """Synchronize base state with memory manager state.""" + if self.memory_config.enable_categorization: + # Update base state to match memory manager + all_items = self.memory_manager.get_all_items() + self._state = copy.deepcopy(all_items) + + def configure_memory(self, config: MemoryConfig) -> None: + """Update memory configuration. + + Args: + config: New memory configuration + """ + self.memory_config = config + self.memory_manager.config = config + + def get_memory_config(self) -> MemoryConfig: + """Get current memory configuration. + + Returns: + Current memory configuration + """ + return self.memory_config diff --git a/src/strands/agent/memory/lifecycle.py b/src/strands/agent/memory/lifecycle.py new file mode 100644 index 000000000..3f25eaafe --- /dev/null +++ b/src/strands/agent/memory/lifecycle.py @@ -0,0 +1,260 @@ +"""Memory lifecycle management for automatic cleanup and archival.""" + +import time +from typing import Any, Dict, Optional, Protocol + +from .config import MemoryCategory, MemoryConfig, MemoryThresholds +from .metrics import MemoryMetrics + + +class MemoryItem(Protocol): + """Protocol for memory items that can be managed by lifecycle manager.""" + + category: MemoryCategory + created_at: float + last_accessed: float + access_count: int + size: int + + +class CategorizedMemoryItem: + """A memory item with categorization and lifecycle metadata.""" + + def __init__(self, key: str, value: Any, category: MemoryCategory = MemoryCategory.ACTIVE): + """Initialize a categorized memory item. + + Args: + key: The key for this memory item + value: The value to store + category: Memory category for this item + """ + self.key = key + self.value = value + self.category = category + self.created_at = time.time() + self.last_accessed = time.time() + self.access_count = 0 + self.size = self._estimate_size(value) + + def _estimate_size(self, value: Any) -> int: + """Estimate the size of the value.""" + import json + + try: + return len(json.dumps(value).encode("utf-8")) + except (TypeError, ValueError): + return len(str(value).encode("utf-8")) + + def access(self) -> None: + """Record an access to this memory item.""" + self.last_accessed = time.time() + self.access_count += 1 + + def age(self) -> float: + """Get the age of this memory item in seconds.""" + return time.time() - self.created_at + + def idle_time(self) -> float: + """Get the idle time since last access in seconds.""" + return time.time() - self.last_accessed + + def should_demote(self, inactive_threshold: float) -> bool: + """Check if item should be demoted from active to cached.""" + return self.category == MemoryCategory.ACTIVE and self.idle_time() > inactive_threshold + + def should_archive(self, archive_threshold: float) -> bool: + """Check if item should be archived.""" + return self.category == MemoryCategory.CACHED and self.age() > archive_threshold + + +class MemoryLifecycleManager: + """Manages the lifecycle of memory items with automatic cleanup and archival.""" + + def __init__(self, config: MemoryConfig): + """Initialize the memory lifecycle manager. + + Args: + config: Memory management configuration + """ + self.config = config + # Ensure thresholds are initialized + assert config.thresholds is not None, "MemoryConfig must have initialized thresholds" + self.metrics = MemoryMetrics() + self._items: Dict[str, CategorizedMemoryItem] = {} + self._last_cleanup = time.time() + + @property + def thresholds(self) -> MemoryThresholds: + """Get memory thresholds, ensuring they are never None.""" + assert self.config.thresholds is not None + return self.config.thresholds + + def add_item(self, key: str, value: Any, category: MemoryCategory = MemoryCategory.ACTIVE) -> None: + """Add or update a memory item.""" + item = CategorizedMemoryItem(key, value, category) + self._items[key] = item + + if category == MemoryCategory.ACTIVE: + self.metrics.record_promotion() + + # Trigger cleanup if needed + if self.config.enable_lifecycle: + self._check_and_cleanup() + + def get_item(self, key: str) -> Optional[Any]: + """Get a memory item by key, updating access statistics.""" + if key in self._items: + item = self._items[key] + item.access() + self.metrics.record_access(hit=True) + + # Promote to active if accessed frequently + if item.category == MemoryCategory.CACHED and self._should_promote(item): + item.category = MemoryCategory.ACTIVE + self.metrics.record_promotion() + + return item.value + else: + self.metrics.record_access(hit=False) + return None + + def remove_item(self, key: str) -> bool: + """Remove a memory item.""" + if key in self._items: + del self._items[key] + return True + return False + + def get_items_by_category(self, category: MemoryCategory) -> Dict[str, Any]: + """Get all items in a specific category.""" + return {key: item.value for key, item in self._items.items() if item.category == category} + + def get_all_items(self) -> Dict[str, Any]: + """Get all memory items (backward compatibility with AgentState).""" + return {key: item.value for key, item in self._items.items()} + + def cleanup_memory(self, force: bool = False) -> int: + """Perform memory cleanup and return number of items removed.""" + if not self.config.enable_lifecycle and not force: + return 0 + + removed_count = 0 + current_time = time.time() + + # First, demote old active items to cached + for item in list(self._items.values()): + if item.should_demote(self.thresholds.cache_ttl): + item.category = MemoryCategory.CACHED + self.metrics.record_demotion() + + # Archive old cached items + if self.config.enable_archival: + for _key, item in list(self._items.items()): + if item.should_archive(self.thresholds.archive_after): + item.category = MemoryCategory.ARCHIVED + self.metrics.record_archival() + + # Remove items if over memory limits + total_size = sum(item.size for item in self._items.values()) + if force or total_size > self.thresholds.total_memory_limit: + removed_count += self._emergency_cleanup() + + self.metrics.record_cleanup() + self._last_cleanup = current_time + self._update_metrics() + + return removed_count + + def _check_and_cleanup(self) -> None: + """Check if cleanup is needed and perform it.""" + total_size = sum(item.size for item in self._items.values()) + utilization = total_size / self.thresholds.total_memory_limit + + if utilization >= self.thresholds.cleanup_threshold: + self.cleanup_memory() + + # Update metrics periodically + current_time = time.time() + if current_time - self._last_cleanup > 300: # Update every 5 minutes + self._update_metrics() + self._last_cleanup = current_time + + def _should_promote(self, item: CategorizedMemoryItem) -> bool: + """Determine if a cached item should be promoted to active.""" + # Promote if accessed frequently in recent time + recent_accesses = item.access_count + age_hours = item.age() / 3600 + + # Simple heuristic: if accessed more than once per hour on average + return age_hours > 0 and (recent_accesses / age_hours) > 1.0 + + def _emergency_cleanup(self) -> int: + """Perform emergency cleanup when memory limits are exceeded.""" + removed_count = 0 + + # Sort items by priority (LRU-like: least recently used first) + items_by_priority = sorted(self._items.items(), key=lambda x: (x[1].category.value, x[1].last_accessed)) + + # Calculate how many items to remove + total_items = len(self._items) + target_removal = max(1, int(total_items * self.thresholds.emergency_cleanup_ratio)) + + # Remove least important items first + for key, item in items_by_priority[:target_removal]: + # Don't remove active metadata + if item.category != MemoryCategory.METADATA: + del self._items[key] + removed_count += 1 + + return removed_count + + def _update_metrics(self) -> None: + """Update memory usage statistics.""" + from .metrics import MemoryUsageStats + + stats = MemoryUsageStats() + + for item in self._items.values(): + stats.total_size += item.size + stats.total_items += 1 + + if item.category == MemoryCategory.ACTIVE: + stats.active_size += item.size + stats.active_items += 1 + elif item.category == MemoryCategory.CACHED: + stats.cached_size += item.size + stats.cached_items += 1 + elif item.category == MemoryCategory.ARCHIVED: + stats.archived_size += item.size + stats.archived_items += 1 + elif item.category == MemoryCategory.METADATA: + stats.metadata_size += item.size + stats.metadata_items += 1 + + self.metrics.update_stats(stats) + + def get_memory_report(self) -> Dict[str, Any]: + """Get a comprehensive memory usage report.""" + self._update_metrics() + return self.metrics.get_summary() + + def optimize_memory(self) -> Dict[str, Any]: + """Perform memory optimization and return optimization results.""" + initial_size = sum(item.size for item in self._items.values()) + initial_count = len(self._items) + + # Perform cleanup + removed_count = self.cleanup_memory(force=True) + + final_size = sum(item.size for item in self._items.values()) + final_count = len(self._items) + + return { + "initial_size": initial_size, + "final_size": final_size, + "size_reduction": initial_size - final_size, + "initial_count": initial_count, + "final_count": final_count, + "items_removed": removed_count, + "size_reduction_pct": int(((initial_size - final_size) / initial_size * 100)) if initial_size > 0 else 0, + } diff --git a/src/strands/agent/memory/metrics.py b/src/strands/agent/memory/metrics.py new file mode 100644 index 000000000..739edc85c --- /dev/null +++ b/src/strands/agent/memory/metrics.py @@ -0,0 +1,221 @@ +"""Memory usage metrics and monitoring.""" + +import json +import time +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from .config import MemoryCategory + + +@dataclass +class MemoryUsageStats: + """Statistics for memory usage tracking.""" + + # Size metrics (estimated tokens/bytes) + total_size: int = 0 + active_size: int = 0 + cached_size: int = 0 + archived_size: int = 0 + metadata_size: int = 0 + + # Count metrics + total_items: int = 0 + active_items: int = 0 + cached_items: int = 0 + archived_items: int = 0 + metadata_items: int = 0 + + # Performance metrics + hit_rate: float = 0.0 # Cache hit rate + miss_rate: float = 0.0 # Cache miss rate + cleanup_count: int = 0 # Number of cleanups performed + last_cleanup: Optional[float] = None # Timestamp of last cleanup + + # Lifecycle metrics + promotions: int = 0 # Items promoted to active + demotions: int = 0 # Items demoted from active + archival_count: int = 0 # Items archived + + def utilization_ratio(self, limit: int) -> float: + """Calculate memory utilization ratio.""" + if limit == 0: + return 0.0 + return min(1.0, self.total_size / limit) + + def category_distribution(self) -> Dict[str, float]: + """Get distribution of memory across categories.""" + if self.total_size == 0: + return {cat.value: 0.0 for cat in MemoryCategory} + + return { + MemoryCategory.ACTIVE.value: self.active_size / self.total_size, + MemoryCategory.CACHED.value: self.cached_size / self.total_size, + MemoryCategory.ARCHIVED.value: self.archived_size / self.total_size, + MemoryCategory.METADATA.value: self.metadata_size / self.total_size, + } + + +@dataclass +class MemoryMetrics: + """Memory metrics collection and analysis.""" + + # Current statistics + stats: MemoryUsageStats = field(default_factory=MemoryUsageStats) + + # Historical data (last N measurements) + history: List[MemoryUsageStats] = field(default_factory=list) + max_history_size: int = 100 + + # Access tracking + access_count: int = 0 + hit_count: int = 0 + miss_count: int = 0 + + # Timing metrics + last_access_time: Optional[float] = None + creation_time: float = field(default_factory=time.time) + + def record_access(self, hit: bool = True) -> None: + """Record a memory access (hit or miss).""" + self.access_count += 1 + self.last_access_time = time.time() + + if hit: + self.hit_count += 1 + else: + self.miss_count += 1 + + # Update hit/miss rates + if self.access_count > 0: + self.stats.hit_rate = self.hit_count / self.access_count + self.stats.miss_rate = self.miss_count / self.access_count + + def record_cleanup(self) -> None: + """Record a memory cleanup operation.""" + self.stats.cleanup_count += 1 + self.stats.last_cleanup = time.time() + + def record_promotion(self) -> None: + """Record a memory item promotion.""" + self.stats.promotions += 1 + + def record_demotion(self) -> None: + """Record a memory item demotion.""" + self.stats.demotions += 1 + + def record_archival(self) -> None: + """Record a memory item archival.""" + self.stats.archival_count += 1 + + def update_stats(self, new_stats: MemoryUsageStats) -> None: + """Update current statistics and save to history.""" + # Save current stats to history + if len(self.history) >= self.max_history_size: + self.history.pop(0) # Remove oldest entry + + self.history.append(self.stats) + + # Preserve accumulated metrics from current stats + new_stats.hit_rate = self.stats.hit_rate + new_stats.miss_rate = self.stats.miss_rate + new_stats.cleanup_count = self.stats.cleanup_count + new_stats.promotions = self.stats.promotions + new_stats.demotions = self.stats.demotions + new_stats.archival_count = self.stats.archival_count + new_stats.last_cleanup = self.stats.last_cleanup + + self.stats = new_stats + + def estimate_item_size(self, value: Any) -> int: + """Estimate the size of a memory item in bytes/tokens.""" + try: + # Use JSON serialization size as rough estimate + json_str = json.dumps(value) + # Rough token estimate: ~4 characters per token + return len(json_str.encode("utf-8")) + except (TypeError, ValueError): + # Fallback for non-serializable objects + return len(str(value).encode("utf-8")) + + def get_trend_analysis(self, window: int = 10) -> Dict[str, float]: + """Analyze trends in memory usage over the last N measurements.""" + if len(self.history) < 1: + return {"trend": 0.0, "volatility": 0.0} + + # Get all stats (history + current), excluding initial empty stats + all_history = self.history + [self.stats] + # Skip the first item if it's the initial empty stats (total_size == 0) + if all_history and all_history[0].total_size == 0 and len(all_history) > 1: + all_history = all_history[1:] + + # Apply window to the complete set + all_stats = all_history[-window:] if len(all_history) >= window else all_history + + # Calculate trend (simple linear regression slope) + if len(all_stats) < 2: + return {"trend": 0.0, "volatility": 0.0} + + sizes = [stats.total_size for stats in all_stats] + n = len(sizes) + x_mean = (n - 1) / 2 + y_mean = sum(sizes) / n + + numerator = sum((i - x_mean) * (sizes[i] - y_mean) for i in range(n)) + denominator = sum((i - x_mean) ** 2 for i in range(n)) + + trend = numerator / denominator if denominator != 0 else 0.0 + + # Calculate volatility (standard deviation of sizes) + variance = sum((size - y_mean) ** 2 for size in sizes) / n + volatility = variance**0.5 + + return { + "trend": trend, + "volatility": volatility, + "avg_size": y_mean, + "min_size": min(sizes), + "max_size": max(sizes), + } + + def should_cleanup(self, threshold: float, limit: int) -> bool: + """Determine if memory cleanup should be triggered.""" + utilization = self.stats.utilization_ratio(limit) + return utilization >= threshold + + def get_summary(self) -> Dict[str, Any]: + """Get a comprehensive summary of memory metrics.""" + trend_analysis = self.get_trend_analysis() + + return { + "current_stats": { + "total_size": self.stats.total_size, + "total_items": self.stats.total_items, + "distribution": self.stats.category_distribution(), + "hit_rate": self.stats.hit_rate, + "cleanup_count": self.stats.cleanup_count, + }, + "performance": { + "access_count": self.access_count, + "hit_rate": self.hit_count / self.access_count if self.access_count > 0 else 0.0, + "miss_rate": self.miss_count / self.access_count if self.access_count > 0 else 0.0, + "avg_response_time": self._calculate_avg_response_time(), + }, + "trends": trend_analysis, + "lifecycle": { + "promotions": self.stats.promotions, + "demotions": self.stats.demotions, + "archival_count": self.stats.archival_count, + }, + "timestamps": { + "creation_time": self.creation_time, + "last_access": self.last_access_time, + "last_cleanup": self.stats.last_cleanup, + }, + } + + def _calculate_avg_response_time(self) -> Optional[float]: + """Calculate average response time based on access patterns.""" + # Placeholder for future implementation + # Could track timing data for memory operations + return None diff --git a/tests/strands/agent/memory/__init__.py b/tests/strands/agent/memory/__init__.py new file mode 100644 index 000000000..f93d43409 --- /dev/null +++ b/tests/strands/agent/memory/__init__.py @@ -0,0 +1 @@ +"""Tests for memory management system.""" diff --git a/tests/strands/agent/memory/test_config.py b/tests/strands/agent/memory/test_config.py new file mode 100644 index 000000000..b2e97390d --- /dev/null +++ b/tests/strands/agent/memory/test_config.py @@ -0,0 +1,135 @@ +"""Tests for memory management configuration.""" + +from strands.agent.memory.config import MemoryCategory, MemoryConfig, MemoryThresholds + + +def test_memory_category_enum(): + """Test MemoryCategory enum values.""" + assert MemoryCategory.ACTIVE.value == "active" + assert MemoryCategory.CACHED.value == "cached" + assert MemoryCategory.ARCHIVED.value == "archived" + assert MemoryCategory.METADATA.value == "metadata" + + +def test_memory_thresholds_defaults(): + """Test MemoryThresholds default values.""" + thresholds = MemoryThresholds() + + assert thresholds.active_memory_limit == 8192 + assert thresholds.cached_memory_limit == 32768 + assert thresholds.total_memory_limit == 131072 + assert thresholds.cleanup_threshold == 0.8 + assert thresholds.emergency_threshold == 0.95 + assert thresholds.cache_ttl == 3600 + assert thresholds.archive_after == 86400 + assert thresholds.cleanup_ratio == 0.3 + assert thresholds.emergency_cleanup_ratio == 0.5 + + +def test_memory_thresholds_custom_values(): + """Test MemoryThresholds with custom values.""" + thresholds = MemoryThresholds( + active_memory_limit=4096, + cached_memory_limit=16384, + total_memory_limit=65536, + cleanup_threshold=0.7, + emergency_threshold=0.9, + ) + + assert thresholds.active_memory_limit == 4096 + assert thresholds.cached_memory_limit == 16384 + assert thresholds.total_memory_limit == 65536 + assert thresholds.cleanup_threshold == 0.7 + assert thresholds.emergency_threshold == 0.9 + + +def test_memory_config_defaults(): + """Test MemoryConfig default values.""" + config = MemoryConfig() + + assert config.enable_categorization is True + assert config.enable_lifecycle is True + assert config.enable_metrics is True + assert config.enable_archival is True + assert config.cleanup_strategy == "lru" + assert config.strict_validation is True + assert config.thresholds is not None + assert isinstance(config.thresholds, MemoryThresholds) + + +def test_memory_config_with_custom_thresholds(): + """Test MemoryConfig with custom thresholds.""" + thresholds = MemoryThresholds(active_memory_limit=2048) + config = MemoryConfig(thresholds=thresholds) + + assert config.thresholds.active_memory_limit == 2048 + assert config.thresholds.cached_memory_limit == 32768 # Default value + + +def test_memory_config_conservative(): + """Test conservative memory configuration.""" + config = MemoryConfig.conservative() + + assert config.thresholds.active_memory_limit == 4096 + assert config.thresholds.cached_memory_limit == 16384 + assert config.thresholds.total_memory_limit == 65536 + assert config.thresholds.cleanup_threshold == 0.7 + assert config.thresholds.cleanup_ratio == 0.4 + + +def test_memory_config_aggressive(): + """Test aggressive memory configuration.""" + config = MemoryConfig.aggressive() + + assert config.thresholds.active_memory_limit == 16384 + assert config.thresholds.cached_memory_limit == 65536 + assert config.thresholds.total_memory_limit == 262144 + assert config.thresholds.cleanup_threshold == 0.9 + assert config.thresholds.cleanup_ratio == 0.2 + + +def test_memory_config_minimal(): + """Test minimal memory configuration.""" + config = MemoryConfig.minimal() + + assert config.enable_lifecycle is False + assert config.enable_metrics is False + assert config.enable_archival is False + assert config.enable_categorization is True # Still enabled by default + + assert config.thresholds.active_memory_limit == 2048 + assert config.thresholds.cached_memory_limit == 8192 + assert config.thresholds.total_memory_limit == 32768 + + +def test_memory_config_custom_features(): + """Test MemoryConfig with custom feature toggles.""" + config = MemoryConfig( + enable_categorization=False, + enable_lifecycle=False, + enable_metrics=False, + enable_archival=False, + cleanup_strategy="fifo", + strict_validation=False, + ) + + assert config.enable_categorization is False + assert config.enable_lifecycle is False + assert config.enable_metrics is False + assert config.enable_archival is False + assert config.cleanup_strategy == "fifo" + assert config.strict_validation is False + + +def test_memory_config_post_init(): + """Test MemoryConfig __post_init__ method.""" + # Test with None thresholds (should create default) + config = MemoryConfig(thresholds=None) + assert config.thresholds is not None + assert isinstance(config.thresholds, MemoryThresholds) + + # Test with provided thresholds (should keep them) + custom_thresholds = MemoryThresholds(active_memory_limit=1024) + config = MemoryConfig(thresholds=custom_thresholds) + assert config.thresholds is custom_thresholds + assert config.thresholds.active_memory_limit == 1024 diff --git a/tests/strands/agent/memory/test_enhanced_state.py b/tests/strands/agent/memory/test_enhanced_state.py new file mode 100644 index 000000000..905051465 --- /dev/null +++ b/tests/strands/agent/memory/test_enhanced_state.py @@ -0,0 +1,328 @@ +"""Tests for enhanced agent state with memory management.""" + +import pytest + +from strands.agent.memory.config import MemoryCategory, MemoryConfig +from strands.agent.memory.enhanced_state import EnhancedAgentState + + +def test_enhanced_agent_state_initialization(): + """Test EnhancedAgentState initialization.""" + state = EnhancedAgentState() + + assert state.memory_config is not None + assert state.memory_manager is not None + assert isinstance(state.memory_config, MemoryConfig) + + +def test_enhanced_agent_state_with_initial_state(): + """Test EnhancedAgentState initialization with initial state.""" + initial_state = {"key1": "value1", "key2": "value2"} + state = EnhancedAgentState(initial_state=initial_state) + + # Should be available through both interfaces + assert state.get("key1") == "value1" + assert state.get("key2") == "value2" + + # Should be available through parent interface (backward compatibility) + parent_state = super(EnhancedAgentState, state).get() + assert parent_state["key1"] == "value1" + assert parent_state["key2"] == "value2" + + +def test_enhanced_agent_state_with_memory_config(): + """Test EnhancedAgentState with custom memory configuration.""" + config = MemoryConfig.conservative() + state = EnhancedAgentState(memory_config=config) + + assert state.memory_config is config + assert state.memory_manager.config is config + + +def test_enhanced_agent_state_set_get_basic(): + """Test basic set and get operations.""" + state = EnhancedAgentState() + + state.set("test_key", "test_value") + assert state.get("test_key") == "test_value" + + +def test_enhanced_agent_state_set_with_category(): + """Test set operation with memory category.""" + state = EnhancedAgentState() + + state.set("active_key", "active_value", MemoryCategory.ACTIVE) + state.set("cached_key", "cached_value", MemoryCategory.CACHED) + state.set("metadata_key", "metadata_value", MemoryCategory.METADATA) + + assert state.get("active_key") == "active_value" + assert state.get("cached_key") == "cached_value" + assert state.get("metadata_key") == "metadata_value" + + +def test_enhanced_agent_state_get_by_category(): + """Test retrieving items by category.""" + state = EnhancedAgentState() + + state.set("active1", "value1", MemoryCategory.ACTIVE) + state.set("active2", "value2", MemoryCategory.ACTIVE) + state.set("cached1", "value3", MemoryCategory.CACHED) + + active_items = state.get_by_category(MemoryCategory.ACTIVE) + cached_items = state.get_by_category(MemoryCategory.CACHED) + archived_items = state.get_by_category(MemoryCategory.ARCHIVED) + + assert len(active_items) == 2 + assert len(cached_items) == 1 + assert len(archived_items) == 0 + assert active_items["active1"] == "value1" + assert cached_items["cached1"] == "value3" + + +def test_enhanced_agent_state_get_by_category_disabled(): + """Test get_by_category when categorization is disabled.""" + config = MemoryConfig(enable_categorization=False) + state = EnhancedAgentState(memory_config=config) + + state.set("key1", "value1") + state.set("key2", "value2") + + # Should return all items regardless of category when disabled + items = state.get_by_category(MemoryCategory.ACTIVE) + assert len(items) == 2 + assert items["key1"] == "value1" + assert items["key2"] == "value2" + + +def test_enhanced_agent_state_get_entire_state(): + """Test getting entire state.""" + state = EnhancedAgentState() + + state.set("key1", "value1", MemoryCategory.ACTIVE) + state.set("key2", "value2", MemoryCategory.CACHED) + + # Get entire state + all_items = state.get() + assert len(all_items) == 2 + assert all_items["key1"] == "value1" + assert all_items["key2"] == "value2" + + +def test_enhanced_agent_state_get_filtered_by_category(): + """Test getting state filtered by category.""" + state = EnhancedAgentState() + + state.set("active1", "value1", MemoryCategory.ACTIVE) + state.set("cached1", "value2", MemoryCategory.CACHED) + + # Get only active items + active_items = state.get(category=MemoryCategory.ACTIVE) + assert len(active_items) == 1 + assert active_items["active1"] == "value1" + + # Get only cached items + cached_items = state.get(category=MemoryCategory.CACHED) + assert len(cached_items) == 1 + assert cached_items["cached1"] == "value2" + + +def test_enhanced_agent_state_delete(): + """Test deleting items.""" + state = EnhancedAgentState() + + state.set("key1", "value1") + state.set("key2", "value2") + + state.delete("key1") + + assert state.get("key1") is None + assert state.get("key2") == "value2" + + # Should also be deleted from parent state + parent_state = super(EnhancedAgentState, state).get() + assert "key1" not in parent_state + assert parent_state["key2"] == "value2" + + +def test_enhanced_agent_state_convenience_methods(): + """Test convenience methods for memory categories.""" + state = EnhancedAgentState() + + # Test set_metadata + state.set_metadata("meta_key", "meta_value") + metadata_items = state.get_by_category(MemoryCategory.METADATA) + assert metadata_items["meta_key"] == "meta_value" + + # Test get_active_memory + state.set("active_key", "active_value", MemoryCategory.ACTIVE) + active_items = state.get_active_memory() + assert active_items["active_key"] == "active_value" + + # Test get_cached_memory + state.set("cached_key", "cached_value", MemoryCategory.CACHED) + cached_items = state.get_cached_memory() + assert cached_items["cached_key"] == "cached_value" + + +def test_enhanced_agent_state_cleanup_memory(): + """Test memory cleanup functionality.""" + state = EnhancedAgentState() + + state.set("key1", "value1") + state.set("key2", "value2") + + # Test cleanup (might not remove anything without time passage) + removed_count = state.cleanup_memory() + assert removed_count >= 0 + + # Test forced cleanup + removed_count = state.cleanup_memory(force=True) + assert removed_count >= 0 + + +def test_enhanced_agent_state_cleanup_memory_disabled(): + """Test cleanup when lifecycle management is disabled.""" + config = MemoryConfig(enable_lifecycle=False) + state = EnhancedAgentState(memory_config=config) + + state.set("key1", "value1") + + # Should return 0 when lifecycle disabled + removed_count = state.cleanup_memory() + assert removed_count == 0 + + # Should work with force=True + removed_count = state.cleanup_memory(force=True) + assert removed_count >= 0 + + +def test_enhanced_agent_state_get_memory_stats(): + """Test getting memory statistics.""" + state = EnhancedAgentState() + + state.set("key1", "value1", MemoryCategory.ACTIVE) + state.set("key2", "value2", MemoryCategory.CACHED) + + stats = state.get_memory_stats() + + assert "current_stats" in stats or "total_items" in stats + + # Test with metrics disabled + config = MemoryConfig(enable_metrics=False) + state_no_metrics = EnhancedAgentState(memory_config=config) + stats = state_no_metrics.get_memory_stats() + + assert stats["metrics_enabled"] is False + assert "total_items" in stats + + +def test_enhanced_agent_state_optimize_memory(): + """Test memory optimization.""" + state = EnhancedAgentState() + + state.set("key1", "value1") + state.set("key2", "value2") + + optimization_results = state.optimize_memory() + + # Should return optimization results or skip info + assert isinstance(optimization_results, dict) + + # Test with lifecycle disabled + config = MemoryConfig(enable_lifecycle=False) + state_no_lifecycle = EnhancedAgentState(memory_config=config) + results = state_no_lifecycle.optimize_memory() + + assert results.get("optimization_skipped") is True + assert results.get("reason") == "lifecycle_disabled" + + +def test_enhanced_agent_state_configure_memory(): + """Test memory configuration updates.""" + state = EnhancedAgentState() + original_config = state.get_memory_config() + + new_config = MemoryConfig.aggressive() + state.configure_memory(new_config) + + assert state.get_memory_config() is new_config + assert state.memory_manager.config is new_config + assert state.get_memory_config() is not original_config + + +def test_enhanced_agent_state_backward_compatibility(): + """Test backward compatibility with base AgentState interface.""" + state = EnhancedAgentState() + + # All base AgentState operations should work + state.set("key1", "value1") + assert state.get("key1") == "value1" + + # Get entire state should work + all_state = state.get() + assert all_state["key1"] == "value1" + + # Delete should work + state.delete("key1") + assert state.get("key1") is None + + # Validation should still work + with pytest.raises(ValueError, match="Key cannot be None"): + state.set(None, "value") + + with pytest.raises(ValueError, match="not JSON serializable"): + state.set("key", lambda x: x) + + +def test_enhanced_agent_state_with_categorization_disabled(): + """Test behavior when categorization is disabled.""" + config = MemoryConfig(enable_categorization=False) + state = EnhancedAgentState(memory_config=config) + + # Set operations should still work + state.set("key1", "value1") + state.set("key2", "value2", MemoryCategory.CACHED) # Category ignored + + # Get should work normally + assert state.get("key1") == "value1" + assert state.get("key2") == "value2" + + # Should use parent implementation when categorization disabled + all_items = state.get() + assert all_items["key1"] == "value1" + assert all_items["key2"] == "value2" + + +def test_enhanced_agent_state_json_validation(): + """Test that JSON validation is maintained.""" + state = EnhancedAgentState() + + # Valid JSON types should work + state.set("string", "test") + state.set("int", 42) + state.set("bool", True) + state.set("list", [1, 2, 3]) + state.set("dict", {"nested": "value"}) + state.set("null", None) + + # Invalid JSON types should raise ValueError + with pytest.raises(ValueError, match="not JSON serializable"): + state.set("function", lambda x: x) + + with pytest.raises(ValueError, match="not JSON serializable"): + state.set("object", object()) + + +def test_enhanced_agent_state_key_validation(): + """Test that key validation is maintained.""" + state = EnhancedAgentState() + + # Invalid keys should raise ValueError + with pytest.raises(ValueError, match="Key cannot be None"): + state.set(None, "value") + + with pytest.raises(ValueError, match="Key cannot be empty"): + state.set("", "value") + + with pytest.raises(ValueError, match="Key must be a string"): + state.set(123, "value") diff --git a/tests/strands/agent/memory/test_lifecycle.py b/tests/strands/agent/memory/test_lifecycle.py new file mode 100644 index 000000000..549f845b4 --- /dev/null +++ b/tests/strands/agent/memory/test_lifecycle.py @@ -0,0 +1,348 @@ +"""Tests for memory lifecycle management.""" + +import time +from unittest.mock import patch + +from strands.agent.memory.config import MemoryCategory, MemoryConfig, MemoryThresholds +from strands.agent.memory.lifecycle import CategorizedMemoryItem, MemoryLifecycleManager + + +def test_categorized_memory_item_creation(): + """Test CategorizedMemoryItem creation and properties.""" + item = CategorizedMemoryItem("test_key", "test_value") + + assert item.key == "test_key" + assert item.value == "test_value" + assert item.category == MemoryCategory.ACTIVE + assert item.access_count == 0 + assert item.size > 0 + assert item.created_at <= time.time() + assert item.last_accessed <= time.time() + + +def test_categorized_memory_item_with_category(): + """Test CategorizedMemoryItem with specific category.""" + item = CategorizedMemoryItem("test_key", "test_value", MemoryCategory.CACHED) + + assert item.category == MemoryCategory.CACHED + + +def test_categorized_memory_item_access(): + """Test memory item access tracking.""" + item = CategorizedMemoryItem("test_key", "test_value") + initial_access_time = item.last_accessed + initial_access_count = item.access_count + + time.sleep(0.01) # Small delay to ensure time difference + item.access() + + assert item.access_count == initial_access_count + 1 + assert item.last_accessed > initial_access_time + + +def test_categorized_memory_item_age(): + """Test memory item age calculation.""" + with patch("time.time") as mock_time: + mock_time.return_value = 1000.0 + item = CategorizedMemoryItem("test_key", "test_value") + + mock_time.return_value = 1010.0 + assert item.age() == 10.0 + + +def test_categorized_memory_item_idle_time(): + """Test memory item idle time calculation.""" + with patch("time.time") as mock_time: + mock_time.return_value = 1000.0 + item = CategorizedMemoryItem("test_key", "test_value") + + mock_time.return_value = 1005.0 + item.access() + + mock_time.return_value = 1015.0 + assert item.idle_time() == 10.0 + + +def test_categorized_memory_item_should_demote(): + """Test demotion logic for memory items.""" + item = CategorizedMemoryItem("test_key", "test_value", MemoryCategory.ACTIVE) + + with patch.object(item, "idle_time", return_value=3700): # > 1 hour + assert item.should_demote(3600) is True + + with patch.object(item, "idle_time", return_value=1800): # < 1 hour + assert item.should_demote(3600) is False + + # Non-active items should not be demoted + item.category = MemoryCategory.CACHED + with patch.object(item, "idle_time", return_value=3700): + assert item.should_demote(3600) is False + + +def test_categorized_memory_item_should_archive(): + """Test archival logic for memory items.""" + item = CategorizedMemoryItem("test_key", "test_value", MemoryCategory.CACHED) + + with patch.object(item, "age", return_value=90000): # > 24 hours + assert item.should_archive(86400) is True + + with patch.object(item, "age", return_value=43200): # < 24 hours + assert item.should_archive(86400) is False + + # Non-cached items should not be archived + item.category = MemoryCategory.ACTIVE + with patch.object(item, "age", return_value=90000): + assert item.should_archive(86400) is False + + +def test_categorized_memory_item_size_estimation(): + """Test size estimation for different value types.""" + # Test with different value types + string_item = CategorizedMemoryItem("key", "hello") + dict_item = CategorizedMemoryItem("key", {"nested": "data"}) + list_item = CategorizedMemoryItem("key", [1, 2, 3, 4, 5]) + + assert string_item.size > 0 + assert dict_item.size > 0 + assert list_item.size > 0 + + # Larger objects should have larger estimated sizes + large_string = "x" * 1000 + large_item = CategorizedMemoryItem("key", large_string) + assert large_item.size > string_item.size + + +def test_memory_lifecycle_manager_initialization(): + """Test MemoryLifecycleManager initialization.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + assert manager.config is config + assert manager._items == {} + assert manager._last_cleanup <= time.time() + + +def test_memory_lifecycle_manager_add_item(): + """Test adding items to memory manager.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + manager.add_item("key1", "value1") + manager.add_item("key2", "value2", MemoryCategory.CACHED) + + assert "key1" in manager._items + assert "key2" in manager._items + assert manager._items["key1"].category == MemoryCategory.ACTIVE + assert manager._items["key2"].category == MemoryCategory.CACHED + assert manager._items["key1"].value == "value1" + assert manager._items["key2"].value == "value2" + + +def test_memory_lifecycle_manager_get_item(): + """Test retrieving items from memory manager.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + manager.add_item("existing_key", "existing_value") + + # Test existing key + value = manager.get_item("existing_key") + assert value == "existing_value" + + # Test non-existing key + value = manager.get_item("non_existing_key") + assert value is None + + # Check that access was recorded + item = manager._items["existing_key"] + assert item.access_count == 1 + + +def test_memory_lifecycle_manager_remove_item(): + """Test removing items from memory manager.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + manager.add_item("key1", "value1") + + # Remove existing item + assert manager.remove_item("key1") is True + assert "key1" not in manager._items + + # Remove non-existing item + assert manager.remove_item("non_existing") is False + + +def test_memory_lifecycle_manager_get_items_by_category(): + """Test retrieving items by category.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + manager.add_item("active1", "value1", MemoryCategory.ACTIVE) + manager.add_item("active2", "value2", MemoryCategory.ACTIVE) + manager.add_item("cached1", "value3", MemoryCategory.CACHED) + + active_items = manager.get_items_by_category(MemoryCategory.ACTIVE) + cached_items = manager.get_items_by_category(MemoryCategory.CACHED) + archived_items = manager.get_items_by_category(MemoryCategory.ARCHIVED) + + assert len(active_items) == 2 + assert len(cached_items) == 1 + assert len(archived_items) == 0 + assert active_items["active1"] == "value1" + assert cached_items["cached1"] == "value3" + + +def test_memory_lifecycle_manager_get_all_items(): + """Test retrieving all items.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + manager.add_item("key1", "value1", MemoryCategory.ACTIVE) + manager.add_item("key2", "value2", MemoryCategory.CACHED) + + all_items = manager.get_all_items() + + assert len(all_items) == 2 + assert all_items["key1"] == "value1" + assert all_items["key2"] == "value2" + + +def test_memory_lifecycle_manager_promotion(): + """Test automatic promotion of frequently accessed items.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + # Add cached item + manager.add_item("key1", "value1", MemoryCategory.CACHED) + + # Mock frequent access pattern + with patch.object(manager, "_should_promote", return_value=True): + # Access should trigger promotion + manager.get_item("key1") + + # Check that item was promoted + assert manager._items["key1"].category == MemoryCategory.ACTIVE + + +def test_memory_lifecycle_manager_should_promote(): + """Test promotion logic.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + # Create item with access pattern + item = CategorizedMemoryItem("key", "value", MemoryCategory.CACHED) + + # Test promotion logic + with patch.object(item, "age", return_value=7200): # 2 hours + with patch.object(item, "access_count", 3): # 3 accesses in 2 hours > 1/hour + assert manager._should_promote(item) is True + + with patch.object(item, "access_count", 1): # 1 access in 2 hours < 1/hour + assert manager._should_promote(item) is False + + # Test with new item (age = 0) + with patch.object(item, "age", return_value=0): + assert manager._should_promote(item) is False + + +def test_memory_lifecycle_manager_cleanup_memory(): + """Test memory cleanup functionality.""" + thresholds = MemoryThresholds( + cache_ttl=1800, # 30 minutes + archive_after=3600, # 1 hour + ) + config = MemoryConfig(thresholds=thresholds, enable_lifecycle=True) + manager = MemoryLifecycleManager(config) + + # Add items with different ages + manager.add_item("active_old", "value1", MemoryCategory.ACTIVE) + manager.add_item("cached_old", "value2", MemoryCategory.CACHED) + + # Mock ages for demotion/archival + active_item = manager._items["active_old"] + cached_item = manager._items["cached_old"] + + with patch.object(active_item, "should_demote", return_value=True): + with patch.object(cached_item, "should_archive", return_value=True): + # Run cleanup + manager.cleanup_memory() + + # Check changes + assert active_item.category == MemoryCategory.CACHED + assert cached_item.category == MemoryCategory.ARCHIVED + + +def test_memory_lifecycle_manager_emergency_cleanup(): + """Test emergency cleanup when memory limits exceeded.""" + thresholds = MemoryThresholds(total_memory_limit=100) # Very small limit + config = MemoryConfig(thresholds=thresholds, enable_lifecycle=True) + manager = MemoryLifecycleManager(config) + + # Add items that exceed limit + large_value = "x" * 50 + manager.add_item("key1", large_value, MemoryCategory.CACHED) + manager.add_item("key2", large_value, MemoryCategory.CACHED) + manager.add_item("key3", large_value, MemoryCategory.ACTIVE) + + initial_count = len(manager._items) + + # Force cleanup + removed_count = manager.cleanup_memory(force=True) + + # Should have removed some items + assert removed_count > 0 + assert len(manager._items) < initial_count + + +def test_memory_lifecycle_manager_disabled_lifecycle(): + """Test behavior when lifecycle management is disabled.""" + config = MemoryConfig(enable_lifecycle=False) + manager = MemoryLifecycleManager(config) + + manager.add_item("key1", "value1") + + # Cleanup should do nothing when lifecycle disabled + removed_count = manager.cleanup_memory(force=False) + assert removed_count == 0 + + # But should work with force=True + removed_count = manager.cleanup_memory(force=True) + assert removed_count >= 0 # May or may not remove items + + +def test_memory_lifecycle_manager_get_memory_report(): + """Test memory usage report generation.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + manager.add_item("key1", "value1", MemoryCategory.ACTIVE) + manager.add_item("key2", "value2", MemoryCategory.CACHED) + + report = manager.get_memory_report() + + assert "current_stats" in report + assert "performance" in report + assert "trends" in report + assert "lifecycle" in report + assert "timestamps" in report + + +def test_memory_lifecycle_manager_optimize_memory(): + """Test memory optimization.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + # Add some items + manager.add_item("key1", "value1") + manager.add_item("key2", "value2") + + optimization_results = manager.optimize_memory() + + assert "initial_size" in optimization_results + assert "final_size" in optimization_results + assert "size_reduction" in optimization_results + assert "initial_count" in optimization_results + assert "final_count" in optimization_results + assert "items_removed" in optimization_results + assert "size_reduction_pct" in optimization_results diff --git a/tests/strands/agent/memory/test_metrics.py b/tests/strands/agent/memory/test_metrics.py new file mode 100644 index 000000000..af747e8a3 --- /dev/null +++ b/tests/strands/agent/memory/test_metrics.py @@ -0,0 +1,297 @@ +"""Tests for memory metrics and monitoring.""" + +import time + +from strands.agent.memory.config import MemoryCategory +from strands.agent.memory.metrics import MemoryMetrics, MemoryUsageStats + + +def test_memory_usage_stats_defaults(): + """Test MemoryUsageStats default values.""" + stats = MemoryUsageStats() + + assert stats.total_size == 0 + assert stats.active_size == 0 + assert stats.cached_size == 0 + assert stats.archived_size == 0 + assert stats.metadata_size == 0 + assert stats.total_items == 0 + assert stats.hit_rate == 0.0 + assert stats.cleanup_count == 0 + assert stats.promotions == 0 + + +def test_memory_usage_stats_utilization_ratio(): + """Test utilization ratio calculation.""" + stats = MemoryUsageStats() + stats.total_size = 5000 + + # Normal case + assert stats.utilization_ratio(10000) == 0.5 + + # Over limit case + assert stats.utilization_ratio(2500) == 1.0 + + # Zero limit case + assert stats.utilization_ratio(0) == 0.0 + + +def test_memory_usage_stats_category_distribution(): + """Test category distribution calculation.""" + stats = MemoryUsageStats() + stats.total_size = 1000 + stats.active_size = 400 + stats.cached_size = 300 + stats.archived_size = 200 + stats.metadata_size = 100 + + distribution = stats.category_distribution() + + assert distribution[MemoryCategory.ACTIVE.value] == 0.4 + assert distribution[MemoryCategory.CACHED.value] == 0.3 + assert distribution[MemoryCategory.ARCHIVED.value] == 0.2 + assert distribution[MemoryCategory.METADATA.value] == 0.1 + + +def test_memory_usage_stats_category_distribution_empty(): + """Test category distribution with zero total size.""" + stats = MemoryUsageStats() + # total_size = 0 by default + + distribution = stats.category_distribution() + + for category in MemoryCategory: + assert distribution[category.value] == 0.0 + + +def test_memory_metrics_initialization(): + """Test MemoryMetrics initialization.""" + metrics = MemoryMetrics() + + assert isinstance(metrics.stats, MemoryUsageStats) + assert metrics.history == [] + assert metrics.max_history_size == 100 + assert metrics.access_count == 0 + assert metrics.hit_count == 0 + assert metrics.miss_count == 0 + assert metrics.last_access_time is None + assert metrics.creation_time <= time.time() + + +def test_memory_metrics_record_access(): + """Test recording memory access (hits and misses).""" + metrics = MemoryMetrics() + + # Record hits + metrics.record_access(hit=True) + metrics.record_access(hit=True) + + assert metrics.access_count == 2 + assert metrics.hit_count == 2 + assert metrics.miss_count == 0 + assert metrics.stats.hit_rate == 1.0 + assert metrics.stats.miss_rate == 0.0 + assert metrics.last_access_time is not None + + # Record miss + metrics.record_access(hit=False) + + assert metrics.access_count == 3 + assert metrics.hit_count == 2 + assert metrics.miss_count == 1 + assert metrics.stats.hit_rate == 2 / 3 + assert metrics.stats.miss_rate == 1 / 3 + + +def test_memory_metrics_record_operations(): + """Test recording various memory operations.""" + metrics = MemoryMetrics() + + # Test cleanup recording + metrics.record_cleanup() + assert metrics.stats.cleanup_count == 1 + assert metrics.stats.last_cleanup is not None + + # Test promotion recording + metrics.record_promotion() + assert metrics.stats.promotions == 1 + + # Test demotion recording + metrics.record_demotion() + assert metrics.stats.demotions == 1 + + # Test archival recording + metrics.record_archival() + assert metrics.stats.archival_count == 1 + + +def test_memory_metrics_update_stats(): + """Test updating statistics with history tracking.""" + metrics = MemoryMetrics() + + # Create new stats + new_stats = MemoryUsageStats() + new_stats.total_size = 1000 + new_stats.total_items = 10 + + # Update stats + metrics.update_stats(new_stats) + + # Check that old stats were saved to history + assert len(metrics.history) == 1 + assert metrics.history[0].total_size == 0 # Old default stats + + # Check that new stats are current + assert metrics.stats.total_size == 1000 + assert metrics.stats.total_items == 10 + + +def test_memory_metrics_history_limit(): + """Test that history is limited to max_history_size.""" + metrics = MemoryMetrics() + metrics.max_history_size = 3 + + # Add more stats than the limit + for i in range(5): + new_stats = MemoryUsageStats() + new_stats.total_size = i * 100 + metrics.update_stats(new_stats) + + # Check that history is limited + assert len(metrics.history) == 3 + + # Check that oldest entries were removed (should have sizes 100, 200, 300) + assert metrics.history[0].total_size == 100 + assert metrics.history[1].total_size == 200 + assert metrics.history[2].total_size == 300 + + +def test_memory_metrics_estimate_item_size(): + """Test item size estimation.""" + metrics = MemoryMetrics() + + # Test with JSON-serializable objects + assert metrics.estimate_item_size("hello") > 0 + assert metrics.estimate_item_size({"key": "value"}) > 0 + assert metrics.estimate_item_size([1, 2, 3]) > 0 + + # Test that larger objects have larger estimated sizes + small_obj = "hi" + large_obj = "hello world" * 100 + assert metrics.estimate_item_size(large_obj) > metrics.estimate_item_size(small_obj) + + # Test with non-serializable object + class CustomObject: + def __str__(self): + return "custom object" + + custom_obj = CustomObject() + size = metrics.estimate_item_size(custom_obj) + assert size > 0 # Should fallback to string representation + + +def test_memory_metrics_trend_analysis(): + """Test trend analysis functionality.""" + metrics = MemoryMetrics() + + # Test with insufficient history + trend = metrics.get_trend_analysis() + assert trend["trend"] == 0.0 + assert trend["volatility"] == 0.0 + + # Add some history with increasing sizes + sizes = [100, 200, 300, 400, 500] + for size in sizes: + stats = MemoryUsageStats() + stats.total_size = size + metrics.update_stats(stats) + + trend = metrics.get_trend_analysis() + assert trend["trend"] > 0 # Should show upward trend + assert trend["avg_size"] == 300 # Average of 100, 200, 300, 400, 500 + assert trend["min_size"] == 100 + assert trend["max_size"] == 500 + assert trend["volatility"] > 0 + + +def test_memory_metrics_trend_analysis_with_window(): + """Test trend analysis with custom window size.""" + metrics = MemoryMetrics() + + # Add history + sizes = [100, 200, 300, 400, 500, 600] + for size in sizes: + stats = MemoryUsageStats() + stats.total_size = size + metrics.update_stats(stats) + + # Analyze with smaller window + trend = metrics.get_trend_analysis(window=3) + # Should only consider last 3 values: 400, 500, 600 + assert trend["avg_size"] == 500 + assert trend["min_size"] == 400 + assert trend["max_size"] == 600 + + +def test_memory_metrics_should_cleanup(): + """Test cleanup decision logic.""" + metrics = MemoryMetrics() + + # Setup stats + metrics.stats.total_size = 8000 + + # Should cleanup when over threshold + assert metrics.should_cleanup(threshold=0.8, limit=10000) is True + + # Should not cleanup when under threshold + assert metrics.should_cleanup(threshold=0.9, limit=10000) is False + + # Edge case: zero limit + assert metrics.should_cleanup(threshold=0.5, limit=0) is False + + +def test_memory_metrics_get_summary(): + """Test comprehensive summary generation.""" + metrics = MemoryMetrics() + + # Setup some data + metrics.record_access(hit=True) + metrics.record_access(hit=False) + metrics.record_cleanup() + metrics.record_promotion() + + # Add some history for trend analysis + for i in range(3): + stats = MemoryUsageStats() + stats.total_size = (i + 1) * 100 + metrics.update_stats(stats) + + summary = metrics.get_summary() + + # Check structure + assert "current_stats" in summary + assert "performance" in summary + assert "trends" in summary + assert "lifecycle" in summary + assert "timestamps" in summary + + # Check current stats + current_stats = summary["current_stats"] + assert "total_size" in current_stats + assert "distribution" in current_stats + assert "hit_rate" in current_stats + + # Check performance metrics + performance = summary["performance"] + assert performance["access_count"] == 2 + assert performance["hit_rate"] == 0.5 + assert performance["miss_rate"] == 0.5 + + # Check lifecycle metrics + lifecycle = summary["lifecycle"] + assert lifecycle["promotions"] == 1 + + # Check timestamps + timestamps = summary["timestamps"] + assert "creation_time" in timestamps + assert "last_access" in timestamps