diff --git a/python/packages/autogen-core/src/autogen_core/_queue.py b/python/packages/autogen-core/src/autogen_core/_queue.py
index 699921a37f5d..c5a4bbc80619 100644
--- a/python/packages/autogen-core/src/autogen_core/_queue.py
+++ b/python/packages/autogen-core/src/autogen_core/_queue.py
@@ -122,7 +122,7 @@ async def put(self, item: T) -> None:
         self._putters.append(putter)
         try:
             await putter
-        except:
+        except BaseException:  # CancelledError must not bypass this cleanup
             putter.cancel()  # Just in case putter is not done yet.
             try:
                 # Clean self._putters from canceled putters.
@@ -169,7 +169,7 @@ async def get(self) -> T:
         self._getters.append(getter)
         try:
             await getter
-        except:
+        except BaseException:  # CancelledError must not bypass this cleanup
             getter.cancel()  # Just in case getter is not done yet.
             try:
                 # Clean self._getters from canceled getters.
diff --git a/python/packages/autogen-ext/src/autogen_ext/agents/retry_agent.py b/python/packages/autogen-ext/src/autogen_ext/agents/retry_agent.py
new file mode 100644
index 000000000000..874089d18ccb
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/agents/retry_agent.py
@@ -0,0 +1,244 @@
+"""Retry agent wrapper with exponential backoff, circuit breaker, and fallback support."""
+
+import asyncio
+import logging
+import random
+import time
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Callable, Optional, Type
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class RetryConfig:
+    """Configuration for retry behavior."""
+
+    max_retries: int = 3
+    base_delay: float = 1.0
+    max_delay: float = 60.0
+    exponential_base: float = 2.0
+    jitter: bool = True
+    retry_on: Optional[tuple[Type[Exception], ...]] = None
+    timeout: Optional[float] = None
+
+
+class CircuitBreakerState(Enum):
+    """States for the circuit breaker pattern."""
+
+    CLOSED = "closed"
+    OPEN = "open"
+    HALF_OPEN = "half_open"
+
+
+class CircuitBreaker:
+    """Circuit breaker to prevent repeated calls to a failing service."""
+
+    def __init__(self, failure_threshold: int = 5, recovery_timeout: float = 30.0) -> None:
+        self.failure_threshold = failure_threshold
+        self.recovery_timeout = recovery_timeout
+        self._state = CircuitBreakerState.CLOSED
+        self._consecutive_failures = 0
+        self._last_failure_time: Optional[float] = None
+
+    @property
+    def state(self) -> CircuitBreakerState:
+        if self._state == CircuitBreakerState.OPEN and self._last_failure_time is not None:
+            elapsed = time.monotonic() - self._last_failure_time
+            if elapsed >= self.recovery_timeout:
+                self._state = CircuitBreakerState.HALF_OPEN
+                logger.info("Circuit breaker transitioned to HALF_OPEN after %.1fs", elapsed)
+        return self._state
+
+    def record_success(self) -> None:
+        """Record a successful execution, resetting the breaker."""
+        if self._state in (CircuitBreakerState.HALF_OPEN, CircuitBreakerState.CLOSED):
+            self._consecutive_failures = 0
+            self._state = CircuitBreakerState.CLOSED
+            logger.debug("Circuit breaker reset to CLOSED after success")
+
+    def record_failure(self) -> None:
+        """Record a failure and potentially trip the breaker."""
+        self._consecutive_failures += 1
+        self._last_failure_time = time.monotonic()
+        if self._consecutive_failures >= self.failure_threshold:
+            self._state = CircuitBreakerState.OPEN
+            logger.warning(
+                "Circuit breaker tripped to OPEN after %d consecutive failures",
+                self._consecutive_failures,
+            )
+
+    def can_execute(self) -> bool:
+        """Check whether execution is allowed under the current state."""
+        current_state = self.state
+        if current_state == CircuitBreakerState.CLOSED:
+            return True
+        if current_state == 
CircuitBreakerState.HALF_OPEN: + return True + return False + + +@dataclass +class RetryMetrics: + """Tracks retry statistics for observability.""" + + total_attempts: int = 0 + successful_attempts: int = 0 + failed_attempts: int = 0 + total_retry_delay: float = 0.0 + circuit_breaker_trips: int = 0 + last_error: Optional[Exception] = field(default=None, repr=False) + + +class RetryAgent: + """Wraps any agent with retry logic, circuit breaking, and optional fallback. + + Args: + agent: The inner agent to wrap (duck-typed, must have an ``execute`` method). + config: Retry configuration controlling backoff and limits. + fallback_agent: An optional agent invoked when all retries are exhausted. + circuit_breaker: An optional CircuitBreaker instance for failure isolation. + """ + + def __init__( + self, + agent: Any, + config: Optional[RetryConfig] = None, + fallback_agent: Any = None, + circuit_breaker: Optional[CircuitBreaker] = None, + ) -> None: + self._agent = agent + self._config = config or RetryConfig() + self._fallback_agent = fallback_agent + self._circuit_breaker = circuit_breaker + self._metrics = RetryMetrics() + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def execute(self, *args: Any, **kwargs: Any) -> Any: + """Execute the wrapped agent with retry logic. + + Raises the last encountered exception when all retries (and the + optional fallback) are exhausted. + """ + last_error: Optional[Exception] = None + + for attempt in range(1, self._config.max_retries + 2): # attempt 1 = first try + # Circuit breaker gate + if self._circuit_breaker and not self._circuit_breaker.can_execute(): + self._metrics.circuit_breaker_trips += 1 + logger.warning("Circuit breaker is OPEN – skipping attempt %d", attempt) + break + + self._metrics.total_attempts += 1 + + try: + result = await self._execute_with_timeout(*args, **kwargs) + self._metrics.successful_attempts += 1 + if self._circuit_breaker: + self._circuit_breaker.record_success() + return result + except Exception as exc: + last_error = exc + self._metrics.failed_attempts += 1 + self._metrics.last_error = exc + + if self._circuit_breaker: + self._circuit_breaker.record_failure() + + if not self._should_retry(exc, attempt): + logger.debug("Not retrying after attempt %d: %s", attempt, exc) + break + + delay = self._calculate_delay(attempt) + self._metrics.total_retry_delay += delay + logger.info( + "Attempt %d failed (%s). Retrying in %.2fs …", + attempt, + type(exc).__name__, + delay, + ) + await asyncio.sleep(delay) + + # All retries exhausted – try fallback + if self._fallback_agent is not None: + logger.info("All retries exhausted. 
Invoking fallback agent.") + try: + return await self._fallback_agent.execute(*args, **kwargs) + except Exception as fallback_exc: + logger.error("Fallback agent also failed: %s", fallback_exc) + raise fallback_exc from last_error + + if last_error is not None: + raise last_error + raise RuntimeError("RetryAgent finished without a result or error") + + def get_metrics(self) -> RetryMetrics: + """Return a snapshot of the current retry metrics.""" + return self._metrics + + def reset_metrics(self) -> None: + """Reset all tracked metrics.""" + self._metrics = RetryMetrics() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _calculate_delay(self, attempt: int) -> float: + """Return the backoff delay for the given *attempt* number.""" + delay = min( + self._config.base_delay * (self._config.exponential_base ** (attempt - 1)), + self._config.max_delay, + ) + if self._config.jitter: + delay = delay * random.uniform(0.5, 1.0) + return delay + + def _should_retry(self, error: Exception, attempt: int) -> bool: + """Decide whether *error* on *attempt* is eligible for a retry.""" + if attempt > self._config.max_retries: + return False + if self._config.retry_on is not None: + return isinstance(error, self._config.retry_on) + return True + + async def _execute_with_timeout(self, *args: Any, **kwargs: Any) -> Any: + """Run the inner agent's ``execute`` with an optional timeout.""" + coro = self._agent.execute(*args, **kwargs) + if self._config.timeout is not None: + try: + return await asyncio.wait_for(coro, timeout=self._config.timeout) + except asyncio.TimeoutError: + raise asyncio.TimeoutError( + f"Agent execution timed out after {self._config.timeout}s" + ) + return await coro + + +# ------------------------------------------------------------------ +# Standalone utility +# ------------------------------------------------------------------ + + +async def retry_with_backoff( + func: Callable[..., Any], + config: Optional[RetryConfig] = None, + *args: Any, + **kwargs: Any, +) -> Any: + """Convenience helper that retries an async callable with backoff. + + ``func`` must be an async function (coroutine function). The helper + creates a thin adapter and delegates to :class:`RetryAgent`. 
+ """ + + class _FuncAdapter: + async def execute(self, *a: Any, **kw: Any) -> Any: + return await func(*a, **kw) + + agent = RetryAgent(_FuncAdapter(), config=config) + return await agent.execute(*args, **kwargs) diff --git a/python/packages/autogen-ext/src/autogen_ext/code_executors/azure/_azure_container_code_executor.py b/python/packages/autogen-ext/src/autogen_ext/code_executors/azure/_azure_container_code_executor.py index 17c4b16a2c15..c2864a7d890f 100644 --- a/python/packages/autogen-ext/src/autogen_ext/code_executors/azure/_azure_container_code_executor.py +++ b/python/packages/autogen-ext/src/autogen_ext/code_executors/azure/_azure_container_code_executor.py @@ -100,11 +100,13 @@ def __init__( Callable[..., Any], FunctionWithRequirementsStr, ] - ] = [], + ] | None = None, functions_module: str = "functions", suppress_result_output: bool = False, session_id: Optional[str] = None, ): + if functions is None: + functions = [] if timeout < 1: raise ValueError("Timeout must be greater than or equal to 1.") diff --git a/python/packages/autogen-ext/src/autogen_ext/code_executors/docker/_docker_code_executor.py b/python/packages/autogen-ext/src/autogen_ext/code_executors/docker/_docker_code_executor.py index 701658572141..ec44cbf6c3c8 100644 --- a/python/packages/autogen-ext/src/autogen_ext/code_executors/docker/_docker_code_executor.py +++ b/python/packages/autogen-ext/src/autogen_ext/code_executors/docker/_docker_code_executor.py @@ -172,13 +172,15 @@ def __init__( Callable[..., Any], FunctionWithRequirementsStr, ] - ] = [], + ] | None = None, functions_module: str = "functions", extra_volumes: Optional[Dict[str, Dict[str, str]]] = None, extra_hosts: Optional[Dict[str, str]] = None, init_command: Optional[str] = None, delete_tmp_files: bool = False, ): + if functions is None: + functions = [] if timeout < 1: raise ValueError("Timeout must be greater than or equal to 1.") diff --git a/python/packages/autogen-ext/src/autogen_ext/code_executors/local/__init__.py b/python/packages/autogen-ext/src/autogen_ext/code_executors/local/__init__.py index f21d9fe4b8ef..e1bb685c1278 100644 --- a/python/packages/autogen-ext/src/autogen_ext/code_executors/local/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/code_executors/local/__init__.py @@ -154,7 +154,7 @@ def __init__( Callable[..., Any], FunctionWithRequirementsStr, ] - ] = [], + ] | None = None, functions_module: str = "functions", cleanup_temp_files: bool = True, virtual_env_context: Optional[SimpleNamespace] = None, @@ -168,6 +168,8 @@ def __init__( stacklevel=2, ) + if functions is None: + functions = [] if timeout < 1: raise ValueError("Timeout must be greater than or equal to 1.") self._timeout = timeout diff --git a/python/packages/autogen-ext/tests/test_retry_agent.py b/python/packages/autogen-ext/tests/test_retry_agent.py new file mode 100644 index 000000000000..ebaba3712236 --- /dev/null +++ b/python/packages/autogen-ext/tests/test_retry_agent.py @@ -0,0 +1,332 @@ +"""Tests for RetryAgent with exponential backoff, circuit breaker, and fallback.""" + +import asyncio + +import pytest + +from autogen_ext.agents.retry_agent import ( + CircuitBreaker, + CircuitBreakerState, + RetryAgent, + RetryConfig, + RetryMetrics, + retry_with_backoff, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class MockAgent: + """Simple agent that succeeds on every call.""" + + def __init__(self, return_value: str = 
"ok") -> None: + self._return_value = return_value + self.call_count = 0 + + async def execute(self, *args, **kwargs): + self.call_count += 1 + return self._return_value + + +class FailNTimesAgent: + """Agent that raises for the first *n* calls, then succeeds.""" + + def __init__(self, fail_count: int, exception: Exception | None = None) -> None: + self._fail_count = fail_count + self._exception = exception or RuntimeError("transient failure") + self.call_count = 0 + + async def execute(self, *args, **kwargs): + self.call_count += 1 + if self.call_count <= self._fail_count: + raise self._exception + return "recovered" + + +class AlwaysFailAgent: + """Agent that always raises.""" + + def __init__(self, exception: Exception | None = None) -> None: + self._exception = exception or RuntimeError("permanent failure") + self.call_count = 0 + + async def execute(self, *args, **kwargs): + self.call_count += 1 + raise self._exception + + +class SlowAgent: + """Agent that sleeps longer than expected.""" + + def __init__(self, delay: float = 5.0) -> None: + self._delay = delay + + async def execute(self, *args, **kwargs): + await asyncio.sleep(self._delay) + return "slow_ok" + + +# --------------------------------------------------------------------------- +# Tests – basic execution +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_successful_execution_no_retries(): + agent = MockAgent(return_value="hello") + wrapper = RetryAgent(agent, RetryConfig(max_retries=3)) + + result = await wrapper.execute() + + assert result == "hello" + assert agent.call_count == 1 + metrics = wrapper.get_metrics() + assert metrics.total_attempts == 1 + assert metrics.successful_attempts == 1 + assert metrics.failed_attempts == 0 + + +@pytest.mark.asyncio +async def test_retry_on_transient_failure(): + agent = FailNTimesAgent(fail_count=2) + config = RetryConfig(max_retries=3, base_delay=0.01, jitter=False) + wrapper = RetryAgent(agent, config) + + result = await wrapper.execute() + + assert result == "recovered" + assert agent.call_count == 3 + metrics = wrapper.get_metrics() + assert metrics.successful_attempts == 1 + assert metrics.failed_attempts == 2 + + +@pytest.mark.asyncio +async def test_max_retries_exceeded(): + agent = AlwaysFailAgent() + config = RetryConfig(max_retries=2, base_delay=0.01, jitter=False) + wrapper = RetryAgent(agent, config) + + with pytest.raises(RuntimeError, match="permanent failure"): + await wrapper.execute() + + assert agent.call_count == 3 # 1 initial + 2 retries + metrics = wrapper.get_metrics() + assert metrics.failed_attempts == 3 + assert metrics.successful_attempts == 0 + + +# --------------------------------------------------------------------------- +# Tests – backoff timing +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_exponential_backoff_delays(): + config = RetryConfig(base_delay=1.0, exponential_base=2.0, jitter=False) + wrapper = RetryAgent(MockAgent(), config) + + assert wrapper._calculate_delay(1) == pytest.approx(1.0) + assert wrapper._calculate_delay(2) == pytest.approx(2.0) + assert wrapper._calculate_delay(3) == pytest.approx(4.0) + + +@pytest.mark.asyncio +async def test_max_delay_cap(): + config = RetryConfig(base_delay=10.0, max_delay=15.0, exponential_base=2.0, jitter=False) + wrapper = RetryAgent(MockAgent(), config) + + assert wrapper._calculate_delay(5) == pytest.approx(15.0) + + +@pytest.mark.asyncio +async def 
test_jitter_randomisation():
+    config = RetryConfig(base_delay=2.0, jitter=True)
+    wrapper = RetryAgent(MockAgent(), config)
+
+    delays = [wrapper._calculate_delay(1) for _ in range(50)]
+    assert all(1.0 <= d <= 2.0 for d in delays)
+    assert len(set(delays)) > 1  # not all identical
+
+
+# ---------------------------------------------------------------------------
+# Tests – circuit breaker
+# ---------------------------------------------------------------------------
+
+def test_circuit_breaker_initial_state():
+    cb = CircuitBreaker(failure_threshold=3)
+    assert cb.state == CircuitBreakerState.CLOSED
+    assert cb.can_execute() is True
+
+
+def test_circuit_breaker_trips_on_threshold():
+    cb = CircuitBreaker(failure_threshold=3)
+    for _ in range(3):
+        cb.record_failure()
+    assert cb.state == CircuitBreakerState.OPEN
+    assert cb.can_execute() is False
+
+
+def test_circuit_breaker_resets_on_success():
+    cb = CircuitBreaker(failure_threshold=3)
+    cb.record_failure()
+    cb.record_failure()
+    cb.record_success()
+    assert cb.state == CircuitBreakerState.CLOSED
+    assert cb.can_execute() is True
+
+
+def test_circuit_breaker_half_open_after_recovery(monkeypatch):
+    cb = CircuitBreaker(failure_threshold=2, recovery_timeout=1.0)
+    cb.record_failure()
+    cb.record_failure()
+    assert cb.state == CircuitBreakerState.OPEN
+
+    # Fast-forward time past the recovery timeout
+    original_last = cb._last_failure_time
+    monkeypatch.setattr(cb, "_last_failure_time", original_last - 2.0)
+    assert cb.state == CircuitBreakerState.HALF_OPEN
+    assert cb.can_execute() is True
+
+
+@pytest.mark.asyncio
+async def test_circuit_breaker_blocks_execution():
+    cb = CircuitBreaker(failure_threshold=1)
+    cb.record_failure()  # trips immediately
+
+    agent = MockAgent()
+    config = RetryConfig(max_retries=3, base_delay=0.01, jitter=False)
+    wrapper = RetryAgent(agent, config, circuit_breaker=cb)
+
+    with pytest.raises(RuntimeError, match="without a result"):
+        await wrapper.execute()
+
+    assert agent.call_count == 0
+    assert wrapper.get_metrics().circuit_breaker_trips >= 1
+
+
+# ---------------------------------------------------------------------------
+# Tests – fallback agent
+# ---------------------------------------------------------------------------
+
+@pytest.mark.asyncio
+async def test_fallback_agent_activation():
+    primary = AlwaysFailAgent()
+    fallback = MockAgent(return_value="fallback_result")
+    config = RetryConfig(max_retries=1, base_delay=0.01, jitter=False)
+    wrapper = RetryAgent(primary, config, fallback_agent=fallback)
+
+    result = await wrapper.execute()
+
+    assert result == "fallback_result"
+    assert primary.call_count == 2  # 1 initial + 1 retry
+    assert fallback.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_fallback_agent_also_fails():
+    primary = AlwaysFailAgent(RuntimeError("primary"))
+    fallback = AlwaysFailAgent(ValueError("fallback"))
+    config = RetryConfig(max_retries=0, base_delay=0.01, jitter=False)
+    wrapper = RetryAgent(primary, config, fallback_agent=fallback)
+
+    with pytest.raises(ValueError, match="fallback"):
+        await wrapper.execute()
+
+
+# ---------------------------------------------------------------------------
+# Tests – timeout handling
+# ---------------------------------------------------------------------------
+
+@pytest.mark.asyncio
+async def test_timeout_raises_on_slow_agent():
+    agent = SlowAgent(delay=5.0)
+    config = RetryConfig(max_retries=0, timeout=0.05)
+    wrapper = RetryAgent(agent, config)
+
+    with pytest.raises(asyncio.TimeoutError):
+        
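# asyncio.wait_for cancels SlowAgent's coroutine at the 0.05s deadline and
+        # raises TimeoutError, which RetryAgent propagates once retries
+        # (here max_retries=0) are exhausted.
+        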
await wrapper.execute() + + +@pytest.mark.asyncio +async def test_timeout_retries_then_fails(): + agent = SlowAgent(delay=5.0) + config = RetryConfig(max_retries=2, timeout=0.05, base_delay=0.01, jitter=False) + wrapper = RetryAgent(agent, config) + + with pytest.raises(asyncio.TimeoutError): + await wrapper.execute() + + assert wrapper.get_metrics().failed_attempts == 3 + + +# --------------------------------------------------------------------------- +# Tests – selective retry_on +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_retry_on_filters_exception_types(): + agent = AlwaysFailAgent(ValueError("not retryable")) + config = RetryConfig(max_retries=3, base_delay=0.01, retry_on=(RuntimeError,)) + wrapper = RetryAgent(agent, config) + + with pytest.raises(ValueError, match="not retryable"): + await wrapper.execute() + + assert agent.call_count == 1 # no retries for ValueError + + +# --------------------------------------------------------------------------- +# Tests – metrics +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_metrics_tracking(): + agent = FailNTimesAgent(fail_count=1) + config = RetryConfig(max_retries=3, base_delay=0.01, jitter=False) + wrapper = RetryAgent(agent, config) + + await wrapper.execute() + m = wrapper.get_metrics() + + assert m.total_attempts == 2 + assert m.successful_attempts == 1 + assert m.failed_attempts == 1 + assert m.total_retry_delay > 0 + + +@pytest.mark.asyncio +async def test_reset_metrics(): + agent = MockAgent() + wrapper = RetryAgent(agent, RetryConfig()) + + await wrapper.execute() + wrapper.reset_metrics() + m = wrapper.get_metrics() + + assert m.total_attempts == 0 + assert m.successful_attempts == 0 + + +# --------------------------------------------------------------------------- +# Tests – standalone utility +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_retry_with_backoff_utility(): + call_count = 0 + + async def flaky_func(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise RuntimeError("not yet") + return "done" + + config = RetryConfig(max_retries=5, base_delay=0.01, jitter=False) + result = await retry_with_backoff(flaky_func, config) + + assert result == "done" + assert call_count == 3
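
Why the _queue.py hunks catch `BaseException`: the cleanup in those handlers
must also run when the waiting task is cancelled, and `asyncio.CancelledError`
has derived from `BaseException` rather than `Exception` since Python 3.8, so
`except Exception:` would silently skip the cleanup and could lose a wakeup.
A minimal check of the class hierarchy:

    import asyncio

    # On Python >= 3.8, CancelledError bypasses `except Exception` handlers.
    assert not issubclass(asyncio.CancelledError, Exception)
    assert issubclass(asyncio.CancelledError, BaseException)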
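
A usage sketch of the new wrapper; `PrimaryAgent` and `BackupAgent` are
hypothetical stand-ins, since any object exposing an async `execute` method
works:

    import asyncio

    from autogen_ext.agents.retry_agent import CircuitBreaker, RetryAgent, RetryConfig


    class PrimaryAgent:
        """Stand-in primary agent that always fails."""

        async def execute(self, prompt: str) -> str:
            raise RuntimeError("upstream unavailable")


    class BackupAgent:
        """Stand-in fallback agent."""

        async def execute(self, prompt: str) -> str:
            return f"fallback answer for: {prompt}"


    async def main() -> None:
        wrapper = RetryAgent(
            PrimaryAgent(),
            RetryConfig(max_retries=2, base_delay=0.5, timeout=10.0),
            fallback_agent=BackupAgent(),
            circuit_breaker=CircuitBreaker(failure_threshold=5),
        )
        # The first try and two retries fail, then the fallback answers.
        print(await wrapper.execute("hello"))
        print(wrapper.get_metrics())  # RetryMetrics(total_attempts=3, ...)


    asyncio.run(main())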
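
The code-executor changes swap the shared mutable default (`functions=[]`)
for a `None` sentinel resolved inside `__init__`. The pitfall that pattern
avoids, in miniature:

    def broken(items: list = []) -> list:  # one list object shared by every call
        items.append(1)
        return items

    assert broken() == [1]
    assert broken() == [1, 1]  # state leaks across calls through the default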