diff --git a/src/strands/agent/context/__init__.py b/src/strands/agent/context/__init__.py new file mode 100644 index 000000000..a611dc3f4 --- /dev/null +++ b/src/strands/agent/context/__init__.py @@ -0,0 +1,27 @@ +"""Intelligent context management for optimized agent interactions. + +This module provides advanced context management capabilities including: +- Dynamic tool selection based on context and task requirements +- Context window optimization and intelligent pruning +- Tool usage analytics and performance tracking +- Relevance-based filtering and scoring + +The context management system works alongside the memory management +to provide efficient and intelligent agent interactions. +""" + +from .analytics import ToolUsageAnalytics, ToolUsageStats +from .context_optimizer import ContextOptimizer, ContextWindow +from .relevance_scoring import RelevanceScorer, SimilarityMetric +from .tool_manager import DynamicToolManager, ToolSelectionCriteria + +__all__ = [ + "DynamicToolManager", + "ToolSelectionCriteria", + "ContextOptimizer", + "ContextWindow", + "RelevanceScorer", + "SimilarityMetric", + "ToolUsageAnalytics", + "ToolUsageStats", +] diff --git a/src/strands/agent/context/analytics.py b/src/strands/agent/context/analytics.py new file mode 100644 index 000000000..874d21c9e --- /dev/null +++ b/src/strands/agent/context/analytics.py @@ -0,0 +1,205 @@ +"""Tool usage analytics and performance tracking.""" + +import time +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + + +@dataclass +class ToolUsageStats: + """Statistics for a single tool's usage.""" + + tool_name: str + total_calls: int = 0 + successful_calls: int = 0 + failed_calls: int = 0 + total_execution_time: float = 0.0 + last_used: Optional[float] = None + avg_relevance_score: float = 0.0 + relevance_scores: List[float] = field(default_factory=list) + + @property + def success_rate(self) -> float: + """Calculate success rate of tool calls.""" + if self.total_calls == 0: + return 0.0 + return self.successful_calls / self.total_calls + + @property + def avg_execution_time(self) -> float: + """Calculate average execution time per call.""" + if self.successful_calls == 0: + return 0.0 + return self.total_execution_time / self.successful_calls + + def record_usage(self, success: bool, execution_time: float, relevance_score: Optional[float] = None) -> None: + """Record a tool usage event. + + Args: + success: Whether the tool call was successful + execution_time: Time taken for execution in seconds + relevance_score: Relevance score if available + """ + self.total_calls += 1 + self.last_used = time.time() + + if success: + self.successful_calls += 1 + self.total_execution_time += execution_time + else: + self.failed_calls += 1 + + if relevance_score is not None: + self.relevance_scores.append(relevance_score) + # Update running average + self.avg_relevance_score = sum(self.relevance_scores) / len(self.relevance_scores) + + +@dataclass +class ContextPerformanceStats: + """Performance statistics for context management.""" + + total_context_builds: int = 0 + total_pruning_operations: int = 0 + avg_context_size: float = 0.0 + avg_pruning_ratio: float = 0.0 + context_sizes: List[int] = field(default_factory=list) + pruning_ratios: List[float] = field(default_factory=list) + + def record_context_build(self, context_size: int, original_size: int) -> None: + """Record a context building operation. 
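+
+        Example (illustrative usage; a single build that halves the context):
+
+            stats = ContextPerformanceStats()
+            stats.record_context_build(context_size=500, original_size=1000)
+            # stats.avg_context_size == 500.0 and stats.avg_pruning_ratio == 0.5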
+ + Args: + context_size: Final context size + original_size: Original size before optimization + """ + self.total_context_builds += 1 + self.context_sizes.append(context_size) + + if original_size > 0: + pruning_ratio = 1.0 - (context_size / original_size) + self.pruning_ratios.append(pruning_ratio) + self.total_pruning_operations += 1 + + # Update averages + self.avg_context_size = sum(self.context_sizes) / len(self.context_sizes) + if self.pruning_ratios: + self.avg_pruning_ratio = sum(self.pruning_ratios) / len(self.pruning_ratios) + + +class ToolUsageAnalytics: + """Tracks and analyzes tool usage patterns for optimization.""" + + def __init__(self) -> None: + """Initialize tool usage analytics.""" + self.tool_stats: Dict[str, ToolUsageStats] = {} + self.context_stats = ContextPerformanceStats() + self._start_time = time.time() + + def get_tool_stats(self, tool_name: str) -> ToolUsageStats: + """Get or create stats for a tool. + + Args: + tool_name: Name of the tool + + Returns: + Tool usage statistics + """ + if tool_name not in self.tool_stats: + self.tool_stats[tool_name] = ToolUsageStats(tool_name=tool_name) + return self.tool_stats[tool_name] + + def record_tool_usage( + self, tool_name: str, success: bool, execution_time: float, relevance_score: Optional[float] = None + ) -> None: + """Record a tool usage event. + + Args: + tool_name: Name of the tool used + success: Whether the execution was successful + execution_time: Time taken for execution + relevance_score: Relevance score if available + """ + stats = self.get_tool_stats(tool_name) + stats.record_usage(success, execution_time, relevance_score) + + def record_context_build(self, context_size: int, original_size: int) -> None: + """Record context building statistics. + + Args: + context_size: Final optimized context size + original_size: Original context size before optimization + """ + self.context_stats.record_context_build(context_size, original_size) + + def get_tool_rankings(self, min_calls: int = 5) -> List[Tuple[str, float]]: + """Get tools ranked by performance score. + + Args: + min_calls: Minimum calls required for ranking + + Returns: + List of (tool_name, score) tuples sorted by score + """ + rankings = [] + + for tool_name, stats in self.tool_stats.items(): + if stats.total_calls >= min_calls: + # Composite score based on success rate, relevance, and recency + recency_factor = self._calculate_recency_factor(stats.last_used) + performance_score = 0.4 * stats.success_rate + 0.4 * stats.avg_relevance_score + 0.2 * recency_factor + rankings.append((tool_name, performance_score)) + + rankings.sort(key=lambda x: x[1], reverse=True) + return rankings + + def _calculate_recency_factor(self, last_used: Optional[float]) -> float: + """Calculate recency factor for tool usage. + + Args: + last_used: Timestamp of last usage + + Returns: + Recency factor between 0.0 and 1.0 + """ + if last_used is None: + return 0.0 + + # Decay over 24 hours + time_since_use = time.time() - last_used + decay_period = 24 * 3600 # 24 hours in seconds + + return max(0.0, 1.0 - (time_since_use / decay_period)) + + def get_summary_report(self) -> Dict[str, Any]: + """Generate comprehensive analytics summary. 
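+
+        Example (illustrative; assumes `analytics` is a ToolUsageAnalytics
+        instance with recorded usage, and the values shown are hypothetical):
+
+            report = analytics.get_summary_report()
+            report["total_tool_calls"]       # e.g. 42
+            report["overall_success_rate"]   # e.g. 0.93
+            report["context_optimization"]   # nested build/pruning averages
+            report["top_tools"]              # e.g. [("file_read", 0.87), ...]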
+ + Returns: + Dictionary containing analytics summary + """ + total_tools = len(self.tool_stats) + total_calls = sum(stats.total_calls for stats in self.tool_stats.values()) + + if total_calls > 0: + overall_success_rate = sum(stats.successful_calls for stats in self.tool_stats.values()) / total_calls + else: + overall_success_rate = 0.0 + + return { + "uptime_seconds": time.time() - self._start_time, + "total_tools_used": total_tools, + "total_tool_calls": total_calls, + "overall_success_rate": overall_success_rate, + "context_optimization": { + "total_builds": self.context_stats.total_context_builds, + "avg_context_size": self.context_stats.avg_context_size, + "avg_pruning_ratio": self.context_stats.avg_pruning_ratio, + }, + "top_tools": self.get_tool_rankings()[:5], + } + + def reset_stats(self) -> None: + """Reset all analytics data.""" + self.tool_stats.clear() + self.context_stats = ContextPerformanceStats() + self._start_time = time.time() diff --git a/src/strands/agent/context/context_optimizer.py b/src/strands/agent/context/context_optimizer.py new file mode 100644 index 000000000..8cb426c58 --- /dev/null +++ b/src/strands/agent/context/context_optimizer.py @@ -0,0 +1,276 @@ +"""Context window optimization and intelligent pruning.""" + +import json +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple + +from .relevance_scoring import ContextRelevanceFilter, RelevanceScorer, TextRelevanceScorer + + +@dataclass +class ContextItem: + """A single item in the context window.""" + + key: str + value: Any + size: int # Estimated size in tokens/characters + relevance_score: float = 0.0 + timestamp: float = 0.0 + metadata: Optional[Dict[str, Any]] = None + + +@dataclass +class ContextWindow: + """Represents an optimized context window.""" + + items: List[ContextItem] + total_size: int + max_size: int + optimization_stats: Dict[str, Any] + + @property + def utilization(self) -> float: + """Calculate context window utilization.""" + if self.max_size == 0: + return 0.0 + return self.total_size / self.max_size + + def to_dict(self) -> Dict[str, Any]: + """Convert context window to dictionary format.""" + return {item.key: item.value for item in self.items} + + +class ContextOptimizer: + """Optimizes context windows for efficient agent interactions.""" + + def __init__( + self, max_context_size: int = 8192, relevance_threshold: float = 0.3, scorer: Optional[RelevanceScorer] = None + ): + """Initialize context optimizer. + + Args: + max_context_size: Maximum context size in tokens/characters + relevance_threshold: Minimum relevance score to include + scorer: Relevance scorer to use + """ + self.max_context_size = max_context_size + self.relevance_threshold = relevance_threshold + self.scorer = scorer or TextRelevanceScorer() + self.relevance_filter = ContextRelevanceFilter(self.scorer) + + def optimize_context( + self, + context_items: Dict[str, Any], + task_description: str, + required_keys: Optional[List[str]] = None, + size_estimator: Optional[Callable[[Any], int]] = None, + ) -> ContextWindow: + """Optimize context window for a specific task. 
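+
+        Example (illustrative; keys and task text are hypothetical):
+
+            optimizer = ContextOptimizer(max_context_size=256)
+            window = optimizer.optimize_context(
+                context_items={"notes": "ML training tips", "weather": "sunny today"},
+                task_description="train a machine learning model",
+                required_keys=["notes"],
+            )
+            window.to_dict()  # low-relevance items are pruned; "notes" is always kept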
+ + Args: + context_items: All available context items + task_description: Description of the task/query + required_keys: Keys that must be included + size_estimator: Optional function to estimate item size + + Returns: + Optimized context window + """ + # Initialize size estimator + if size_estimator is None: + size_estimator = self._estimate_size + + # Score all items + scored_items = [] + for key, value in context_items.items(): + relevance = self.scorer.score(value, task_description) + size = size_estimator(value) + + item = ContextItem(key=key, value=value, size=size, relevance_score=relevance) + scored_items.append(item) + + # Separate required and optional items + required_items = [] + optional_items = [] + + for item in scored_items: + if required_keys and item.key in required_keys: + required_items.append(item) + else: + optional_items.append(item) + + # Build optimized context + optimized_items = self._build_optimized_context(required_items, optional_items, self.max_context_size) + + # Calculate total size + total_size = sum(item.size for item in optimized_items) + + # Generate optimization stats + stats = { + "original_items": len(context_items), + "optimized_items": len(optimized_items), + "original_size": sum(item.size for item in scored_items), + "optimized_size": total_size, + "pruning_ratio": 1.0 - (len(optimized_items) / len(context_items)) if context_items else 0.0, + "avg_relevance": ( + sum(item.relevance_score for item in optimized_items) / len(optimized_items) if optimized_items else 0.0 + ), + } + + return ContextWindow( + items=optimized_items, total_size=total_size, max_size=self.max_context_size, optimization_stats=stats + ) + + def _build_optimized_context( + self, required_items: List[ContextItem], optional_items: List[ContextItem], max_size: int + ) -> List[ContextItem]: + """Build optimized context with required and optional items. + + Args: + required_items: Items that must be included + optional_items: Items that may be included based on relevance + max_size: Maximum total size + + Returns: + List of items for optimized context + """ + # Start with required items, compressing if needed + context_items = [] + current_size = 0 + + # First add required items, compressing if they're too large + for item in required_items: + if item.size <= max_size: + context_items.append(item) + current_size += item.size + else: + # Required item is too large, must compress + compressed_item = self._try_compress_item(item, max_size) + if compressed_item: + context_items.append(compressed_item) + current_size += compressed_item.size + else: + # Can't compress enough, still add it (will exceed limit) + context_items.append(item) + current_size += item.size + + # Sort optional items by relevance + optional_items.sort(key=lambda x: x.relevance_score, reverse=True) + + # Add optional items that fit and meet threshold + for item in optional_items: + if item.relevance_score >= self.relevance_threshold: + if current_size + item.size <= max_size: + context_items.append(item) + current_size += item.size + else: + # Try compression strategies + compressed_item = self._try_compress_item(item, max_size - current_size) + if compressed_item and compressed_item.size <= max_size - current_size: + context_items.append(compressed_item) + current_size += compressed_item.size + + return context_items + + def _try_compress_item(self, item: ContextItem, target_size: int) -> Optional[ContextItem]: + """Try to compress an item to fit within target size. 
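+
+        Example (illustrative; sizes follow the rough 4-chars-per-token
+        estimate used by _estimate_size, and `optimizer` is a ContextOptimizer):
+
+            item = ContextItem(key="log", value="x" * 1000, size=250, relevance_score=0.9)
+            compressed = optimizer._try_compress_item(item, target_size=20)
+            # compressed.value ends with "...", compressed.size <= 20,
+            # and compressed.metadata["truncated"] is True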
+ + Args: + item: Item to compress + target_size: Target size limit + + Returns: + Compressed item or None if compression not possible + """ + if item.size <= target_size: + return item + + # Simple truncation strategy for strings + if isinstance(item.value, str): + # Estimate characters per token (rough approximation) + chars_per_token = 4 + target_chars = target_size * chars_per_token + + if len(item.value) > target_chars: + truncated_value = item.value[:target_chars] + "..." + return ContextItem( + key=item.key, + value=truncated_value, + size=self._estimate_size(truncated_value), + relevance_score=item.relevance_score * 0.8, # Reduce score for truncated + metadata={"truncated": True, "original_size": item.size}, + ) + + # For other types, we can't easily compress + return None + + def _estimate_size(self, value: Any) -> int: + """Estimate size of a value in tokens. + + Args: + value: Value to estimate + + Returns: + Estimated size in tokens + """ + # Simple character-based estimation + # Rough approximation: 1 token ≈ 4 characters + if isinstance(value, str): + return max(1, len(value) // 4) if value else 0 + elif isinstance(value, dict): + json_str = json.dumps(value) + return max(1, len(json_str) // 4) if json_str else 0 + elif isinstance(value, list): + json_str = json.dumps(value) + return max(1, len(json_str) // 4) if json_str else 0 + else: + str_value = str(value) + return max(1, len(str_value) // 4) if str_value else 0 + + def merge_contexts(self, contexts: List[ContextWindow], task_description: str) -> ContextWindow: + """Merge multiple context windows into one optimized window. + + Args: + contexts: List of context windows to merge + task_description: Task description for relevance scoring + + Returns: + Merged and optimized context window + """ + # Collect all items + all_items: Dict[str, Any] = {} + item_relevance: Dict[str, float] = {} + for context in contexts: + for item in context.items: + # Use highest relevance score if duplicate keys + if item.key in all_items: + if item.relevance_score > item_relevance[item.key]: + all_items[item.key] = item.value + item_relevance[item.key] = item.relevance_score + else: + all_items[item.key] = item.value + item_relevance[item.key] = item.relevance_score + + # Re-optimize merged context + return self.optimize_context(all_items, task_description) + + def get_pruning_recommendations(self, context_window: ContextWindow) -> List[Tuple[str, str]]: + """Get recommendations for further context pruning. 
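+
+        Example (hypothetical output; assumes an existing ContextWindow,
+        wording matches the checks below):
+
+            recs = optimizer.get_pruning_recommendations(window)
+            # e.g. [("weather", "Low relevance (0.21), consider removing"),
+            #       ("transcript", "Large item (2048 tokens), consider summarizing")]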
+ + Args: + context_window: Current context window + + Returns: + List of (item_key, recommendation) tuples + """ + recommendations = [] + + for item in context_window.items: + if item.relevance_score < 0.5: + recommendations.append((item.key, f"Low relevance ({item.relevance_score:.2f}), consider removing")) + elif item.size > context_window.max_size * 0.2: + recommendations.append((item.key, f"Large item ({item.size} tokens), consider summarizing")) + elif item.metadata and item.metadata.get("truncated"): + recommendations.append((item.key, "Item was truncated, consider using summary instead")) + + return recommendations diff --git a/src/strands/agent/context/relevance_scoring.py b/src/strands/agent/context/relevance_scoring.py new file mode 100644 index 000000000..fe63302a9 --- /dev/null +++ b/src/strands/agent/context/relevance_scoring.py @@ -0,0 +1,277 @@ +"""Relevance scoring and similarity calculations for context management.""" + +import json +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional + + +class SimilarityMetric(Enum): + """Available similarity metrics for relevance scoring.""" + + JACCARD = "jaccard" + COSINE = "cosine" + LEVENSHTEIN = "levenshtein" + SEMANTIC = "semantic" # Future: requires embedding model + + +@dataclass +class ScoredItem: + """An item with its relevance score.""" + + key: str + value: Any + score: float + metadata: Optional[Dict[str, Any]] = None + + +class RelevanceScorer(ABC): + """Base class for relevance scoring implementations.""" + + @abstractmethod + def score(self, item: Any, context: Any) -> float: + """Calculate relevance score between item and context. + + Args: + item: The item to score + context: The context to score against + + Returns: + Relevance score between 0.0 and 1.0 + """ + pass + + +class TextRelevanceScorer(RelevanceScorer): + """Relevance scorer for text-based content using string similarity.""" + + def __init__(self, metric: SimilarityMetric = SimilarityMetric.JACCARD): + """Initialize text relevance scorer. + + Args: + metric: The similarity metric to use + """ + self.metric = metric + + def score(self, item: Any, context: Any) -> float: + """Calculate text relevance score. 
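+
+        Example (Jaccard metric; scores are rough lexical-overlap estimates):
+
+            scorer = TextRelevanceScorer(metric=SimilarityMetric.JACCARD)
+            scorer.score("read the log file", "read files")
+            # > 0.0 thanks to the shared word "read" and the partial match file/files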
+ + Args: + item: Text item to score + context: Context text to score against + + Returns: + Relevance score between 0.0 and 1.0 + """ + # Convert to strings + item_text = self._to_text(item) + context_text = self._to_text(context) + + if self.metric == SimilarityMetric.JACCARD: + return self._jaccard_similarity(item_text, context_text) + elif self.metric == SimilarityMetric.LEVENSHTEIN: + return self._levenshtein_similarity(item_text, context_text) + else: + # Default to Jaccard + return self._jaccard_similarity(item_text, context_text) + + def _to_text(self, value: Any) -> str: + """Convert any value to text representation.""" + if isinstance(value, str): + return value + elif isinstance(value, (dict, list)): + return json.dumps(value, sort_keys=True) + else: + return str(value) + + def _jaccard_similarity(self, text1: str, text2: str) -> float: + """Calculate Jaccard similarity between two texts.""" + # Tokenize by words + words1 = set(text1.lower().split()) + words2 = set(text2.lower().split()) + + if not words1 and not words2: + return 1.0 + if not words1 or not words2: + return 0.0 + + # Direct intersection + intersection = words1.intersection(words2) + + # Also count partial matches (e.g., "read" and "reads", "file" and "files") + partial_matches = 0.0 + for w1 in words1: + for w2 in words2: + if w1 != w2 and (w1 in w2 or w2 in w1) and min(len(w1), len(w2)) >= 3: + partial_matches += 0.5 + break + + union = words1.union(words2) + + # Calculate score with partial matches + score = (len(intersection) + partial_matches) / len(union) + return min(1.0, score) + + def _levenshtein_similarity(self, text1: str, text2: str) -> float: + """Calculate normalized Levenshtein similarity.""" + distance = self._levenshtein_distance(text1, text2) + max_len = max(len(text1), len(text2)) + + if max_len == 0: + return 1.0 + + return 1.0 - (distance / max_len) + + def _levenshtein_distance(self, s1: str, s2: str) -> int: + """Calculate Levenshtein distance between two strings.""" + if len(s1) < len(s2): + return self._levenshtein_distance(s2, s1) + + if len(s2) == 0: + return len(s1) + + previous_row = list(range(len(s2) + 1)) + for i, c1 in enumerate(s1): + current_row = [i + 1] + for j, c2 in enumerate(s2): + insertions = previous_row[j + 1] + 1 + deletions = current_row[j] + 1 + substitutions = previous_row[j] + (c1 != c2) + current_row.append(min(insertions, deletions, substitutions)) + previous_row = current_row.copy() + + return previous_row[-1] + + +class ToolRelevanceScorer(RelevanceScorer): + """Relevance scorer specifically for tool selection.""" + + def __init__(self, text_scorer: Optional[TextRelevanceScorer] = None): + """Initialize tool relevance scorer. + + Args: + text_scorer: Text scorer to use for description matching + """ + self.text_scorer = text_scorer or TextRelevanceScorer() + + def score(self, item: Any, context: Any) -> float: + """Score tool relevance based on tool metadata and context. 
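+
+        Example (illustrative; the tool and task names are hypothetical):
+
+            scorer = ToolRelevanceScorer()
+            scorer.score(
+                {"name": "file_reader", "description": "Reads files from disk"},
+                {"task": "read a configuration file", "required_tools": ["file_reader"]},
+            )
+            # returns 1.0 because the tool is explicitly listed in required_tools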
+ + Args: + item: Tool or tool metadata + context: Task context or requirements + + Returns: + Relevance score between 0.0 and 1.0 + """ + # Extract tool information + tool_info = self._extract_tool_info(item) + context_info = self._extract_context_info(context) + + # Calculate component scores + tool_full_text = f"{tool_info.get('name', '')} {tool_info.get('description', '')}" + context_full_text = f"{context_info.get('task', '')} {context_info.get('requirements', '')}" + + # Score tool against full context + full_score = self.text_scorer.score(tool_full_text, context_full_text) + + # Also calculate individual component scores for fine-tuning + name_score = self.text_scorer.score(tool_info.get("name", ""), context_info.get("task", "")) + description_score = self.text_scorer.score( + tool_info.get("description", ""), context_info.get("requirements", "") + ) + + # Check for explicit tool mentions + if "required_tools" in context_info: + required = context_info["required_tools"] + if tool_info.get("name") in required: + return 1.0 # Maximum relevance for required tools + + # Weighted combination with emphasis on full text match + return 0.5 * full_score + 0.2 * name_score + 0.3 * description_score + + def _extract_tool_info(self, item: Any) -> Dict[str, Any]: + """Extract relevant information from tool object.""" + if isinstance(item, dict): + return item + + # Handle tool objects + info = {} + if hasattr(item, "tool_name"): + info["name"] = item.tool_name + elif hasattr(item, "name"): + info["name"] = item.name + + if hasattr(item, "tool_spec"): + info["description"] = item.tool_spec.get("description", "") + elif hasattr(item, "description"): + info["description"] = item.description + + if hasattr(item, "parameters"): + info["parameters"] = item.parameters + + return info + + def _extract_context_info(self, context: Any) -> Dict[str, Any]: + """Extract relevant information from context.""" + if isinstance(context, dict): + return context + elif isinstance(context, str): + return {"task": context, "requirements": context} + else: + return {"task": str(context)} + + +class ContextRelevanceFilter: + """Filters and ranks items based on relevance to context.""" + + def __init__(self, scorer: RelevanceScorer): + """Initialize relevance filter. + + Args: + scorer: The relevance scorer to use + """ + self.scorer = scorer + + def filter_relevant( + self, items: Dict[str, Any], context: Any, min_score: float = 0.3, max_items: Optional[int] = None + ) -> List[ScoredItem]: + """Filter items by relevance score. + + Args: + items: Dictionary of items to filter + context: Context to score against + min_score: Minimum relevance score threshold + max_items: Maximum number of items to return + + Returns: + List of scored items sorted by relevance + """ + scored_items = [] + + for key, value in items.items(): + score = self.scorer.score(value, context) + if score >= min_score: + scored_items.append(ScoredItem(key=key, value=value, score=score)) + + # Sort by score descending + scored_items.sort(key=lambda x: x.score, reverse=True) + + if max_items is not None: + scored_items = scored_items[:max_items] + + return scored_items + + def get_top_k(self, items: Dict[str, Any], context: Any, k: int = 5) -> List[ScoredItem]: + """Get top-k most relevant items. 
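+
+        Example (illustrative; keys and context text are hypothetical):
+
+            relevance_filter = ContextRelevanceFilter(TextRelevanceScorer())
+            top = relevance_filter.get_top_k(
+                {"a": "machine learning notes", "b": "weather report"},
+                context="train a machine learning model",
+                k=1,
+            )
+            # top[0].key == "a" for this context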
+ + Args: + items: Dictionary of items to rank + context: Context to score against + k: Number of top items to return + + Returns: + Top k items by relevance score + """ + return self.filter_relevant(items, context, min_score=0.0, max_items=k) diff --git a/src/strands/agent/context/tool_manager.py b/src/strands/agent/context/tool_manager.py new file mode 100644 index 000000000..f96bf8c67 --- /dev/null +++ b/src/strands/agent/context/tool_manager.py @@ -0,0 +1,336 @@ +"""Dynamic tool management with intelligent selection and filtering.""" + +import time +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Set, Tuple + +from ...types.tools import AgentTool as Tool +from .analytics import ToolUsageAnalytics +from .relevance_scoring import ToolRelevanceScorer + + +@dataclass +class ToolSelectionCriteria: + """Criteria for tool selection.""" + + task_description: str + required_capabilities: Optional[List[str]] = None + excluded_tools: Optional[Set[str]] = None + max_tools: int = 20 + min_relevance_score: float = 0.2 + prefer_recent: bool = True + context_hints: Optional[Dict[str, Any]] = None + + +@dataclass +class ToolSelectionResult: + """Result of tool selection process.""" + + selected_tools: List[Tool] + relevance_scores: Dict[str, float] + selection_reasoning: Dict[str, str] + total_candidates: int + selection_time: float + + +class DynamicToolManager: + """Manages dynamic tool selection based on context and performance.""" + + def __init__(self, analytics: Optional[ToolUsageAnalytics] = None, scorer: Optional[ToolRelevanceScorer] = None): + """Initialize dynamic tool manager. + + Args: + analytics: Tool usage analytics instance + scorer: Tool relevance scorer + """ + self.analytics = analytics or ToolUsageAnalytics() + self.scorer = scorer or ToolRelevanceScorer() + self._tool_cache: Dict[str, Tool] = {} + + def select_tools(self, available_tools: List[Tool], criteria: ToolSelectionCriteria) -> ToolSelectionResult: + """Select optimal tools based on criteria. 
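+
+        Example (illustrative; assumes `manager` is a DynamicToolManager and
+        `available_tools` is any list of AgentTool instances):
+
+            criteria = ToolSelectionCriteria(
+                task_description="summarize log files",
+                required_capabilities=["read"],
+                max_tools=5,
+            )
+            result = manager.select_tools(available_tools, criteria)
+            result.relevance_scores  # {tool_name: score, ...} for the selected tools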
+ + Args: + available_tools: List of all available tools + criteria: Selection criteria + + Returns: + Tool selection result with selected tools and metadata + """ + start_time = time.time() + + # Update tool cache + self._update_tool_cache(available_tools) + + # Filter excluded tools + candidate_tools = [ + tool + for tool in available_tools + if not criteria.excluded_tools or tool.tool_name not in criteria.excluded_tools + ] + + # Score tools based on relevance + scored_tools = self._score_tools(candidate_tools, criteria) + + # Apply performance-based adjustments + if criteria.prefer_recent: + scored_tools = self._adjust_scores_by_performance(scored_tools) + + # Filter by minimum score and capability requirements + filtered_tools = self._filter_tools(scored_tools, criteria) + + # Select top tools up to max_tools limit + selected = filtered_tools[: criteria.max_tools] + + # Generate selection reasoning + reasoning = self._generate_selection_reasoning(selected, scored_tools, criteria) + + # Record analytics + selection_time = time.time() - start_time + + return ToolSelectionResult( + selected_tools=[tool for tool, _ in selected], + relevance_scores={tool.tool_name: score for tool, score in selected}, + selection_reasoning=reasoning, + total_candidates=len(available_tools), + selection_time=selection_time, + ) + + def _update_tool_cache(self, tools: List[Tool]) -> None: + """Update internal tool cache.""" + for tool in tools: + self._tool_cache[tool.tool_name] = tool + + def _score_tools(self, tools: List[Tool], criteria: ToolSelectionCriteria) -> List[Tuple[Tool, float]]: + """Score tools based on relevance to criteria. + + Args: + tools: List of tools to score + criteria: Selection criteria + + Returns: + List of (tool, score) tuples + """ + scored = [] + + # Prepare context for scoring + context = { + "task": criteria.task_description, + "requirements": criteria.task_description, + } + + if criteria.required_capabilities: + context["required_capabilities"] = ", ".join(criteria.required_capabilities) + + if criteria.context_hints: + context.update(criteria.context_hints) + + for tool in tools: + # Create tool info for scoring + tool_info = { + "name": tool.tool_name, + "description": tool.tool_spec.get("description", ""), + } + + # Calculate relevance score + score = self.scorer.score(tool_info, context) + + # Boost score for tools with required capabilities + if criteria.required_capabilities: + if self._has_required_capabilities(tool, criteria.required_capabilities): + score = min(1.0, score * 1.5) + + scored.append((tool, score)) + + # Record relevance score in analytics + self.analytics.record_tool_usage( + tool.tool_name, + success=True, # Just recording relevance, not actual usage + execution_time=0.0, + relevance_score=score, + ) + + # Sort by score descending + scored.sort(key=lambda x: x[1], reverse=True) + + return scored + + def _adjust_scores_by_performance(self, scored_tools: List[Tuple[Tool, float]]) -> List[Tuple[Tool, float]]: + """Adjust scores based on historical performance. 
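+
+        Example (worked numbers for the adjustment applied below):
+
+            With success_rate 0.9 and avg_relevance_score 0.8, the performance
+            factor is 0.7 * 0.9 + 0.3 * 0.8 = 0.87, giving an adjustment of
+            0.7 + 0.6 * 0.87 = 1.222, i.e. roughly a 22% boost to the base score;
+            a factor of 0.0 would instead reduce the base score by 30%.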
+ + Args: + scored_tools: List of (tool, score) tuples + + Returns: + Adjusted list of (tool, score) tuples + """ + adjusted = [] + + for tool, base_score in scored_tools: + stats = self.analytics.get_tool_stats(tool.tool_name) + + # Calculate performance multiplier + if stats.total_calls >= 5: # Minimum calls for reliable stats + performance_factor = 0.7 * stats.success_rate + 0.3 * stats.avg_relevance_score + + # Adjust score with performance factor + # Limit adjustment to ±30% + adjustment = 0.7 + (0.6 * performance_factor) + adjusted_score = base_score * adjustment + else: + # No adjustment for tools with insufficient data + adjusted_score = base_score + + adjusted.append((tool, min(1.0, adjusted_score))) + + # Re-sort by adjusted scores + adjusted.sort(key=lambda x: x[1], reverse=True) + + return adjusted + + def _filter_tools( + self, scored_tools: List[Tuple[Tool, float]], criteria: ToolSelectionCriteria + ) -> List[Tuple[Tool, float]]: + """Filter tools based on criteria. + + Args: + scored_tools: List of (tool, score) tuples + criteria: Selection criteria + + Returns: + Filtered list of (tool, score) tuples + """ + filtered = [] + + for tool, score in scored_tools: + # Check minimum score + if score < criteria.min_relevance_score: + continue + + # Check required capabilities + if criteria.required_capabilities: + if not self._has_required_capabilities(tool, criteria.required_capabilities): + continue + + filtered.append((tool, score)) + + return filtered + + def _has_required_capabilities(self, tool: Tool, required_capabilities: List[str]) -> bool: + """Check if tool has required capabilities. + + Args: + tool: Tool to check + required_capabilities: List of required capability keywords + + Returns: + True if tool has all required capabilities + """ + # Check tool name and description for capability keywords + tool_text = f"{tool.tool_name} {tool.tool_spec.get('description', '')}".lower() + + for capability in required_capabilities: + if capability.lower() not in tool_text: + return False + + return True + + def _generate_selection_reasoning( + self, selected: List[Tuple[Tool, float]], all_scored: List[Tuple[Tool, float]], criteria: ToolSelectionCriteria + ) -> Dict[str, str]: + """Generate reasoning for tool selection. 
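+
+        Example (hypothetical entry; phrasing matches the reasons built below):
+
+            {"file_reader": "High relevance to task (0.85); Excellent success rate (95.0%)"}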
+ + Args: + selected: Selected tools with scores + all_scored: All scored tools + criteria: Selection criteria + + Returns: + Dictionary of reasoning by tool name + """ + reasoning = {} + + for tool, score in selected: + reasons = [] + + # Relevance reasoning + if score >= 0.8: + reasons.append(f"High relevance to task ({score:.2f})") + elif score >= 0.5: + reasons.append(f"Good relevance to task ({score:.2f})") + else: + reasons.append(f"Moderate relevance to task ({score:.2f})") + + # Performance reasoning + stats = self.analytics.get_tool_stats(tool.tool_name) + if stats.total_calls >= 5: + if stats.success_rate >= 0.9: + reasons.append(f"Excellent success rate ({stats.success_rate:.1%})") + elif stats.success_rate >= 0.7: + reasons.append(f"Good success rate ({stats.success_rate:.1%})") + + # Capability reasoning + if criteria.required_capabilities: + matching_caps = [ + cap + for cap in criteria.required_capabilities + if cap.lower() in f"{tool.tool_name} {tool.tool_spec.get('description', '')}".lower() + ] + if matching_caps: + reasons.append(f"Matches capabilities: {', '.join(matching_caps)}") + + reasoning[tool.tool_name] = "; ".join(reasons) + + return reasoning + + def get_tool_recommendations( + self, task_description: str, recent_tools: Optional[List[str]] = None, max_recommendations: int = 5 + ) -> List[Tuple[str, float]]: + """Get tool recommendations based on task and history. + + Args: + task_description: Description of the task + recent_tools: Recently used tool names + max_recommendations: Maximum recommendations to return + + Returns: + List of (tool_name, confidence) tuples + """ + recommendations = [] + + # Get performance-based recommendations + tool_rankings = self.analytics.get_tool_rankings(min_calls=3) + + for tool_name, performance_score in tool_rankings[: max_recommendations * 2]: + if tool_name in self._tool_cache: + tool = self._tool_cache[tool_name] + + # Calculate relevance to current task + relevance = self.scorer.score( + {"name": tool.tool_name, "description": tool.tool_spec.get("description", "")}, + {"task": task_description}, + ) + + # Combine performance and relevance + confidence = 0.6 * performance_score + 0.4 * relevance + + # Boost if recently used successfully + if recent_tools and tool_name in recent_tools: + confidence = min(1.0, confidence * 1.2) + + recommendations.append((tool_name, confidence)) + + # Sort by confidence and limit + recommendations.sort(key=lambda x: x[1], reverse=True) + + return recommendations[:max_recommendations] + + def update_tool_performance(self, tool_name: str, success: bool, execution_time: float) -> None: + """Update tool performance metrics after usage. 
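+
+        Example (illustrative; the tool name is hypothetical):
+
+            manager.update_tool_performance("file_reader", success=True, execution_time=0.12)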
+ + Args: + tool_name: Name of the tool + success: Whether execution was successful + execution_time: Time taken for execution + """ + self.analytics.record_tool_usage(tool_name, success, execution_time) diff --git a/tests/strands/agent/context/__init__.py b/tests/strands/agent/context/__init__.py new file mode 100644 index 000000000..165e9a3df --- /dev/null +++ b/tests/strands/agent/context/__init__.py @@ -0,0 +1 @@ +"""Tests for intelligent context management.""" diff --git a/tests/strands/agent/context/test_analytics.py b/tests/strands/agent/context/test_analytics.py new file mode 100644 index 000000000..777b02cf7 --- /dev/null +++ b/tests/strands/agent/context/test_analytics.py @@ -0,0 +1,284 @@ +"""Tests for tool usage analytics.""" + +import time + +import pytest + +from strands.agent.context.analytics import ( + ContextPerformanceStats, + ToolUsageAnalytics, + ToolUsageStats, +) + + +class TestToolUsageStats: + """Tests for ToolUsageStats.""" + + def test_initialization(self): + """Test ToolUsageStats initialization.""" + stats = ToolUsageStats(tool_name="test_tool") + + assert stats.tool_name == "test_tool" + assert stats.total_calls == 0 + assert stats.successful_calls == 0 + assert stats.failed_calls == 0 + assert stats.total_execution_time == 0.0 + assert stats.last_used is None + assert stats.avg_relevance_score == 0.0 + assert stats.relevance_scores == [] + + def test_success_rate_calculation(self): + """Test success rate calculation.""" + stats = ToolUsageStats(tool_name="test") + + # No calls + assert stats.success_rate == 0.0 + + # Some successful calls + stats.total_calls = 10 + stats.successful_calls = 7 + assert stats.success_rate == 0.7 + + # All successful + stats.successful_calls = 10 + assert stats.success_rate == 1.0 + + def test_avg_execution_time(self): + """Test average execution time calculation.""" + stats = ToolUsageStats(tool_name="test") + + # No successful calls + assert stats.avg_execution_time == 0.0 + + # With successful calls + stats.successful_calls = 5 + stats.total_execution_time = 10.0 + assert stats.avg_execution_time == 2.0 + + def test_record_usage_success(self): + """Test recording successful usage.""" + stats = ToolUsageStats(tool_name="test") + + stats.record_usage(success=True, execution_time=1.5, relevance_score=0.8) + + assert stats.total_calls == 1 + assert stats.successful_calls == 1 + assert stats.failed_calls == 0 + assert stats.total_execution_time == 1.5 + assert stats.last_used is not None + assert stats.avg_relevance_score == 0.8 + assert stats.relevance_scores == [0.8] + + def test_record_usage_failure(self): + """Test recording failed usage.""" + stats = ToolUsageStats(tool_name="test") + + stats.record_usage(success=False, execution_time=0.5) + + assert stats.total_calls == 1 + assert stats.successful_calls == 0 + assert stats.failed_calls == 1 + assert stats.total_execution_time == 0.0 # Failures don't count + assert stats.last_used is not None + + def test_relevance_score_averaging(self): + """Test relevance score averaging.""" + stats = ToolUsageStats(tool_name="test") + + stats.record_usage(True, 1.0, relevance_score=0.8) + stats.record_usage(True, 1.0, relevance_score=0.6) + stats.record_usage(True, 1.0, relevance_score=0.7) + + assert stats.avg_relevance_score == pytest.approx(0.7, 0.01) + assert len(stats.relevance_scores) == 3 + + +class TestContextPerformanceStats: + """Tests for ContextPerformanceStats.""" + + def test_initialization(self): + """Test ContextPerformanceStats initialization.""" + stats = 
ContextPerformanceStats() + + assert stats.total_context_builds == 0 + assert stats.total_pruning_operations == 0 + assert stats.avg_context_size == 0.0 + assert stats.avg_pruning_ratio == 0.0 + assert stats.context_sizes == [] + assert stats.pruning_ratios == [] + + def test_record_context_build(self): + """Test recording context build operations.""" + stats = ContextPerformanceStats() + + # First build with pruning + stats.record_context_build(context_size=500, original_size=1000) + + assert stats.total_context_builds == 1 + assert stats.total_pruning_operations == 1 + assert stats.avg_context_size == 500.0 + assert stats.avg_pruning_ratio == 0.5 + + # Second build without pruning (same size) + stats.record_context_build(context_size=300, original_size=300) + + assert stats.total_context_builds == 2 + assert stats.total_pruning_operations == 2 + assert stats.avg_context_size == 400.0 # (500 + 300) / 2 + assert stats.avg_pruning_ratio == 0.25 # (0.5 + 0) / 2 + + def test_pruning_ratio_calculation(self): + """Test pruning ratio calculation.""" + stats = ContextPerformanceStats() + + # 75% reduction + stats.record_context_build(context_size=250, original_size=1000) + assert stats.pruning_ratios[-1] == 0.75 + + # No reduction + stats.record_context_build(context_size=500, original_size=500) + assert stats.pruning_ratios[-1] == 0.0 + + # Edge case: original size is 0 + stats.record_context_build(context_size=100, original_size=0) + # Should not add a pruning ratio for this case + assert len(stats.pruning_ratios) == 2 + + +class TestToolUsageAnalytics: + """Tests for ToolUsageAnalytics.""" + + def test_initialization(self): + """Test ToolUsageAnalytics initialization.""" + analytics = ToolUsageAnalytics() + + assert analytics.tool_stats == {} + assert isinstance(analytics.context_stats, ContextPerformanceStats) + assert analytics._start_time <= time.time() + + def test_get_tool_stats(self): + """Test getting or creating tool stats.""" + analytics = ToolUsageAnalytics() + + # First access creates new stats + stats1 = analytics.get_tool_stats("tool1") + assert stats1.tool_name == "tool1" + assert "tool1" in analytics.tool_stats + + # Second access returns same stats + stats2 = analytics.get_tool_stats("tool1") + assert stats1 is stats2 + + def test_record_tool_usage(self): + """Test recording tool usage.""" + analytics = ToolUsageAnalytics() + + analytics.record_tool_usage(tool_name="calculator", success=True, execution_time=0.5, relevance_score=0.9) + + stats = analytics.tool_stats["calculator"] + assert stats.total_calls == 1 + assert stats.successful_calls == 1 + assert stats.total_execution_time == 0.5 + assert stats.avg_relevance_score == 0.9 + + def test_record_context_build(self): + """Test recording context build stats.""" + analytics = ToolUsageAnalytics() + + analytics.record_context_build(context_size=500, original_size=1000) + + assert analytics.context_stats.total_context_builds == 1 + assert analytics.context_stats.avg_pruning_ratio == 0.5 + + def test_tool_rankings(self): + """Test tool performance rankings.""" + analytics = ToolUsageAnalytics() + + # Add usage data for multiple tools + for i in range(10): + analytics.record_tool_usage("tool_a", True, 1.0, 0.9) + analytics.record_tool_usage("tool_b", True, 1.0, 0.7) + analytics.record_tool_usage("tool_c", i < 5, 1.0, 0.8) # 50% success + + rankings = analytics.get_tool_rankings(min_calls=5) + + # Should be sorted by performance score + assert len(rankings) == 3 + assert rankings[0][0] == "tool_a" # Highest score + assert 
rankings[0][1] > rankings[1][1] # Decreasing scores + assert rankings[1][1] > rankings[2][1] + + def test_tool_rankings_min_calls_filter(self): + """Test that min_calls filter works correctly.""" + analytics = ToolUsageAnalytics() + + # tool_a: 10 calls, tool_b: 3 calls + for _ in range(10): + analytics.record_tool_usage("tool_a", True, 1.0, 0.9) + for _ in range(3): + analytics.record_tool_usage("tool_b", True, 1.0, 0.9) + + # With min_calls=5, only tool_a should be ranked + rankings = analytics.get_tool_rankings(min_calls=5) + assert len(rankings) == 1 + assert rankings[0][0] == "tool_a" + + def test_recency_factor_calculation(self): + """Test recency factor calculation.""" + analytics = ToolUsageAnalytics() + + # Recent usage + current_time = time.time() + factor_recent = analytics._calculate_recency_factor(current_time - 3600) # 1 hour ago + assert 0.9 < factor_recent <= 1.0 + + # Old usage + factor_old = analytics._calculate_recency_factor(current_time - 86400) # 24 hours ago + assert factor_old == pytest.approx(0.0, 0.1) + + # No usage + factor_none = analytics._calculate_recency_factor(None) + assert factor_none == 0.0 + + def test_summary_report(self): + """Test comprehensive summary report generation.""" + analytics = ToolUsageAnalytics() + + # Add some usage data + analytics.record_tool_usage("tool1", True, 1.0, 0.8) + analytics.record_tool_usage("tool2", False, 0.5, 0.6) + analytics.record_context_build(500, 1000) + + report = analytics.get_summary_report() + + # Check report structure + assert "uptime_seconds" in report + assert report["total_tools_used"] == 2 + assert report["total_tool_calls"] == 2 + assert report["overall_success_rate"] == 0.5 + + # Check context optimization + ctx_opt = report["context_optimization"] + assert ctx_opt["total_builds"] == 1 + assert ctx_opt["avg_context_size"] == 500.0 + assert ctx_opt["avg_pruning_ratio"] == 0.5 + + # Check top tools (may be empty due to min_calls filter) + assert "top_tools" in report + assert isinstance(report["top_tools"], list) + + def test_reset_stats(self): + """Test resetting analytics data.""" + analytics = ToolUsageAnalytics() + + # Add some data + analytics.record_tool_usage("tool1", True, 1.0) + analytics.record_context_build(500, 1000) + original_start_time = analytics._start_time + + # Reset + analytics.reset_stats() + + assert len(analytics.tool_stats) == 0 + assert analytics.context_stats.total_context_builds == 0 + assert analytics._start_time >= original_start_time diff --git a/tests/strands/agent/context/test_context_optimizer.py b/tests/strands/agent/context/test_context_optimizer.py new file mode 100644 index 000000000..d438197e0 --- /dev/null +++ b/tests/strands/agent/context/test_context_optimizer.py @@ -0,0 +1,295 @@ +"""Tests for context window optimization.""" + +from strands.agent.context.context_optimizer import ( + ContextItem, + ContextOptimizer, + ContextWindow, +) + + +class TestContextItem: + """Tests for ContextItem dataclass.""" + + def test_context_item_creation(self): + """Test creating a context item.""" + item = ContextItem( + key="test_key", + value="test value", + size=10, + relevance_score=0.8, + timestamp=1234567890.0, + metadata={"source": "test"}, + ) + + assert item.key == "test_key" + assert item.value == "test value" + assert item.size == 10 + assert item.relevance_score == 0.8 + assert item.timestamp == 1234567890.0 + assert item.metadata == {"source": "test"} + + def test_context_item_defaults(self): + """Test context item with default values.""" + item = 
ContextItem(key="test", value="value", size=5) + + assert item.relevance_score == 0.0 + assert item.timestamp == 0.0 + assert item.metadata is None + + +class TestContextWindow: + """Tests for ContextWindow dataclass.""" + + def test_context_window_creation(self): + """Test creating a context window.""" + items = [ContextItem("key1", "value1", 10), ContextItem("key2", "value2", 20)] + + window = ContextWindow(items=items, total_size=30, max_size=100, optimization_stats={"items_removed": 5}) + + assert len(window.items) == 2 + assert window.total_size == 30 + assert window.max_size == 100 + assert window.optimization_stats["items_removed"] == 5 + + def test_utilization_calculation(self): + """Test context window utilization calculation.""" + window = ContextWindow(items=[], total_size=75, max_size=100, optimization_stats={}) + + assert window.utilization == 0.75 + + # Edge case: max_size is 0 + window_zero = ContextWindow(items=[], total_size=50, max_size=0, optimization_stats={}) + assert window_zero.utilization == 0.0 + + def test_to_dict_conversion(self): + """Test converting context window to dictionary.""" + items = [ContextItem("key1", "value1", 10), ContextItem("key2", {"nested": "value"}, 20)] + + window = ContextWindow(items, 30, 100, {}) + result = window.to_dict() + + assert result == {"key1": "value1", "key2": {"nested": "value"}} + + +class TestContextOptimizer: + """Tests for ContextOptimizer.""" + + def test_initialization(self): + """Test ContextOptimizer initialization.""" + optimizer = ContextOptimizer(max_context_size=4096, relevance_threshold=0.4) + + assert optimizer.max_context_size == 4096 + assert optimizer.relevance_threshold == 0.4 + assert optimizer.scorer is not None + assert optimizer.relevance_filter is not None + + def test_size_estimation(self): + """Test default size estimation.""" + optimizer = ContextOptimizer() + + # String estimation (4 chars ≈ 1 token) + assert optimizer._estimate_size("hello world") == 2 # 11 chars / 4 + assert optimizer._estimate_size("a" * 100) == 25 + + # Dict estimation + dict_size = optimizer._estimate_size({"key": "value"}) + assert dict_size > 0 + + # List estimation + list_size = optimizer._estimate_size([1, 2, 3, 4, 5]) + assert list_size > 0 + + # Other types + assert optimizer._estimate_size(42) > 0 + + def test_optimize_context_basic(self): + """Test basic context optimization.""" + optimizer = ContextOptimizer(max_context_size=50) + + context_items = { + "relevant1": "This is about machine learning", + "relevant2": "Machine learning algorithms", + "irrelevant": "Weather forecast for tomorrow", + } + + task = "explain machine learning concepts" + + result = optimizer.optimize_context(context_items, task) + + assert isinstance(result, ContextWindow) + assert len(result.items) <= len(context_items) + assert result.total_size <= result.max_size + + # Check that relevant items are included + keys = [item.key for item in result.items] + assert "relevant1" in keys or "relevant2" in keys + + def test_required_keys_inclusion(self): + """Test that required keys are always included.""" + optimizer = ContextOptimizer( + max_context_size=20, + relevance_threshold=0.9, # Very high threshold + ) + + context_items = {"required": "must be included", "optional1": "relevant content", "optional2": "also relevant"} + + result = optimizer.optimize_context(context_items, "some task", required_keys=["required"]) + + # Required key must be in result + keys = [item.key for item in result.items] + assert "required" in keys + + def 
test_size_constraints(self): + """Test that size constraints are respected.""" + optimizer = ContextOptimizer(max_context_size=10) # Very small + + context_items = { + f"item{i}": "x" * 20 # Each item ~5 tokens + for i in range(10) + } + + result = optimizer.optimize_context(context_items, "task") + + # Total size must not exceed max + assert result.total_size <= optimizer.max_context_size + assert len(result.items) < len(context_items) # Some items excluded + + def test_relevance_filtering(self): + """Test relevance threshold filtering.""" + optimizer = ContextOptimizer(max_context_size=1000, relevance_threshold=0.5) + + context_items = { + "high_relevance": "machine learning model training", + "low_relevance": "random unrelated content", + } + + task = "train a machine learning model" + + result = optimizer.optimize_context(context_items, task) + + # Check relevance scores + for item in result.items: + if item.key == "low_relevance": + assert item.relevance_score < optimizer.relevance_threshold + + def test_optimization_stats(self): + """Test optimization statistics generation.""" + optimizer = ContextOptimizer() + + context_items = {f"item{i}": f"content {i}" for i in range(5)} + + result = optimizer.optimize_context(context_items, "content 1") + + stats = result.optimization_stats + assert "original_items" in stats + assert "optimized_items" in stats + assert "original_size" in stats + assert "optimized_size" in stats + assert "pruning_ratio" in stats + assert "avg_relevance" in stats + + assert stats["original_items"] == 5 + assert stats["optimized_items"] <= 5 + + def test_item_compression(self): + """Test item compression for oversized items.""" + optimizer = ContextOptimizer(max_context_size=50) + + # Create an item that's too large + large_item = ContextItem( + key="large", + value="x" * 1000, # ~250 tokens + size=250, + relevance_score=0.9, + ) + + compressed = optimizer._try_compress_item(large_item, 20) + + assert compressed is not None + assert compressed.size <= 20 + assert compressed.value.endswith("...") + assert compressed.metadata["truncated"] is True + + def test_merge_contexts(self): + """Test merging multiple context windows.""" + optimizer = ContextOptimizer() + + # Create two context windows + window1 = ContextWindow( + items=[ + ContextItem("key1", "value1", 10, relevance_score=0.8), + ContextItem("shared", "value_old", 15, relevance_score=0.6), + ], + total_size=25, + max_size=100, + optimization_stats={}, + ) + + window2 = ContextWindow( + items=[ + ContextItem("key2", "value2", 20, relevance_score=0.7), + ContextItem("shared", "value_new", 15, relevance_score=0.9), + ], + total_size=35, + max_size=100, + optimization_stats={}, + ) + + merged = optimizer.merge_contexts([window1, window2], "merge task") + + # Should have 3 unique keys + keys = [item.key for item in merged.items] + assert len(set(keys)) <= 3 + + # For duplicate keys, higher relevance should win + shared_items = [item for item in merged.items if item.key == "shared"] + if shared_items: + assert shared_items[0].value == "value_new" # Higher relevance + + def test_pruning_recommendations(self): + """Test getting pruning recommendations.""" + optimizer = ContextOptimizer() + + window = ContextWindow( + items=[ + ContextItem("low_rel", "value", 10, relevance_score=0.3), + ContextItem("large", "value", 300, relevance_score=0.8), + ContextItem("truncated", "val...", 5, relevance_score=0.6, metadata={"truncated": True}), + ], + total_size=315, + max_size=1000, + optimization_stats={}, + ) + + 
recommendations = optimizer.get_pruning_recommendations(window) + + # Should have recommendations for each problematic item + assert len(recommendations) == 3 + + # Check recommendation types + rec_keys = [r[0] for r in recommendations] + assert "low_rel" in rec_keys # Low relevance + assert "large" in rec_keys # Large size + assert "truncated" in rec_keys # Was truncated + + def test_custom_size_estimator(self): + """Test using custom size estimator.""" + optimizer = ContextOptimizer() + + # Custom estimator that counts words + def word_count_estimator(value): + return len(str(value).split()) + + context_items = { + "short": "hello world", # 2 words + "long": "this is a much longer sentence with more words", # 10 words + } + + result = optimizer.optimize_context(context_items, "task", size_estimator=word_count_estimator) + + # Check that custom estimator was used + for item in result.items: + if item.key == "short": + assert item.size == 2 + elif item.key == "long": + assert item.size == 10 diff --git a/tests/strands/agent/context/test_performance.py b/tests/strands/agent/context/test_performance.py new file mode 100644 index 000000000..c1c1edcb5 --- /dev/null +++ b/tests/strands/agent/context/test_performance.py @@ -0,0 +1,250 @@ +"""Performance tests for context management.""" + +import time +from typing import Any, Dict, List + +import pytest + +from strands.agent.context import ( + ContextOptimizer, + DynamicToolManager, + ToolSelectionCriteria, +) +from strands.agent.context.relevance_scoring import TextRelevanceScorer +from strands.types.tools import AgentTool, ToolSpec, ToolUse + + +class MockTool(AgentTool): + """Mock tool for performance testing.""" + + def __init__(self, name: str, description: str): + super().__init__() + self._name = name + self._description = description + + @property + def tool_name(self) -> str: + return self._name + + @property + def tool_spec(self) -> ToolSpec: + return { + "name": self._name, + "description": self._description, + "inputSchema": {"type": "object", "properties": {}}, + } + + @property + def tool_type(self) -> str: + return "mock" + + async def stream(self, tool_use: ToolUse, invocation_state: Dict[str, Any], **kwargs): + """Mock stream implementation.""" + yield {"type": "result", "result": "mock"} + + +class TestContextOptimizerPerformance: + """Performance tests for ContextOptimizer.""" + + @pytest.mark.parametrize("num_items", [100, 500, 1000]) + def test_optimization_speed(self, num_items: int): + """Test context optimization speed with varying sizes.""" + optimizer = ContextOptimizer(max_context_size=1000) + + # Create large context + context_items = { + f"item_{i}": f"This is context item {i} with some content about {i % 10}" for i in range(num_items) + } + + task = "Find items about content 5" + + start_time = time.time() + result = optimizer.optimize_context(context_items, task) + optimization_time = time.time() - start_time + + # Performance assertions + assert optimization_time < 1.0 # Should complete within 1 second + assert result.total_size <= optimizer.max_context_size + + # Verify optimization worked + pruning_ratio = result.optimization_stats["pruning_ratio"] + if num_items > 100: + assert pruning_ratio > 0 # Should prune something for large contexts + + print(f"Optimized {num_items} items in {optimization_time:.3f}s (pruned {pruning_ratio:.1%})") + + def test_relevance_scoring_performance(self): + """Test relevance scoring performance.""" + scorer = TextRelevanceScorer() + + # Create test data + items = [f"Item {i} with 
content about topic {i % 20}" for i in range(1000)] + context = "Looking for items about topic 15" + + start_time = time.time() + scores = [scorer.score(item, context) for item in items] + scoring_time = time.time() - start_time + + assert scoring_time < 0.5 # 1000 items in 0.5s + assert len(scores) == 1000 + assert all(0 <= s <= 1 for s in scores) + + print(f"Scored 1000 items in {scoring_time:.3f}s ({1000 / scoring_time:.0f} items/sec)") + + def test_large_context_compression(self): + """Test performance of context compression for large items.""" + optimizer = ContextOptimizer(max_context_size=500) + + # Create context with large items + large_text = "x" * 10000 # Very large item + context_items = {"large_item": large_text, "normal_item": "Normal sized content", "another_large": "y" * 5000} + + start_time = time.time() + result = optimizer.optimize_context( + context_items, + "task requiring all items", + required_keys=["large_item"], # Force inclusion of large item + ) + compression_time = time.time() - start_time + + assert compression_time < 0.1 # Should be fast + assert result.total_size <= optimizer.max_context_size + + # Check that large item was compressed + large_items = [item for item in result.items if item.key == "large_item"] + if large_items: + assert large_items[0].value != large_text # Should be compressed + assert large_items[0].value.endswith("...") + + +class TestToolManagerPerformance: + """Performance tests for DynamicToolManager.""" + + def create_mock_tools(self, count: int) -> List[AgentTool]: + """Create mock tools for testing.""" + categories = ["file", "data", "web", "system", "analysis", "compute"] + actions = ["read", "write", "process", "analyze", "fetch", "transform"] + + tools = [] + for i in range(count): + category = categories[i % len(categories)] + action = actions[i % len(actions)] + name = f"{category}_{action}_tool_{i}" + description = f"A tool that {action}s {category} data and performs operations" + tools.append(MockTool(name, description)) + + return tools + + @pytest.mark.parametrize("num_tools", [50, 100, 500]) + def test_tool_selection_speed(self, num_tools: int): + """Test tool selection speed with varying numbers of tools.""" + manager = DynamicToolManager() + tools = self.create_mock_tools(num_tools) + + criteria = ToolSelectionCriteria(task_description="analyze data from files and web sources", max_tools=20) + + start_time = time.time() + result = manager.select_tools(tools, criteria) + selection_time = time.time() - start_time + + # Performance assertions + assert selection_time < 0.5 # Should complete quickly + assert len(result.selected_tools) <= criteria.max_tools + assert result.selection_time > 0 + + print(f"Selected from {num_tools} tools in {selection_time:.3f}s") + + def test_tool_scoring_with_history(self): + """Test tool selection performance with usage history.""" + manager = DynamicToolManager() + tools = self.create_mock_tools(100) + + # Simulate usage history + for i in range(50): + tool_name = tools[i].tool_name + success = i % 3 != 0 # 2/3 success rate + manager.update_tool_performance(tool_name, success, 0.1 * (i % 5)) + + criteria = ToolSelectionCriteria(task_description="process and analyze data", prefer_recent=True) + + start_time = time.time() + result = manager.select_tools(tools, criteria) + selection_time = time.time() - start_time + + assert selection_time < 0.5 + assert len(result.selected_tools) > 0 + + # Tools with history should be considered + selected_names = [t.tool_name for t in result.selected_tools] + 
tools_with_history = [tools[i].tool_name for i in range(50)] + overlap = set(selected_names) & set(tools_with_history) + assert len(overlap) > 0 # Some tools with history should be selected + + def test_recommendations_performance(self): + """Test performance of tool recommendations.""" + manager = DynamicToolManager() + tools = self.create_mock_tools(200) + manager._update_tool_cache(tools) + + # Add usage history for subset of tools + for i in range(100): + tool = tools[i] + for _ in range(5): # Minimum calls for ranking + manager.update_tool_performance(tool.tool_name, success=True, execution_time=0.1) + + start_time = time.time() + recommendations = manager.get_tool_recommendations(task_description="analyze web data", max_recommendations=10) + rec_time = time.time() - start_time + + assert rec_time < 0.1 # Should be very fast + assert len(recommendations) <= 10 + + print(f"Generated {len(recommendations)} recommendations in {rec_time:.3f}s") + + +class TestIntegratedPerformance: + """Integration performance tests.""" + + def test_full_context_optimization_pipeline(self): + """Test full pipeline from tool selection to context optimization.""" + # Setup + tool_manager = DynamicToolManager() + context_optimizer = ContextOptimizer(max_context_size=2000) + + # Create tools and context + tools = [] + for i in range(100): + # Create tools with more relevant descriptions + category = ["analyze", "process", "extract", "transform", "filter"][i % 5] + topic = i % 10 + tools.append(MockTool(f"{category}_tool_{i}", f"Tool that {category}s data related to topic {topic}")) + + context_items = {f"ctx_{i}": f"Context data {i} related to {i % 10}" for i in range(500)} + + task = "Analyze data items related to topic 5" + + # Measure full pipeline + start_time = time.time() + + # Step 1: Select tools + tool_criteria = ToolSelectionCriteria(task_description=task, max_tools=10) + tool_result = tool_manager.select_tools(tools, tool_criteria) + + # Step 2: Optimize context + context_result = context_optimizer.optimize_context( + context_items, + task, + required_keys=["ctx_5", "ctx_15"], # Require some specific items + ) + + total_time = time.time() - start_time + + # Performance assertions + assert total_time < 1.0 # Full pipeline under 1 second + assert len(tool_result.selected_tools) > 0 + assert context_result.total_size <= context_optimizer.max_context_size + + print(f"Full pipeline completed in {total_time:.3f}s:") + print(f" - Selected {len(tool_result.selected_tools)} tools") + print(f" - Optimized context from {len(context_items)} to {len(context_result.items)} items") + print(f" - Context reduction: {context_result.optimization_stats['pruning_ratio']:.1%}") diff --git a/tests/strands/agent/context/test_relevance_scoring.py b/tests/strands/agent/context/test_relevance_scoring.py new file mode 100644 index 000000000..f492f768a --- /dev/null +++ b/tests/strands/agent/context/test_relevance_scoring.py @@ -0,0 +1,208 @@ +"""Tests for relevance scoring functionality.""" + +from strands.agent.context.relevance_scoring import ( + ContextRelevanceFilter, + ScoredItem, + SimilarityMetric, + TextRelevanceScorer, + ToolRelevanceScorer, +) + + +class TestTextRelevanceScorer: + """Tests for TextRelevanceScorer.""" + + def test_jaccard_similarity_identical(self): + """Test Jaccard similarity for identical texts.""" + scorer = TextRelevanceScorer(metric=SimilarityMetric.JACCARD) + score = scorer.score("hello world", "hello world") + assert score == 1.0 + + def test_jaccard_similarity_partial(self): + """Test 
Jaccard similarity for partially overlapping texts.""" + scorer = TextRelevanceScorer(metric=SimilarityMetric.JACCARD) + score = scorer.score("hello world", "hello universe") + # Intersection: {hello}, Union: {hello, world, universe} + # Score: 1/3 ≈ 0.333 + assert 0.3 < score < 0.4 + + def test_jaccard_similarity_no_overlap(self): + """Test Jaccard similarity for non-overlapping texts.""" + scorer = TextRelevanceScorer(metric=SimilarityMetric.JACCARD) + score = scorer.score("hello world", "foo bar") + assert score == 0.0 + + def test_jaccard_similarity_empty(self): + """Test Jaccard similarity with empty strings.""" + scorer = TextRelevanceScorer(metric=SimilarityMetric.JACCARD) + assert scorer.score("", "") == 1.0 + assert scorer.score("hello", "") == 0.0 + assert scorer.score("", "world") == 0.0 + + def test_levenshtein_similarity_identical(self): + """Test Levenshtein similarity for identical texts.""" + scorer = TextRelevanceScorer(metric=SimilarityMetric.LEVENSHTEIN) + score = scorer.score("hello", "hello") + assert score == 1.0 + + def test_levenshtein_similarity_one_edit(self): + """Test Levenshtein similarity with one character difference.""" + scorer = TextRelevanceScorer(metric=SimilarityMetric.LEVENSHTEIN) + score = scorer.score("hello", "hallo") + # 1 edit in 5 characters = 0.8 similarity + assert 0.79 < score < 0.81 + + def test_text_conversion(self): + """Test conversion of different types to text.""" + scorer = TextRelevanceScorer() + + # Dict conversion + dict_score = scorer.score({"key": "value", "number": 42}, '{"key": "value"}') + assert dict_score > 0 + + # List conversion + list_score = scorer.score([1, 2, 3], "[1, 2, 3]") + assert list_score > 0 + + # Number conversion + number_score = scorer.score(42, "42") + assert number_score > 0 + + +class TestToolRelevanceScorer: + """Tests for ToolRelevanceScorer.""" + + def test_tool_scoring_by_name(self): + """Test tool scoring based on name matching.""" + scorer = ToolRelevanceScorer() + + tool_info = {"name": "file_reader", "description": "Reads files from the filesystem"} + + context = {"task": "I need to read a file", "requirements": "file reading capability"} + + score = scorer.score(tool_info, context) + assert score > 0.1 # Should have some relevance due to partial matches + + def test_tool_scoring_by_description(self): + """Test tool scoring based on description matching.""" + scorer = ToolRelevanceScorer() + + tool_info = {"name": "tool_x", "description": "Performs database queries and data analysis"} + + context = {"task": "analyze data from database", "requirements": "need to query and analyze data"} + + score = scorer.score(tool_info, context) + assert score > 0.1 # Should have some relevance due to partial matches + + def test_required_tools_max_score(self): + """Test that required tools get maximum score.""" + scorer = ToolRelevanceScorer() + + tool_info = {"name": "calculator", "description": "Basic math operations"} + + context = {"task": "perform calculations", "required_tools": ["calculator", "converter"]} + + score = scorer.score(tool_info, context) + assert score == 1.0 # Required tool should get max score + + def test_tool_object_extraction(self): + """Test extraction from tool objects.""" + scorer = ToolRelevanceScorer() + + # Mock tool object + class MockTool: + name = "test_tool" + description = "A test tool" + parameters = {"param1": "string"} + + context = {"task": "test something"} + score = scorer.score(MockTool(), context) + assert score > 0 # Should be able to score tool object + + +class 
TestContextRelevanceFilter: + """Tests for ContextRelevanceFilter.""" + + def test_filter_by_min_score(self): + """Test filtering items by minimum score.""" + scorer = TextRelevanceScorer() + filter = ContextRelevanceFilter(scorer) + + items = {"item1": "hello world", "item2": "foo bar", "item3": "hello universe"} + + context = "hello world programming" + + filtered = filter.filter_relevant(items, context, min_score=0.2) + + # Should include item1 and item3 (both contain "hello") + assert len(filtered) >= 2 + assert any(item.key == "item1" for item in filtered) + assert any(item.key == "item3" for item in filtered) + + # item2 should have low/zero score + item2_scores = [item.score for item in filtered if item.key == "item2"] + if item2_scores: + assert item2_scores[0] < 0.2 + + def test_filter_max_items(self): + """Test limiting number of returned items.""" + scorer = TextRelevanceScorer() + filter = ContextRelevanceFilter(scorer) + + items = {f"item{i}": f"test content {i}" for i in range(10)} + context = "test content" + + filtered = filter.filter_relevant(items, context, min_score=0.0, max_items=3) + + assert len(filtered) == 3 + + def test_filter_sorting(self): + """Test that items are sorted by relevance score.""" + scorer = TextRelevanceScorer() + filter = ContextRelevanceFilter(scorer) + + items = {"exact": "hello world", "partial": "hello", "none": "foo bar"} + + context = "hello world" + + filtered = filter.filter_relevant(items, context, min_score=0.0) + + # Should be sorted by score descending + assert filtered[0].key == "exact" # Highest score + assert filtered[0].score > filtered[1].score + if len(filtered) > 2: + assert filtered[1].score >= filtered[2].score + + def test_get_top_k(self): + """Test getting top-k items.""" + scorer = TextRelevanceScorer() + filter = ContextRelevanceFilter(scorer) + + items = {f"item{i}": f"content {i % 3}" for i in range(10)} + context = "content 1" + + top_3 = filter.get_top_k(items, context, k=3) + + assert len(top_3) == 3 + # All should have scores (even if 0) + assert all(hasattr(item, "score") for item in top_3) + # Should be sorted by score + assert top_3[0].score >= top_3[1].score >= top_3[2].score + + +class TestScoredItem: + """Tests for ScoredItem dataclass.""" + + def test_scored_item_creation(self): + """Test creating a scored item.""" + item = ScoredItem(key="test_key", value={"data": "test"}, score=0.75, metadata={"source": "test"}) + + assert item.key == "test_key" + assert item.value == {"data": "test"} + assert item.score == 0.75 + assert item.metadata == {"source": "test"} + + def test_scored_item_defaults(self): + """Test scored item with default metadata.""" + item = ScoredItem(key="test", value="value", score=0.5) + assert item.metadata is None diff --git a/tests/strands/agent/context/test_tool_manager.py b/tests/strands/agent/context/test_tool_manager.py new file mode 100644 index 000000000..e622efd3d --- /dev/null +++ b/tests/strands/agent/context/test_tool_manager.py @@ -0,0 +1,323 @@ +"""Tests for dynamic tool management.""" + +from typing import Any, Dict + +from strands.agent.context.analytics import ToolUsageAnalytics +from strands.agent.context.tool_manager import ( + DynamicToolManager, + ToolSelectionCriteria, + ToolSelectionResult, +) +from strands.types.tools import AgentTool, ToolSpec, ToolUse + + +class MockTool(AgentTool): + """Mock tool for testing.""" + + def __init__(self, name: str, description: str = None): + super().__init__() + self._name = name + self._description = description or f"A tool that 
{name}" + + @property + def tool_name(self) -> str: + return self._name + + @property + def tool_spec(self) -> ToolSpec: + return { + "name": self._name, + "description": self._description, + "inputSchema": {"type": "object", "properties": {}}, + } + + @property + def tool_type(self) -> str: + return "mock" + + async def stream(self, tool_use: ToolUse, invocation_state: Dict[str, Any], **kwargs): + """Mock stream implementation.""" + yield {"type": "result", "result": "mock"} + + +class TestToolSelectionCriteria: + """Tests for ToolSelectionCriteria dataclass.""" + + def test_criteria_creation(self): + """Test creating selection criteria.""" + criteria = ToolSelectionCriteria( + task_description="analyze data", + required_capabilities=["data", "analysis"], + excluded_tools={"dangerous_tool"}, + max_tools=15, + min_relevance_score=0.4, + prefer_recent=False, + context_hints={"domain": "finance"}, + ) + + assert criteria.task_description == "analyze data" + assert criteria.required_capabilities == ["data", "analysis"] + assert "dangerous_tool" in criteria.excluded_tools + assert criteria.max_tools == 15 + assert criteria.min_relevance_score == 0.4 + assert criteria.prefer_recent is False + assert criteria.context_hints["domain"] == "finance" + + def test_criteria_defaults(self): + """Test default values for criteria.""" + criteria = ToolSelectionCriteria(task_description="test task") + + assert criteria.required_capabilities is None + assert criteria.excluded_tools is None + assert criteria.max_tools == 20 + assert criteria.min_relevance_score == 0.2 + assert criteria.prefer_recent is True + assert criteria.context_hints is None + + +class TestDynamicToolManager: + """Tests for DynamicToolManager.""" + + def test_initialization(self): + """Test DynamicToolManager initialization.""" + manager = DynamicToolManager() + + assert manager.analytics is not None + assert isinstance(manager.analytics, ToolUsageAnalytics) + assert manager.scorer is not None + assert manager._tool_cache == {} + + def test_tool_selection_basic(self): + """Test basic tool selection.""" + manager = DynamicToolManager() + + tools = [ + MockTool("file_reader", "Reads files from disk and file system for data processing"), + MockTool("file_writer", "Writes files to disk"), + MockTool("calculator", "Performs calculations and computes statistics on data"), + MockTool("web_scraper", "Scrapes web pages"), + ] + + criteria = ToolSelectionCriteria( + task_description="I need to read a file and calculate some statistics", + min_relevance_score=0.1, # Lower threshold for testing + ) + + result = manager.select_tools(tools, criteria) + + assert isinstance(result, ToolSelectionResult) + assert len(result.selected_tools) > 0 + assert len(result.selected_tools) <= criteria.max_tools + + # Should include relevant tools + tool_names = [t.tool_name for t in result.selected_tools] + assert "file_reader" in tool_names or "calculator" in tool_names + + def test_excluded_tools_filtering(self): + """Test that excluded tools are filtered out.""" + manager = DynamicToolManager() + + tools = [MockTool("safe_tool"), MockTool("dangerous_tool"), MockTool("another_tool")] + + criteria = ToolSelectionCriteria(task_description="any task", excluded_tools={"dangerous_tool"}) + + result = manager.select_tools(tools, criteria) + + tool_names = [t.tool_name for t in result.selected_tools] + assert "dangerous_tool" not in tool_names + + def test_max_tools_limit(self): + """Test that max_tools limit is respected.""" + manager = DynamicToolManager() + + # Create 
many tools + tools = [MockTool(f"tool_{i}") for i in range(50)] + + criteria = ToolSelectionCriteria(task_description="use all tools", max_tools=5) + + result = manager.select_tools(tools, criteria) + + assert len(result.selected_tools) <= 5 + + def test_minimum_relevance_filtering(self): + """Test filtering by minimum relevance score.""" + manager = DynamicToolManager() + + tools = [ + MockTool("very_relevant_tool", "exactly what the task needs"), + MockTool("unrelated_tool", "something completely different"), + ] + + criteria = ToolSelectionCriteria( + task_description="exactly what the task needs", + min_relevance_score=0.7, # High threshold + ) + + result = manager.select_tools(tools, criteria) + + # Only highly relevant tools should be selected + for tool in result.selected_tools: + assert result.relevance_scores[tool.tool_name] >= 0.7 + + def test_required_capabilities_filtering(self): + """Test filtering by required capabilities.""" + manager = DynamicToolManager() + + tools = [ + MockTool("data_analyzer", "analyzes data and generates reports"), + MockTool("file_reader", "reads files"), + MockTool("data_visualizer", "creates data visualizations"), + ] + + criteria = ToolSelectionCriteria( + task_description="analyze some data", + required_capabilities=["data", "analyz"], # Partial match + ) + + result = manager.select_tools(tools, criteria) + + # Should include tools with "data" and "analyz" in name/description + tool_names = [t.tool_name for t in result.selected_tools] + assert "data_analyzer" in tool_names + + def test_performance_based_adjustment(self): + """Test performance-based score adjustment.""" + analytics = ToolUsageAnalytics() + manager = DynamicToolManager(analytics=analytics) + + # Record good performance for tool1 + for _ in range(10): + analytics.record_tool_usage("tool1", True, 0.5, 0.8) + + # Record poor performance for tool2 + for _ in range(10): + analytics.record_tool_usage("tool2", False, 1.0, 0.8) + + tools = [ + MockTool("tool1", "description"), + MockTool("tool2", "description"), # Same description + ] + + criteria = ToolSelectionCriteria(task_description="description", prefer_recent=True) + + result = manager.select_tools(tools, criteria) + + # tool1 should have higher adjusted score due to better performance + if len(result.selected_tools) > 0: + assert result.selected_tools[0].tool_name == "tool1" + + def test_selection_reasoning_generation(self): + """Test that selection reasoning is generated.""" + manager = DynamicToolManager() + + tools = [ + MockTool("high_relevance_tool", "exactly matches the task"), + MockTool("low_relevance_tool", "unrelated functionality"), + ] + + criteria = ToolSelectionCriteria(task_description="exactly matches the task") + + result = manager.select_tools(tools, criteria) + + # Should have reasoning for selected tools + for tool in result.selected_tools: + assert tool.tool_name in result.selection_reasoning + reasoning = result.selection_reasoning[tool.tool_name] + assert len(reasoning) > 0 + assert "relevance" in reasoning.lower() + + def test_tool_recommendations(self): + """Test getting tool recommendations.""" + analytics = ToolUsageAnalytics() + manager = DynamicToolManager(analytics=analytics) + + # Add some tools to cache + tools = [MockTool("frequently_used", "common tool"), MockTool("rarely_used", "uncommon tool")] + manager._update_tool_cache(tools) + + # Record usage history + for _ in range(10): + analytics.record_tool_usage("frequently_used", True, 0.5, 0.9) + + recommendations = 
manager.get_tool_recommendations(task_description="common task", max_recommendations=3) + + assert len(recommendations) <= 3 + # Recommendations should be (tool_name, confidence) tuples + if recommendations: + assert len(recommendations[0]) == 2 + assert isinstance(recommendations[0][0], str) + assert isinstance(recommendations[0][1], float) + + def test_recent_tools_boost(self): + """Test that recently used tools get a confidence boost.""" + manager = DynamicToolManager() + + # Add tools to cache + tools = [MockTool("tool1"), MockTool("tool2")] + manager._update_tool_cache(tools) + + # Record some usage + manager.analytics.record_tool_usage("tool1", True, 0.5, 0.5) + manager.analytics.record_tool_usage("tool2", True, 0.5, 0.5) + + # Get recommendations with recent tools + recommendations = manager.get_tool_recommendations(task_description="any task", recent_tools=["tool1"]) + + # tool1 should have higher confidence due to recency boost + rec_dict = dict(recommendations) + if "tool1" in rec_dict and "tool2" in rec_dict: + assert rec_dict["tool1"] > rec_dict["tool2"] + + def test_update_tool_performance(self): + """Test updating tool performance metrics.""" + manager = DynamicToolManager() + + # Update performance + manager.update_tool_performance("test_tool", success=True, execution_time=1.5) + + # Check that it was recorded + stats = manager.analytics.get_tool_stats("test_tool") + assert stats.total_calls == 1 + assert stats.successful_calls == 1 + assert stats.total_execution_time == 1.5 + + def test_selection_result_metadata(self): + """Test that selection result contains proper metadata.""" + manager = DynamicToolManager() + + tools = [MockTool(f"tool_{i}") for i in range(10)] + criteria = ToolSelectionCriteria(task_description="test") + + result = manager.select_tools(tools, criteria) + + assert result.total_candidates == 10 + assert result.selection_time > 0 + assert isinstance(result.relevance_scores, dict) + assert isinstance(result.selection_reasoning, dict) + + def test_no_tools_scenario(self): + """Test behavior when no tools are available.""" + manager = DynamicToolManager() + + criteria = ToolSelectionCriteria(task_description="any task") + result = manager.select_tools([], criteria) + + assert len(result.selected_tools) == 0 + assert result.total_candidates == 0 + assert result.relevance_scores == {} + + def test_all_tools_below_threshold(self): + """Test when all tools are below relevance threshold.""" + manager = DynamicToolManager() + + tools = [MockTool("unrelated1", "completely unrelated"), MockTool("unrelated2", "also unrelated")] + + criteria = ToolSelectionCriteria( + task_description="specific database query optimization", + min_relevance_score=0.9, # Very high threshold + ) + + result = manager.select_tools(tools, criteria) + + # No tools should be selected + assert len(result.selected_tools) == 0
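
Taken together, these tests imply an end-to-end usage flow for the new context package. The sketch below is illustrative only: the class and method names (DynamicToolManager, ToolSelectionCriteria, ContextOptimizer, select_tools, optimize_context, update_tool_performance) come from the tests above, while the run_task wrapper and the concrete argument values are assumptions.

from strands.agent.context import (
    ContextOptimizer,
    DynamicToolManager,
    ToolSelectionCriteria,
)


def run_task(tools, context_items, task):
    # 1. Narrow the candidate tools to the ones most relevant to the task.
    manager = DynamicToolManager()
    criteria = ToolSelectionCriteria(task_description=task, max_tools=10)
    tool_result = manager.select_tools(tools, criteria)

    # 2. Prune the raw context so it fits the configured size budget.
    optimizer = ContextOptimizer(max_context_size=2000)
    context_result = optimizer.optimize_context(context_items, task)

    # 3. After execution, feed outcomes back so later selections can favor
    #    reliable tools (the success/execution_time values here are placeholders).
    for tool in tool_result.selected_tools:
        manager.update_tool_performance(tool.tool_name, success=True, execution_time=0.1)

    return tool_result.selected_tools, context_result.items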