From d585c0e31b5f6d88b881cd0d7b40b17f82ec43af Mon Sep 17 00:00:00 2001 From: yike5460 Date: Thu, 7 Aug 2025 17:04:27 +0800 Subject: [PATCH 1/2] feat: add enhanced memory management foundation * Add tiered memory categorization system (ACTIVE, CACHED, ARCHIVED, METADATA) * Implement EnhancedAgentState with backward compatibility to existing AgentState * Add automatic memory lifecycle management with promotion, demotion, and archival * Include comprehensive metrics tracking with trend analysis and performance monitoring * Provide configurable memory thresholds and cleanup strategies * Add emergency cleanup for memory limit enforcement * Maintain full JSON validation and key validation from parent classes Breaking: None - fully backward compatible with existing AgentState interface Testing: 65 comprehensive unit tests with 100% pass rate --- src/strands/agent/memory/__init__.py | 27 ++ src/strands/agent/memory/config.py | 101 +++++ src/strands/agent/memory/enhanced_state.py | 221 +++++++++++ src/strands/agent/memory/lifecycle.py | 249 +++++++++++++ src/strands/agent/memory/metrics.py | 221 +++++++++++ tests/strands/agent/memory/__init__.py | 1 + tests/strands/agent/memory/test_config.py | 137 +++++++ .../agent/memory/test_enhanced_state.py | 328 ++++++++++++++++ tests/strands/agent/memory/test_lifecycle.py | 351 ++++++++++++++++++ tests/strands/agent/memory/test_metrics.py | 301 +++++++++++++++ 10 files changed, 1937 insertions(+) create mode 100644 src/strands/agent/memory/__init__.py create mode 100644 src/strands/agent/memory/config.py create mode 100644 src/strands/agent/memory/enhanced_state.py create mode 100644 src/strands/agent/memory/lifecycle.py create mode 100644 src/strands/agent/memory/metrics.py create mode 100644 tests/strands/agent/memory/__init__.py create mode 100644 tests/strands/agent/memory/test_config.py create mode 100644 tests/strands/agent/memory/test_enhanced_state.py create mode 100644 tests/strands/agent/memory/test_lifecycle.py create mode 100644 tests/strands/agent/memory/test_metrics.py diff --git a/src/strands/agent/memory/__init__.py b/src/strands/agent/memory/__init__.py new file mode 100644 index 000000000..89d41e134 --- /dev/null +++ b/src/strands/agent/memory/__init__.py @@ -0,0 +1,27 @@ +"""Memory management system for agents. + +This package provides enhanced memory management capabilities including: +- Memory categorization (active, cached, archived) +- Memory lifecycle management with automatic cleanup +- Memory usage monitoring and metrics +- Configurable memory thresholds and policies + +The memory management system is designed to be backward compatible with +existing AgentState usage while providing advanced memory optimization +capabilities for complex multi-agent scenarios. 
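
Example (an illustrative sketch only; the exact items demoted, archived, or
evicted depend on the configured thresholds):

    from strands.agent.memory import EnhancedAgentState, MemoryCategory, MemoryConfig

    state = EnhancedAgentState(memory_config=MemoryConfig.conservative())
    state.set("plan", {"step": 1})                            # ACTIVE by default
    state.set("raw_tool_output", "chunked text", MemoryCategory.CACHED)
    state.set_metadata("run_id", "run-001")                   # METADATA category

    active = state.get_active_memory()                        # {"plan": {"step": 1}}
    removed = state.cleanup_memory()                          # count of evicted items
    stats = state.get_memory_stats()                          # usage, trends, lifecycle counters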
+""" + +from .config import MemoryConfig, MemoryCategory, MemoryThresholds +from .enhanced_state import EnhancedAgentState +from .lifecycle import MemoryLifecycleManager +from .metrics import MemoryMetrics, MemoryUsageStats + +__all__ = [ + "MemoryConfig", + "MemoryCategory", + "MemoryThresholds", + "EnhancedAgentState", + "MemoryLifecycleManager", + "MemoryMetrics", + "MemoryUsageStats", +] \ No newline at end of file diff --git a/src/strands/agent/memory/config.py b/src/strands/agent/memory/config.py new file mode 100644 index 000000000..a427dd850 --- /dev/null +++ b/src/strands/agent/memory/config.py @@ -0,0 +1,101 @@ +"""Memory management configuration and types.""" + +from dataclasses import dataclass +from enum import Enum +from typing import Optional + + +class MemoryCategory(Enum): + """Categories for memory classification.""" + + ACTIVE = "active" # Currently active/working memory + CACHED = "cached" # Recently used but not active + ARCHIVED = "archived" # Historical data for potential retrieval + METADATA = "metadata" # System metadata and statistics + + +@dataclass +class MemoryThresholds: + """Memory management thresholds and limits.""" + + # Size thresholds (in estimated tokens/bytes) + active_memory_limit: int = 8192 # ~8K tokens for active memory + cached_memory_limit: int = 32768 # ~32K tokens for cached memory + total_memory_limit: int = 131072 # ~128K tokens total limit + + # Cleanup thresholds (percentages) + cleanup_threshold: float = 0.8 # Start cleanup at 80% of limit + emergency_threshold: float = 0.95 # Emergency cleanup at 95% + + # Time-based thresholds (seconds) + cache_ttl: int = 3600 # Cache TTL: 1 hour + archive_after: int = 86400 # Archive after: 24 hours + + # Cleanup ratios (how much to remove during cleanup) + cleanup_ratio: float = 0.3 # Remove 30% during cleanup + emergency_cleanup_ratio: float = 0.5 # Remove 50% during emergency + + +@dataclass +class MemoryConfig: + """Configuration for memory management system.""" + + # Feature toggles + enable_categorization: bool = True # Enable memory categorization + enable_lifecycle: bool = True # Enable automatic lifecycle management + enable_metrics: bool = True # Enable memory metrics collection + enable_archival: bool = True # Enable memory archival + + # Thresholds configuration + thresholds: MemoryThresholds = None + + # Cleanup strategy + cleanup_strategy: str = "lru" # LRU, FIFO, or custom + + # Validation settings + strict_validation: bool = True # Strict JSON validation + + def __post_init__(self): + """Initialize default thresholds if not provided.""" + if self.thresholds is None: + self.thresholds = MemoryThresholds() + + @classmethod + def conservative(cls) -> "MemoryConfig": + """Create conservative memory configuration with lower limits.""" + return cls( + thresholds=MemoryThresholds( + active_memory_limit=4096, + cached_memory_limit=16384, + total_memory_limit=65536, + cleanup_threshold=0.7, + cleanup_ratio=0.4 + ) + ) + + @classmethod + def aggressive(cls) -> "MemoryConfig": + """Create aggressive memory configuration with higher limits.""" + return cls( + thresholds=MemoryThresholds( + active_memory_limit=16384, + cached_memory_limit=65536, + total_memory_limit=262144, + cleanup_threshold=0.9, + cleanup_ratio=0.2 + ) + ) + + @classmethod + def minimal(cls) -> "MemoryConfig": + """Create minimal memory configuration with basic features only.""" + return cls( + enable_lifecycle=False, + enable_metrics=False, + enable_archival=False, + thresholds=MemoryThresholds( + active_memory_limit=2048, + 
cached_memory_limit=8192, + total_memory_limit=32768 + ) + ) \ No newline at end of file diff --git a/src/strands/agent/memory/enhanced_state.py b/src/strands/agent/memory/enhanced_state.py new file mode 100644 index 000000000..c09f9415c --- /dev/null +++ b/src/strands/agent/memory/enhanced_state.py @@ -0,0 +1,221 @@ +"""Enhanced agent state with memory management capabilities.""" + +import copy +from typing import Any, Dict, Optional + +from ..state import AgentState +from .config import MemoryCategory, MemoryConfig +from .lifecycle import MemoryLifecycleManager + + +class EnhancedAgentState(AgentState): + """Enhanced AgentState with memory categorization, lifecycle management, and metrics. + + This class extends the base AgentState to provide: + - Memory categorization (active, cached, archived, metadata) + - Automatic memory lifecycle management + - Memory usage monitoring and metrics + - Configurable memory thresholds and cleanup policies + + The enhanced state maintains full backward compatibility with the base AgentState + interface while adding advanced memory management capabilities. + """ + + def __init__( + self, + initial_state: Optional[Dict[str, Any]] = None, + memory_config: Optional[MemoryConfig] = None + ): + """Initialize EnhancedAgentState. + + Args: + initial_state: Initial state dictionary (backward compatibility) + memory_config: Memory management configuration + """ + # Initialize base AgentState for backward compatibility + super().__init__(initial_state) + + # Initialize memory management + self.memory_config = memory_config or MemoryConfig() + self.memory_manager = MemoryLifecycleManager(self.memory_config) + + # Migrate existing state to memory manager if provided + if initial_state: + for key, value in initial_state.items(): + self.memory_manager.add_item(key, value, MemoryCategory.ACTIVE) + + def set(self, key: str, value: Any, category: MemoryCategory = MemoryCategory.ACTIVE) -> None: + """Set a value in the state with optional memory category. + + Args: + key: The key to store the value under + value: The value to store (must be JSON serializable) + category: Memory category for the value + + Raises: + ValueError: If key is invalid, or if value is not JSON serializable + """ + # Validate using parent class methods + self._validate_key(key) + self._validate_json_serializable(value) + + # Store in both base state (for backward compatibility) and memory manager + super().set(key, value) + + if self.memory_config.enable_categorization: + self.memory_manager.add_item(key, value, category) + + def get(self, key: Optional[str] = None, category: Optional[MemoryCategory] = None) -> Any: + """Get a value or entire state, optionally filtered by category. + + Args: + key: The key to retrieve (if None, returns entire state object) + category: Optional memory category filter + + Returns: + The stored value, filtered state dict, or None if not found + """ + if key is None: + # Return entire state, optionally filtered by category + if category is not None and self.memory_config.enable_categorization: + return copy.deepcopy(self.memory_manager.get_items_by_category(category)) + else: + # Backward compatibility: return base state + return super().get() + else: + # Return specific key + if self.memory_config.enable_categorization: + return self.memory_manager.get_item(key) + else: + return super().get(key) + + def delete(self, key: str) -> None: + """Delete a specific key from the state. 
+ + Args: + key: The key to delete + """ + self._validate_key(key) + + # Delete from both base state and memory manager + super().delete(key) + + if self.memory_config.enable_categorization: + self.memory_manager.remove_item(key) + + def get_by_category(self, category: MemoryCategory) -> Dict[str, Any]: + """Get all items in a specific memory category. + + Args: + category: The memory category to retrieve + + Returns: + Dictionary of all items in the specified category + """ + if not self.memory_config.enable_categorization: + # If categorization disabled, return all items for any category + return super().get() or {} + + return self.memory_manager.get_items_by_category(category) + + def set_metadata(self, key: str, value: Any) -> None: + """Set a metadata value (convenience method). + + Args: + key: The key to store the metadata under + value: The metadata value + """ + self.set(key, value, MemoryCategory.METADATA) + + def get_active_memory(self) -> Dict[str, Any]: + """Get all active memory items (convenience method). + + Returns: + Dictionary of all active memory items + """ + return self.get_by_category(MemoryCategory.ACTIVE) + + def get_cached_memory(self) -> Dict[str, Any]: + """Get all cached memory items (convenience method). + + Returns: + Dictionary of all cached memory items + """ + return self.get_by_category(MemoryCategory.CACHED) + + def cleanup_memory(self, force: bool = False) -> int: + """Perform memory cleanup and return number of items removed. + + Args: + force: Force cleanup even if lifecycle management is disabled + + Returns: + Number of items removed during cleanup + """ + if not self.memory_config.enable_lifecycle and not force: + return 0 + + removed_count = self.memory_manager.cleanup_memory(force) + + # Sync base state with memory manager after cleanup + self._sync_base_state() + + return removed_count + + def get_memory_stats(self) -> Dict[str, Any]: + """Get comprehensive memory usage statistics. + + Returns: + Dictionary containing memory usage statistics and metrics + """ + if not self.memory_config.enable_metrics: + # Return basic stats if metrics disabled + all_items = super().get() or {} + return { + "total_items": len(all_items), + "categories_enabled": False, + "lifecycle_enabled": self.memory_config.enable_lifecycle, + "metrics_enabled": False + } + + return self.memory_manager.get_memory_report() + + def optimize_memory(self) -> Dict[str, Any]: + """Optimize memory usage and return optimization results. + + Returns: + Dictionary containing optimization results and statistics + """ + if not self.memory_config.enable_lifecycle: + return {"optimization_skipped": True, "reason": "lifecycle_disabled"} + + optimization_results = self.memory_manager.optimize_memory() + + # Sync base state after optimization + self._sync_base_state() + + return optimization_results + + def _sync_base_state(self) -> None: + """Synchronize base state with memory manager state.""" + if self.memory_config.enable_categorization: + # Update base state to match memory manager + all_items = self.memory_manager.get_all_items() + self._state = copy.deepcopy(all_items) + + def configure_memory(self, config: MemoryConfig) -> None: + """Update memory configuration. + + Args: + config: New memory configuration + """ + self.memory_config = config + self.memory_manager.config = config + + def get_memory_config(self) -> MemoryConfig: + """Get current memory configuration. 
+ + Returns: + Current memory configuration + """ + return self.memory_config \ No newline at end of file diff --git a/src/strands/agent/memory/lifecycle.py b/src/strands/agent/memory/lifecycle.py new file mode 100644 index 000000000..bfc126d90 --- /dev/null +++ b/src/strands/agent/memory/lifecycle.py @@ -0,0 +1,249 @@ +"""Memory lifecycle management for automatic cleanup and archival.""" + +import time +from typing import Any, Dict, List, Optional, Protocol, Tuple + +from .config import MemoryCategory, MemoryConfig +from .metrics import MemoryMetrics + + +class MemoryItem(Protocol): + """Protocol for memory items that can be managed by lifecycle manager.""" + + category: MemoryCategory + created_at: float + last_accessed: float + access_count: int + size: int + + +class CategorizedMemoryItem: + """A memory item with categorization and lifecycle metadata.""" + + def __init__(self, key: str, value: Any, category: MemoryCategory = MemoryCategory.ACTIVE): + self.key = key + self.value = value + self.category = category + self.created_at = time.time() + self.last_accessed = time.time() + self.access_count = 0 + self.size = self._estimate_size(value) + + def _estimate_size(self, value: Any) -> int: + """Estimate the size of the value.""" + import json + try: + return len(json.dumps(value).encode('utf-8')) + except (TypeError, ValueError): + return len(str(value).encode('utf-8')) + + def access(self) -> None: + """Record an access to this memory item.""" + self.last_accessed = time.time() + self.access_count += 1 + + def age(self) -> float: + """Get the age of this memory item in seconds.""" + return time.time() - self.created_at + + def idle_time(self) -> float: + """Get the idle time since last access in seconds.""" + return time.time() - self.last_accessed + + def should_demote(self, inactive_threshold: float) -> bool: + """Check if item should be demoted from active to cached.""" + return (self.category == MemoryCategory.ACTIVE and + self.idle_time() > inactive_threshold) + + def should_archive(self, archive_threshold: float) -> bool: + """Check if item should be archived.""" + return (self.category == MemoryCategory.CACHED and + self.age() > archive_threshold) + + +class MemoryLifecycleManager: + """Manages the lifecycle of memory items with automatic cleanup and archival.""" + + def __init__(self, config: MemoryConfig): + self.config = config + self.metrics = MemoryMetrics() + self._items: Dict[str, CategorizedMemoryItem] = {} + self._last_cleanup = time.time() + + def add_item(self, key: str, value: Any, category: MemoryCategory = MemoryCategory.ACTIVE) -> None: + """Add or update a memory item.""" + item = CategorizedMemoryItem(key, value, category) + self._items[key] = item + + if category == MemoryCategory.ACTIVE: + self.metrics.record_promotion() + + # Trigger cleanup if needed + if self.config.enable_lifecycle: + self._check_and_cleanup() + + def get_item(self, key: str) -> Optional[Any]: + """Get a memory item by key, updating access statistics.""" + if key in self._items: + item = self._items[key] + item.access() + self.metrics.record_access(hit=True) + + # Promote to active if accessed frequently + if (item.category == MemoryCategory.CACHED and + self._should_promote(item)): + item.category = MemoryCategory.ACTIVE + self.metrics.record_promotion() + + return item.value + else: + self.metrics.record_access(hit=False) + return None + + def remove_item(self, key: str) -> bool: + """Remove a memory item.""" + if key in self._items: + del self._items[key] + return True + return False + 
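
A minimal usage sketch of the manager (import paths as exercised by the tests
added in this patch; what gets demoted, archived, or evicted depends on the
thresholds chosen):

    from strands.agent.memory.config import MemoryCategory, MemoryConfig, MemoryThresholds
    from strands.agent.memory.lifecycle import MemoryLifecycleManager

    config = MemoryConfig(thresholds=MemoryThresholds(cache_ttl=1800, archive_after=3600))
    manager = MemoryLifecycleManager(config)

    manager.add_item("scratchpad", {"step": 1})                    # ACTIVE by default
    manager.add_item("tool_output", "cached text", MemoryCategory.CACHED)

    _ = manager.get_item("tool_output")            # records the access; hot CACHED items may be promoted
    removed = manager.cleanup_memory(force=True)   # demote idle ACTIVE, archive stale CACHED, evict if over limit
    report = manager.get_memory_report()           # stats, trends, and lifecycle counters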
+ def get_items_by_category(self, category: MemoryCategory) -> Dict[str, Any]: + """Get all items in a specific category.""" + return { + key: item.value + for key, item in self._items.items() + if item.category == category + } + + def get_all_items(self) -> Dict[str, Any]: + """Get all memory items (backward compatibility with AgentState).""" + return {key: item.value for key, item in self._items.items()} + + def cleanup_memory(self, force: bool = False) -> int: + """Perform memory cleanup and return number of items removed.""" + if not self.config.enable_lifecycle and not force: + return 0 + + removed_count = 0 + current_time = time.time() + + # First, demote old active items to cached + for item in list(self._items.values()): + if item.should_demote(self.config.thresholds.cache_ttl): + item.category = MemoryCategory.CACHED + self.metrics.record_demotion() + + # Archive old cached items + if self.config.enable_archival: + for key, item in list(self._items.items()): + if item.should_archive(self.config.thresholds.archive_after): + item.category = MemoryCategory.ARCHIVED + self.metrics.record_archival() + + # Remove items if over memory limits + total_size = sum(item.size for item in self._items.values()) + if force or total_size > self.config.thresholds.total_memory_limit: + removed_count += self._emergency_cleanup() + + self.metrics.record_cleanup() + self._last_cleanup = current_time + self._update_metrics() + + return removed_count + + def _check_and_cleanup(self) -> None: + """Check if cleanup is needed and perform it.""" + total_size = sum(item.size for item in self._items.values()) + utilization = total_size / self.config.thresholds.total_memory_limit + + if utilization >= self.config.thresholds.cleanup_threshold: + self.cleanup_memory() + + # Update metrics periodically + current_time = time.time() + if current_time - self._last_cleanup > 300: # Update every 5 minutes + self._update_metrics() + self._last_cleanup = current_time + + def _should_promote(self, item: CategorizedMemoryItem) -> bool: + """Determine if a cached item should be promoted to active.""" + # Promote if accessed frequently in recent time + recent_accesses = item.access_count + age_hours = item.age() / 3600 + + # Simple heuristic: if accessed more than once per hour on average + return age_hours > 0 and (recent_accesses / age_hours) > 1.0 + + def _emergency_cleanup(self) -> int: + """Perform emergency cleanup when memory limits are exceeded.""" + removed_count = 0 + + # Sort items by priority (LRU-like: least recently used first) + items_by_priority = sorted( + self._items.items(), + key=lambda x: (x[1].category.value, x[1].last_accessed) + ) + + # Calculate how many items to remove + total_items = len(self._items) + target_removal = max(1, int(total_items * self.config.thresholds.emergency_cleanup_ratio)) + + # Remove least important items first + for key, item in items_by_priority[:target_removal]: + # Don't remove active metadata + if item.category != MemoryCategory.METADATA: + del self._items[key] + removed_count += 1 + + return removed_count + + def _update_metrics(self) -> None: + """Update memory usage statistics.""" + from .metrics import MemoryUsageStats + + stats = MemoryUsageStats() + + for item in self._items.values(): + stats.total_size += item.size + stats.total_items += 1 + + if item.category == MemoryCategory.ACTIVE: + stats.active_size += item.size + stats.active_items += 1 + elif item.category == MemoryCategory.CACHED: + stats.cached_size += item.size + stats.cached_items += 1 + elif 
item.category == MemoryCategory.ARCHIVED: + stats.archived_size += item.size + stats.archived_items += 1 + elif item.category == MemoryCategory.METADATA: + stats.metadata_size += item.size + stats.metadata_items += 1 + + self.metrics.update_stats(stats) + + def get_memory_report(self) -> Dict[str, Any]: + """Get a comprehensive memory usage report.""" + self._update_metrics() + return self.metrics.get_summary() + + def optimize_memory(self) -> Dict[str, int]: + """Perform memory optimization and return optimization results.""" + initial_size = sum(item.size for item in self._items.values()) + initial_count = len(self._items) + + # Perform cleanup + removed_count = self.cleanup_memory(force=True) + + final_size = sum(item.size for item in self._items.values()) + final_count = len(self._items) + + return { + "initial_size": initial_size, + "final_size": final_size, + "size_reduction": initial_size - final_size, + "initial_count": initial_count, + "final_count": final_count, + "items_removed": removed_count, + "size_reduction_pct": ((initial_size - final_size) / initial_size * 100) if initial_size > 0 else 0 + } \ No newline at end of file diff --git a/src/strands/agent/memory/metrics.py b/src/strands/agent/memory/metrics.py new file mode 100644 index 000000000..e1e6b6691 --- /dev/null +++ b/src/strands/agent/memory/metrics.py @@ -0,0 +1,221 @@ +"""Memory usage metrics and monitoring.""" + +import json +import time +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from .config import MemoryCategory + + +@dataclass +class MemoryUsageStats: + """Statistics for memory usage tracking.""" + + # Size metrics (estimated tokens/bytes) + total_size: int = 0 + active_size: int = 0 + cached_size: int = 0 + archived_size: int = 0 + metadata_size: int = 0 + + # Count metrics + total_items: int = 0 + active_items: int = 0 + cached_items: int = 0 + archived_items: int = 0 + metadata_items: int = 0 + + # Performance metrics + hit_rate: float = 0.0 # Cache hit rate + miss_rate: float = 0.0 # Cache miss rate + cleanup_count: int = 0 # Number of cleanups performed + last_cleanup: Optional[float] = None # Timestamp of last cleanup + + # Lifecycle metrics + promotions: int = 0 # Items promoted to active + demotions: int = 0 # Items demoted from active + archival_count: int = 0 # Items archived + + def utilization_ratio(self, limit: int) -> float: + """Calculate memory utilization ratio.""" + if limit == 0: + return 0.0 + return min(1.0, self.total_size / limit) + + def category_distribution(self) -> Dict[str, float]: + """Get distribution of memory across categories.""" + if self.total_size == 0: + return {cat.value: 0.0 for cat in MemoryCategory} + + return { + MemoryCategory.ACTIVE.value: self.active_size / self.total_size, + MemoryCategory.CACHED.value: self.cached_size / self.total_size, + MemoryCategory.ARCHIVED.value: self.archived_size / self.total_size, + MemoryCategory.METADATA.value: self.metadata_size / self.total_size, + } + + +@dataclass +class MemoryMetrics: + """Memory metrics collection and analysis.""" + + # Current statistics + stats: MemoryUsageStats = field(default_factory=MemoryUsageStats) + + # Historical data (last N measurements) + history: List[MemoryUsageStats] = field(default_factory=list) + max_history_size: int = 100 + + # Access tracking + access_count: int = 0 + hit_count: int = 0 + miss_count: int = 0 + + # Timing metrics + last_access_time: Optional[float] = None + creation_time: float = field(default_factory=time.time) + + def 
record_access(self, hit: bool = True) -> None: + """Record a memory access (hit or miss).""" + self.access_count += 1 + self.last_access_time = time.time() + + if hit: + self.hit_count += 1 + else: + self.miss_count += 1 + + # Update hit/miss rates + if self.access_count > 0: + self.stats.hit_rate = self.hit_count / self.access_count + self.stats.miss_rate = self.miss_count / self.access_count + + def record_cleanup(self) -> None: + """Record a memory cleanup operation.""" + self.stats.cleanup_count += 1 + self.stats.last_cleanup = time.time() + + def record_promotion(self) -> None: + """Record a memory item promotion.""" + self.stats.promotions += 1 + + def record_demotion(self) -> None: + """Record a memory item demotion.""" + self.stats.demotions += 1 + + def record_archival(self) -> None: + """Record a memory item archival.""" + self.stats.archival_count += 1 + + def update_stats(self, new_stats: MemoryUsageStats) -> None: + """Update current statistics and save to history.""" + # Save current stats to history + if len(self.history) >= self.max_history_size: + self.history.pop(0) # Remove oldest entry + + self.history.append(self.stats) + + # Preserve accumulated metrics from current stats + new_stats.hit_rate = self.stats.hit_rate + new_stats.miss_rate = self.stats.miss_rate + new_stats.cleanup_count = self.stats.cleanup_count + new_stats.promotions = self.stats.promotions + new_stats.demotions = self.stats.demotions + new_stats.archival_count = self.stats.archival_count + new_stats.last_cleanup = self.stats.last_cleanup + + self.stats = new_stats + + def estimate_item_size(self, value: Any) -> int: + """Estimate the size of a memory item in bytes/tokens.""" + try: + # Use JSON serialization size as rough estimate + json_str = json.dumps(value) + # Rough token estimate: ~4 characters per token + return len(json_str.encode('utf-8')) + except (TypeError, ValueError): + # Fallback for non-serializable objects + return len(str(value).encode('utf-8')) + + def get_trend_analysis(self, window: int = 10) -> Dict[str, float]: + """Analyze trends in memory usage over the last N measurements.""" + if len(self.history) < 1: + return {"trend": 0.0, "volatility": 0.0} + + # Get all stats (history + current), excluding initial empty stats + all_history = self.history + [self.stats] + # Skip the first item if it's the initial empty stats (total_size == 0) + if all_history and all_history[0].total_size == 0 and len(all_history) > 1: + all_history = all_history[1:] + + # Apply window to the complete set + all_stats = all_history[-window:] if len(all_history) >= window else all_history + + # Calculate trend (simple linear regression slope) + if len(all_stats) < 2: + return {"trend": 0.0, "volatility": 0.0} + + sizes = [stats.total_size for stats in all_stats] + n = len(sizes) + x_mean = (n - 1) / 2 + y_mean = sum(sizes) / n + + numerator = sum((i - x_mean) * (sizes[i] - y_mean) for i in range(n)) + denominator = sum((i - x_mean) ** 2 for i in range(n)) + + trend = numerator / denominator if denominator != 0 else 0.0 + + # Calculate volatility (standard deviation of sizes) + variance = sum((size - y_mean) ** 2 for size in sizes) / n + volatility = variance ** 0.5 + + return { + "trend": trend, + "volatility": volatility, + "avg_size": y_mean, + "min_size": min(sizes), + "max_size": max(sizes) + } + + def should_cleanup(self, threshold: float, limit: int) -> bool: + """Determine if memory cleanup should be triggered.""" + utilization = self.stats.utilization_ratio(limit) + return utilization >= threshold 
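
A minimal sketch of how the metrics feed the cleanup decision (import path as
exercised by the tests added in this patch; sizes are the rough JSON-based
estimates, not exact byte counts):

    from strands.agent.memory.metrics import MemoryMetrics, MemoryUsageStats

    metrics = MemoryMetrics()
    metrics.record_access(hit=True)
    metrics.record_access(hit=False)                  # hit_rate becomes 0.5

    # Snapshots are normally pushed in by the lifecycle manager.
    for size in (100, 400, 900):
        snapshot = MemoryUsageStats()
        snapshot.total_size = size
        metrics.update_stats(snapshot)

    trend = metrics.get_trend_analysis(window=3)                 # positive "trend", plus avg/min/max sizes
    needed = metrics.should_cleanup(threshold=0.8, limit=1000)   # True: 900/1000 >= 0.8
    summary = metrics.get_summary()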
+ + def get_summary(self) -> Dict[str, Any]: + """Get a comprehensive summary of memory metrics.""" + trend_analysis = self.get_trend_analysis() + + return { + "current_stats": { + "total_size": self.stats.total_size, + "total_items": self.stats.total_items, + "distribution": self.stats.category_distribution(), + "hit_rate": self.stats.hit_rate, + "cleanup_count": self.stats.cleanup_count + }, + "performance": { + "access_count": self.access_count, + "hit_rate": self.hit_count / self.access_count if self.access_count > 0 else 0.0, + "miss_rate": self.miss_count / self.access_count if self.access_count > 0 else 0.0, + "avg_response_time": self._calculate_avg_response_time() + }, + "trends": trend_analysis, + "lifecycle": { + "promotions": self.stats.promotions, + "demotions": self.stats.demotions, + "archival_count": self.stats.archival_count + }, + "timestamps": { + "creation_time": self.creation_time, + "last_access": self.last_access_time, + "last_cleanup": self.stats.last_cleanup + } + } + + def _calculate_avg_response_time(self) -> Optional[float]: + """Calculate average response time based on access patterns.""" + # Placeholder for future implementation + # Could track timing data for memory operations + return None \ No newline at end of file diff --git a/tests/strands/agent/memory/__init__.py b/tests/strands/agent/memory/__init__.py new file mode 100644 index 000000000..3f50f6ae7 --- /dev/null +++ b/tests/strands/agent/memory/__init__.py @@ -0,0 +1 @@ +"""Tests for memory management system.""" \ No newline at end of file diff --git a/tests/strands/agent/memory/test_config.py b/tests/strands/agent/memory/test_config.py new file mode 100644 index 000000000..a028ba0aa --- /dev/null +++ b/tests/strands/agent/memory/test_config.py @@ -0,0 +1,137 @@ +"""Tests for memory management configuration.""" + +import pytest + +from strands.agent.memory.config import MemoryCategory, MemoryConfig, MemoryThresholds + + +def test_memory_category_enum(): + """Test MemoryCategory enum values.""" + assert MemoryCategory.ACTIVE.value == "active" + assert MemoryCategory.CACHED.value == "cached" + assert MemoryCategory.ARCHIVED.value == "archived" + assert MemoryCategory.METADATA.value == "metadata" + + +def test_memory_thresholds_defaults(): + """Test MemoryThresholds default values.""" + thresholds = MemoryThresholds() + + assert thresholds.active_memory_limit == 8192 + assert thresholds.cached_memory_limit == 32768 + assert thresholds.total_memory_limit == 131072 + assert thresholds.cleanup_threshold == 0.8 + assert thresholds.emergency_threshold == 0.95 + assert thresholds.cache_ttl == 3600 + assert thresholds.archive_after == 86400 + assert thresholds.cleanup_ratio == 0.3 + assert thresholds.emergency_cleanup_ratio == 0.5 + + +def test_memory_thresholds_custom_values(): + """Test MemoryThresholds with custom values.""" + thresholds = MemoryThresholds( + active_memory_limit=4096, + cached_memory_limit=16384, + total_memory_limit=65536, + cleanup_threshold=0.7, + emergency_threshold=0.9 + ) + + assert thresholds.active_memory_limit == 4096 + assert thresholds.cached_memory_limit == 16384 + assert thresholds.total_memory_limit == 65536 + assert thresholds.cleanup_threshold == 0.7 + assert thresholds.emergency_threshold == 0.9 + + +def test_memory_config_defaults(): + """Test MemoryConfig default values.""" + config = MemoryConfig() + + assert config.enable_categorization is True + assert config.enable_lifecycle is True + assert config.enable_metrics is True + assert config.enable_archival is True + assert 
config.cleanup_strategy == "lru" + assert config.strict_validation is True + assert config.thresholds is not None + assert isinstance(config.thresholds, MemoryThresholds) + + +def test_memory_config_with_custom_thresholds(): + """Test MemoryConfig with custom thresholds.""" + thresholds = MemoryThresholds(active_memory_limit=2048) + config = MemoryConfig(thresholds=thresholds) + + assert config.thresholds.active_memory_limit == 2048 + assert config.thresholds.cached_memory_limit == 32768 # Default value + + +def test_memory_config_conservative(): + """Test conservative memory configuration.""" + config = MemoryConfig.conservative() + + assert config.thresholds.active_memory_limit == 4096 + assert config.thresholds.cached_memory_limit == 16384 + assert config.thresholds.total_memory_limit == 65536 + assert config.thresholds.cleanup_threshold == 0.7 + assert config.thresholds.cleanup_ratio == 0.4 + + +def test_memory_config_aggressive(): + """Test aggressive memory configuration.""" + config = MemoryConfig.aggressive() + + assert config.thresholds.active_memory_limit == 16384 + assert config.thresholds.cached_memory_limit == 65536 + assert config.thresholds.total_memory_limit == 262144 + assert config.thresholds.cleanup_threshold == 0.9 + assert config.thresholds.cleanup_ratio == 0.2 + + +def test_memory_config_minimal(): + """Test minimal memory configuration.""" + config = MemoryConfig.minimal() + + assert config.enable_lifecycle is False + assert config.enable_metrics is False + assert config.enable_archival is False + assert config.enable_categorization is True # Still enabled by default + + assert config.thresholds.active_memory_limit == 2048 + assert config.thresholds.cached_memory_limit == 8192 + assert config.thresholds.total_memory_limit == 32768 + + +def test_memory_config_custom_features(): + """Test MemoryConfig with custom feature toggles.""" + config = MemoryConfig( + enable_categorization=False, + enable_lifecycle=False, + enable_metrics=False, + enable_archival=False, + cleanup_strategy="fifo", + strict_validation=False + ) + + assert config.enable_categorization is False + assert config.enable_lifecycle is False + assert config.enable_metrics is False + assert config.enable_archival is False + assert config.cleanup_strategy == "fifo" + assert config.strict_validation is False + + +def test_memory_config_post_init(): + """Test MemoryConfig __post_init__ method.""" + # Test with None thresholds (should create default) + config = MemoryConfig(thresholds=None) + assert config.thresholds is not None + assert isinstance(config.thresholds, MemoryThresholds) + + # Test with provided thresholds (should keep them) + custom_thresholds = MemoryThresholds(active_memory_limit=1024) + config = MemoryConfig(thresholds=custom_thresholds) + assert config.thresholds is custom_thresholds + assert config.thresholds.active_memory_limit == 1024 \ No newline at end of file diff --git a/tests/strands/agent/memory/test_enhanced_state.py b/tests/strands/agent/memory/test_enhanced_state.py new file mode 100644 index 000000000..af500213b --- /dev/null +++ b/tests/strands/agent/memory/test_enhanced_state.py @@ -0,0 +1,328 @@ +"""Tests for enhanced agent state with memory management.""" + +import pytest + +from strands.agent.memory.config import MemoryCategory, MemoryConfig +from strands.agent.memory.enhanced_state import EnhancedAgentState + + +def test_enhanced_agent_state_initialization(): + """Test EnhancedAgentState initialization.""" + state = EnhancedAgentState() + + assert state.memory_config is 
not None + assert state.memory_manager is not None + assert isinstance(state.memory_config, MemoryConfig) + + +def test_enhanced_agent_state_with_initial_state(): + """Test EnhancedAgentState initialization with initial state.""" + initial_state = {"key1": "value1", "key2": "value2"} + state = EnhancedAgentState(initial_state=initial_state) + + # Should be available through both interfaces + assert state.get("key1") == "value1" + assert state.get("key2") == "value2" + + # Should be available through parent interface (backward compatibility) + parent_state = super(EnhancedAgentState, state).get() + assert parent_state["key1"] == "value1" + assert parent_state["key2"] == "value2" + + +def test_enhanced_agent_state_with_memory_config(): + """Test EnhancedAgentState with custom memory configuration.""" + config = MemoryConfig.conservative() + state = EnhancedAgentState(memory_config=config) + + assert state.memory_config is config + assert state.memory_manager.config is config + + +def test_enhanced_agent_state_set_get_basic(): + """Test basic set and get operations.""" + state = EnhancedAgentState() + + state.set("test_key", "test_value") + assert state.get("test_key") == "test_value" + + +def test_enhanced_agent_state_set_with_category(): + """Test set operation with memory category.""" + state = EnhancedAgentState() + + state.set("active_key", "active_value", MemoryCategory.ACTIVE) + state.set("cached_key", "cached_value", MemoryCategory.CACHED) + state.set("metadata_key", "metadata_value", MemoryCategory.METADATA) + + assert state.get("active_key") == "active_value" + assert state.get("cached_key") == "cached_value" + assert state.get("metadata_key") == "metadata_value" + + +def test_enhanced_agent_state_get_by_category(): + """Test retrieving items by category.""" + state = EnhancedAgentState() + + state.set("active1", "value1", MemoryCategory.ACTIVE) + state.set("active2", "value2", MemoryCategory.ACTIVE) + state.set("cached1", "value3", MemoryCategory.CACHED) + + active_items = state.get_by_category(MemoryCategory.ACTIVE) + cached_items = state.get_by_category(MemoryCategory.CACHED) + archived_items = state.get_by_category(MemoryCategory.ARCHIVED) + + assert len(active_items) == 2 + assert len(cached_items) == 1 + assert len(archived_items) == 0 + assert active_items["active1"] == "value1" + assert cached_items["cached1"] == "value3" + + +def test_enhanced_agent_state_get_by_category_disabled(): + """Test get_by_category when categorization is disabled.""" + config = MemoryConfig(enable_categorization=False) + state = EnhancedAgentState(memory_config=config) + + state.set("key1", "value1") + state.set("key2", "value2") + + # Should return all items regardless of category when disabled + items = state.get_by_category(MemoryCategory.ACTIVE) + assert len(items) == 2 + assert items["key1"] == "value1" + assert items["key2"] == "value2" + + +def test_enhanced_agent_state_get_entire_state(): + """Test getting entire state.""" + state = EnhancedAgentState() + + state.set("key1", "value1", MemoryCategory.ACTIVE) + state.set("key2", "value2", MemoryCategory.CACHED) + + # Get entire state + all_items = state.get() + assert len(all_items) == 2 + assert all_items["key1"] == "value1" + assert all_items["key2"] == "value2" + + +def test_enhanced_agent_state_get_filtered_by_category(): + """Test getting state filtered by category.""" + state = EnhancedAgentState() + + state.set("active1", "value1", MemoryCategory.ACTIVE) + state.set("cached1", "value2", MemoryCategory.CACHED) + + # Get only active 
items + active_items = state.get(category=MemoryCategory.ACTIVE) + assert len(active_items) == 1 + assert active_items["active1"] == "value1" + + # Get only cached items + cached_items = state.get(category=MemoryCategory.CACHED) + assert len(cached_items) == 1 + assert cached_items["cached1"] == "value2" + + +def test_enhanced_agent_state_delete(): + """Test deleting items.""" + state = EnhancedAgentState() + + state.set("key1", "value1") + state.set("key2", "value2") + + state.delete("key1") + + assert state.get("key1") is None + assert state.get("key2") == "value2" + + # Should also be deleted from parent state + parent_state = super(EnhancedAgentState, state).get() + assert "key1" not in parent_state + assert parent_state["key2"] == "value2" + + +def test_enhanced_agent_state_convenience_methods(): + """Test convenience methods for memory categories.""" + state = EnhancedAgentState() + + # Test set_metadata + state.set_metadata("meta_key", "meta_value") + metadata_items = state.get_by_category(MemoryCategory.METADATA) + assert metadata_items["meta_key"] == "meta_value" + + # Test get_active_memory + state.set("active_key", "active_value", MemoryCategory.ACTIVE) + active_items = state.get_active_memory() + assert active_items["active_key"] == "active_value" + + # Test get_cached_memory + state.set("cached_key", "cached_value", MemoryCategory.CACHED) + cached_items = state.get_cached_memory() + assert cached_items["cached_key"] == "cached_value" + + +def test_enhanced_agent_state_cleanup_memory(): + """Test memory cleanup functionality.""" + state = EnhancedAgentState() + + state.set("key1", "value1") + state.set("key2", "value2") + + # Test cleanup (might not remove anything without time passage) + removed_count = state.cleanup_memory() + assert removed_count >= 0 + + # Test forced cleanup + removed_count = state.cleanup_memory(force=True) + assert removed_count >= 0 + + +def test_enhanced_agent_state_cleanup_memory_disabled(): + """Test cleanup when lifecycle management is disabled.""" + config = MemoryConfig(enable_lifecycle=False) + state = EnhancedAgentState(memory_config=config) + + state.set("key1", "value1") + + # Should return 0 when lifecycle disabled + removed_count = state.cleanup_memory() + assert removed_count == 0 + + # Should work with force=True + removed_count = state.cleanup_memory(force=True) + assert removed_count >= 0 + + +def test_enhanced_agent_state_get_memory_stats(): + """Test getting memory statistics.""" + state = EnhancedAgentState() + + state.set("key1", "value1", MemoryCategory.ACTIVE) + state.set("key2", "value2", MemoryCategory.CACHED) + + stats = state.get_memory_stats() + + assert "current_stats" in stats or "total_items" in stats + + # Test with metrics disabled + config = MemoryConfig(enable_metrics=False) + state_no_metrics = EnhancedAgentState(memory_config=config) + stats = state_no_metrics.get_memory_stats() + + assert stats["metrics_enabled"] is False + assert "total_items" in stats + + +def test_enhanced_agent_state_optimize_memory(): + """Test memory optimization.""" + state = EnhancedAgentState() + + state.set("key1", "value1") + state.set("key2", "value2") + + optimization_results = state.optimize_memory() + + # Should return optimization results or skip info + assert isinstance(optimization_results, dict) + + # Test with lifecycle disabled + config = MemoryConfig(enable_lifecycle=False) + state_no_lifecycle = EnhancedAgentState(memory_config=config) + results = state_no_lifecycle.optimize_memory() + + assert 
results.get("optimization_skipped") is True + assert results.get("reason") == "lifecycle_disabled" + + +def test_enhanced_agent_state_configure_memory(): + """Test memory configuration updates.""" + state = EnhancedAgentState() + original_config = state.get_memory_config() + + new_config = MemoryConfig.aggressive() + state.configure_memory(new_config) + + assert state.get_memory_config() is new_config + assert state.memory_manager.config is new_config + assert state.get_memory_config() is not original_config + + +def test_enhanced_agent_state_backward_compatibility(): + """Test backward compatibility with base AgentState interface.""" + state = EnhancedAgentState() + + # All base AgentState operations should work + state.set("key1", "value1") + assert state.get("key1") == "value1" + + # Get entire state should work + all_state = state.get() + assert all_state["key1"] == "value1" + + # Delete should work + state.delete("key1") + assert state.get("key1") is None + + # Validation should still work + with pytest.raises(ValueError, match="Key cannot be None"): + state.set(None, "value") + + with pytest.raises(ValueError, match="not JSON serializable"): + state.set("key", lambda x: x) + + +def test_enhanced_agent_state_with_categorization_disabled(): + """Test behavior when categorization is disabled.""" + config = MemoryConfig(enable_categorization=False) + state = EnhancedAgentState(memory_config=config) + + # Set operations should still work + state.set("key1", "value1") + state.set("key2", "value2", MemoryCategory.CACHED) # Category ignored + + # Get should work normally + assert state.get("key1") == "value1" + assert state.get("key2") == "value2" + + # Should use parent implementation when categorization disabled + all_items = state.get() + assert all_items["key1"] == "value1" + assert all_items["key2"] == "value2" + + +def test_enhanced_agent_state_json_validation(): + """Test that JSON validation is maintained.""" + state = EnhancedAgentState() + + # Valid JSON types should work + state.set("string", "test") + state.set("int", 42) + state.set("bool", True) + state.set("list", [1, 2, 3]) + state.set("dict", {"nested": "value"}) + state.set("null", None) + + # Invalid JSON types should raise ValueError + with pytest.raises(ValueError, match="not JSON serializable"): + state.set("function", lambda x: x) + + with pytest.raises(ValueError, match="not JSON serializable"): + state.set("object", object()) + + +def test_enhanced_agent_state_key_validation(): + """Test that key validation is maintained.""" + state = EnhancedAgentState() + + # Invalid keys should raise ValueError + with pytest.raises(ValueError, match="Key cannot be None"): + state.set(None, "value") + + with pytest.raises(ValueError, match="Key cannot be empty"): + state.set("", "value") + + with pytest.raises(ValueError, match="Key must be a string"): + state.set(123, "value") \ No newline at end of file diff --git a/tests/strands/agent/memory/test_lifecycle.py b/tests/strands/agent/memory/test_lifecycle.py new file mode 100644 index 000000000..15895e7d4 --- /dev/null +++ b/tests/strands/agent/memory/test_lifecycle.py @@ -0,0 +1,351 @@ +"""Tests for memory lifecycle management.""" + +import time +from unittest.mock import patch + +import pytest + +from strands.agent.memory.config import MemoryCategory, MemoryConfig, MemoryThresholds +from strands.agent.memory.lifecycle import CategorizedMemoryItem, MemoryLifecycleManager + + +def test_categorized_memory_item_creation(): + """Test CategorizedMemoryItem creation and properties.""" + 
item = CategorizedMemoryItem("test_key", "test_value") + + assert item.key == "test_key" + assert item.value == "test_value" + assert item.category == MemoryCategory.ACTIVE + assert item.access_count == 0 + assert item.size > 0 + assert item.created_at <= time.time() + assert item.last_accessed <= time.time() + + +def test_categorized_memory_item_with_category(): + """Test CategorizedMemoryItem with specific category.""" + item = CategorizedMemoryItem("test_key", "test_value", MemoryCategory.CACHED) + + assert item.category == MemoryCategory.CACHED + + +def test_categorized_memory_item_access(): + """Test memory item access tracking.""" + item = CategorizedMemoryItem("test_key", "test_value") + initial_access_time = item.last_accessed + initial_access_count = item.access_count + + time.sleep(0.01) # Small delay to ensure time difference + item.access() + + assert item.access_count == initial_access_count + 1 + assert item.last_accessed > initial_access_time + + +def test_categorized_memory_item_age(): + """Test memory item age calculation.""" + with patch('time.time') as mock_time: + mock_time.return_value = 1000.0 + item = CategorizedMemoryItem("test_key", "test_value") + + mock_time.return_value = 1010.0 + assert item.age() == 10.0 + + +def test_categorized_memory_item_idle_time(): + """Test memory item idle time calculation.""" + with patch('time.time') as mock_time: + mock_time.return_value = 1000.0 + item = CategorizedMemoryItem("test_key", "test_value") + + mock_time.return_value = 1005.0 + item.access() + + mock_time.return_value = 1015.0 + assert item.idle_time() == 10.0 + + +def test_categorized_memory_item_should_demote(): + """Test demotion logic for memory items.""" + item = CategorizedMemoryItem("test_key", "test_value", MemoryCategory.ACTIVE) + + with patch.object(item, 'idle_time', return_value=3700): # > 1 hour + assert item.should_demote(3600) is True + + with patch.object(item, 'idle_time', return_value=1800): # < 1 hour + assert item.should_demote(3600) is False + + # Non-active items should not be demoted + item.category = MemoryCategory.CACHED + with patch.object(item, 'idle_time', return_value=3700): + assert item.should_demote(3600) is False + + +def test_categorized_memory_item_should_archive(): + """Test archival logic for memory items.""" + item = CategorizedMemoryItem("test_key", "test_value", MemoryCategory.CACHED) + + with patch.object(item, 'age', return_value=90000): # > 24 hours + assert item.should_archive(86400) is True + + with patch.object(item, 'age', return_value=43200): # < 24 hours + assert item.should_archive(86400) is False + + # Non-cached items should not be archived + item.category = MemoryCategory.ACTIVE + with patch.object(item, 'age', return_value=90000): + assert item.should_archive(86400) is False + + +def test_categorized_memory_item_size_estimation(): + """Test size estimation for different value types.""" + # Test with different value types + string_item = CategorizedMemoryItem("key", "hello") + dict_item = CategorizedMemoryItem("key", {"nested": "data"}) + list_item = CategorizedMemoryItem("key", [1, 2, 3, 4, 5]) + + assert string_item.size > 0 + assert dict_item.size > 0 + assert list_item.size > 0 + + # Larger objects should have larger estimated sizes + large_string = "x" * 1000 + large_item = CategorizedMemoryItem("key", large_string) + assert large_item.size > string_item.size + + +def test_memory_lifecycle_manager_initialization(): + """Test MemoryLifecycleManager initialization.""" + config = MemoryConfig() + manager = 
MemoryLifecycleManager(config) + + assert manager.config is config + assert manager._items == {} + assert manager._last_cleanup <= time.time() + + +def test_memory_lifecycle_manager_add_item(): + """Test adding items to memory manager.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + manager.add_item("key1", "value1") + manager.add_item("key2", "value2", MemoryCategory.CACHED) + + assert "key1" in manager._items + assert "key2" in manager._items + assert manager._items["key1"].category == MemoryCategory.ACTIVE + assert manager._items["key2"].category == MemoryCategory.CACHED + assert manager._items["key1"].value == "value1" + assert manager._items["key2"].value == "value2" + + +def test_memory_lifecycle_manager_get_item(): + """Test retrieving items from memory manager.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + manager.add_item("existing_key", "existing_value") + + # Test existing key + value = manager.get_item("existing_key") + assert value == "existing_value" + + # Test non-existing key + value = manager.get_item("non_existing_key") + assert value is None + + # Check that access was recorded + item = manager._items["existing_key"] + assert item.access_count == 1 + + +def test_memory_lifecycle_manager_remove_item(): + """Test removing items from memory manager.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + manager.add_item("key1", "value1") + + # Remove existing item + assert manager.remove_item("key1") is True + assert "key1" not in manager._items + + # Remove non-existing item + assert manager.remove_item("non_existing") is False + + +def test_memory_lifecycle_manager_get_items_by_category(): + """Test retrieving items by category.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + manager.add_item("active1", "value1", MemoryCategory.ACTIVE) + manager.add_item("active2", "value2", MemoryCategory.ACTIVE) + manager.add_item("cached1", "value3", MemoryCategory.CACHED) + + active_items = manager.get_items_by_category(MemoryCategory.ACTIVE) + cached_items = manager.get_items_by_category(MemoryCategory.CACHED) + archived_items = manager.get_items_by_category(MemoryCategory.ARCHIVED) + + assert len(active_items) == 2 + assert len(cached_items) == 1 + assert len(archived_items) == 0 + assert active_items["active1"] == "value1" + assert cached_items["cached1"] == "value3" + + +def test_memory_lifecycle_manager_get_all_items(): + """Test retrieving all items.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + manager.add_item("key1", "value1", MemoryCategory.ACTIVE) + manager.add_item("key2", "value2", MemoryCategory.CACHED) + + all_items = manager.get_all_items() + + assert len(all_items) == 2 + assert all_items["key1"] == "value1" + assert all_items["key2"] == "value2" + + +def test_memory_lifecycle_manager_promotion(): + """Test automatic promotion of frequently accessed items.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + # Add cached item + manager.add_item("key1", "value1", MemoryCategory.CACHED) + item = manager._items["key1"] + + # Mock frequent access pattern + with patch.object(manager, '_should_promote', return_value=True): + # Access should trigger promotion + manager.get_item("key1") + + # Check that item was promoted + assert manager._items["key1"].category == MemoryCategory.ACTIVE + + +def test_memory_lifecycle_manager_should_promote(): + """Test promotion logic.""" + config = MemoryConfig() + manager = 
MemoryLifecycleManager(config) + + # Create item with access pattern + item = CategorizedMemoryItem("key", "value", MemoryCategory.CACHED) + + # Test promotion logic + with patch.object(item, 'age', return_value=7200): # 2 hours + with patch.object(item, 'access_count', 3): # 3 accesses in 2 hours > 1/hour + assert manager._should_promote(item) is True + + with patch.object(item, 'access_count', 1): # 1 access in 2 hours < 1/hour + assert manager._should_promote(item) is False + + # Test with new item (age = 0) + with patch.object(item, 'age', return_value=0): + assert manager._should_promote(item) is False + + +def test_memory_lifecycle_manager_cleanup_memory(): + """Test memory cleanup functionality.""" + thresholds = MemoryThresholds( + cache_ttl=1800, # 30 minutes + archive_after=3600 # 1 hour + ) + config = MemoryConfig(thresholds=thresholds, enable_lifecycle=True) + manager = MemoryLifecycleManager(config) + + # Add items with different ages + manager.add_item("active_old", "value1", MemoryCategory.ACTIVE) + manager.add_item("cached_old", "value2", MemoryCategory.CACHED) + + # Mock ages for demotion/archival + active_item = manager._items["active_old"] + cached_item = manager._items["cached_old"] + + with patch.object(active_item, 'should_demote', return_value=True): + with patch.object(cached_item, 'should_archive', return_value=True): + # Run cleanup + manager.cleanup_memory() + + # Check changes + assert active_item.category == MemoryCategory.CACHED + assert cached_item.category == MemoryCategory.ARCHIVED + + +def test_memory_lifecycle_manager_emergency_cleanup(): + """Test emergency cleanup when memory limits exceeded.""" + thresholds = MemoryThresholds(total_memory_limit=100) # Very small limit + config = MemoryConfig(thresholds=thresholds, enable_lifecycle=True) + manager = MemoryLifecycleManager(config) + + # Add items that exceed limit + large_value = "x" * 50 + manager.add_item("key1", large_value, MemoryCategory.CACHED) + manager.add_item("key2", large_value, MemoryCategory.CACHED) + manager.add_item("key3", large_value, MemoryCategory.ACTIVE) + + initial_count = len(manager._items) + + # Force cleanup + removed_count = manager.cleanup_memory(force=True) + + # Should have removed some items + assert removed_count > 0 + assert len(manager._items) < initial_count + + +def test_memory_lifecycle_manager_disabled_lifecycle(): + """Test behavior when lifecycle management is disabled.""" + config = MemoryConfig(enable_lifecycle=False) + manager = MemoryLifecycleManager(config) + + manager.add_item("key1", "value1") + + # Cleanup should do nothing when lifecycle disabled + removed_count = manager.cleanup_memory(force=False) + assert removed_count == 0 + + # But should work with force=True + removed_count = manager.cleanup_memory(force=True) + assert removed_count >= 0 # May or may not remove items + + +def test_memory_lifecycle_manager_get_memory_report(): + """Test memory usage report generation.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + manager.add_item("key1", "value1", MemoryCategory.ACTIVE) + manager.add_item("key2", "value2", MemoryCategory.CACHED) + + report = manager.get_memory_report() + + assert "current_stats" in report + assert "performance" in report + assert "trends" in report + assert "lifecycle" in report + assert "timestamps" in report + + +def test_memory_lifecycle_manager_optimize_memory(): + """Test memory optimization.""" + config = MemoryConfig() + manager = MemoryLifecycleManager(config) + + # Add some items + 
manager.add_item("key1", "value1") + manager.add_item("key2", "value2") + + optimization_results = manager.optimize_memory() + + assert "initial_size" in optimization_results + assert "final_size" in optimization_results + assert "size_reduction" in optimization_results + assert "initial_count" in optimization_results + assert "final_count" in optimization_results + assert "items_removed" in optimization_results + assert "size_reduction_pct" in optimization_results \ No newline at end of file diff --git a/tests/strands/agent/memory/test_metrics.py b/tests/strands/agent/memory/test_metrics.py new file mode 100644 index 000000000..94f5c1911 --- /dev/null +++ b/tests/strands/agent/memory/test_metrics.py @@ -0,0 +1,301 @@ +"""Tests for memory metrics and monitoring.""" + +import json +import time +from unittest.mock import patch + +import pytest + +from strands.agent.memory.config import MemoryCategory +from strands.agent.memory.metrics import MemoryMetrics, MemoryUsageStats + + +def test_memory_usage_stats_defaults(): + """Test MemoryUsageStats default values.""" + stats = MemoryUsageStats() + + assert stats.total_size == 0 + assert stats.active_size == 0 + assert stats.cached_size == 0 + assert stats.archived_size == 0 + assert stats.metadata_size == 0 + assert stats.total_items == 0 + assert stats.hit_rate == 0.0 + assert stats.cleanup_count == 0 + assert stats.promotions == 0 + + +def test_memory_usage_stats_utilization_ratio(): + """Test utilization ratio calculation.""" + stats = MemoryUsageStats() + stats.total_size = 5000 + + # Normal case + assert stats.utilization_ratio(10000) == 0.5 + + # Over limit case + assert stats.utilization_ratio(2500) == 1.0 + + # Zero limit case + assert stats.utilization_ratio(0) == 0.0 + + +def test_memory_usage_stats_category_distribution(): + """Test category distribution calculation.""" + stats = MemoryUsageStats() + stats.total_size = 1000 + stats.active_size = 400 + stats.cached_size = 300 + stats.archived_size = 200 + stats.metadata_size = 100 + + distribution = stats.category_distribution() + + assert distribution[MemoryCategory.ACTIVE.value] == 0.4 + assert distribution[MemoryCategory.CACHED.value] == 0.3 + assert distribution[MemoryCategory.ARCHIVED.value] == 0.2 + assert distribution[MemoryCategory.METADATA.value] == 0.1 + + +def test_memory_usage_stats_category_distribution_empty(): + """Test category distribution with zero total size.""" + stats = MemoryUsageStats() + # total_size = 0 by default + + distribution = stats.category_distribution() + + for category in MemoryCategory: + assert distribution[category.value] == 0.0 + + +def test_memory_metrics_initialization(): + """Test MemoryMetrics initialization.""" + metrics = MemoryMetrics() + + assert isinstance(metrics.stats, MemoryUsageStats) + assert metrics.history == [] + assert metrics.max_history_size == 100 + assert metrics.access_count == 0 + assert metrics.hit_count == 0 + assert metrics.miss_count == 0 + assert metrics.last_access_time is None + assert metrics.creation_time <= time.time() + + +def test_memory_metrics_record_access(): + """Test recording memory access (hits and misses).""" + metrics = MemoryMetrics() + + # Record hits + metrics.record_access(hit=True) + metrics.record_access(hit=True) + + assert metrics.access_count == 2 + assert metrics.hit_count == 2 + assert metrics.miss_count == 0 + assert metrics.stats.hit_rate == 1.0 + assert metrics.stats.miss_rate == 0.0 + assert metrics.last_access_time is not None + + # Record miss + metrics.record_access(hit=False) + + 
assert metrics.access_count == 3 + assert metrics.hit_count == 2 + assert metrics.miss_count == 1 + assert metrics.stats.hit_rate == 2/3 + assert metrics.stats.miss_rate == 1/3 + + +def test_memory_metrics_record_operations(): + """Test recording various memory operations.""" + metrics = MemoryMetrics() + + # Test cleanup recording + metrics.record_cleanup() + assert metrics.stats.cleanup_count == 1 + assert metrics.stats.last_cleanup is not None + + # Test promotion recording + metrics.record_promotion() + assert metrics.stats.promotions == 1 + + # Test demotion recording + metrics.record_demotion() + assert metrics.stats.demotions == 1 + + # Test archival recording + metrics.record_archival() + assert metrics.stats.archival_count == 1 + + +def test_memory_metrics_update_stats(): + """Test updating statistics with history tracking.""" + metrics = MemoryMetrics() + + # Create new stats + new_stats = MemoryUsageStats() + new_stats.total_size = 1000 + new_stats.total_items = 10 + + # Update stats + metrics.update_stats(new_stats) + + # Check that old stats were saved to history + assert len(metrics.history) == 1 + assert metrics.history[0].total_size == 0 # Old default stats + + # Check that new stats are current + assert metrics.stats.total_size == 1000 + assert metrics.stats.total_items == 10 + + +def test_memory_metrics_history_limit(): + """Test that history is limited to max_history_size.""" + metrics = MemoryMetrics() + metrics.max_history_size = 3 + + # Add more stats than the limit + for i in range(5): + new_stats = MemoryUsageStats() + new_stats.total_size = i * 100 + metrics.update_stats(new_stats) + + # Check that history is limited + assert len(metrics.history) == 3 + + # Check that oldest entries were removed (should have sizes 100, 200, 300) + assert metrics.history[0].total_size == 100 + assert metrics.history[1].total_size == 200 + assert metrics.history[2].total_size == 300 + + +def test_memory_metrics_estimate_item_size(): + """Test item size estimation.""" + metrics = MemoryMetrics() + + # Test with JSON-serializable objects + assert metrics.estimate_item_size("hello") > 0 + assert metrics.estimate_item_size({"key": "value"}) > 0 + assert metrics.estimate_item_size([1, 2, 3]) > 0 + + # Test that larger objects have larger estimated sizes + small_obj = "hi" + large_obj = "hello world" * 100 + assert metrics.estimate_item_size(large_obj) > metrics.estimate_item_size(small_obj) + + # Test with non-serializable object + class CustomObject: + def __str__(self): + return "custom object" + + custom_obj = CustomObject() + size = metrics.estimate_item_size(custom_obj) + assert size > 0 # Should fallback to string representation + + +def test_memory_metrics_trend_analysis(): + """Test trend analysis functionality.""" + metrics = MemoryMetrics() + + # Test with insufficient history + trend = metrics.get_trend_analysis() + assert trend["trend"] == 0.0 + assert trend["volatility"] == 0.0 + + # Add some history with increasing sizes + sizes = [100, 200, 300, 400, 500] + for size in sizes: + stats = MemoryUsageStats() + stats.total_size = size + metrics.update_stats(stats) + + trend = metrics.get_trend_analysis() + assert trend["trend"] > 0 # Should show upward trend + assert trend["avg_size"] == 300 # Average of 100, 200, 300, 400, 500 + assert trend["min_size"] == 100 + assert trend["max_size"] == 500 + assert trend["volatility"] > 0 + + +def test_memory_metrics_trend_analysis_with_window(): + """Test trend analysis with custom window size.""" + metrics = MemoryMetrics() + + # Add 
history + sizes = [100, 200, 300, 400, 500, 600] + for size in sizes: + stats = MemoryUsageStats() + stats.total_size = size + metrics.update_stats(stats) + + # Analyze with smaller window + trend = metrics.get_trend_analysis(window=3) + # Should only consider last 3 values: 400, 500, 600 + assert trend["avg_size"] == 500 + assert trend["min_size"] == 400 + assert trend["max_size"] == 600 + + +def test_memory_metrics_should_cleanup(): + """Test cleanup decision logic.""" + metrics = MemoryMetrics() + + # Setup stats + metrics.stats.total_size = 8000 + + # Should cleanup when over threshold + assert metrics.should_cleanup(threshold=0.8, limit=10000) is True + + # Should not cleanup when under threshold + assert metrics.should_cleanup(threshold=0.9, limit=10000) is False + + # Edge case: zero limit + assert metrics.should_cleanup(threshold=0.5, limit=0) is False + + +def test_memory_metrics_get_summary(): + """Test comprehensive summary generation.""" + metrics = MemoryMetrics() + + # Setup some data + metrics.record_access(hit=True) + metrics.record_access(hit=False) + metrics.record_cleanup() + metrics.record_promotion() + + # Add some history for trend analysis + for i in range(3): + stats = MemoryUsageStats() + stats.total_size = (i + 1) * 100 + metrics.update_stats(stats) + + summary = metrics.get_summary() + + # Check structure + assert "current_stats" in summary + assert "performance" in summary + assert "trends" in summary + assert "lifecycle" in summary + assert "timestamps" in summary + + # Check current stats + current_stats = summary["current_stats"] + assert "total_size" in current_stats + assert "distribution" in current_stats + assert "hit_rate" in current_stats + + # Check performance metrics + performance = summary["performance"] + assert performance["access_count"] == 2 + assert performance["hit_rate"] == 0.5 + assert performance["miss_rate"] == 0.5 + + # Check lifecycle metrics + lifecycle = summary["lifecycle"] + assert lifecycle["promotions"] == 1 + + # Check timestamps + timestamps = summary["timestamps"] + assert "creation_time" in timestamps + assert "last_access" in timestamps \ No newline at end of file From 696ff1d48e92b19ff52b6464ba84b1b3a4b99a4f Mon Sep 17 00:00:00 2001 From: yike5460 Date: Thu, 7 Aug 2025 17:28:55 +0800 Subject: [PATCH 2/2] fix: resolve linting and type checking errors * Add missing docstrings for __init__ methods * Fix variable name collision with unused loop variable * Add Optional typing import and proper type annotations * Create thresholds property to handle Optional MemoryThresholds * Fix return type annotation for optimize_memory method * Ensure proper type safety throughout memory management system --- src/strands/agent/memory/__init__.py | 6 +- src/strands/agent/memory/config.py | 80 +++++---- src/strands/agent/memory/enhanced_state.py | 106 ++++++------ src/strands/agent/memory/lifecycle.py | 161 ++++++++++-------- src/strands/agent/memory/metrics.py | 106 ++++++------ tests/strands/agent/memory/__init__.py | 2 +- tests/strands/agent/memory/test_config.py | 28 ++- .../agent/memory/test_enhanced_state.py | 114 ++++++------- tests/strands/agent/memory/test_lifecycle.py | 147 ++++++++-------- tests/strands/agent/memory/test_metrics.py | 104 ++++++----- 10 files changed, 424 insertions(+), 430 deletions(-) diff --git a/src/strands/agent/memory/__init__.py b/src/strands/agent/memory/__init__.py index 89d41e134..bc26f2dff 100644 --- a/src/strands/agent/memory/__init__.py +++ b/src/strands/agent/memory/__init__.py @@ -11,17 +11,17 @@ 
capabilities for complex multi-agent scenarios. """ -from .config import MemoryConfig, MemoryCategory, MemoryThresholds +from .config import MemoryCategory, MemoryConfig, MemoryThresholds from .enhanced_state import EnhancedAgentState from .lifecycle import MemoryLifecycleManager from .metrics import MemoryMetrics, MemoryUsageStats __all__ = [ "MemoryConfig", - "MemoryCategory", + "MemoryCategory", "MemoryThresholds", "EnhancedAgentState", "MemoryLifecycleManager", "MemoryMetrics", "MemoryUsageStats", -] \ No newline at end of file +] diff --git a/src/strands/agent/memory/config.py b/src/strands/agent/memory/config.py index a427dd850..99580e331 100644 --- a/src/strands/agent/memory/config.py +++ b/src/strands/agent/memory/config.py @@ -7,59 +7,59 @@ class MemoryCategory(Enum): """Categories for memory classification.""" - - ACTIVE = "active" # Currently active/working memory - CACHED = "cached" # Recently used but not active - ARCHIVED = "archived" # Historical data for potential retrieval - METADATA = "metadata" # System metadata and statistics + + ACTIVE = "active" # Currently active/working memory + CACHED = "cached" # Recently used but not active + ARCHIVED = "archived" # Historical data for potential retrieval + METADATA = "metadata" # System metadata and statistics @dataclass class MemoryThresholds: """Memory management thresholds and limits.""" - + # Size thresholds (in estimated tokens/bytes) - active_memory_limit: int = 8192 # ~8K tokens for active memory - cached_memory_limit: int = 32768 # ~32K tokens for cached memory - total_memory_limit: int = 131072 # ~128K tokens total limit - + active_memory_limit: int = 8192 # ~8K tokens for active memory + cached_memory_limit: int = 32768 # ~32K tokens for cached memory + total_memory_limit: int = 131072 # ~128K tokens total limit + # Cleanup thresholds (percentages) - cleanup_threshold: float = 0.8 # Start cleanup at 80% of limit - emergency_threshold: float = 0.95 # Emergency cleanup at 95% - + cleanup_threshold: float = 0.8 # Start cleanup at 80% of limit + emergency_threshold: float = 0.95 # Emergency cleanup at 95% + # Time-based thresholds (seconds) - cache_ttl: int = 3600 # Cache TTL: 1 hour - archive_after: int = 86400 # Archive after: 24 hours - + cache_ttl: int = 3600 # Cache TTL: 1 hour + archive_after: int = 86400 # Archive after: 24 hours + # Cleanup ratios (how much to remove during cleanup) - cleanup_ratio: float = 0.3 # Remove 30% during cleanup - emergency_cleanup_ratio: float = 0.5 # Remove 50% during emergency + cleanup_ratio: float = 0.3 # Remove 30% during cleanup + emergency_cleanup_ratio: float = 0.5 # Remove 50% during emergency @dataclass class MemoryConfig: """Configuration for memory management system.""" - + # Feature toggles - enable_categorization: bool = True # Enable memory categorization - enable_lifecycle: bool = True # Enable automatic lifecycle management - enable_metrics: bool = True # Enable memory metrics collection - enable_archival: bool = True # Enable memory archival - + enable_categorization: bool = True # Enable memory categorization + enable_lifecycle: bool = True # Enable automatic lifecycle management + enable_metrics: bool = True # Enable memory metrics collection + enable_archival: bool = True # Enable memory archival + # Thresholds configuration - thresholds: MemoryThresholds = None - + thresholds: Optional[MemoryThresholds] = None + # Cleanup strategy - cleanup_strategy: str = "lru" # LRU, FIFO, or custom - + cleanup_strategy: str = "lru" # LRU, FIFO, or custom + # Validation settings - 
strict_validation: bool = True # Strict JSON validation - - def __post_init__(self): + strict_validation: bool = True # Strict JSON validation + + def __post_init__(self) -> None: """Initialize default thresholds if not provided.""" if self.thresholds is None: self.thresholds = MemoryThresholds() - + @classmethod def conservative(cls) -> "MemoryConfig": """Create conservative memory configuration with lower limits.""" @@ -69,10 +69,10 @@ def conservative(cls) -> "MemoryConfig": cached_memory_limit=16384, total_memory_limit=65536, cleanup_threshold=0.7, - cleanup_ratio=0.4 + cleanup_ratio=0.4, ) ) - + @classmethod def aggressive(cls) -> "MemoryConfig": """Create aggressive memory configuration with higher limits.""" @@ -82,10 +82,10 @@ def aggressive(cls) -> "MemoryConfig": cached_memory_limit=65536, total_memory_limit=262144, cleanup_threshold=0.9, - cleanup_ratio=0.2 + cleanup_ratio=0.2, ) ) - + @classmethod def minimal(cls) -> "MemoryConfig": """Create minimal memory configuration with basic features only.""" @@ -93,9 +93,5 @@ def minimal(cls) -> "MemoryConfig": enable_lifecycle=False, enable_metrics=False, enable_archival=False, - thresholds=MemoryThresholds( - active_memory_limit=2048, - cached_memory_limit=8192, - total_memory_limit=32768 - ) - ) \ No newline at end of file + thresholds=MemoryThresholds(active_memory_limit=2048, cached_memory_limit=8192, total_memory_limit=32768), + ) diff --git a/src/strands/agent/memory/enhanced_state.py b/src/strands/agent/memory/enhanced_state.py index c09f9415c..3289add0a 100644 --- a/src/strands/agent/memory/enhanced_state.py +++ b/src/strands/agent/memory/enhanced_state.py @@ -10,68 +10,64 @@ class EnhancedAgentState(AgentState): """Enhanced AgentState with memory categorization, lifecycle management, and metrics. - + This class extends the base AgentState to provide: - Memory categorization (active, cached, archived, metadata) - Automatic memory lifecycle management - Memory usage monitoring and metrics - Configurable memory thresholds and cleanup policies - + The enhanced state maintains full backward compatibility with the base AgentState interface while adding advanced memory management capabilities. """ - - def __init__( - self, - initial_state: Optional[Dict[str, Any]] = None, - memory_config: Optional[MemoryConfig] = None - ): + + def __init__(self, initial_state: Optional[Dict[str, Any]] = None, memory_config: Optional[MemoryConfig] = None): """Initialize EnhancedAgentState. - + Args: initial_state: Initial state dictionary (backward compatibility) memory_config: Memory management configuration """ # Initialize base AgentState for backward compatibility super().__init__(initial_state) - + # Initialize memory management self.memory_config = memory_config or MemoryConfig() self.memory_manager = MemoryLifecycleManager(self.memory_config) - + # Migrate existing state to memory manager if provided if initial_state: for key, value in initial_state.items(): self.memory_manager.add_item(key, value, MemoryCategory.ACTIVE) - + def set(self, key: str, value: Any, category: MemoryCategory = MemoryCategory.ACTIVE) -> None: """Set a value in the state with optional memory category. 
- + Args: key: The key to store the value under - value: The value to store (must be JSON serializable) + value: The value to store (must be JSON serializable) category: Memory category for the value - + Raises: ValueError: If key is invalid, or if value is not JSON serializable """ # Validate using parent class methods self._validate_key(key) self._validate_json_serializable(value) - + # Store in both base state (for backward compatibility) and memory manager super().set(key, value) - + if self.memory_config.enable_categorization: self.memory_manager.add_item(key, value, category) - + def get(self, key: Optional[str] = None, category: Optional[MemoryCategory] = None) -> Any: """Get a value or entire state, optionally filtered by category. - + Args: key: The key to retrieve (if None, returns entire state object) category: Optional memory category filter - + Returns: The stored value, filtered state dict, or None if not found """ @@ -88,83 +84,83 @@ def get(self, key: Optional[str] = None, category: Optional[MemoryCategory] = No return self.memory_manager.get_item(key) else: return super().get(key) - + def delete(self, key: str) -> None: """Delete a specific key from the state. - + Args: key: The key to delete """ self._validate_key(key) - + # Delete from both base state and memory manager super().delete(key) - + if self.memory_config.enable_categorization: self.memory_manager.remove_item(key) - + def get_by_category(self, category: MemoryCategory) -> Dict[str, Any]: """Get all items in a specific memory category. - + Args: category: The memory category to retrieve - + Returns: Dictionary of all items in the specified category """ if not self.memory_config.enable_categorization: # If categorization disabled, return all items for any category return super().get() or {} - + return self.memory_manager.get_items_by_category(category) - + def set_metadata(self, key: str, value: Any) -> None: """Set a metadata value (convenience method). - + Args: key: The key to store the metadata under value: The metadata value """ self.set(key, value, MemoryCategory.METADATA) - + def get_active_memory(self) -> Dict[str, Any]: """Get all active memory items (convenience method). - + Returns: Dictionary of all active memory items """ return self.get_by_category(MemoryCategory.ACTIVE) - + def get_cached_memory(self) -> Dict[str, Any]: """Get all cached memory items (convenience method). - + Returns: Dictionary of all cached memory items """ return self.get_by_category(MemoryCategory.CACHED) - + def cleanup_memory(self, force: bool = False) -> int: """Perform memory cleanup and return number of items removed. - + Args: force: Force cleanup even if lifecycle management is disabled - + Returns: Number of items removed during cleanup """ if not self.memory_config.enable_lifecycle and not force: return 0 - + removed_count = self.memory_manager.cleanup_memory(force) - + # Sync base state with memory manager after cleanup self._sync_base_state() - + return removed_count - + def get_memory_stats(self) -> Dict[str, Any]: """Get comprehensive memory usage statistics. 
- + Returns: Dictionary containing memory usage statistics and metrics """ @@ -175,47 +171,47 @@ def get_memory_stats(self) -> Dict[str, Any]: "total_items": len(all_items), "categories_enabled": False, "lifecycle_enabled": self.memory_config.enable_lifecycle, - "metrics_enabled": False + "metrics_enabled": False, } - + return self.memory_manager.get_memory_report() - + def optimize_memory(self) -> Dict[str, Any]: """Optimize memory usage and return optimization results. - + Returns: Dictionary containing optimization results and statistics """ if not self.memory_config.enable_lifecycle: return {"optimization_skipped": True, "reason": "lifecycle_disabled"} - + optimization_results = self.memory_manager.optimize_memory() - + # Sync base state after optimization self._sync_base_state() - + return optimization_results - + def _sync_base_state(self) -> None: """Synchronize base state with memory manager state.""" if self.memory_config.enable_categorization: # Update base state to match memory manager all_items = self.memory_manager.get_all_items() self._state = copy.deepcopy(all_items) - + def configure_memory(self, config: MemoryConfig) -> None: """Update memory configuration. - + Args: config: New memory configuration """ self.memory_config = config self.memory_manager.config = config - + def get_memory_config(self) -> MemoryConfig: """Get current memory configuration. - + Returns: Current memory configuration """ - return self.memory_config \ No newline at end of file + return self.memory_config diff --git a/src/strands/agent/memory/lifecycle.py b/src/strands/agent/memory/lifecycle.py index bfc126d90..3f25eaafe 100644 --- a/src/strands/agent/memory/lifecycle.py +++ b/src/strands/agent/memory/lifecycle.py @@ -1,26 +1,33 @@ """Memory lifecycle management for automatic cleanup and archival.""" import time -from typing import Any, Dict, List, Optional, Protocol, Tuple +from typing import Any, Dict, Optional, Protocol -from .config import MemoryCategory, MemoryConfig +from .config import MemoryCategory, MemoryConfig, MemoryThresholds from .metrics import MemoryMetrics class MemoryItem(Protocol): """Protocol for memory items that can be managed by lifecycle manager.""" - + category: MemoryCategory created_at: float last_accessed: float access_count: int size: int - + class CategorizedMemoryItem: """A memory item with categorization and lifecycle metadata.""" - + def __init__(self, key: str, value: Any, category: MemoryCategory = MemoryCategory.ACTIVE): + """Initialize a categorized memory item. 
+ + Args: + key: The key for this memory item + value: The value to store + category: Memory category for this item + """ self.key = key self.value = value self.category = category @@ -28,185 +35,189 @@ def __init__(self, key: str, value: Any, category: MemoryCategory = MemoryCatego self.last_accessed = time.time() self.access_count = 0 self.size = self._estimate_size(value) - + def _estimate_size(self, value: Any) -> int: """Estimate the size of the value.""" import json + try: - return len(json.dumps(value).encode('utf-8')) + return len(json.dumps(value).encode("utf-8")) except (TypeError, ValueError): - return len(str(value).encode('utf-8')) - + return len(str(value).encode("utf-8")) + def access(self) -> None: """Record an access to this memory item.""" self.last_accessed = time.time() self.access_count += 1 - + def age(self) -> float: """Get the age of this memory item in seconds.""" return time.time() - self.created_at - + def idle_time(self) -> float: """Get the idle time since last access in seconds.""" return time.time() - self.last_accessed - + def should_demote(self, inactive_threshold: float) -> bool: """Check if item should be demoted from active to cached.""" - return (self.category == MemoryCategory.ACTIVE and - self.idle_time() > inactive_threshold) - + return self.category == MemoryCategory.ACTIVE and self.idle_time() > inactive_threshold + def should_archive(self, archive_threshold: float) -> bool: """Check if item should be archived.""" - return (self.category == MemoryCategory.CACHED and - self.age() > archive_threshold) + return self.category == MemoryCategory.CACHED and self.age() > archive_threshold class MemoryLifecycleManager: """Manages the lifecycle of memory items with automatic cleanup and archival.""" - + def __init__(self, config: MemoryConfig): + """Initialize the memory lifecycle manager. 
+ + Args: + config: Memory management configuration + """ self.config = config + # Ensure thresholds are initialized + assert config.thresholds is not None, "MemoryConfig must have initialized thresholds" self.metrics = MemoryMetrics() self._items: Dict[str, CategorizedMemoryItem] = {} self._last_cleanup = time.time() - + + @property + def thresholds(self) -> MemoryThresholds: + """Get memory thresholds, ensuring they are never None.""" + assert self.config.thresholds is not None + return self.config.thresholds + def add_item(self, key: str, value: Any, category: MemoryCategory = MemoryCategory.ACTIVE) -> None: """Add or update a memory item.""" item = CategorizedMemoryItem(key, value, category) self._items[key] = item - + if category == MemoryCategory.ACTIVE: self.metrics.record_promotion() - + # Trigger cleanup if needed if self.config.enable_lifecycle: self._check_and_cleanup() - + def get_item(self, key: str) -> Optional[Any]: """Get a memory item by key, updating access statistics.""" if key in self._items: item = self._items[key] item.access() self.metrics.record_access(hit=True) - + # Promote to active if accessed frequently - if (item.category == MemoryCategory.CACHED and - self._should_promote(item)): + if item.category == MemoryCategory.CACHED and self._should_promote(item): item.category = MemoryCategory.ACTIVE self.metrics.record_promotion() - + return item.value else: self.metrics.record_access(hit=False) return None - + def remove_item(self, key: str) -> bool: """Remove a memory item.""" if key in self._items: del self._items[key] return True return False - + def get_items_by_category(self, category: MemoryCategory) -> Dict[str, Any]: """Get all items in a specific category.""" - return { - key: item.value - for key, item in self._items.items() - if item.category == category - } - + return {key: item.value for key, item in self._items.items() if item.category == category} + def get_all_items(self) -> Dict[str, Any]: """Get all memory items (backward compatibility with AgentState).""" return {key: item.value for key, item in self._items.items()} - + def cleanup_memory(self, force: bool = False) -> int: """Perform memory cleanup and return number of items removed.""" if not self.config.enable_lifecycle and not force: return 0 - + removed_count = 0 current_time = time.time() - + # First, demote old active items to cached for item in list(self._items.values()): - if item.should_demote(self.config.thresholds.cache_ttl): + if item.should_demote(self.thresholds.cache_ttl): item.category = MemoryCategory.CACHED self.metrics.record_demotion() - + # Archive old cached items if self.config.enable_archival: - for key, item in list(self._items.items()): - if item.should_archive(self.config.thresholds.archive_after): + for _key, item in list(self._items.items()): + if item.should_archive(self.thresholds.archive_after): item.category = MemoryCategory.ARCHIVED self.metrics.record_archival() - + # Remove items if over memory limits total_size = sum(item.size for item in self._items.values()) - if force or total_size > self.config.thresholds.total_memory_limit: + if force or total_size > self.thresholds.total_memory_limit: removed_count += self._emergency_cleanup() - + self.metrics.record_cleanup() self._last_cleanup = current_time self._update_metrics() - + return removed_count - + def _check_and_cleanup(self) -> None: """Check if cleanup is needed and perform it.""" total_size = sum(item.size for item in self._items.values()) - utilization = total_size / 
self.config.thresholds.total_memory_limit - - if utilization >= self.config.thresholds.cleanup_threshold: + utilization = total_size / self.thresholds.total_memory_limit + + if utilization >= self.thresholds.cleanup_threshold: self.cleanup_memory() - + # Update metrics periodically current_time = time.time() if current_time - self._last_cleanup > 300: # Update every 5 minutes self._update_metrics() self._last_cleanup = current_time - + def _should_promote(self, item: CategorizedMemoryItem) -> bool: """Determine if a cached item should be promoted to active.""" # Promote if accessed frequently in recent time recent_accesses = item.access_count age_hours = item.age() / 3600 - + # Simple heuristic: if accessed more than once per hour on average return age_hours > 0 and (recent_accesses / age_hours) > 1.0 - + def _emergency_cleanup(self) -> int: """Perform emergency cleanup when memory limits are exceeded.""" removed_count = 0 - + # Sort items by priority (LRU-like: least recently used first) - items_by_priority = sorted( - self._items.items(), - key=lambda x: (x[1].category.value, x[1].last_accessed) - ) - + items_by_priority = sorted(self._items.items(), key=lambda x: (x[1].category.value, x[1].last_accessed)) + # Calculate how many items to remove total_items = len(self._items) - target_removal = max(1, int(total_items * self.config.thresholds.emergency_cleanup_ratio)) - + target_removal = max(1, int(total_items * self.thresholds.emergency_cleanup_ratio)) + # Remove least important items first for key, item in items_by_priority[:target_removal]: # Don't remove active metadata if item.category != MemoryCategory.METADATA: del self._items[key] removed_count += 1 - + return removed_count - + def _update_metrics(self) -> None: """Update memory usage statistics.""" from .metrics import MemoryUsageStats - + stats = MemoryUsageStats() - + for item in self._items.values(): stats.total_size += item.size stats.total_items += 1 - + if item.category == MemoryCategory.ACTIVE: stats.active_size += item.size stats.active_items += 1 @@ -219,25 +230,25 @@ def _update_metrics(self) -> None: elif item.category == MemoryCategory.METADATA: stats.metadata_size += item.size stats.metadata_items += 1 - + self.metrics.update_stats(stats) - + def get_memory_report(self) -> Dict[str, Any]: """Get a comprehensive memory usage report.""" self._update_metrics() return self.metrics.get_summary() - - def optimize_memory(self) -> Dict[str, int]: + + def optimize_memory(self) -> Dict[str, Any]: """Perform memory optimization and return optimization results.""" initial_size = sum(item.size for item in self._items.values()) initial_count = len(self._items) - + # Perform cleanup removed_count = self.cleanup_memory(force=True) - + final_size = sum(item.size for item in self._items.values()) final_count = len(self._items) - + return { "initial_size": initial_size, "final_size": final_size, @@ -245,5 +256,5 @@ def optimize_memory(self) -> Dict[str, int]: "initial_count": initial_count, "final_count": final_count, "items_removed": removed_count, - "size_reduction_pct": ((initial_size - final_size) / initial_size * 100) if initial_size > 0 else 0 - } \ No newline at end of file + "size_reduction_pct": int(((initial_size - final_size) / initial_size * 100)) if initial_size > 0 else 0, + } diff --git a/src/strands/agent/memory/metrics.py b/src/strands/agent/memory/metrics.py index e1e6b6691..739edc85c 100644 --- a/src/strands/agent/memory/metrics.py +++ b/src/strands/agent/memory/metrics.py @@ -11,43 +11,43 @@ @dataclass class 
MemoryUsageStats: """Statistics for memory usage tracking.""" - + # Size metrics (estimated tokens/bytes) total_size: int = 0 active_size: int = 0 cached_size: int = 0 archived_size: int = 0 metadata_size: int = 0 - + # Count metrics total_items: int = 0 active_items: int = 0 cached_items: int = 0 archived_items: int = 0 metadata_items: int = 0 - + # Performance metrics - hit_rate: float = 0.0 # Cache hit rate - miss_rate: float = 0.0 # Cache miss rate - cleanup_count: int = 0 # Number of cleanups performed + hit_rate: float = 0.0 # Cache hit rate + miss_rate: float = 0.0 # Cache miss rate + cleanup_count: int = 0 # Number of cleanups performed last_cleanup: Optional[float] = None # Timestamp of last cleanup - + # Lifecycle metrics - promotions: int = 0 # Items promoted to active - demotions: int = 0 # Items demoted from active - archival_count: int = 0 # Items archived - + promotions: int = 0 # Items promoted to active + demotions: int = 0 # Items demoted from active + archival_count: int = 0 # Items archived + def utilization_ratio(self, limit: int) -> float: """Calculate memory utilization ratio.""" if limit == 0: return 0.0 return min(1.0, self.total_size / limit) - + def category_distribution(self) -> Dict[str, float]: """Get distribution of memory across categories.""" if self.total_size == 0: return {cat.value: 0.0 for cat in MemoryCategory} - + return { MemoryCategory.ACTIVE.value: self.active_size / self.total_size, MemoryCategory.CACHED.value: self.cached_size / self.total_size, @@ -56,66 +56,66 @@ def category_distribution(self) -> Dict[str, float]: } -@dataclass +@dataclass class MemoryMetrics: """Memory metrics collection and analysis.""" - + # Current statistics stats: MemoryUsageStats = field(default_factory=MemoryUsageStats) - + # Historical data (last N measurements) history: List[MemoryUsageStats] = field(default_factory=list) max_history_size: int = 100 - + # Access tracking access_count: int = 0 hit_count: int = 0 miss_count: int = 0 - + # Timing metrics last_access_time: Optional[float] = None creation_time: float = field(default_factory=time.time) - + def record_access(self, hit: bool = True) -> None: """Record a memory access (hit or miss).""" self.access_count += 1 self.last_access_time = time.time() - + if hit: self.hit_count += 1 else: self.miss_count += 1 - + # Update hit/miss rates if self.access_count > 0: self.stats.hit_rate = self.hit_count / self.access_count self.stats.miss_rate = self.miss_count / self.access_count - + def record_cleanup(self) -> None: """Record a memory cleanup operation.""" self.stats.cleanup_count += 1 self.stats.last_cleanup = time.time() - + def record_promotion(self) -> None: """Record a memory item promotion.""" self.stats.promotions += 1 - + def record_demotion(self) -> None: """Record a memory item demotion.""" self.stats.demotions += 1 - + def record_archival(self) -> None: """Record a memory item archival.""" self.stats.archival_count += 1 - + def update_stats(self, new_stats: MemoryUsageStats) -> None: """Update current statistics and save to history.""" # Save current stats to history if len(self.history) >= self.max_history_size: self.history.pop(0) # Remove oldest entry - + self.history.append(self.stats) - + # Preserve accumulated metrics from current stats new_stats.hit_rate = self.stats.hit_rate new_stats.miss_rate = self.stats.miss_rate @@ -124,98 +124,98 @@ def update_stats(self, new_stats: MemoryUsageStats) -> None: new_stats.demotions = self.stats.demotions new_stats.archival_count = self.stats.archival_count 
new_stats.last_cleanup = self.stats.last_cleanup - + self.stats = new_stats - + def estimate_item_size(self, value: Any) -> int: """Estimate the size of a memory item in bytes/tokens.""" try: # Use JSON serialization size as rough estimate json_str = json.dumps(value) # Rough token estimate: ~4 characters per token - return len(json_str.encode('utf-8')) + return len(json_str.encode("utf-8")) except (TypeError, ValueError): # Fallback for non-serializable objects - return len(str(value).encode('utf-8')) - + return len(str(value).encode("utf-8")) + def get_trend_analysis(self, window: int = 10) -> Dict[str, float]: """Analyze trends in memory usage over the last N measurements.""" if len(self.history) < 1: return {"trend": 0.0, "volatility": 0.0} - + # Get all stats (history + current), excluding initial empty stats all_history = self.history + [self.stats] # Skip the first item if it's the initial empty stats (total_size == 0) if all_history and all_history[0].total_size == 0 and len(all_history) > 1: all_history = all_history[1:] - + # Apply window to the complete set all_stats = all_history[-window:] if len(all_history) >= window else all_history - + # Calculate trend (simple linear regression slope) if len(all_stats) < 2: return {"trend": 0.0, "volatility": 0.0} - + sizes = [stats.total_size for stats in all_stats] n = len(sizes) x_mean = (n - 1) / 2 y_mean = sum(sizes) / n - + numerator = sum((i - x_mean) * (sizes[i] - y_mean) for i in range(n)) denominator = sum((i - x_mean) ** 2 for i in range(n)) - + trend = numerator / denominator if denominator != 0 else 0.0 - + # Calculate volatility (standard deviation of sizes) variance = sum((size - y_mean) ** 2 for size in sizes) / n - volatility = variance ** 0.5 - + volatility = variance**0.5 + return { "trend": trend, "volatility": volatility, "avg_size": y_mean, "min_size": min(sizes), - "max_size": max(sizes) + "max_size": max(sizes), } - + def should_cleanup(self, threshold: float, limit: int) -> bool: """Determine if memory cleanup should be triggered.""" utilization = self.stats.utilization_ratio(limit) return utilization >= threshold - + def get_summary(self) -> Dict[str, Any]: """Get a comprehensive summary of memory metrics.""" trend_analysis = self.get_trend_analysis() - + return { "current_stats": { "total_size": self.stats.total_size, "total_items": self.stats.total_items, "distribution": self.stats.category_distribution(), "hit_rate": self.stats.hit_rate, - "cleanup_count": self.stats.cleanup_count + "cleanup_count": self.stats.cleanup_count, }, "performance": { "access_count": self.access_count, "hit_rate": self.hit_count / self.access_count if self.access_count > 0 else 0.0, "miss_rate": self.miss_count / self.access_count if self.access_count > 0 else 0.0, - "avg_response_time": self._calculate_avg_response_time() + "avg_response_time": self._calculate_avg_response_time(), }, "trends": trend_analysis, "lifecycle": { "promotions": self.stats.promotions, "demotions": self.stats.demotions, - "archival_count": self.stats.archival_count + "archival_count": self.stats.archival_count, }, "timestamps": { "creation_time": self.creation_time, "last_access": self.last_access_time, - "last_cleanup": self.stats.last_cleanup - } + "last_cleanup": self.stats.last_cleanup, + }, } - + def _calculate_avg_response_time(self) -> Optional[float]: """Calculate average response time based on access patterns.""" # Placeholder for future implementation # Could track timing data for memory operations - return None \ No newline at end of file + return 
None diff --git a/tests/strands/agent/memory/__init__.py b/tests/strands/agent/memory/__init__.py index 3f50f6ae7..f93d43409 100644 --- a/tests/strands/agent/memory/__init__.py +++ b/tests/strands/agent/memory/__init__.py @@ -1 +1 @@ -"""Tests for memory management system.""" \ No newline at end of file +"""Tests for memory management system.""" diff --git a/tests/strands/agent/memory/test_config.py b/tests/strands/agent/memory/test_config.py index a028ba0aa..b2e97390d 100644 --- a/tests/strands/agent/memory/test_config.py +++ b/tests/strands/agent/memory/test_config.py @@ -1,7 +1,5 @@ """Tests for memory management configuration.""" -import pytest - from strands.agent.memory.config import MemoryCategory, MemoryConfig, MemoryThresholds @@ -16,7 +14,7 @@ def test_memory_category_enum(): def test_memory_thresholds_defaults(): """Test MemoryThresholds default values.""" thresholds = MemoryThresholds() - + assert thresholds.active_memory_limit == 8192 assert thresholds.cached_memory_limit == 32768 assert thresholds.total_memory_limit == 131072 @@ -35,9 +33,9 @@ def test_memory_thresholds_custom_values(): cached_memory_limit=16384, total_memory_limit=65536, cleanup_threshold=0.7, - emergency_threshold=0.9 + emergency_threshold=0.9, ) - + assert thresholds.active_memory_limit == 4096 assert thresholds.cached_memory_limit == 16384 assert thresholds.total_memory_limit == 65536 @@ -48,7 +46,7 @@ def test_memory_thresholds_custom_values(): def test_memory_config_defaults(): """Test MemoryConfig default values.""" config = MemoryConfig() - + assert config.enable_categorization is True assert config.enable_lifecycle is True assert config.enable_metrics is True @@ -63,7 +61,7 @@ def test_memory_config_with_custom_thresholds(): """Test MemoryConfig with custom thresholds.""" thresholds = MemoryThresholds(active_memory_limit=2048) config = MemoryConfig(thresholds=thresholds) - + assert config.thresholds.active_memory_limit == 2048 assert config.thresholds.cached_memory_limit == 32768 # Default value @@ -71,7 +69,7 @@ def test_memory_config_with_custom_thresholds(): def test_memory_config_conservative(): """Test conservative memory configuration.""" config = MemoryConfig.conservative() - + assert config.thresholds.active_memory_limit == 4096 assert config.thresholds.cached_memory_limit == 16384 assert config.thresholds.total_memory_limit == 65536 @@ -82,7 +80,7 @@ def test_memory_config_conservative(): def test_memory_config_aggressive(): """Test aggressive memory configuration.""" config = MemoryConfig.aggressive() - + assert config.thresholds.active_memory_limit == 16384 assert config.thresholds.cached_memory_limit == 65536 assert config.thresholds.total_memory_limit == 262144 @@ -93,12 +91,12 @@ def test_memory_config_aggressive(): def test_memory_config_minimal(): """Test minimal memory configuration.""" config = MemoryConfig.minimal() - + assert config.enable_lifecycle is False assert config.enable_metrics is False assert config.enable_archival is False assert config.enable_categorization is True # Still enabled by default - + assert config.thresholds.active_memory_limit == 2048 assert config.thresholds.cached_memory_limit == 8192 assert config.thresholds.total_memory_limit == 32768 @@ -112,9 +110,9 @@ def test_memory_config_custom_features(): enable_metrics=False, enable_archival=False, cleanup_strategy="fifo", - strict_validation=False + strict_validation=False, ) - + assert config.enable_categorization is False assert config.enable_lifecycle is False assert config.enable_metrics is False @@ -129,9 
+127,9 @@ def test_memory_config_post_init(): config = MemoryConfig(thresholds=None) assert config.thresholds is not None assert isinstance(config.thresholds, MemoryThresholds) - + # Test with provided thresholds (should keep them) custom_thresholds = MemoryThresholds(active_memory_limit=1024) config = MemoryConfig(thresholds=custom_thresholds) assert config.thresholds is custom_thresholds - assert config.thresholds.active_memory_limit == 1024 \ No newline at end of file + assert config.thresholds.active_memory_limit == 1024 diff --git a/tests/strands/agent/memory/test_enhanced_state.py b/tests/strands/agent/memory/test_enhanced_state.py index af500213b..905051465 100644 --- a/tests/strands/agent/memory/test_enhanced_state.py +++ b/tests/strands/agent/memory/test_enhanced_state.py @@ -9,7 +9,7 @@ def test_enhanced_agent_state_initialization(): """Test EnhancedAgentState initialization.""" state = EnhancedAgentState() - + assert state.memory_config is not None assert state.memory_manager is not None assert isinstance(state.memory_config, MemoryConfig) @@ -19,11 +19,11 @@ def test_enhanced_agent_state_with_initial_state(): """Test EnhancedAgentState initialization with initial state.""" initial_state = {"key1": "value1", "key2": "value2"} state = EnhancedAgentState(initial_state=initial_state) - + # Should be available through both interfaces assert state.get("key1") == "value1" assert state.get("key2") == "value2" - + # Should be available through parent interface (backward compatibility) parent_state = super(EnhancedAgentState, state).get() assert parent_state["key1"] == "value1" @@ -34,7 +34,7 @@ def test_enhanced_agent_state_with_memory_config(): """Test EnhancedAgentState with custom memory configuration.""" config = MemoryConfig.conservative() state = EnhancedAgentState(memory_config=config) - + assert state.memory_config is config assert state.memory_manager.config is config @@ -42,7 +42,7 @@ def test_enhanced_agent_state_with_memory_config(): def test_enhanced_agent_state_set_get_basic(): """Test basic set and get operations.""" state = EnhancedAgentState() - + state.set("test_key", "test_value") assert state.get("test_key") == "test_value" @@ -50,11 +50,11 @@ def test_enhanced_agent_state_set_get_basic(): def test_enhanced_agent_state_set_with_category(): """Test set operation with memory category.""" state = EnhancedAgentState() - + state.set("active_key", "active_value", MemoryCategory.ACTIVE) state.set("cached_key", "cached_value", MemoryCategory.CACHED) state.set("metadata_key", "metadata_value", MemoryCategory.METADATA) - + assert state.get("active_key") == "active_value" assert state.get("cached_key") == "cached_value" assert state.get("metadata_key") == "metadata_value" @@ -63,15 +63,15 @@ def test_enhanced_agent_state_set_with_category(): def test_enhanced_agent_state_get_by_category(): """Test retrieving items by category.""" state = EnhancedAgentState() - + state.set("active1", "value1", MemoryCategory.ACTIVE) state.set("active2", "value2", MemoryCategory.ACTIVE) state.set("cached1", "value3", MemoryCategory.CACHED) - + active_items = state.get_by_category(MemoryCategory.ACTIVE) cached_items = state.get_by_category(MemoryCategory.CACHED) archived_items = state.get_by_category(MemoryCategory.ARCHIVED) - + assert len(active_items) == 2 assert len(cached_items) == 1 assert len(archived_items) == 0 @@ -83,10 +83,10 @@ def test_enhanced_agent_state_get_by_category_disabled(): """Test get_by_category when categorization is disabled.""" config = 
MemoryConfig(enable_categorization=False) state = EnhancedAgentState(memory_config=config) - + state.set("key1", "value1") state.set("key2", "value2") - + # Should return all items regardless of category when disabled items = state.get_by_category(MemoryCategory.ACTIVE) assert len(items) == 2 @@ -97,10 +97,10 @@ def test_enhanced_agent_state_get_by_category_disabled(): def test_enhanced_agent_state_get_entire_state(): """Test getting entire state.""" state = EnhancedAgentState() - + state.set("key1", "value1", MemoryCategory.ACTIVE) state.set("key2", "value2", MemoryCategory.CACHED) - + # Get entire state all_items = state.get() assert len(all_items) == 2 @@ -111,15 +111,15 @@ def test_enhanced_agent_state_get_entire_state(): def test_enhanced_agent_state_get_filtered_by_category(): """Test getting state filtered by category.""" state = EnhancedAgentState() - + state.set("active1", "value1", MemoryCategory.ACTIVE) state.set("cached1", "value2", MemoryCategory.CACHED) - + # Get only active items active_items = state.get(category=MemoryCategory.ACTIVE) assert len(active_items) == 1 assert active_items["active1"] == "value1" - + # Get only cached items cached_items = state.get(category=MemoryCategory.CACHED) assert len(cached_items) == 1 @@ -129,15 +129,15 @@ def test_enhanced_agent_state_get_filtered_by_category(): def test_enhanced_agent_state_delete(): """Test deleting items.""" state = EnhancedAgentState() - + state.set("key1", "value1") state.set("key2", "value2") - + state.delete("key1") - + assert state.get("key1") is None assert state.get("key2") == "value2" - + # Should also be deleted from parent state parent_state = super(EnhancedAgentState, state).get() assert "key1" not in parent_state @@ -147,17 +147,17 @@ def test_enhanced_agent_state_delete(): def test_enhanced_agent_state_convenience_methods(): """Test convenience methods for memory categories.""" state = EnhancedAgentState() - + # Test set_metadata state.set_metadata("meta_key", "meta_value") metadata_items = state.get_by_category(MemoryCategory.METADATA) assert metadata_items["meta_key"] == "meta_value" - + # Test get_active_memory state.set("active_key", "active_value", MemoryCategory.ACTIVE) active_items = state.get_active_memory() assert active_items["active_key"] == "active_value" - + # Test get_cached_memory state.set("cached_key", "cached_value", MemoryCategory.CACHED) cached_items = state.get_cached_memory() @@ -167,14 +167,14 @@ def test_enhanced_agent_state_convenience_methods(): def test_enhanced_agent_state_cleanup_memory(): """Test memory cleanup functionality.""" state = EnhancedAgentState() - + state.set("key1", "value1") state.set("key2", "value2") - + # Test cleanup (might not remove anything without time passage) removed_count = state.cleanup_memory() assert removed_count >= 0 - + # Test forced cleanup removed_count = state.cleanup_memory(force=True) assert removed_count >= 0 @@ -184,13 +184,13 @@ def test_enhanced_agent_state_cleanup_memory_disabled(): """Test cleanup when lifecycle management is disabled.""" config = MemoryConfig(enable_lifecycle=False) state = EnhancedAgentState(memory_config=config) - + state.set("key1", "value1") - + # Should return 0 when lifecycle disabled removed_count = state.cleanup_memory() assert removed_count == 0 - + # Should work with force=True removed_count = state.cleanup_memory(force=True) assert removed_count >= 0 @@ -199,19 +199,19 @@ def test_enhanced_agent_state_cleanup_memory_disabled(): def test_enhanced_agent_state_get_memory_stats(): """Test getting memory 
statistics.""" state = EnhancedAgentState() - + state.set("key1", "value1", MemoryCategory.ACTIVE) state.set("key2", "value2", MemoryCategory.CACHED) - + stats = state.get_memory_stats() - + assert "current_stats" in stats or "total_items" in stats - + # Test with metrics disabled config = MemoryConfig(enable_metrics=False) state_no_metrics = EnhancedAgentState(memory_config=config) stats = state_no_metrics.get_memory_stats() - + assert stats["metrics_enabled"] is False assert "total_items" in stats @@ -219,20 +219,20 @@ def test_enhanced_agent_state_get_memory_stats(): def test_enhanced_agent_state_optimize_memory(): """Test memory optimization.""" state = EnhancedAgentState() - + state.set("key1", "value1") state.set("key2", "value2") - + optimization_results = state.optimize_memory() - + # Should return optimization results or skip info assert isinstance(optimization_results, dict) - + # Test with lifecycle disabled config = MemoryConfig(enable_lifecycle=False) state_no_lifecycle = EnhancedAgentState(memory_config=config) results = state_no_lifecycle.optimize_memory() - + assert results.get("optimization_skipped") is True assert results.get("reason") == "lifecycle_disabled" @@ -241,10 +241,10 @@ def test_enhanced_agent_state_configure_memory(): """Test memory configuration updates.""" state = EnhancedAgentState() original_config = state.get_memory_config() - + new_config = MemoryConfig.aggressive() state.configure_memory(new_config) - + assert state.get_memory_config() is new_config assert state.memory_manager.config is new_config assert state.get_memory_config() is not original_config @@ -253,23 +253,23 @@ def test_enhanced_agent_state_configure_memory(): def test_enhanced_agent_state_backward_compatibility(): """Test backward compatibility with base AgentState interface.""" state = EnhancedAgentState() - + # All base AgentState operations should work state.set("key1", "value1") assert state.get("key1") == "value1" - + # Get entire state should work all_state = state.get() assert all_state["key1"] == "value1" - + # Delete should work state.delete("key1") assert state.get("key1") is None - + # Validation should still work with pytest.raises(ValueError, match="Key cannot be None"): state.set(None, "value") - + with pytest.raises(ValueError, match="not JSON serializable"): state.set("key", lambda x: x) @@ -278,15 +278,15 @@ def test_enhanced_agent_state_with_categorization_disabled(): """Test behavior when categorization is disabled.""" config = MemoryConfig(enable_categorization=False) state = EnhancedAgentState(memory_config=config) - + # Set operations should still work state.set("key1", "value1") state.set("key2", "value2", MemoryCategory.CACHED) # Category ignored - + # Get should work normally assert state.get("key1") == "value1" assert state.get("key2") == "value2" - + # Should use parent implementation when categorization disabled all_items = state.get() assert all_items["key1"] == "value1" @@ -296,7 +296,7 @@ def test_enhanced_agent_state_with_categorization_disabled(): def test_enhanced_agent_state_json_validation(): """Test that JSON validation is maintained.""" state = EnhancedAgentState() - + # Valid JSON types should work state.set("string", "test") state.set("int", 42) @@ -304,11 +304,11 @@ def test_enhanced_agent_state_json_validation(): state.set("list", [1, 2, 3]) state.set("dict", {"nested": "value"}) state.set("null", None) - + # Invalid JSON types should raise ValueError with pytest.raises(ValueError, match="not JSON serializable"): state.set("function", lambda x: x) 
- + with pytest.raises(ValueError, match="not JSON serializable"): state.set("object", object()) @@ -316,13 +316,13 @@ def test_enhanced_agent_state_json_validation(): def test_enhanced_agent_state_key_validation(): """Test that key validation is maintained.""" state = EnhancedAgentState() - + # Invalid keys should raise ValueError with pytest.raises(ValueError, match="Key cannot be None"): state.set(None, "value") - + with pytest.raises(ValueError, match="Key cannot be empty"): state.set("", "value") - + with pytest.raises(ValueError, match="Key must be a string"): - state.set(123, "value") \ No newline at end of file + state.set(123, "value") diff --git a/tests/strands/agent/memory/test_lifecycle.py b/tests/strands/agent/memory/test_lifecycle.py index 15895e7d4..549f845b4 100644 --- a/tests/strands/agent/memory/test_lifecycle.py +++ b/tests/strands/agent/memory/test_lifecycle.py @@ -3,8 +3,6 @@ import time from unittest.mock import patch -import pytest - from strands.agent.memory.config import MemoryCategory, MemoryConfig, MemoryThresholds from strands.agent.memory.lifecycle import CategorizedMemoryItem, MemoryLifecycleManager @@ -12,7 +10,7 @@ def test_categorized_memory_item_creation(): """Test CategorizedMemoryItem creation and properties.""" item = CategorizedMemoryItem("test_key", "test_value") - + assert item.key == "test_key" assert item.value == "test_value" assert item.category == MemoryCategory.ACTIVE @@ -25,7 +23,7 @@ def test_categorized_memory_item_creation(): def test_categorized_memory_item_with_category(): """Test CategorizedMemoryItem with specific category.""" item = CategorizedMemoryItem("test_key", "test_value", MemoryCategory.CACHED) - + assert item.category == MemoryCategory.CACHED @@ -34,33 +32,33 @@ def test_categorized_memory_item_access(): item = CategorizedMemoryItem("test_key", "test_value") initial_access_time = item.last_accessed initial_access_count = item.access_count - + time.sleep(0.01) # Small delay to ensure time difference item.access() - + assert item.access_count == initial_access_count + 1 assert item.last_accessed > initial_access_time def test_categorized_memory_item_age(): """Test memory item age calculation.""" - with patch('time.time') as mock_time: + with patch("time.time") as mock_time: mock_time.return_value = 1000.0 item = CategorizedMemoryItem("test_key", "test_value") - + mock_time.return_value = 1010.0 assert item.age() == 10.0 def test_categorized_memory_item_idle_time(): """Test memory item idle time calculation.""" - with patch('time.time') as mock_time: + with patch("time.time") as mock_time: mock_time.return_value = 1000.0 item = CategorizedMemoryItem("test_key", "test_value") - + mock_time.return_value = 1005.0 item.access() - + mock_time.return_value = 1015.0 assert item.idle_time() == 10.0 @@ -68,32 +66,32 @@ def test_categorized_memory_item_idle_time(): def test_categorized_memory_item_should_demote(): """Test demotion logic for memory items.""" item = CategorizedMemoryItem("test_key", "test_value", MemoryCategory.ACTIVE) - - with patch.object(item, 'idle_time', return_value=3700): # > 1 hour + + with patch.object(item, "idle_time", return_value=3700): # > 1 hour assert item.should_demote(3600) is True - - with patch.object(item, 'idle_time', return_value=1800): # < 1 hour + + with patch.object(item, "idle_time", return_value=1800): # < 1 hour assert item.should_demote(3600) is False - + # Non-active items should not be demoted item.category = MemoryCategory.CACHED - with patch.object(item, 'idle_time', return_value=3700): + 
     with patch.object(item, "idle_time", return_value=3700):
         assert item.should_demote(3600) is False
 
 
 def test_categorized_memory_item_should_archive():
     """Test archival logic for memory items."""
     item = CategorizedMemoryItem("test_key", "test_value", MemoryCategory.CACHED)
-    
-    with patch.object(item, 'age', return_value=90000):  # > 24 hours
+
+    with patch.object(item, "age", return_value=90000):  # > 24 hours
         assert item.should_archive(86400) is True
-    
-    with patch.object(item, 'age', return_value=43200):  # < 24 hours
+
+    with patch.object(item, "age", return_value=43200):  # < 24 hours
         assert item.should_archive(86400) is False
-    
+
     # Non-cached items should not be archived
     item.category = MemoryCategory.ACTIVE
-    with patch.object(item, 'age', return_value=90000):
+    with patch.object(item, "age", return_value=90000):
         assert item.should_archive(86400) is False
@@ -103,11 +101,11 @@ def test_categorized_memory_item_size_estimation():
     string_item = CategorizedMemoryItem("key", "hello")
     dict_item = CategorizedMemoryItem("key", {"nested": "data"})
     list_item = CategorizedMemoryItem("key", [1, 2, 3, 4, 5])
-    
+
     assert string_item.size > 0
     assert dict_item.size > 0
     assert list_item.size > 0
-    
+
     # Larger objects should have larger estimated sizes
     large_string = "x" * 1000
     large_item = CategorizedMemoryItem("key", large_string)
@@ -118,7 +116,7 @@ def test_memory_lifecycle_manager_initialization():
     """Test MemoryLifecycleManager initialization."""
     config = MemoryConfig()
     manager = MemoryLifecycleManager(config)
-    
+
     assert manager.config is config
     assert manager._items == {}
     assert manager._last_cleanup <= time.time()
@@ -128,10 +126,10 @@ def test_memory_lifecycle_manager_add_item():
     """Test adding items to memory manager."""
     config = MemoryConfig()
     manager = MemoryLifecycleManager(config)
-    
+
     manager.add_item("key1", "value1")
     manager.add_item("key2", "value2", MemoryCategory.CACHED)
-    
+
     assert "key1" in manager._items
     assert "key2" in manager._items
     assert manager._items["key1"].category == MemoryCategory.ACTIVE
@@ -144,17 +142,17 @@ def test_memory_lifecycle_manager_get_item():
     """Test retrieving items from memory manager."""
     config = MemoryConfig()
     manager = MemoryLifecycleManager(config)
-    
+
     manager.add_item("existing_key", "existing_value")
-    
+
     # Test existing key
     value = manager.get_item("existing_key")
     assert value == "existing_value"
-    
+
     # Test non-existing key
     value = manager.get_item("non_existing_key")
     assert value is None
-    
+
     # Check that access was recorded
     item = manager._items["existing_key"]
     assert item.access_count == 1
@@ -164,13 +162,13 @@ def test_memory_lifecycle_manager_remove_item():
     """Test removing items from memory manager."""
     config = MemoryConfig()
     manager = MemoryLifecycleManager(config)
-    
+
     manager.add_item("key1", "value1")
-    
+
     # Remove existing item
     assert manager.remove_item("key1") is True
     assert "key1" not in manager._items
-    
+
     # Remove non-existing item
     assert manager.remove_item("non_existing") is False
 
@@ -179,15 +177,15 @@ def test_memory_lifecycle_manager_get_items_by_category():
     """Test retrieving items by category."""
     config = MemoryConfig()
     manager = MemoryLifecycleManager(config)
-    
+
     manager.add_item("active1", "value1", MemoryCategory.ACTIVE)
     manager.add_item("active2", "value2", MemoryCategory.ACTIVE)
     manager.add_item("cached1", "value3", MemoryCategory.CACHED)
-    
+
     active_items = manager.get_items_by_category(MemoryCategory.ACTIVE)
     cached_items = manager.get_items_by_category(MemoryCategory.CACHED)
     archived_items = manager.get_items_by_category(MemoryCategory.ARCHIVED)
-    
+
     assert len(active_items) == 2
     assert len(cached_items) == 1
     assert len(archived_items) == 0
@@ -199,12 +197,12 @@ def test_memory_lifecycle_manager_get_all_items():
     """Test retrieving all items."""
     config = MemoryConfig()
     manager = MemoryLifecycleManager(config)
-    
+
     manager.add_item("key1", "value1", MemoryCategory.ACTIVE)
     manager.add_item("key2", "value2", MemoryCategory.CACHED)
-    
+
     all_items = manager.get_all_items()
-    
+
     assert len(all_items) == 2
     assert all_items["key1"] == "value1"
     assert all_items["key2"] == "value2"
@@ -214,16 +212,15 @@ def test_memory_lifecycle_manager_promotion():
     """Test automatic promotion of frequently accessed items."""
     config = MemoryConfig()
     manager = MemoryLifecycleManager(config)
-    
+
     # Add cached item
     manager.add_item("key1", "value1", MemoryCategory.CACHED)
-    item = manager._items["key1"]
-    
+
     # Mock frequent access pattern
-    with patch.object(manager, '_should_promote', return_value=True):
+    with patch.object(manager, "_should_promote", return_value=True):
         # Access should trigger promotion
         manager.get_item("key1")
-    
+
     # Check that item was promoted
     assert manager._items["key1"].category == MemoryCategory.ACTIVE
 
@@ -232,20 +229,20 @@ def test_memory_lifecycle_manager_should_promote():
     """Test promotion logic."""
     config = MemoryConfig()
     manager = MemoryLifecycleManager(config)
-    
+
     # Create item with access pattern
     item = CategorizedMemoryItem("key", "value", MemoryCategory.CACHED)
-    
+
     # Test promotion logic
-    with patch.object(item, 'age', return_value=7200):  # 2 hours
-        with patch.object(item, 'access_count', 3):  # 3 accesses in 2 hours > 1/hour
+    with patch.object(item, "age", return_value=7200):  # 2 hours
+        with patch.object(item, "access_count", 3):  # 3 accesses in 2 hours > 1/hour
             assert manager._should_promote(item) is True
-        
-        with patch.object(item, 'access_count', 1):  # 1 access in 2 hours < 1/hour
+
+        with patch.object(item, "access_count", 1):  # 1 access in 2 hours < 1/hour
             assert manager._should_promote(item) is False
-    
+
     # Test with new item (age = 0)
-    with patch.object(item, 'age', return_value=0):
+    with patch.object(item, "age", return_value=0):
         assert manager._should_promote(item) is False
 
 
@@ -253,24 +250,24 @@ def test_memory_lifecycle_manager_cleanup_memory():
     """Test memory cleanup functionality."""
     thresholds = MemoryThresholds(
         cache_ttl=1800,  # 30 minutes
-        archive_after=3600  # 1 hour
+        archive_after=3600,  # 1 hour
     )
     config = MemoryConfig(thresholds=thresholds, enable_lifecycle=True)
     manager = MemoryLifecycleManager(config)
-    
+
     # Add items with different ages
     manager.add_item("active_old", "value1", MemoryCategory.ACTIVE)
    manager.add_item("cached_old", "value2", MemoryCategory.CACHED)
-    
+
     # Mock ages for demotion/archival
     active_item = manager._items["active_old"]
     cached_item = manager._items["cached_old"]
-    
-    with patch.object(active_item, 'should_demote', return_value=True):
-        with patch.object(cached_item, 'should_archive', return_value=True):
+
+    with patch.object(active_item, "should_demote", return_value=True):
+        with patch.object(cached_item, "should_archive", return_value=True):
             # Run cleanup
             manager.cleanup_memory()
-    
+
     # Check changes
     assert active_item.category == MemoryCategory.CACHED
     assert cached_item.category == MemoryCategory.ARCHIVED
@@ -281,18 +278,18 @@ def test_memory_lifecycle_manager_emergency_cleanup():
     thresholds = MemoryThresholds(total_memory_limit=100)  # Very small limit
     config = MemoryConfig(thresholds=thresholds, enable_lifecycle=True)
     manager = MemoryLifecycleManager(config)
-    
+
     # Add items that exceed limit
     large_value = "x" * 50
     manager.add_item("key1", large_value, MemoryCategory.CACHED)
     manager.add_item("key2", large_value, MemoryCategory.CACHED)
     manager.add_item("key3", large_value, MemoryCategory.ACTIVE)
-    
+
     initial_count = len(manager._items)
-    
+
     # Force cleanup
     removed_count = manager.cleanup_memory(force=True)
-    
+
     # Should have removed some items
     assert removed_count > 0
     assert len(manager._items) < initial_count
@@ -302,13 +299,13 @@ def test_memory_lifecycle_manager_disabled_lifecycle():
     """Test behavior when lifecycle management is disabled."""
     config = MemoryConfig(enable_lifecycle=False)
     manager = MemoryLifecycleManager(config)
-    
+
     manager.add_item("key1", "value1")
-    
+
     # Cleanup should do nothing when lifecycle disabled
     removed_count = manager.cleanup_memory(force=False)
     assert removed_count == 0
-    
+
     # But should work with force=True
     removed_count = manager.cleanup_memory(force=True)
     assert removed_count >= 0  # May or may not remove items
@@ -318,12 +315,12 @@ def test_memory_lifecycle_manager_get_memory_report():
     """Test memory usage report generation."""
     config = MemoryConfig()
     manager = MemoryLifecycleManager(config)
-    
+
     manager.add_item("key1", "value1", MemoryCategory.ACTIVE)
     manager.add_item("key2", "value2", MemoryCategory.CACHED)
-    
+
     report = manager.get_memory_report()
-    
+
     assert "current_stats" in report
     assert "performance" in report
     assert "trends" in report
@@ -335,17 +332,17 @@ def test_memory_lifecycle_manager_optimize_memory():
     """Test memory optimization."""
     config = MemoryConfig()
     manager = MemoryLifecycleManager(config)
-    
+
     # Add some items
     manager.add_item("key1", "value1")
     manager.add_item("key2", "value2")
-    
+
     optimization_results = manager.optimize_memory()
-    
+
     assert "initial_size" in optimization_results
     assert "final_size" in optimization_results
     assert "size_reduction" in optimization_results
     assert "initial_count" in optimization_results
     assert "final_count" in optimization_results
     assert "items_removed" in optimization_results
-    assert "size_reduction_pct" in optimization_results
\ No newline at end of file
+    assert "size_reduction_pct" in optimization_results
diff --git a/tests/strands/agent/memory/test_metrics.py b/tests/strands/agent/memory/test_metrics.py
index 94f5c1911..af747e8a3 100644
--- a/tests/strands/agent/memory/test_metrics.py
+++ b/tests/strands/agent/memory/test_metrics.py
@@ -1,10 +1,6 @@
 """Tests for memory metrics and monitoring."""
 
-import json
 import time
-from unittest.mock import patch
-
-import pytest
 
 from strands.agent.memory.config import MemoryCategory
 from strands.agent.memory.metrics import MemoryMetrics, MemoryUsageStats
@@ -13,7 +9,7 @@ def test_memory_usage_stats_defaults():
     """Test MemoryUsageStats default values."""
     stats = MemoryUsageStats()
-    
+
     assert stats.total_size == 0
     assert stats.active_size == 0
     assert stats.cached_size == 0
@@ -29,13 +25,13 @@ def test_memory_usage_stats_utilization_ratio():
     """Test utilization ratio calculation."""
     stats = MemoryUsageStats()
     stats.total_size = 5000
-    
+
     # Normal case
     assert stats.utilization_ratio(10000) == 0.5
-    
-    # Over limit case 
+
+    # Over limit case
     assert stats.utilization_ratio(2500) == 1.0
-    
+
     # Zero limit case
     assert stats.utilization_ratio(0) == 0.0
 
@@ -48,9 +44,9 @@ def test_memory_usage_stats_category_distribution():
     stats.cached_size = 300
     stats.archived_size = 200
     stats.metadata_size = 100
-    
+
     distribution = stats.category_distribution()
-    
+
     assert distribution[MemoryCategory.ACTIVE.value] == 0.4
     assert distribution[MemoryCategory.CACHED.value] == 0.3
     assert distribution[MemoryCategory.ARCHIVED.value] == 0.2
@@ -61,9 +57,9 @@ def test_memory_usage_stats_category_distribution_empty():
     """Test category distribution with zero total size."""
     stats = MemoryUsageStats()
     # total_size = 0 by default
-    
+
     distribution = stats.category_distribution()
-    
+
     for category in MemoryCategory:
         assert distribution[category.value] == 0.0
 
@@ -71,7 +67,7 @@ def test_memory_usage_stats_category_distribution_empty():
 def test_memory_metrics_initialization():
     """Test MemoryMetrics initialization."""
     metrics = MemoryMetrics()
-    
+
     assert isinstance(metrics.stats, MemoryUsageStats)
     assert metrics.history == []
     assert metrics.max_history_size == 100
@@ -85,45 +81,45 @@ def test_memory_metrics_initialization():
 def test_memory_metrics_record_access():
     """Test recording memory access (hits and misses)."""
     metrics = MemoryMetrics()
-    
+
     # Record hits
     metrics.record_access(hit=True)
     metrics.record_access(hit=True)
-    
+
     assert metrics.access_count == 2
     assert metrics.hit_count == 2
     assert metrics.miss_count == 0
     assert metrics.stats.hit_rate == 1.0
     assert metrics.stats.miss_rate == 0.0
     assert metrics.last_access_time is not None
-    
+
     # Record miss
     metrics.record_access(hit=False)
-    
+
     assert metrics.access_count == 3
     assert metrics.hit_count == 2
     assert metrics.miss_count == 1
-    assert metrics.stats.hit_rate == 2/3
-    assert metrics.stats.miss_rate == 1/3
+    assert metrics.stats.hit_rate == 2 / 3
+    assert metrics.stats.miss_rate == 1 / 3
 
 
 def test_memory_metrics_record_operations():
     """Test recording various memory operations."""
     metrics = MemoryMetrics()
-    
+
     # Test cleanup recording
     metrics.record_cleanup()
     assert metrics.stats.cleanup_count == 1
     assert metrics.stats.last_cleanup is not None
-    
+
     # Test promotion recording
     metrics.record_promotion()
     assert metrics.stats.promotions == 1
-    
+
     # Test demotion recording
     metrics.record_demotion()
     assert metrics.stats.demotions == 1
-    
+
     # Test archival recording
     metrics.record_archival()
     assert metrics.stats.archival_count == 1
@@ -132,19 +128,19 @@ def test_memory_metrics_update_stats():
     """Test updating statistics with history tracking."""
     metrics = MemoryMetrics()
-    
+
     # Create new stats
     new_stats = MemoryUsageStats()
     new_stats.total_size = 1000
     new_stats.total_items = 10
-    
+
     # Update stats
     metrics.update_stats(new_stats)
-    
+
     # Check that old stats were saved to history
     assert len(metrics.history) == 1
     assert metrics.history[0].total_size == 0  # Old default stats
-    
+
     # Check that new stats are current
     assert metrics.stats.total_size == 1000
     assert metrics.stats.total_items == 10
@@ -154,16 +150,16 @@ def test_memory_metrics_history_limit():
     """Test that history is limited to max_history_size."""
     metrics = MemoryMetrics()
     metrics.max_history_size = 3
-    
+
     # Add more stats than the limit
     for i in range(5):
         new_stats = MemoryUsageStats()
         new_stats.total_size = i * 100
         metrics.update_stats(new_stats)
-    
+
     # Check that history is limited
     assert len(metrics.history) == 3
-    
+
     # Check that oldest entries were removed (should have sizes 100, 200, 300)
     assert metrics.history[0].total_size == 100
     assert metrics.history[1].total_size == 200
@@ -173,22 +169,22 @@ def test_memory_metrics_estimate_item_size():
     """Test item size estimation."""
     metrics = MemoryMetrics()
-    
+
     # Test with JSON-serializable objects
     assert metrics.estimate_item_size("hello") > 0
     assert metrics.estimate_item_size({"key": "value"}) > 0
     assert metrics.estimate_item_size([1, 2, 3]) > 0
-    
+
     # Test that larger objects have larger estimated sizes
     small_obj = "hi"
     large_obj = "hello world" * 100
     assert metrics.estimate_item_size(large_obj) > metrics.estimate_item_size(small_obj)
-    
+
     # Test with non-serializable object
     class CustomObject:
         def __str__(self):
             return "custom object"
-    
+
     custom_obj = CustomObject()
     size = metrics.estimate_item_size(custom_obj)
     assert size > 0  # Should fallback to string representation
@@ -197,19 +193,19 @@ def __str__(self):
 def test_memory_metrics_trend_analysis():
     """Test trend analysis functionality."""
     metrics = MemoryMetrics()
-    
+
     # Test with insufficient history
     trend = metrics.get_trend_analysis()
     assert trend["trend"] == 0.0
     assert trend["volatility"] == 0.0
-    
+
     # Add some history with increasing sizes
     sizes = [100, 200, 300, 400, 500]
     for size in sizes:
         stats = MemoryUsageStats()
         stats.total_size = size
         metrics.update_stats(stats)
-    
+
     trend = metrics.get_trend_analysis()
     assert trend["trend"] > 0  # Should show upward trend
     assert trend["avg_size"] == 300  # Average of 100, 200, 300, 400, 500
@@ -221,14 +217,14 @@ def test_memory_metrics_trend_analysis_with_window():
     """Test trend analysis with custom window size."""
     metrics = MemoryMetrics()
-    
+
     # Add history
     sizes = [100, 200, 300, 400, 500, 600]
     for size in sizes:
         stats = MemoryUsageStats()
         stats.total_size = size
         metrics.update_stats(stats)
-    
+
     # Analyze with smaller window
     trend = metrics.get_trend_analysis(window=3)
     # Should only consider last 3 values: 400, 500, 600
@@ -240,16 +236,16 @@ def test_memory_metrics_should_cleanup():
     """Test cleanup decision logic."""
     metrics = MemoryMetrics()
-    
+
     # Setup stats
     metrics.stats.total_size = 8000
-    
+
     # Should cleanup when over threshold
     assert metrics.should_cleanup(threshold=0.8, limit=10000) is True
-    
-    # Should not cleanup when under threshold 
+
+    # Should not cleanup when under threshold
     assert metrics.should_cleanup(threshold=0.9, limit=10000) is False
-    
+
     # Edge case: zero limit
     assert metrics.should_cleanup(threshold=0.5, limit=0) is False
 
@@ -257,45 +253,45 @@ def test_memory_metrics_get_summary():
     """Test comprehensive summary generation."""
     metrics = MemoryMetrics()
-    
+
     # Setup some data
     metrics.record_access(hit=True)
     metrics.record_access(hit=False)
     metrics.record_cleanup()
     metrics.record_promotion()
-    
+
     # Add some history for trend analysis
     for i in range(3):
         stats = MemoryUsageStats()
         stats.total_size = (i + 1) * 100
         metrics.update_stats(stats)
-    
+
     summary = metrics.get_summary()
-    
+
     # Check structure
     assert "current_stats" in summary
     assert "performance" in summary
     assert "trends" in summary
     assert "lifecycle" in summary
     assert "timestamps" in summary
-    
+
     # Check current stats
     current_stats = summary["current_stats"]
     assert "total_size" in current_stats
     assert "distribution" in current_stats
     assert "hit_rate" in current_stats
-    
+
     # Check performance metrics
     performance = summary["performance"]
     assert performance["access_count"] == 2
     assert performance["hit_rate"] == 0.5
     assert performance["miss_rate"] == 0.5
-    
+
     # Check lifecycle metrics
     lifecycle = summary["lifecycle"]
     assert lifecycle["promotions"] == 1
-    
+
     # Check timestamps
     timestamps = summary["timestamps"]
     assert "creation_time" in timestamps
-    assert "last_access" in timestamps
\ No newline at end of file
+    assert "last_access" in timestamps