diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index ead290a35..0328dc8b8 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -3,8 +3,9 @@ This package includes an abstract base Model class along with concrete implementations for specific providers. """ -from . import bedrock, model +from . import bedrock, fallback, model from .bedrock import BedrockModel +from .fallback import FallbackModel from .model import Model -__all__ = ["bedrock", "model", "BedrockModel", "Model"] +__all__ = ["bedrock", "fallback", "model", "BedrockModel", "FallbackModel", "Model"] diff --git a/src/strands/models/fallback.py b/src/strands/models/fallback.py new file mode 100644 index 000000000..e3ce63f3a --- /dev/null +++ b/src/strands/models/fallback.py @@ -0,0 +1,871 @@ +"""FallbackModel implementation for automatic failover between models. + +This module provides the FallbackModel class, which wraps two Model instances (primary and fallback) +and automatically switches to the fallback model when the primary model fails with retryable errors. +The implementation is provider-agnostic and works with any combination of Strands model types. + +Key Features: +- **Automatic Failover**: Switches to fallback on throttling, connection, and network errors +- **Circuit Breaker Pattern**: Temporarily skips failing primary model to prevent cascading failures +- **Provider Agnostic**: Works with any combination of model providers (OpenAI→Bedrock, etc.) +- **Configurable Behavior**: Customizable thresholds, error detection, and statistics tracking +- **Full Model Interface**: Supports both streaming and structured output methods +- **Comprehensive Logging**: Detailed logging for debugging and monitoring + +Circuit Breaker Behavior: +The circuit breaker monitors primary model failures within a sliding time window. When the failure +threshold is exceeded, it "opens" and routes all requests directly to the fallback model for a +cooldown period. This prevents wasting time and resources on a consistently failing primary model. + +Circuit States: +- **Closed** (default): Attempt primary model for each request +- **Open**: Skip primary model, use fallback model directly +- **Half-Open**: After cooldown, next request tests if primary has recovered + +Error Classification: +By default, the following errors trigger fallback: +- ModelThrottledException (rate limiting) +- Connection errors (network, timeout, refused, unavailable, etc.) +- Custom errors via the should_fallback configuration function + +Non-retryable errors (like ContextWindowOverflowException) are re-raised without fallback. 
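+
+A custom classifier can widen or narrow this set; note that a configured
+should_fallback function replaces the built-in checks rather than extending them.
+A minimal sketch (ProviderServerError is a hypothetical exception, named here only
+to illustrate the hook):
+
+    ```python
+    def classify(error: Exception) -> bool:
+        # Keep the built-in throttling behavior and additionally treat a
+        # hypothetical provider-specific server error as retryable.
+        return isinstance(error, (ModelThrottledException, ProviderServerError))
+
+    model = FallbackModel(primary=..., fallback=..., should_fallback=classify)
+    ```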
+
+Example usage:
+    ```python
+    from strands.models import FallbackModel, BedrockModel
+    from strands.models.openai import OpenAIModel
+
+    # Same-provider fallback (different model sizes)
+    model = FallbackModel(
+        primary=BedrockModel(model_id="claude-3-opus"),
+        fallback=BedrockModel(model_id="claude-3-haiku"),
+        circuit_failure_threshold=3,
+        circuit_time_window=60.0,
+        circuit_cooldown_seconds=30
+    )
+
+    # Cross-provider fallback
+    cross_provider = FallbackModel(
+        primary=OpenAIModel(model_id="gpt-4"),
+        fallback=BedrockModel(model_id="claude-3-sonnet"),
+        circuit_failure_threshold=5,
+        circuit_time_window=120.0
+    )
+
+    # Use with an agent
+    from strands.agent import Agent
+    agent = Agent(model=model)
+    response = agent("Hello!")
+
+    # Monitor fallback statistics
+    stats = model.get_stats()
+    print(f"Fallback count: {stats['fallback_count']}")
+    print(f"Circuit open: {stats['circuit_open']}")
+    ```
+
+"""
+
+import logging
+import time
+from collections import deque
+from typing import Any, AsyncGenerator, Callable, Optional, Type, TypeVar, Union
+
+from pydantic import BaseModel
+from typing_extensions import TypedDict, Unpack, override
+
+from ..types.content import Messages
+from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
+from ..types.streaming import StreamEvent
+from ..types.tools import ToolChoice, ToolSpec
+from .model import Model
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class FallbackConfig(TypedDict, total=False):
+    """Configuration for FallbackModel.
+
+    This TypedDict defines the optional configuration parameters for controlling
+    the FallbackModel's behavior, including circuit breaker settings and error
+    detection logic.
+
+    Attributes:
+        circuit_failure_threshold: Number of primary model failures within the time
+            window before the circuit breaker opens. Once opened, the primary model
+            will be skipped until the cooldown period expires. Default: 3
+        circuit_time_window: Time window in seconds for counting failures. Only
+            failures within this window are counted toward the threshold. Default: 60.0
+        circuit_cooldown_seconds: How long to wait in seconds before retrying the
+            primary model after the circuit breaker opens. Default: 30
+        should_fallback: Optional custom function that takes an Exception and returns
+            a boolean indicating whether to attempt fallback. If provided, this function
+            overrides the default error classification logic. Default: None
+        track_stats: Whether to track statistics about fallback usage, including
+            fallback count, primary failures, and circuit breaker state. Default: True
+    """
+
+    circuit_failure_threshold: int
+    circuit_time_window: float
+    circuit_cooldown_seconds: int
+    should_fallback: Optional[Callable[[Exception], bool]]
+    track_stats: bool
+
+
+class FallbackStats(TypedDict):
+    """Typed structure for fallback statistics to provide better type hints.
+
+    This TypedDict defines the structure returned by the get_stats() method, providing
+    comprehensive information about the FallbackModel's current state and usage
+    statistics for monitoring and debugging purposes.
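+
+    Counters are cumulative from construction (or the most recent reset_stats()
+    call), while recent_failures and the circuit-breaker fields are recomputed
+    each time get_stats() is called.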
+
+    Attributes:
+        fallback_count: Total number of times the fallback model was used due to
+            primary model failures
+        primary_failures: Total number of primary model failures encountered
+        circuit_skips: Number of times requests were routed directly to fallback
+            because the circuit breaker was open
+        using_fallback: Boolean indicating whether the last request used the fallback
+            model (True) or primary model (False)
+        circuit_open: Boolean indicating whether the circuit breaker is currently
+            open, meaning primary model will be skipped
+        recent_failures: Number of primary model failures within the current time
+            window that count toward opening the circuit breaker
+        circuit_open_until: Timestamp (float) when the circuit breaker will close
+            and primary model will be retried, or None if circuit is closed
+        primary_model_name: Human-readable name/identifier of the primary model
+            for debugging and monitoring purposes
+        fallback_model_name: Human-readable name/identifier of the fallback model
+            for debugging and monitoring purposes
+    """
+
+    fallback_count: int
+    primary_failures: int
+    circuit_skips: int
+    using_fallback: bool
+    circuit_open: bool
+    recent_failures: int
+    circuit_open_until: Optional[float]
+    primary_model_name: str
+    fallback_model_name: str
+
+
+class FallbackModel(Model):
+    """A model that automatically falls back to a secondary model on primary model failures.
+
+    FallbackModel wraps two Model instances (primary and fallback) and provides automatic
+    failover when the primary model encounters retryable errors such as throttling,
+    connection issues, or network problems.
+
+    FallbackModel implements a circuit breaker pattern to prevent repeated attempts
+    to a failing primary model. The circuit breaker opens after a configurable number
+    of failures within a time window, temporarily routing all requests to the fallback
+    model until the cooldown period expires.
+
+    Example:
+        ```python
+        from strands.models import FallbackModel, BedrockModel
+
+        # Create a fallback model with two Bedrock models
+        model = FallbackModel(
+            primary=BedrockModel(model_id="claude-3-opus"),
+            fallback=BedrockModel(model_id="claude-3-haiku"),
+            circuit_failure_threshold=3,
+            circuit_time_window=60.0,
+            circuit_cooldown_seconds=30
+        )
+
+        # Use with an agent
+        agent = Agent(model=model)
+        response = agent("Hello!")
+        ```
+
+    Attributes:
+        primary: The primary Model instance to use for requests
+        fallback: The fallback Model instance to use when primary fails
+        circuit_failure_threshold: Number of failures before circuit opens
+        circuit_time_window: Time window in seconds for counting failures
+        circuit_cooldown_seconds: Cooldown period before retrying primary
+        should_fallback: Optional custom function for error classification
+        track_stats: Whether to track usage statistics
+    """
+
+    def __init__(
+        self,
+        *,
+        primary: Model,
+        fallback: Model,
+        **config: Unpack[FallbackConfig],
+    ) -> None:
+        """Initialize the FallbackModel with primary and fallback models.
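+
+        Unrecognized keys in **config are ignored at runtime, though a static type
+        checker will flag them via the Unpack[FallbackConfig] annotation.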
+ + Args: + primary: The primary Model instance to use for requests + fallback: The fallback Model instance to use when primary fails + **config: Configuration options from FallbackConfig TypedDict: + - circuit_failure_threshold: Number of failures before circuit opens (default: 3) + - circuit_time_window: Time window in seconds for counting failures (default: 60.0) + - circuit_cooldown_seconds: Cooldown period in seconds (default: 30) + - should_fallback: Optional custom error classification function (default: None) + - track_stats: Whether to track statistics (default: True) + """ + # Store model instances + self.primary = primary + self.fallback = fallback + + # Initialize config with defaults + self.circuit_failure_threshold = config.get("circuit_failure_threshold", 3) + self.circuit_time_window = config.get("circuit_time_window", 60.0) + self.circuit_cooldown_seconds = config.get("circuit_cooldown_seconds", 30) + self.should_fallback = config.get("should_fallback", None) + self.track_stats = config.get("track_stats", True) + + # Initialize circuit breaker state + self._failure_timestamps: deque[float] = deque(maxlen=100) + self._circuit_open = False + self._circuit_open_until: Optional[float] = None + + # Initialize statistics + self._stats: dict[str, Union[int, bool]] = { + "fallback_count": 0, + "primary_failures": 0, + "circuit_skips": 0, + "using_fallback": False, + } + + logger.info( + "primary=<%s>, fallback=<%s>, circuit_failure_threshold=<%d>, " + "circuit_time_window=<%s>s, circuit_cooldown_seconds=<%d>s | initialized FallbackModel", + self._get_model_name(primary), + self._get_model_name(fallback), + self.circuit_failure_threshold, + self.circuit_time_window, + self.circuit_cooldown_seconds, + ) + + def _check_circuit(self) -> bool: + """Check if the circuit breaker is open and handle cooldown expiration. + + This method checks the current state of the circuit breaker. If the circuit + is open and the cooldown period has expired, it automatically closes the + circuit and logs the event. + + Returns: + True if the circuit is open (primary should be skipped), False if closed + (primary can be attempted). + """ + current_time = time.time() + + # Check if circuit is open + if self._circuit_open: + # Check if cooldown has expired + if self._circuit_open_until is not None and current_time >= self._circuit_open_until: + # Close the circuit + self._circuit_open = False + self._circuit_open_until = None + logger.info("Circuit breaker closed, will retry primary model") + return False + + # Circuit is still open + return True + + # Circuit is closed + return False + + def _handle_primary_failure(self, error: Exception) -> None: + """Handle a primary model failure and potentially open the circuit breaker. + + This method records the failure timestamp, updates statistics, and checks + if the circuit breaker threshold has been reached. If the threshold is + exceeded, the circuit breaker opens and remains open for the configured + cooldown period. + + Args: + error: The exception that caused the primary model to fail. 
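+
+        Note:
+            Failures are counted over a sliding window. For example, with
+            circuit_failure_threshold=3 and circuit_time_window=60.0, failures at
+            t=0s, t=20s, and t=70s leave the circuit closed: when the third failure
+            is recorded, only the failures at t=20s and t=70s fall inside the
+            60-second window.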
+ """ + current_time = time.time() + + # Record the failure timestamp + self._failure_timestamps.append(current_time) + + # Increment failure counter + if self.track_stats: + self._stats["primary_failures"] += 1 + + # Count recent failures within the time window + recent_failures = sum( + 1 for timestamp in self._failure_timestamps if current_time - timestamp <= self.circuit_time_window + ) + + # Check if we should open the circuit + if recent_failures >= self.circuit_failure_threshold: + self._circuit_open = True + self._circuit_open_until = current_time + self.circuit_cooldown_seconds + logger.warning( + "recent_failures=<%d>, time_window=<%s>s, cooldown=<%d>s | circuit breaker opened", + recent_failures, + self.circuit_time_window, + self.circuit_cooldown_seconds, + ) + + def _should_fallback(self, error: Exception) -> bool: + """Determine if an error should trigger fallback to the secondary model. + + This method classifies errors to determine if they are retryable with a + fallback model. By default, it triggers fallback for throttling and + connection/network errors, but not for context window overflow errors. + + A custom classification function can be provided via the should_fallback + configuration parameter to override the default logic. + + Args: + error: The exception raised by the primary model. + + Returns: + True if the error should trigger fallback, False otherwise. + """ + # Check if custom should_fallback function exists in config + if self.should_fallback is not None: + return self.should_fallback(error) + + # Check if error is instance of ModelThrottledException + if isinstance(error, ModelThrottledException): + return True + + # Convert error to lowercase string + error_str = str(error).lower() + + # Check if any connection error keywords present in error string + connection_keywords = [ + "connection", + "network", + "timeout", + "refused", + "unavailable", + "closed", + "aborted", + "reset", + ] + if any(keyword in error_str for keyword in connection_keywords): + return True + + # Check if error is instance of ContextWindowOverflowException + if isinstance(error, ContextWindowOverflowException): + return False + + # Return False as default + return False + + def _get_model_name(self, model: Model) -> str: + """Extract a human-readable name for the model for debugging purposes. + + This method attempts to extract a meaningful identifier from the model's + configuration, falling back to the class name if no specific identifier + is found. This helps with debugging and monitoring by providing clear + model identification in statistics and logs. + + Args: + model: The Model instance to extract a name from. + + Returns: + A string identifier for the model, either from configuration or class name. + """ + try: + # Try to get from model configuration first + config = model.get_config() + if isinstance(config, dict): + # Look for common model identifier fields + for key in ["model_id", "model", "name"]: + if key in config and config[key]: + return str(config[key]) + except Exception: + # If get_config() fails, fall back to class name + pass + + # Fall back to class name + return model.__class__.__name__ + + @override + async def stream( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + *, + tool_choice: ToolChoice | None = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Stream conversation with the model, with automatic fallback on primary failure. 
+
+        This method attempts to stream from the primary model first. If the primary model
+        fails with a retryable error (throttling, connection issues), it automatically
+        falls back to the secondary model. The circuit breaker may skip the primary
+        model entirely if it has failed repeatedly.
+
+        The method handles the following scenarios:
+        1. Circuit breaker open: Skip primary and use fallback directly
+        2. Primary success: Stream from primary model
+        3. Primary failure (retryable): Fall back to secondary model
+        4. Primary failure (non-retryable): Re-raise the exception
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+            tool_choice: Selection strategy for tool invocation.
+            **kwargs: Additional keyword arguments passed to the underlying model.
+
+        Yields:
+            StreamEvent objects from either the primary or fallback model.
+
+        Raises:
+            Exception: Re-raises non-retryable exceptions from the primary model,
+                or exceptions from the fallback model if both models fail.
+
+        Note:
+            If the primary model fails after it has already yielded some events,
+            the fallback stream starts over from the beginning, so callers may
+            observe partial primary output followed by the full fallback output.
+
+        Example:
+            ```python
+            model = FallbackModel(primary=primary_model, fallback=fallback_model)
+
+            async for event in model.stream(messages=[{"role": "user", "content": [{"text": "Hello"}]}]):
+                print(event)
+            ```
+        """
+        # Check if circuit breaker is open
+        if self._check_circuit():
+            # Circuit is open, skip primary and use fallback directly
+            if self.track_stats:
+                self._stats["circuit_skips"] += 1
+                self._stats["using_fallback"] = True
+
+            logger.info(
+                "fallback_model=<%s> | circuit breaker is open, using fallback model directly",
+                self._get_model_name(self.fallback),
+            )
+
+            # Yield events from fallback model
+            async for event in self.fallback.stream(
+                messages=messages,
+                tool_specs=tool_specs,
+                system_prompt=system_prompt,
+                tool_choice=tool_choice,
+                **kwargs,
+            ):
+                yield event
+
+            return
+
+        # Circuit is closed, try primary model
+        if self.track_stats:
+            self._stats["using_fallback"] = False
+
+        try:
+            # Attempt to stream from primary model
+            async for event in self.primary.stream(
+                messages=messages,
+                tool_specs=tool_specs,
+                system_prompt=system_prompt,
+                tool_choice=tool_choice,
+                **kwargs,
+            ):
+                yield event
+
+            # Primary model succeeded
+            return
+
+        except Exception as error:
+            # Primary model failed
+            error_message = str(error)[:200]  # Truncate to first 200 chars
+            logger.warning(
+                "primary_model=<%s>, error_type=<%s>, error_message=<%s> | primary model failed",
+                self._get_model_name(self.primary),
+                error.__class__.__name__,
+                error_message,
+            )
+
+            # Check if we should fallback for this error
+            if not self._should_fallback(error):
+                logger.debug("error_type=<%s> | error is not fallback-eligible, re-raising", error.__class__.__name__)
+                raise
+
+            # Error is fallback-eligible, handle the failure
+            self._handle_primary_failure(error)
+
+            if self.track_stats:
+                self._stats["fallback_count"] += 1
+                self._stats["using_fallback"] = True
+
+            logger.info(
+                "fallback_model=<%s>, fallback_count=<%d> | attempting fallback",
+                self._get_model_name(self.fallback),
+                self._stats["fallback_count"],
+            )
+
+            try:
+                # Attempt to stream from fallback model
+                async for event in self.fallback.stream(
+                    messages=messages,
+                    tool_specs=tool_specs,
+                    system_prompt=system_prompt,
+                    tool_choice=tool_choice,
+                    **kwargs,
+                ):
+                    yield event
+
+                logger.info("fallback_model=<%s> | fallback succeeded", self._get_model_name(self.fallback))
+
+            except Exception as fallback_error:
+                # Both
models failed + logger.error( + "primary_model=<%s>, primary_error=<%s>, fallback_model=<%s>, fallback_error=<%s> | " + "both models failed", + self._get_model_name(self.primary), + error.__class__.__name__, + self._get_model_name(self.fallback), + fallback_error.__class__.__name__, + ) + # Raise the fallback exception, not the primary + raise fallback_error + + @override + async def structured_output( + self, + output_model: Type[T], + prompt: Messages, + system_prompt: Optional[str] = None, + **kwargs: Any, + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + """Generate structured output with the model, with automatic fallback on primary failure. + + This method attempts to generate structured output from the primary model first. If the + primary model fails with a retryable error (throttling, connection issues), it automatically + falls back to the secondary model. The circuit breaker may skip the primary model entirely + if it has failed repeatedly. + + The method handles the following scenarios: + 1. Circuit breaker open: Skip primary and use fallback directly + 2. Primary success: Generate structured output from primary model + 3. Primary failure (retryable): Fall back to secondary model + 4. Primary failure (non-retryable): Re-raise the exception + + Args: + output_model: Pydantic model class defining the expected output structure. + prompt: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments passed to the underlying model. + + Yields: + Dictionary events from either the primary or fallback model, containing + structured output data conforming to the output_model schema. + + Raises: + Exception: Re-raises non-retryable exceptions from the primary model, + or exceptions from the fallback model if both models fail. 
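+
+        Note:
+            If the primary model fails after it has already yielded some events,
+            the fallback request starts over from the beginning, so callers may
+            observe partial primary output followed by the full fallback output.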
+
+        Example:
+            ```python
+            from pydantic import BaseModel
+
+            class Response(BaseModel):
+                answer: str
+                confidence: float
+
+            model = FallbackModel(primary=primary_model, fallback=fallback_model)
+
+            async for event in model.structured_output(
+                output_model=Response,
+                prompt=[{"role": "user", "content": [{"text": "What is 2+2?"}]}]
+            ):
+                if event.get("chunk_type") == "structured_output":
+                    print(event["data"])
+            ```
+        """
+        # Check if circuit breaker is open
+        if self._check_circuit():
+            # Circuit is open, skip primary and use fallback directly
+            if self.track_stats:
+                self._stats["circuit_skips"] += 1
+                self._stats["using_fallback"] = True
+
+            logger.info(
+                "fallback_model=<%s> | circuit breaker is open, using fallback model directly for structured output",
+                self._get_model_name(self.fallback),
+            )
+
+            # Yield events from fallback model
+            async for event in self.fallback.structured_output(
+                output_model=output_model,
+                prompt=prompt,
+                system_prompt=system_prompt,
+                **kwargs,
+            ):
+                yield event
+
+            return
+
+        # Circuit is closed, try primary model
+        if self.track_stats:
+            self._stats["using_fallback"] = False
+
+        try:
+            # Attempt to get structured output from primary model
+            async for event in self.primary.structured_output(
+                output_model=output_model,
+                prompt=prompt,
+                system_prompt=system_prompt,
+                **kwargs,
+            ):
+                yield event
+
+            # Primary model succeeded
+            return
+
+        except Exception as error:
+            # Primary model failed
+            error_message = str(error)[:200]  # Truncate to first 200 chars
+            logger.warning(
+                "primary_model=<%s>, error_type=<%s>, error_message=<%s> | "
+                "primary model failed during structured output",
+                self._get_model_name(self.primary),
+                error.__class__.__name__,
+                error_message,
+            )
+
+            # Check if we should fallback for this error
+            if not self._should_fallback(error):
+                logger.debug("error_type=<%s> | error is not fallback-eligible, re-raising", error.__class__.__name__)
+                raise
+
+            # Error is fallback-eligible, handle the failure
+            self._handle_primary_failure(error)
+
+            if self.track_stats:
+                self._stats["fallback_count"] += 1
+                self._stats["using_fallback"] = True
+
+            logger.info(
+                "fallback_model=<%s>, fallback_count=<%d> | attempting fallback for structured output",
+                self._get_model_name(self.fallback),
+                self._stats["fallback_count"],
+            )
+
+            try:
+                # Attempt to get structured output from fallback model
+                async for event in self.fallback.structured_output(
+                    output_model=output_model,
+                    prompt=prompt,
+                    system_prompt=system_prompt,
+                    **kwargs,
+                ):
+                    yield event
+
+                logger.info(
+                    "fallback_model=<%s> | fallback succeeded for structured output",
+                    self._get_model_name(self.fallback),
+                )
+
+            except Exception as fallback_error:
+                # Both models failed
+                logger.error(
+                    "primary_model=<%s>, primary_error=<%s>, fallback_model=<%s>, fallback_error=<%s> | "
+                    "both models failed during structured output",
+                    self._get_model_name(self.primary),
+                    error.__class__.__name__,
+                    self._get_model_name(self.fallback),
+                    fallback_error.__class__.__name__,
+                )
+                # Raise the fallback exception, not the primary
+                raise fallback_error
+
+    @override
+    def update_config(self, **config: Any) -> None:
+        """Update the FallbackModel configuration.
+
+        This method updates the configuration parameters for the FallbackModel itself,
+        such as circuit breaker thresholds and error detection logic. It does not
+        affect the configuration of the underlying primary or fallback models.
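+
+        Updates take effect on the next request. An already-open circuit keeps its
+        existing circuit_open_until deadline even if circuit_cooldown_seconds is
+        changed.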
+ + Args: + **config: Configuration options from FallbackConfig TypedDict: + - circuit_failure_threshold: Number of failures before circuit opens + - circuit_time_window: Time window in seconds for counting failures + - circuit_cooldown_seconds: Cooldown period in seconds + - should_fallback: Optional custom error classification function + - track_stats: Whether to track statistics + + Example: + ```python + model = FallbackModel(primary=primary_model, fallback=fallback_model) + + # Update circuit breaker thresholds + model.update_config( + circuit_failure_threshold=5, + circuit_time_window=120.0 + ) + ``` + + Note: + This method only updates the FallbackModel's configuration. To update + the configuration of the underlying primary or fallback models, call + their respective update_config() methods directly. + """ + # Update circuit_failure_threshold if provided + if "circuit_failure_threshold" in config: + self.circuit_failure_threshold = config["circuit_failure_threshold"] + + # Update circuit_time_window if provided + if "circuit_time_window" in config: + self.circuit_time_window = config["circuit_time_window"] + + # Update circuit_cooldown_seconds if provided + if "circuit_cooldown_seconds" in config: + self.circuit_cooldown_seconds = config["circuit_cooldown_seconds"] + + # Update should_fallback if provided + if "should_fallback" in config: + self.should_fallback = config["should_fallback"] + + # Update track_stats if provided + if "track_stats" in config: + self.track_stats = config["track_stats"] + + def get_stats(self) -> FallbackStats: + """Get current statistics about fallback usage and circuit breaker state. + + This method returns comprehensive statistics including fallback counts, + primary failures, circuit breaker state, recent failure counts within + the configured time window, and model identifiers for debugging. 
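+
+        Reading stats has no side effects: an expired cooldown is not applied here,
+        so circuit_open may still read True until the next stream() or
+        structured_output() call runs the circuit check.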
+ + Returns: + FallbackStats containing: + - fallback_count: Total number of times fallback was used + - primary_failures: Total number of primary model failures + - circuit_skips: Number of times circuit breaker skipped primary + - using_fallback: Whether currently using fallback model + - circuit_open: Current circuit breaker state (True if open) + - recent_failures: Number of failures within the time window + - circuit_open_until: Timestamp when circuit will close (or None) + - primary_model_name: Name/identifier of the primary model + - fallback_model_name: Name/identifier of the fallback model + + Example: + ```python + model = FallbackModel(primary=primary_model, fallback=fallback_model) + + # After some usage + stats = model.get_stats() + print(f"Fallback count: {stats['fallback_count']}") + print(f"Circuit open: {stats['circuit_open']}") + print(f"Primary model: {stats['primary_model_name']}") + print(f"Fallback model: {stats['fallback_model_name']}") + ``` + """ + # Get current timestamp + current_time = time.time() + + # Count recent failures within time window using list comprehension + recent_failures = sum( + 1 for timestamp in self._failure_timestamps if current_time - timestamp <= self.circuit_time_window + ) + + # Return FallbackStats with all required fields + return FallbackStats( + fallback_count=int(self._stats["fallback_count"]), + primary_failures=int(self._stats["primary_failures"]), + circuit_skips=int(self._stats["circuit_skips"]), + using_fallback=bool(self._stats["using_fallback"]), + circuit_open=self._circuit_open, + recent_failures=recent_failures, + circuit_open_until=self._circuit_open_until, + primary_model_name=self._get_model_name(self.primary), + fallback_model_name=self._get_model_name(self.fallback), + ) + + def reset_stats(self) -> None: + """Reset all statistics and circuit breaker state. + + This method clears all tracked statistics, failure timestamps, and resets + the circuit breaker to its initial closed state. This is useful for testing + or when you want to start fresh after resolving issues with the primary model. + + Example: + ```python + model = FallbackModel(primary=primary_model, fallback=fallback_model) + + # After some usage and failures + model.reset_stats() + + # Statistics are now cleared + stats = model.get_stats() + assert stats['fallback_count'] == 0 + assert stats['circuit_open'] == False + ``` + + Note: + This method resets the FallbackModel's statistics and circuit breaker + state, but does not affect the underlying primary or fallback models. + """ + # Clear _failure_timestamps deque + self._failure_timestamps.clear() + + # Set _circuit_open=False and _circuit_open_until=None + self._circuit_open = False + self._circuit_open_until = None + + # Reset _stats dict to initial values + self._stats = { + "fallback_count": 0, + "primary_failures": 0, + "circuit_skips": 0, + "using_fallback": False, + } + + @override + def get_config(self) -> dict[str, Any]: + """Get the complete configuration including FallbackModel and underlying models. + + This method returns a comprehensive view of the configuration, including: + - FallbackModel's own configuration (circuit breaker settings, etc.) 
+ - Primary model's configuration + - Fallback model's configuration + - Current statistics (if tracking is enabled) + + Returns: + Dictionary with keys: + - fallback_config: FallbackModel configuration parameters + - primary_config: Configuration from the primary model + - fallback_model_config: Configuration from the fallback model + - stats: Current statistics (if track_stats is True, otherwise None) + + Example: + ```python + model = FallbackModel(primary=primary_model, fallback=fallback_model) + + config = model.get_config() + print(f"Circuit threshold: {config['fallback_config']['circuit_failure_threshold']}") + print(f"Fallback count: {config['stats']['fallback_count']}") + ``` + """ + # Build fallback_config dictionary + fallback_config = { + "circuit_failure_threshold": self.circuit_failure_threshold, + "circuit_time_window": self.circuit_time_window, + "circuit_cooldown_seconds": self.circuit_cooldown_seconds, + "should_fallback": self.should_fallback, + "track_stats": self.track_stats, + } + + # Get primary model config + primary_config = self.primary.get_config() + + # Get fallback model config + fallback_model_config = self.fallback.get_config() + + # Get stats if tracking is enabled + stats = self.get_stats() if self.track_stats else None + + return { + "fallback_config": fallback_config, + "primary_config": primary_config, + "fallback_model_config": fallback_model_config, + "stats": stats, + } diff --git a/tests/strands/models/test_fallback.py b/tests/strands/models/test_fallback.py new file mode 100644 index 000000000..f06bb66fa --- /dev/null +++ b/tests/strands/models/test_fallback.py @@ -0,0 +1,1268 @@ +"""Unit tests for FallbackModel.""" + +from unittest.mock import AsyncMock + +import pytest + +from strands.models import Model +from strands.models.fallback import FallbackModel +from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException + + +@pytest.fixture +def mock_primary(): + """Create a mock primary model.""" + model = AsyncMock(spec=Model) + model.get_config.return_value = {"model": "primary"} + return model + + +@pytest.fixture +def mock_fallback(): + """Create a mock fallback model.""" + model = AsyncMock(spec=Model) + model.get_config.return_value = {"model": "fallback"} + return model + + +@pytest.fixture +def fallback_model(mock_primary, mock_fallback): + """Create a FallbackModel instance with mock models.""" + return FallbackModel(primary=mock_primary, fallback=mock_fallback) + + +async def async_generator(items): + """Helper function to create async generators for test data.""" + for item in items: + yield item + + +@pytest.mark.asyncio +async def test_primary_success(mock_primary, mock_fallback, fallback_model, alist): + """Test that primary model success returns primary results without fallback.""" + # Mock primary.stream to return test events + test_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "Hello"}}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + mock_primary.stream.return_value = async_generator(test_events) + + # Call stream + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + result = await alist(fallback_model.stream(messages)) + + # Assert all events received + assert result == test_events + + # Assert fallback was not called + mock_fallback.stream.assert_not_called() + + # Assert fallback_count is 0 + stats = fallback_model.get_stats() + assert stats["fallback_count"] == 0 + + +@pytest.mark.asyncio +async def 
test_throttle_exception(mock_primary, mock_fallback, fallback_model, alist): + """Test fallback triggers on ModelThrottledException.""" + # Mock primary.stream to raise ModelThrottledException + mock_primary.stream.side_effect = ModelThrottledException("Rate limit exceeded") + + # Mock fallback.stream to return test events + fallback_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "Fallback response"}}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + mock_fallback.stream.return_value = async_generator(fallback_events) + + # Call stream + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + result = await alist(fallback_model.stream(messages)) + + # Assert fallback events received + assert result == fallback_events + + # Assert fallback_count is 1 + stats = fallback_model.get_stats() + assert stats["fallback_count"] == 1 + + +@pytest.mark.asyncio +async def test_connection_error(mock_primary, mock_fallback, fallback_model, alist): + """Test fallback triggers on connection errors.""" + # Mock primary.stream to raise exception with "connection timeout" message + mock_primary.stream.side_effect = Exception("connection timeout error") + + # Mock fallback.stream to return test events + fallback_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "Fallback response"}}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + mock_fallback.stream.return_value = async_generator(fallback_events) + + # Call stream + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + result = await alist(fallback_model.stream(messages)) + + # Assert fallback events received + assert result == fallback_events + + # Assert fallback was called + stats = fallback_model.get_stats() + assert stats["fallback_count"] == 1 + + +@pytest.mark.asyncio +async def test_context_overflow_no_fallback(mock_primary, mock_fallback, fallback_model): + """Test no fallback on ContextWindowOverflowException.""" + # Mock primary.stream to raise ContextWindowOverflowException + mock_primary.stream.side_effect = ContextWindowOverflowException("Context window exceeded") + + # Call stream and expect exception to be re-raised + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + + with pytest.raises(ContextWindowOverflowException): + async for _ in fallback_model.stream(messages): + pass + + # Assert fallback was not called + mock_fallback.stream.assert_not_called() + + # Assert fallback_count is 0 + stats = fallback_model.get_stats() + assert stats["fallback_count"] == 0 + + +@pytest.mark.asyncio +async def test_both_models_fail(mock_primary, mock_fallback, fallback_model): + """Test that fallback exception is raised when both fail.""" + # Mock primary.stream to raise ModelThrottledException + mock_primary.stream.side_effect = ModelThrottledException("Primary throttled") + + # Mock fallback.stream to raise RuntimeError + mock_fallback.stream.side_effect = RuntimeError("Fallback failed") + + # Call stream and expect RuntimeError (not ModelThrottledException) + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + + with pytest.raises(RuntimeError, match="Fallback failed"): + async for _ in fallback_model.stream(messages): + pass + + +# Circuit Breaker Tests + + +@pytest.mark.asyncio +async def test_circuit_opens_after_threshold_failures(mock_primary, mock_fallback, alist): + """Test circuit opens after threshold failures.""" + # Configure FallbackModel with circuit_failure_threshold=2 + fallback_model = FallbackModel( 
+ primary=mock_primary, fallback=mock_fallback, circuit_failure_threshold=2, circuit_time_window=60.0 + ) + + # Mock primary.stream to raise ModelThrottledException + mock_primary.stream.side_effect = ModelThrottledException("Primary throttled") + + # Mock fallback.stream to return test events + fallback_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "Fallback response"}}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + mock_fallback.stream.return_value = async_generator(fallback_events) + + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + + # Trigger first failure + await alist(fallback_model.stream(messages)) + stats = fallback_model.get_stats() + assert stats["circuit_open"] is False + assert stats["primary_failures"] == 1 + + # Trigger second failure - should open circuit + await alist(fallback_model.stream(messages)) + stats = fallback_model.get_stats() + assert stats["circuit_open"] is True + assert stats["primary_failures"] == 2 + + # Reset mock call count to verify next request skips primary + mock_primary.stream.reset_mock() + mock_fallback.stream.reset_mock() + + # Next request should skip primary and go directly to fallback + await alist(fallback_model.stream(messages)) + + # Assert primary was not called (circuit is open) + mock_primary.stream.assert_not_called() + + # Assert fallback was called directly + mock_fallback.stream.assert_called_once() + + # Assert circuit_skips counter increased + stats = fallback_model.get_stats() + assert stats["circuit_skips"] == 1 + + +@pytest.mark.asyncio +async def test_circuit_stays_closed_below_threshold(mock_primary, mock_fallback, alist): + """Test circuit stays closed below threshold.""" + # Configure with threshold=3 + fallback_model = FallbackModel( + primary=mock_primary, fallback=mock_fallback, circuit_failure_threshold=3, circuit_time_window=60.0 + ) + + # Mock primary.stream to raise ModelThrottledException + mock_primary.stream.side_effect = ModelThrottledException("Primary throttled") + + # Mock fallback.stream to return test events + fallback_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "Fallback response"}}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + mock_fallback.stream.return_value = async_generator(fallback_events) + + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + + # Trigger 2 failures (below threshold of 3) + await alist(fallback_model.stream(messages)) + await alist(fallback_model.stream(messages)) + + # Assert circuit_open is False + stats = fallback_model.get_stats() + assert stats["circuit_open"] is False + assert stats["primary_failures"] == 2 + + # Reset mock call count to verify primary is still attempted + mock_primary.stream.reset_mock() + mock_fallback.stream.reset_mock() + + # Next request should still attempt primary (circuit is closed) + await alist(fallback_model.stream(messages)) + + # Assert primary was called (circuit is still closed) + mock_primary.stream.assert_called_once() + + # Assert fallback was also called due to primary failure + mock_fallback.stream.assert_called_once() + + +@pytest.mark.asyncio +async def test_circuit_closes_after_cooldown(mock_primary, mock_fallback, alist): + """Test circuit closes after cooldown period.""" + import time + from unittest.mock import patch + + # Configure with circuit_cooldown_seconds=1 + fallback_model = FallbackModel( + primary=mock_primary, fallback=mock_fallback, circuit_failure_threshold=2, 
circuit_cooldown_seconds=1 + ) + + # Mock primary.stream to raise ModelThrottledException initially + mock_primary.stream.side_effect = ModelThrottledException("Primary throttled") + + # Mock fallback.stream to return test events + fallback_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "Fallback response"}}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + mock_fallback.stream.return_value = async_generator(fallback_events) + + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + + # Open circuit with 2 failures + await alist(fallback_model.stream(messages)) + await alist(fallback_model.stream(messages)) + + # Assert circuit is open + stats = fallback_model.get_stats() + assert stats["circuit_open"] is True + + # Mock time.time to advance past cooldown + current_time = time.time() + with patch("time.time", return_value=current_time + 2): # 2 seconds later + # Reset mocks to track new calls + mock_primary.stream.reset_mock() + mock_fallback.stream.reset_mock() + + # Now make primary succeed to test circuit closing + primary_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "Primary response"}}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + mock_primary.stream.side_effect = None + mock_primary.stream.return_value = async_generator(primary_events) + + # Make request - circuit should close and primary should be retried + result = await alist(fallback_model.stream(messages)) + + # Assert circuit closed and primary was retried + stats = fallback_model.get_stats() + assert stats["circuit_open"] is False + + # Assert primary was called (circuit closed) + mock_primary.stream.assert_called_once() + + # Assert fallback was not called (primary succeeded) + mock_fallback.stream.assert_not_called() + + # Assert we got primary response + assert result == primary_events + + +@pytest.mark.asyncio +async def test_time_window_failure_counting(mock_primary, mock_fallback, alist): + """Test only failures within time window count.""" + import time + from unittest.mock import patch + + # Configure with circuit_time_window=2 + fallback_model = FallbackModel( + primary=mock_primary, fallback=mock_fallback, circuit_failure_threshold=2, circuit_time_window=2.0 + ) + + # Mock primary.stream to raise ModelThrottledException + mock_primary.stream.side_effect = ModelThrottledException("Primary throttled") + + # Mock fallback.stream to return test events + fallback_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "Fallback response"}}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + mock_fallback.stream.return_value = async_generator(fallback_events) + + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + + current_time = time.time() + + # Add old failure timestamp (> 2 seconds ago) by mocking time + with patch("time.time", return_value=current_time - 3): # 3 seconds ago + await alist(fallback_model.stream(messages)) + + # Add recent failure with current time + with patch("time.time", return_value=current_time): + await alist(fallback_model.stream(messages)) + + # Check stats - should show 2 total failures but only 1 recent failure + with patch("time.time", return_value=current_time): + stats = fallback_model.get_stats() + assert stats["primary_failures"] == 2 # Total failures + assert stats["recent_failures"] == 1 # Only recent failure counts + assert stats["circuit_open"] is False # Circuit should stay closed (only 1 recent failure < 
threshold of 2) + + # Add another recent failure to test circuit opening with recent failures only + with patch("time.time", return_value=current_time + 0.5): # Still within 2-second window + await alist(fallback_model.stream(messages)) + + stats = fallback_model.get_stats() + assert stats["primary_failures"] == 3 # Total failures + assert stats["recent_failures"] == 2 # Two recent failures within window + assert stats["circuit_open"] is True # Circuit should open (2 recent failures >= threshold of 2) + + +# Configuration Tests + + +def test_default_config(mock_primary, mock_fallback): + """Test default configuration values are applied.""" + # Create FallbackModel without config + fallback_model = FallbackModel(primary=mock_primary, fallback=mock_fallback) + + # Assert default values + assert fallback_model.circuit_failure_threshold == 3 + assert fallback_model.circuit_time_window == 60.0 + assert fallback_model.circuit_cooldown_seconds == 30 + assert fallback_model.track_stats is True + assert fallback_model.should_fallback is None + + +def test_custom_config(mock_primary, mock_fallback): + """Test custom configuration is applied.""" + # Create FallbackModel with custom values + custom_config = { + "circuit_failure_threshold": 5, + "circuit_time_window": 120.0, + "circuit_cooldown_seconds": 60, + "track_stats": False, + } + + fallback_model = FallbackModel(primary=mock_primary, fallback=mock_fallback, **custom_config) + + # Assert custom values are used + assert fallback_model.circuit_failure_threshold == 5 + assert fallback_model.circuit_time_window == 120.0 + assert fallback_model.circuit_cooldown_seconds == 60 + assert fallback_model.track_stats is False + + +@pytest.mark.asyncio +async def test_custom_should_fallback(mock_primary, mock_fallback, alist): + """Test custom should_fallback function is used.""" + + # Create custom function that returns True for specific error + def custom_should_fallback(error): + return isinstance(error, ValueError) and "custom_error" in str(error) + + # Create FallbackModel with custom should_fallback function + fallback_model = FallbackModel(primary=mock_primary, fallback=mock_fallback, should_fallback=custom_should_fallback) + + # Mock primary to raise that error + mock_primary.stream.side_effect = ValueError("custom_error occurred") + + # Mock fallback.stream to return test events + fallback_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "Fallback response"}}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + mock_fallback.stream.return_value = async_generator(fallback_events) + + # Call stream + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + result = await alist(fallback_model.stream(messages)) + + # Assert fallback is triggered + assert result == fallback_events + stats = fallback_model.get_stats() + assert stats["fallback_count"] == 1 + + # Test that other ValueError doesn't trigger fallback + mock_primary.stream.reset_mock() + mock_fallback.stream.reset_mock() + mock_primary.stream.side_effect = ValueError("different error") + + # This should not trigger fallback and should re-raise + with pytest.raises(ValueError, match="different error"): + async for _ in fallback_model.stream(messages): + pass + + # Assert fallback was not called + mock_fallback.stream.assert_not_called() + + # Assert fallback_count didn't increase + stats = fallback_model.get_stats() + assert stats["fallback_count"] == 1 # Still 1 from previous test + + +def test_update_config(mock_primary, mock_fallback): + 
"""Test configuration can be updated.""" + # Create FallbackModel + fallback_model = FallbackModel(primary=mock_primary, fallback=mock_fallback) + + # Verify initial default values + assert fallback_model.circuit_failure_threshold == 3 + assert fallback_model.circuit_time_window == 60.0 + + # Call update_config with new values + fallback_model.update_config(circuit_failure_threshold=10, circuit_time_window=300.0, circuit_cooldown_seconds=120) + + # Assert config is updated + assert fallback_model.circuit_failure_threshold == 10 + assert fallback_model.circuit_time_window == 300.0 + assert fallback_model.circuit_cooldown_seconds == 120 + + # Assert other values remain unchanged + assert fallback_model.track_stats is True # Default value preserved + + # Test partial update + fallback_model.update_config(circuit_failure_threshold=7) + + # Assert only specified value is updated + assert fallback_model.circuit_failure_threshold == 7 + assert fallback_model.circuit_time_window == 300.0 # Previous value preserved + assert fallback_model.circuit_cooldown_seconds == 120 # Previous value preserved + + +# Statistics Tests + + +@pytest.mark.asyncio +async def test_statistics_tracking(mock_primary, mock_fallback, alist): + """Test statistics are tracked correctly.""" + # Create FallbackModel with circuit_failure_threshold=2 for easier testing + fallback_model = FallbackModel( + primary=mock_primary, fallback=mock_fallback, circuit_failure_threshold=2, track_stats=True + ) + + # Mock fallback.stream to return test events + fallback_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "Fallback response"}}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + mock_fallback.stream.return_value = async_generator(fallback_events) + + # Mock primary success events + primary_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "Primary response"}}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + + # Test initial stats + stats = fallback_model.get_stats() + assert stats["fallback_count"] == 0 + assert stats["primary_failures"] == 0 + assert stats["circuit_skips"] == 0 + assert stats["using_fallback"] is False + assert stats["circuit_open"] is False + assert stats["recent_failures"] == 0 + assert stats["circuit_open_until"] is None + + # Scenario 1: Primary success + mock_primary.stream.return_value = async_generator(primary_events) + result = await alist(fallback_model.stream(messages)) + + stats = fallback_model.get_stats() + assert stats["fallback_count"] == 0 + assert stats["primary_failures"] == 0 + assert stats["circuit_skips"] == 0 + assert stats["using_fallback"] is False + assert stats["circuit_open"] is False + assert result == primary_events + + # Scenario 2: Primary failure, fallback success + mock_primary.stream.side_effect = ModelThrottledException("Primary throttled") + result = await alist(fallback_model.stream(messages)) + + stats = fallback_model.get_stats() + assert stats["fallback_count"] == 1 + assert stats["primary_failures"] == 1 + assert stats["circuit_skips"] == 0 + assert stats["using_fallback"] is True # Last operation used fallback + assert stats["circuit_open"] is False # Still below threshold + assert stats["recent_failures"] == 1 + assert result == fallback_events + + # Scenario 3: Another primary failure, should open circuit + result = await alist(fallback_model.stream(messages)) + + stats = fallback_model.get_stats() + 
assert stats["fallback_count"] == 2 + assert stats["primary_failures"] == 2 + assert stats["circuit_skips"] == 0 + assert stats["using_fallback"] is True + assert stats["circuit_open"] is True # Circuit opened after 2 failures + assert stats["recent_failures"] == 2 + assert stats["circuit_open_until"] is not None + + # Scenario 4: Circuit skip (circuit is open) + mock_primary.stream.reset_mock() + mock_fallback.stream.reset_mock() + mock_fallback.stream.return_value = async_generator(fallback_events) + + result = await alist(fallback_model.stream(messages)) + + stats = fallback_model.get_stats() + assert stats["fallback_count"] == 2 # Doesn't increment for circuit skips + assert stats["primary_failures"] == 2 # No new primary failure + assert stats["circuit_skips"] == 1 # Circuit skip counter increased + assert stats["using_fallback"] is True + assert stats["circuit_open"] is True + + # Verify primary was not called (circuit skip) + mock_primary.stream.assert_not_called() + mock_fallback.stream.assert_called_once() + assert result == fallback_events + + +def test_stats_disabled(mock_primary, mock_fallback): + """Test statistics can be disabled.""" + # Create FallbackModel with track_stats=False + fallback_model = FallbackModel(primary=mock_primary, fallback=mock_fallback, track_stats=False) + + # Call get_config() + config = fallback_model.get_config() + + # Assert stats is None + assert config["stats"] is None + + # Verify other config sections are present + assert "fallback_config" in config + assert "primary_config" in config + assert "fallback_model_config" in config + + # Verify fallback_config contains track_stats=False + assert config["fallback_config"]["track_stats"] is False + + +@pytest.mark.asyncio +async def test_reset_stats(mock_primary, mock_fallback, alist): + """Test reset_stats clears all statistics.""" + # Create FallbackModel with circuit_failure_threshold=2 for easier testing + fallback_model = FallbackModel( + primary=mock_primary, fallback=mock_fallback, circuit_failure_threshold=2, track_stats=True + ) + + # Mock primary.stream to raise ModelThrottledException + mock_primary.stream.side_effect = ModelThrottledException("Primary throttled") + + # Mock fallback.stream to return test events + fallback_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "Fallback response"}}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + mock_fallback.stream.return_value = async_generator(fallback_events) + + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + + # Trigger failures to populate stats and open circuit + await alist(fallback_model.stream(messages)) # First failure + await alist(fallback_model.stream(messages)) # Second failure - opens circuit + + # Verify stats are populated and circuit is open + stats = fallback_model.get_stats() + assert stats["fallback_count"] == 2 + assert stats["primary_failures"] == 2 + assert stats["circuit_open"] is True + assert stats["recent_failures"] == 2 + assert stats["circuit_open_until"] is not None + + # Trigger a circuit skip to populate circuit_skips counter + await alist(fallback_model.stream(messages)) # Circuit skip + + stats = fallback_model.get_stats() + assert stats["circuit_skips"] == 1 + assert stats["using_fallback"] is True + + # Call reset_stats() + fallback_model.reset_stats() + + # Assert all counters are 0 and circuit is closed + stats = fallback_model.get_stats() + assert stats["fallback_count"] == 0 + assert stats["primary_failures"] == 0 + assert 
stats["circuit_skips"] == 0 + assert stats["using_fallback"] is False + assert stats["circuit_open"] is False + assert stats["recent_failures"] == 0 + assert stats["circuit_open_until"] is None + + # Verify circuit is actually closed by checking that primary is attempted again + mock_primary.stream.reset_mock() + mock_fallback.stream.reset_mock() + + # This should attempt primary again (circuit is closed) + await alist(fallback_model.stream(messages)) + + # Verify primary was called (circuit is closed) + mock_primary.stream.assert_called_once() + mock_fallback.stream.assert_called_once() # Called due to primary failure + + +def test_get_config(mock_primary, mock_fallback): + """Test get_config returns all configuration.""" + + # Create FallbackModel with custom config + def custom_should_fallback(error): + return True + + fallback_model = FallbackModel( + primary=mock_primary, + fallback=mock_fallback, + circuit_failure_threshold=5, + circuit_time_window=120.0, + circuit_cooldown_seconds=60, + should_fallback=custom_should_fallback, + track_stats=True, + ) + + # Call get_config() + config = fallback_model.get_config() + + # Assert returns fallback_config, primary_config, fallback_model_config, stats + assert "fallback_config" in config + assert "primary_config" in config + assert "fallback_model_config" in config + assert "stats" in config + + # Verify fallback_config contains all expected fields + fallback_config = config["fallback_config"] + assert fallback_config["circuit_failure_threshold"] == 5 + assert fallback_config["circuit_time_window"] == 120.0 + assert fallback_config["circuit_cooldown_seconds"] == 60 + assert fallback_config["should_fallback"] == custom_should_fallback + assert fallback_config["track_stats"] is True + + # Verify primary_config comes from primary model + assert config["primary_config"] == {"model": "primary"} + + # Verify fallback_model_config comes from fallback model + assert config["fallback_model_config"] == {"model": "fallback"} + + # Verify stats is included (since track_stats=True) + stats = config["stats"] + assert isinstance(stats, dict) + assert "fallback_count" in stats + assert "primary_failures" in stats + assert "circuit_skips" in stats + assert "using_fallback" in stats + assert "circuit_open" in stats + assert "recent_failures" in stats + assert "circuit_open_until" in stats + + # Test with track_stats=False + fallback_model_no_stats = FallbackModel(primary=mock_primary, fallback=mock_fallback, track_stats=False) + + config_no_stats = fallback_model_no_stats.get_config() + assert config_no_stats["stats"] is None + + +# Streaming Tests + + +@pytest.mark.asyncio +async def test_stream_primary_success(mock_primary, mock_fallback, fallback_model, alist): + """Test streaming from primary model.""" + # Mock primary.stream to yield multiple events + test_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"index": 0}}, + {"contentBlockDelta": {"index": 0, "delta": {"text": "Hello"}}}, + {"contentBlockDelta": {"index": 0, "delta": {"text": " world"}}}, + {"contentBlockStop": {"index": 0}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + mock_primary.stream.return_value = async_generator(test_events) + + # Call stream + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + + # Collect all events + events = [] + async for event in fallback_model.stream(messages): + events.append(event) + + # Assert events match primary output + assert events == test_events + + # Verify primary was called with correct parameters + 
+
+
+# Streaming Tests
+
+
+@pytest.mark.asyncio
+async def test_stream_primary_success(mock_primary, mock_fallback, fallback_model, alist):
+    """Test streaming from primary model."""
+    # Mock primary.stream to yield multiple events
+    test_events = [
+        {"messageStart": {"role": "assistant"}},
+        {"contentBlockStart": {"index": 0}},
+        {"contentBlockDelta": {"index": 0, "delta": {"text": "Hello"}}},
+        {"contentBlockDelta": {"index": 0, "delta": {"text": " world"}}},
+        {"contentBlockStop": {"index": 0}},
+        {"messageStop": {"stopReason": "end_turn"}},
+    ]
+    mock_primary.stream.return_value = async_generator(test_events)
+
+    # Call stream
+    messages = [{"role": "user", "content": [{"text": "Hi"}]}]
+
+    # Collect all events
+    events = []
+    async for event in fallback_model.stream(messages):
+        events.append(event)
+
+    # Assert events match primary output
+    assert events == test_events
+
+    # Verify primary was called with correct parameters
+    mock_primary.stream.assert_called_once_with(
+        messages=messages, tool_specs=None, system_prompt=None, tool_choice=None
+    )
+
+    # Verify fallback was not called
+    mock_fallback.stream.assert_not_called()
+
+    # Verify stats show no fallback usage
+    stats = fallback_model.get_stats()
+    assert stats["fallback_count"] == 0
+    assert stats["using_fallback"] is False
+
+
+@pytest.mark.asyncio
+async def test_stream_fallback_after_primary_failure(mock_primary, mock_fallback, fallback_model, alist):
+    """Test streaming from fallback after primary failure."""
+    # Mock primary.stream to raise exception
+    mock_primary.stream.side_effect = ModelThrottledException("Primary throttled")
+
+    # Mock fallback.stream to yield events
+    fallback_events = [
+        {"messageStart": {"role": "assistant"}},
+        {"contentBlockStart": {"index": 0}},
+        {"contentBlockDelta": {"index": 0, "delta": {"text": "Fallback"}}},
+        {"contentBlockDelta": {"index": 0, "delta": {"text": " response"}}},
+        {"contentBlockStop": {"index": 0}},
+        {"messageStop": {"stopReason": "end_turn"}},
+    ]
+    mock_fallback.stream.return_value = async_generator(fallback_events)
+
+    # Call stream
+    messages = [{"role": "user", "content": [{"text": "Hi"}]}]
+    tool_specs = [{"name": "test_tool", "description": "Test tool"}]
+    system_prompt = "You are a helpful assistant"
+
+    # Collect all events
+    events = []
+    async for event in fallback_model.stream(
+        messages, tool_specs=tool_specs, system_prompt=system_prompt, tool_choice="auto"
+    ):
+        events.append(event)
+
+    # Assert fallback events received
+    assert events == fallback_events
+
+    # Verify primary was called first
+    mock_primary.stream.assert_called_once_with(
+        messages=messages, tool_specs=tool_specs, system_prompt=system_prompt, tool_choice="auto"
+    )
+
+    # Verify fallback was called with same parameters
+    mock_fallback.stream.assert_called_once_with(
+        messages=messages, tool_specs=tool_specs, system_prompt=system_prompt, tool_choice="auto"
+    )
+
+    # Verify stats show fallback usage
+    stats = fallback_model.get_stats()
+    assert stats["fallback_count"] == 1
+    assert stats["primary_failures"] == 1
+    assert stats["using_fallback"] is True
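+
+
+# Editor's note: async_generator is a shared test helper defined outside this
+# excerpt. A minimal implementation consistent with how it is used above
+# (wrapping prepared events in a one-shot async iterator) would be:
+#
+#     async def async_generator(items):
+#         for item in items:
+#             yield item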
tool_specs=[{"name": "tool"}], system_prompt="Test prompt", tool_choice="required" + ): + events.append(event) + + # Assert primary.stream not called + mock_primary.stream.assert_not_called() + + # Assert fallback events received + assert events == fallback_events + + # Verify fallback was called with correct parameters + mock_fallback.stream.assert_called_once_with( + messages=messages, tool_specs=[{"name": "tool"}], system_prompt="Test prompt", tool_choice="required" + ) + + # Verify stats show circuit skip + stats = fallback_model.get_stats() + assert stats["circuit_skips"] == 1 + assert stats["using_fallback"] is True + + +# Structured Output Tests + + +@pytest.mark.asyncio +async def test_structured_output_primary_success(mock_primary, mock_fallback, fallback_model, alist): + """Test structured_output from primary model.""" + + from pydantic import BaseModel + + # Create test Pydantic model + class TestResponse(BaseModel): + answer: str + confidence: float + + # Mock primary.structured_output to yield events + test_events = [ + {"chunk_type": "structured_output_start"}, + {"chunk_type": "structured_output_delta", "data": {"answer": "Hello"}}, + {"chunk_type": "structured_output_delta", "data": {"confidence": 0.95}}, + {"chunk_type": "structured_output_complete", "data": TestResponse(answer="Hello", confidence=0.95)}, + ] + mock_primary.structured_output.return_value = async_generator(test_events) + + # Call structured_output + prompt = [{"role": "user", "content": [{"text": "Hi"}]}] + result = await alist( + fallback_model.structured_output( + output_model=TestResponse, prompt=prompt, system_prompt="You are a helpful assistant" + ) + ) + + # Assert events received + assert result == test_events + + # Verify primary was called with correct parameters + mock_primary.structured_output.assert_called_once_with( + output_model=TestResponse, prompt=prompt, system_prompt="You are a helpful assistant" + ) + + # Verify fallback was not called + mock_fallback.structured_output.assert_not_called() + + # Verify stats show no fallback usage + stats = fallback_model.get_stats() + assert stats["fallback_count"] == 0 + assert stats["using_fallback"] is False + + +@pytest.mark.asyncio +async def test_structured_output_fallback_after_primary_failure(mock_primary, mock_fallback, fallback_model, alist): + """Test structured_output from fallback after primary failure.""" + from pydantic import BaseModel + + # Create test Pydantic model + class TestResponse(BaseModel): + answer: str + confidence: float + + # Mock primary.structured_output to raise exception + mock_primary.structured_output.side_effect = ModelThrottledException("Primary throttled") + + # Mock fallback.structured_output to yield events + fallback_events = [ + {"chunk_type": "structured_output_start"}, + {"chunk_type": "structured_output_delta", "data": {"answer": "Fallback response"}}, + {"chunk_type": "structured_output_delta", "data": {"confidence": 0.85}}, + {"chunk_type": "structured_output_complete", "data": TestResponse(answer="Fallback response", confidence=0.85)}, + ] + mock_fallback.structured_output.return_value = async_generator(fallback_events) + + # Call structured_output + prompt = [{"role": "user", "content": [{"text": "Hi"}]}] + result = await alist( + fallback_model.structured_output( + output_model=TestResponse, prompt=prompt, system_prompt="You are a helpful assistant" + ) + ) + + # Assert fallback events received + assert result == fallback_events + + # Verify primary was called first + 
+
+
+@pytest.mark.asyncio
+async def test_structured_output_fallback_after_primary_failure(mock_primary, mock_fallback, fallback_model, alist):
+    """Test structured_output from fallback after primary failure."""
+    from pydantic import BaseModel
+
+    # Create test Pydantic model
+    class TestResponse(BaseModel):
+        answer: str
+        confidence: float
+
+    # Mock primary.structured_output to raise exception
+    mock_primary.structured_output.side_effect = ModelThrottledException("Primary throttled")
+
+    # Mock fallback.structured_output to yield events
+    fallback_events = [
+        {"chunk_type": "structured_output_start"},
+        {"chunk_type": "structured_output_delta", "data": {"answer": "Fallback response"}},
+        {"chunk_type": "structured_output_delta", "data": {"confidence": 0.85}},
+        {"chunk_type": "structured_output_complete", "data": TestResponse(answer="Fallback response", confidence=0.85)},
+    ]
+    mock_fallback.structured_output.return_value = async_generator(fallback_events)
+
+    # Call structured_output
+    prompt = [{"role": "user", "content": [{"text": "Hi"}]}]
+    result = await alist(
+        fallback_model.structured_output(
+            output_model=TestResponse, prompt=prompt, system_prompt="You are a helpful assistant"
+        )
+    )
+
+    # Assert fallback events received
+    assert result == fallback_events
+
+    # Verify primary was called first
+    mock_primary.structured_output.assert_called_once_with(
+        output_model=TestResponse, prompt=prompt, system_prompt="You are a helpful assistant"
+    )
+
+    # Verify fallback was called after primary failure
+    mock_fallback.structured_output.assert_called_once_with(
+        output_model=TestResponse, prompt=prompt, system_prompt="You are a helpful assistant"
+    )
+
+    # Verify stats show fallback usage
+    stats = fallback_model.get_stats()
+    assert stats["fallback_count"] == 1
+    assert stats["primary_failures"] == 1
+    assert stats["using_fallback"] is True
+
+
+@pytest.mark.asyncio
+async def test_structured_output_circuit_open(mock_primary, mock_fallback, alist):
+    """Test structured_output directly from fallback when circuit open."""
+    from pydantic import BaseModel
+
+    # Create test Pydantic model
+    class TestResponse(BaseModel):
+        answer: str
+        confidence: float
+
+    # Create FallbackModel with low threshold to easily open circuit
+    fallback_model = FallbackModel(
+        primary=mock_primary, fallback=mock_fallback, circuit_failure_threshold=2, circuit_time_window=60.0
+    )
+
+    # Open circuit by triggering failures
+    mock_primary.structured_output.side_effect = ModelThrottledException("Primary throttled")
+    fallback_events_for_opening = [
+        {"chunk_type": "structured_output_complete", "data": TestResponse(answer="temp", confidence=0.5)}
+    ]
+    mock_fallback.structured_output.return_value = async_generator(fallback_events_for_opening)
+
+    prompt = [{"role": "user", "content": [{"text": "Hi"}]}]
+
+    # Trigger 2 failures to open circuit
+    await alist(fallback_model.structured_output(TestResponse, prompt))
+    await alist(fallback_model.structured_output(TestResponse, prompt))
+
+    # Verify circuit is open
+    stats = fallback_model.get_stats()
+    assert stats["circuit_open"] is True
+
+    # Reset mocks to track new calls
+    mock_primary.structured_output.reset_mock()
+    mock_fallback.structured_output.reset_mock()
+
+    # Mock fallback.structured_output to yield events for the actual test
+    test_fallback_events = [
+        {"chunk_type": "structured_output_start"},
+        {"chunk_type": "structured_output_delta", "data": {"answer": "Circuit open response"}},
+        {"chunk_type": "structured_output_delta", "data": {"confidence": 0.90}},
+        {
+            "chunk_type": "structured_output_complete",
+            "data": TestResponse(answer="Circuit open response", confidence=0.90),
+        },
+    ]
+    mock_fallback.structured_output.return_value = async_generator(test_fallback_events)
+
+    # Call structured_output - should go directly to fallback
+    result = await alist(
+        fallback_model.structured_output(
+            output_model=TestResponse, prompt=prompt, system_prompt="You are a helpful assistant"
+        )
+    )
+
+    # Assert primary not called (circuit is open)
+    mock_primary.structured_output.assert_not_called()
+
+    # Assert fallback was called directly
+    mock_fallback.structured_output.assert_called_once_with(
+        output_model=TestResponse, prompt=prompt, system_prompt="You are a helpful assistant"
+    )
+
+    # Assert fallback events received
+    assert result == test_fallback_events
+
+    # Verify circuit skip was recorded
+    stats = fallback_model.get_stats()
+    assert stats["circuit_skips"] == 1
+    assert stats["using_fallback"] is True
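+
+
+# Editor's note: the half-open transition described in the module docstring
+# (primary retried after circuit_cooldown_seconds) is not covered above. A
+# sketch of such a test, assuming a cooldown of 0 makes the circuit expire
+# immediately (test name and assumption are the editor's, not the suite's):
+#
+#     @pytest.mark.asyncio
+#     async def test_circuit_half_open_recovery(mock_primary, mock_fallback, alist):
+#         model = FallbackModel(
+#             primary=mock_primary, fallback=mock_fallback,
+#             circuit_failure_threshold=1, circuit_cooldown_seconds=0,
+#         )
+#         mock_primary.stream.side_effect = ModelThrottledException("throttled")
+#         mock_fallback.stream.return_value = async_generator([{"messageStop": {"stopReason": "end_turn"}}])
+#         messages = [{"role": "user", "content": [{"text": "Hi"}]}]
+#         await alist(model.stream(messages))  # failure opens the circuit
+#
+#         mock_primary.stream.reset_mock()
+#         mock_primary.stream.side_effect = None
+#         mock_primary.stream.return_value = async_generator([{"messageStop": {"stopReason": "end_turn"}}])
+#         await alist(model.stream(messages))
+#         mock_primary.stream.assert_called_once()  # primary probed after cooldown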
+
+
+# Enhanced Features Tests (Requirements 10 & 11)
+
+
+def test_fallback_stats_typed_dict(mock_primary, mock_fallback):
+    """Test that get_stats() returns properly typed FallbackStats structure."""
+
+    # Create FallbackModel
+    fallback_model = FallbackModel(primary=mock_primary, fallback=mock_fallback)
+
+    # Get stats
+    stats = fallback_model.get_stats()
+
+    # Verify it's a FallbackStats (TypedDict)
+    assert isinstance(stats, dict)
+
+    # Verify all required fields are present with correct types
+    assert isinstance(stats["fallback_count"], int)
+    assert isinstance(stats["primary_failures"], int)
+    assert isinstance(stats["circuit_skips"], int)
+    assert isinstance(stats["using_fallback"], bool)
+    assert isinstance(stats["circuit_open"], bool)
+    assert isinstance(stats["recent_failures"], int)
+    assert stats["circuit_open_until"] is None or isinstance(stats["circuit_open_until"], float)
+    assert isinstance(stats["primary_model_name"], str)
+    assert isinstance(stats["fallback_model_name"], str)
+
+    # Verify initial values
+    assert stats["fallback_count"] == 0
+    assert stats["primary_failures"] == 0
+    assert stats["circuit_skips"] == 0
+    assert stats["using_fallback"] is False
+    assert stats["circuit_open"] is False
+    assert stats["recent_failures"] == 0
+    assert stats["circuit_open_until"] is None
+
+
+def test_model_name_extraction(mock_primary, mock_fallback):
+    """Test model name extraction from configuration and class name fallback."""
+    # Test with mock models that have get_config() returning model identifiers
+    mock_primary.get_config.return_value = {"model_id": "gpt-4", "other": "value"}
+    mock_fallback.get_config.return_value = {"model": "claude-3-haiku", "other": "value"}
+
+    fallback_model = FallbackModel(primary=mock_primary, fallback=mock_fallback)
+
+    # Test _get_model_name method directly
+    primary_name = fallback_model._get_model_name(mock_primary)
+    fallback_name = fallback_model._get_model_name(mock_fallback)
+
+    assert primary_name == "gpt-4"
+    assert fallback_name == "claude-3-haiku"
+
+    # Test with config that has "name" field
+    mock_primary.get_config.return_value = {"name": "custom-primary", "other": "value"}
+    primary_name = fallback_model._get_model_name(mock_primary)
+    assert primary_name == "custom-primary"
+
+    # Test fallback to class name when no config identifiers
+    mock_primary.get_config.return_value = {"other": "value"}  # No model identifiers
+    primary_name = fallback_model._get_model_name(mock_primary)
+    assert primary_name == "Model"  # Should fall back to class name (mock has spec=Model)
+
+    # Test fallback to class name when get_config() fails
+    mock_primary.get_config.side_effect = Exception("Config error")
+    primary_name = fallback_model._get_model_name(mock_primary)
+    assert primary_name == "Model"  # Should fall back to class name (mock has spec=Model)
+
+
+def test_model_names_in_statistics(mock_primary, mock_fallback):
+    """Test that model names are included in statistics."""
+    # Configure mock models to return specific identifiers
+    mock_primary.get_config.return_value = {"model_id": "test-primary"}
+    mock_fallback.get_config.return_value = {"model_id": "test-fallback"}
+
+    fallback_model = FallbackModel(primary=mock_primary, fallback=mock_fallback)
+
+    # Get stats
+    stats = fallback_model.get_stats()
+
+    # Verify model names are included
+    assert stats["primary_model_name"] == "test-primary"
+    assert stats["fallback_model_name"] == "test-fallback"
+
+    # Test with class name fallback
+    mock_primary.get_config.return_value = {}  # No identifiers
+    mock_fallback.get_config.return_value = {}  # No identifiers
+
+    stats = fallback_model.get_stats()
+    assert stats["primary_model_name"] == "Model"  # Mock has spec=Model
+    assert stats["fallback_model_name"] == "Model"  # Mock has spec=Model
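+
+
+# Editor's sketch: the behavior verified above implies a name-resolution order
+# of "model_id", then "model", then "name", with the class name as the final
+# fallback. Equivalent standalone logic (illustrative; the library's private
+# _get_model_name is the authoritative implementation):
+def resolve_model_name(model) -> str:
+    """Best-effort human-readable identifier for a Model instance."""
+    try:
+        config = model.get_config()
+        for key in ("model_id", "model", "name"):
+            if config.get(key):
+                return str(config[key])
+    except Exception:
+        pass  # any config error falls through to the class name
+    return type(model).__name__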
+
+
+@pytest.mark.asyncio
+async def test_model_names_in_logging(mock_primary, mock_fallback, alist, caplog):
+    """Test that model names appear in logging output for better traceability."""
+    import logging
+
+    # Configure mock models with specific identifiers
+    mock_primary.get_config.return_value = {"model_id": "test-primary-model"}
+    mock_fallback.get_config.return_value = {"model_id": "test-fallback-model"}
+
+    # Set logging level to capture info messages
+    caplog.set_level(logging.INFO)
+
+    fallback_model = FallbackModel(primary=mock_primary, fallback=mock_fallback)
+
+    # Check initialization logging includes model names
+    assert "primary=" in caplog.text
+    assert "fallback=" in caplog.text
+
+    # Clear log for next test
+    caplog.clear()
+
+    # Test fallback scenario logging
+    mock_primary.stream.side_effect = ModelThrottledException("Primary throttled")
+    fallback_events = [
+        {"messageStart": {"role": "assistant"}},
+        {"contentBlockDelta": {"delta": {"text": "Fallback response"}}},
+        {"messageStop": {"stopReason": "end_turn"}},
+    ]
+    mock_fallback.stream.return_value = async_generator(fallback_events)
+
+    messages = [{"role": "user", "content": [{"text": "Hi"}]}]
+    await alist(fallback_model.stream(messages))
+
+    # Check that model names appear in fallback logging
+    assert "primary_model=" in caplog.text
+    assert "fallback_model=" in caplog.text
+
+
+def test_override_decorators_present():
+    """Test that @override decorators are present on overridden methods."""
+    import inspect
+
+    from strands.models.fallback import FallbackModel
+
+    # Check that the methods have @override decorator by checking if they're marked as overrides
+    # Note: This is a basic check - in practice, type checkers like mypy would catch missing @override
+
+    # Verify the methods exist and are callable
+    assert hasattr(FallbackModel, "stream")
+    assert callable(FallbackModel.stream)
+
+    assert hasattr(FallbackModel, "structured_output")
+    assert callable(FallbackModel.structured_output)
+
+    assert hasattr(FallbackModel, "update_config")
+    assert callable(FallbackModel.update_config)
+
+    assert hasattr(FallbackModel, "get_config")
+    assert callable(FallbackModel.get_config)
+
+    # Check method signatures match expected interface
+    stream_sig = inspect.signature(FallbackModel.stream)
+    assert "messages" in stream_sig.parameters
+    assert "tool_specs" in stream_sig.parameters
+    assert "system_prompt" in stream_sig.parameters
+    assert "tool_choice" in stream_sig.parameters
+
+    structured_output_sig = inspect.signature(FallbackModel.structured_output)
+    assert "output_model" in structured_output_sig.parameters
+    assert "prompt" in structured_output_sig.parameters
+    assert "system_prompt" in structured_output_sig.parameters
+
+
+@pytest.mark.asyncio
+async def test_enhanced_debugging_information(mock_primary, mock_fallback, alist):
+    """Test enhanced debugging information in statistics and logging."""
+    # Configure models with realistic identifiers
+    mock_primary.get_config.return_value = {"model_id": "gpt-4-turbo", "provider": "openai", "max_tokens": 4096}
+    mock_fallback.get_config.return_value = {
+        "model_id": "claude-3-sonnet-20240229",
+        "provider": "anthropic",
+        "max_tokens": 4096,
+    }
+
+    fallback_model = FallbackModel(
+        primary=mock_primary,
+        fallback=mock_fallback,
+        circuit_failure_threshold=1,  # Open circuit quickly
+    )
+
+    # Trigger a fallback scenario
+    mock_primary.stream.side_effect = ModelThrottledException("Rate limited")
+    fallback_events = [
+        {"messageStart": {"role": "assistant"}},
+        {"contentBlockDelta": {"delta": {"text": "Fallback response"}}},
+        {"messageStop": {"stopReason": "end_turn"}},
+    ]
+    mock_fallback.stream.return_value = async_generator(fallback_events)
+
+    messages = [{"role": "user", "content": [{"text": "Test message"}]}]
+    await alist(fallback_model.stream(messages))
+
+    # Get enhanced statistics
+    stats = fallback_model.get_stats()
+
+    # Verify debugging information is present
+    assert stats["primary_model_name"] == "gpt-4-turbo"
+    assert stats["fallback_model_name"] == "claude-3-sonnet-20240229"
+    assert stats["fallback_count"] == 1
+    assert stats["primary_failures"] == 1
+    assert stats["using_fallback"] is True
+
+    # The circuit opened on the first failure (threshold=1), so this request
+    # should be skipped straight to the fallback
+    await alist(fallback_model.stream(messages))
+
+    stats = fallback_model.get_stats()
+    assert stats["circuit_open"] is True
+    assert stats["circuit_skips"] == 1
+
+    # Verify that with this enhanced information, debugging is much easier
+    # A developer can now clearly see:
+    # 1. Which specific models are being used
+    # 2. Current circuit breaker state
+    # 3. Detailed failure and fallback counts
+    # 4. Whether the last request used fallback
+    assert all(
+        key in stats
+        for key in [
+            "primary_model_name",
+            "fallback_model_name",
+            "fallback_count",
+            "primary_failures",
+            "circuit_skips",
+            "using_fallback",
+            "circuit_open",
+            "recent_failures",
+            "circuit_open_until",
+        ]
+    )
diff --git a/tests_integ/models/test_fallback_integration.py b/tests_integ/models/test_fallback_integration.py
new file mode 100644
index 000000000..2bdb1d6fa
--- /dev/null
+++ b/tests_integ/models/test_fallback_integration.py
@@ -0,0 +1,268 @@
+"""Integration tests for FallbackModel with real model providers."""
+
+import os
+
+import pytest
+
+from strands import Agent
+from strands.models import BedrockModel
+from strands.models.anthropic import AnthropicModel
+from strands.models.fallback import FallbackModel
+from strands.models.openai import OpenAIModel
+from tests_integ.models import providers
+
+
+class TestFallbackModelIntegration:
+    """Integration tests for FallbackModel with real model instances."""
+
+    @providers.bedrock.mark
+    @pytest.mark.asyncio
+    async def test_same_provider_fallback_bedrock(self):
+        """Test FallbackModel with two BedrockModel instances."""
+        # Use different model IDs - sonnet as primary, haiku as fallback
+        primary = BedrockModel(model_id="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2")
+        fallback = BedrockModel(model_id="us.anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2")
+
+        fallback_model = FallbackModel(
+            primary=primary,
+            fallback=fallback,
+            circuit_failure_threshold=1,  # Open circuit quickly for testing
+            circuit_time_window=60.0,
+            circuit_cooldown_seconds=5,
+        )
+
+        # Test successful primary model usage
+        messages = [{"role": "user", "content": [{"text": "Say 'Hello from primary model'"}]}]
+
+        events = []
+        async for event in fallback_model.stream(messages=messages):
+            events.append(event)
+
+        # Should have received events
+        assert len(events) > 0
+
+        # Check that primary was used (fallback_count should be 0)
+        stats = fallback_model.get_stats()
+        assert stats["fallback_count"] == 0
+        assert not stats["using_fallback"]
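+
+    # Editor's note: the tests in this file exercise the happy path where the
+    # primary succeeds. A sketch of a test that forces a real fallback by
+    # pairing a deliberately broken primary with should_fallback=lambda e: True
+    # (documented to override the default error classification). The bogus
+    # model id is an assumption used only to provoke a failure:
+    #
+    #     @providers.bedrock.mark
+    #     @pytest.mark.asyncio
+    #     async def test_forced_fallback(self):
+    #         primary = BedrockModel(model_id="us.anthropic.nonexistent-model-v0:0", region_name="us-west-2")
+    #         fallback = BedrockModel(model_id="us.anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2")
+    #         model = FallbackModel(primary=primary, fallback=fallback, should_fallback=lambda e: True)
+    #         messages = [{"role": "user", "content": [{"text": "Hi"}]}]
+    #         events = [event async for event in model.stream(messages=messages)]
+    #         assert len(events) > 0
+    #         assert model.get_stats()["fallback_count"] == 1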
os.getenv("OPENAI_API_KEY"), + }, + ) + fallback = BedrockModel(model_id="us.anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2") + + fallback_model = FallbackModel( + primary=primary, + fallback=fallback, + circuit_failure_threshold=2, + circuit_time_window=60.0, + circuit_cooldown_seconds=10, + ) + + # Test successful cross-provider usage + messages = [{"role": "user", "content": [{"text": "Respond with exactly: 'Cross-provider test successful'"}]}] + + events = [] + async for event in fallback_model.stream(messages=messages): + events.append(event) + + # Should have received events + assert len(events) > 0 + + # Verify we can get configuration from both models + config = fallback_model.get_config() + assert "primary_config" in config + assert "fallback_model_config" in config + assert "fallback_config" in config + assert "stats" in config + + @providers.bedrock.mark + @pytest.mark.asyncio + async def test_agent_integration_with_fallback(self): + """Test FallbackModel used with Agent class.""" + # Create fallback model with two Bedrock instances + primary = BedrockModel(model_id="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2") + fallback = BedrockModel(model_id="us.anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2") + + fallback_model = FallbackModel(primary=primary, fallback=fallback, track_stats=True) + + # Create agent with fallback model + agent = Agent(model=fallback_model, system_prompt="You are a helpful assistant. Keep responses brief.") + + # Send test message + response = await agent.invoke_async("What is 2 + 2?") + + # Assert response received + assert response is not None + assert response.message is not None + assert len(response.message["content"]) > 0 + assert response.message["content"][0]["text"] is not None + + # Check that the fallback model was used successfully + stats = fallback_model.get_stats() + assert isinstance(stats, dict) + assert "fallback_count" in stats + assert "primary_failures" in stats + + @pytest.mark.skipif( + "ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY required for Anthropic provider test" + ) + @pytest.mark.asyncio + async def test_cross_provider_anthropic_bedrock(self): + """Test FallbackModel with Anthropic primary and Bedrock fallback.""" + primary = AnthropicModel( + client_args={ + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + model_id="claude-3-7-sonnet-20250219", + max_tokens=512, + ) + fallback = BedrockModel(model_id="us.anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2") + + fallback_model = FallbackModel(primary=primary, fallback=fallback) + + # Test structured output + from pydantic import BaseModel + + class TestResponse(BaseModel): + message: str + number: int + + messages = [{"role": "user", "content": [{"text": "Return a message 'test' and number 42"}]}] + + events = [] + async for event in fallback_model.structured_output(output_model=TestResponse, prompt=messages): + events.append(event) + + # Should have received events + assert len(events) > 0 + + # Check final event has the structured output + final_event = events[-1] + if "output" in final_event: + output = final_event["output"] + assert hasattr(output, "message") + assert hasattr(output, "number") + + @providers.bedrock.mark + @pytest.mark.asyncio + async def test_fallback_statistics_tracking(self): + """Test that statistics are properly tracked during integration tests.""" + primary = BedrockModel(model_id="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2") + 
+
+    @providers.bedrock.mark
+    @pytest.mark.asyncio
+    async def test_fallback_statistics_tracking(self):
+        """Test that statistics are properly tracked during integration tests."""
+        primary = BedrockModel(model_id="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2")
+        fallback = BedrockModel(model_id="us.anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2")
+
+        fallback_model = FallbackModel(primary=primary, fallback=fallback, track_stats=True)
+
+        # Make a successful request
+        messages = [{"role": "user", "content": [{"text": "Say hello"}]}]
+
+        events = []
+        async for event in fallback_model.stream(messages=messages):
+            events.append(event)
+
+        # Check statistics
+        stats = fallback_model.get_stats()
+        assert stats["fallback_count"] == 0  # No fallback should have occurred
+        assert stats["primary_failures"] == 0  # No failures
+        assert not stats["using_fallback"]  # Not using fallback
+        assert not stats["circuit_open"]  # Circuit should be closed
+
+        # Test configuration retrieval
+        config = fallback_model.get_config()
+        assert config["stats"] is not None
+        assert config["fallback_config"]["track_stats"] is True
+
+        # Test stats reset
+        fallback_model.reset_stats()
+        reset_stats = fallback_model.get_stats()
+        assert reset_stats["fallback_count"] == 0
+        assert reset_stats["primary_failures"] == 0
+
+    @providers.bedrock.mark
+    @pytest.mark.asyncio
+    async def test_tool_calling_with_fallback_model(self):
+        """Test that tool_specs and tool_choice parameters work with FallbackModel."""
+        # Create fallback model with two Bedrock instances
+        primary = BedrockModel(model_id="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2")
+        fallback = BedrockModel(model_id="us.anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2")
+
+        fallback_model = FallbackModel(primary=primary, fallback=fallback, track_stats=True)
+
+        # Define a simple tool spec
+        tool_specs = [
+            {
+                "name": "get_weather",
+                "description": "Get weather information for a location",
+                "inputSchema": {
+                    "json": {
+                        "type": "object",
+                        "properties": {
+                            "location": {"type": "string", "description": "The location to get weather for"}
+                        },
+                        "required": ["location"],
+                    }
+                },
+            }
+        ]
+
+        tool_choice = {"auto": {}}
+
+        # Test message that might trigger tool use
+        messages = [{"role": "user", "content": [{"text": "What's the weather in Seattle?"}]}]
+
+        # Stream with tool parameters
+        events = []
+        async for event in fallback_model.stream(messages=messages, tool_specs=tool_specs, tool_choice=tool_choice):
+            events.append(event)
+
+        # Should have received events
+        assert len(events) > 0
+
+        # Verify primary was used (no fallback)
+        stats = fallback_model.get_stats()
+        assert stats["fallback_count"] == 0
+        assert not stats["using_fallback"]
+
+    @providers.bedrock.mark
+    @pytest.mark.asyncio
+    async def test_tool_calling_with_agent_and_fallback_model(self):
+        """Test that FallbackModel works with Agent class when tools are provided."""
+        # Create fallback model with two Bedrock instances
+        primary = BedrockModel(model_id="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2")
+        fallback = BedrockModel(model_id="us.anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2")
+
+        fallback_model = FallbackModel(primary=primary, fallback=fallback, track_stats=True)
+
+        # Create a simple tool using the strands tool decorator
+        from strands import tool
+
+        @tool
+        def get_current_time(timezone: str = "UTC") -> dict:
+            """Get the current time in a specific timezone."""
+            return {"status": "success", "content": [{"text": f"Current time in {timezone}: 12:00 PM"}]}
+
+        # Create agent with fallback model and tool
+        agent = Agent(model=fallback_model, tools=[get_current_time], system_prompt="You are a helpful assistant.")
+
+        # Send test message
+        response = await agent.invoke_async("What time is it?")
+
+        # Assert response received
+        assert response is not None
+        assert response.message is not None
+        assert len(response.message["content"]) > 0
+
+        # Verify primary was used (no fallback)
+        stats = fallback_model.get_stats()
+        assert stats["fallback_count"] == 0
+        assert not stats["using_fallback"]