diff options
| author | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-01 16:06:54 +0900 |
|---|---|---|
| committer | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-01 16:06:54 +0900 |
| commit | 558de36fdd886a117c66dee73850cc219c86b2a4 (patch) | |
| tree | ba3a534f928f2e66fe54d22f9654b737d491bd84 /shared | |
| parent | cd15c3f64d00c6c97f738d59f719bb0938d9f7cb (diff) | |
feat(shared): add retry with backoff and circuit breaker
Diffstat (limited to 'shared')
| -rw-r--r-- | shared/src/shared/resilience.py | 105 | ||||
| -rw-r--r-- | shared/tests/test_resilience.py | 138 |
2 files changed, 243 insertions, 0 deletions
diff --git a/shared/src/shared/resilience.py b/shared/src/shared/resilience.py new file mode 100644 index 0000000..d4e963b --- /dev/null +++ b/shared/src/shared/resilience.py @@ -0,0 +1,105 @@ +"""Retry with exponential backoff and circuit breaker utilities.""" + +from __future__ import annotations + +import asyncio +import enum +import functools +import logging +import random +import time +from typing import Any, Callable + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# retry_with_backoff +# --------------------------------------------------------------------------- + + +def retry_with_backoff( + max_retries: int = 3, + base_delay: float = 1.0, + max_delay: float = 60.0, +) -> Callable: + """Decorator that retries an async function with exponential backoff + jitter.""" + + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + last_exc: BaseException | None = None + for attempt in range(max_retries + 1): + try: + return await func(*args, **kwargs) + except Exception as exc: + last_exc = exc + if attempt < max_retries: + delay = min(base_delay * (2 ** attempt), max_delay) + jitter = delay * random.uniform(0, 0.5) + total_delay = delay + jitter + logger.warning( + "Retry %d/%d for %s after error: %s (delay=%.3fs)", + attempt + 1, + max_retries, + func.__name__, + exc, + total_delay, + ) + await asyncio.sleep(total_delay) + raise last_exc # type: ignore[misc] + + return wrapper + + return decorator + + +# --------------------------------------------------------------------------- +# CircuitBreaker +# --------------------------------------------------------------------------- + + +class CircuitState(enum.Enum): + CLOSED = "closed" + OPEN = "open" + HALF_OPEN = "half_open" + + +class CircuitBreaker: + """Simple circuit breaker implementation.""" + + def __init__( + self, + failure_threshold: int = 5, + recovery_timeout: float = 60.0, + ) -> None: + self._failure_threshold = failure_threshold + self._recovery_timeout = recovery_timeout + self._failure_count: int = 0 + self._state = CircuitState.CLOSED + self._opened_at: float = 0.0 + + @property + def state(self) -> CircuitState: + return self._state + + def allow_request(self) -> bool: + if self._state == CircuitState.CLOSED: + return True + if self._state == CircuitState.OPEN: + if time.monotonic() - self._opened_at >= self._recovery_timeout: + self._state = CircuitState.HALF_OPEN + return True + return False + # HALF_OPEN + return True + + def record_success(self) -> None: + self._failure_count = 0 + self._state = CircuitState.CLOSED + + def record_failure(self) -> None: + self._failure_count += 1 + if self._failure_count >= self._failure_threshold: + self._state = CircuitState.OPEN + self._opened_at = time.monotonic() diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py new file mode 100644 index 0000000..514bcc2 --- /dev/null +++ b/shared/tests/test_resilience.py @@ -0,0 +1,138 @@ +"""Tests for retry with backoff and circuit breaker.""" +import asyncio +import time + +import pytest + +from shared.resilience import CircuitBreaker, CircuitState, retry_with_backoff + + +# --------------------------------------------------------------------------- +# retry_with_backoff tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_retry_succeeds_first_try(): + call_count = 0 + + @retry_with_backoff(max_retries=3, base_delay=0.01) + async def succeed(): + nonlocal call_count + call_count += 1 + return "ok" + + result = await succeed() + assert result == "ok" + assert call_count == 1 + + +@pytest.mark.asyncio +async def test_retry_succeeds_after_failures(): + call_count = 0 + + @retry_with_backoff(max_retries=3, base_delay=0.01) + async def flaky(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise ValueError("not yet") + return "recovered" + + result = await flaky() + assert result == "recovered" + assert call_count == 3 + + +@pytest.mark.asyncio +async def test_retry_raises_after_max_retries(): + call_count = 0 + + @retry_with_backoff(max_retries=3, base_delay=0.01) + async def always_fail(): + nonlocal call_count + call_count += 1 + raise RuntimeError("permanent") + + with pytest.raises(RuntimeError, match="permanent"): + await always_fail() + # 1 initial + 3 retries = 4 calls + assert call_count == 4 + + +@pytest.mark.asyncio +async def test_retry_respects_max_delay(): + """Backoff should be capped at max_delay.""" + @retry_with_backoff(max_retries=2, base_delay=0.01, max_delay=0.02) + async def always_fail(): + raise RuntimeError("fail") + + start = time.monotonic() + with pytest.raises(RuntimeError): + await always_fail() + elapsed = time.monotonic() - start + # With max_delay=0.02 and 2 retries, total delay should be small + assert elapsed < 0.5 + + +# --------------------------------------------------------------------------- +# CircuitBreaker tests +# --------------------------------------------------------------------------- + + +def test_circuit_starts_closed(): + cb = CircuitBreaker(failure_threshold=3, recovery_timeout=0.05) + assert cb.state == CircuitState.CLOSED + assert cb.allow_request() is True + + +def test_circuit_opens_after_threshold(): + cb = CircuitBreaker(failure_threshold=3, recovery_timeout=60.0) + for _ in range(3): + cb.record_failure() + assert cb.state == CircuitState.OPEN + assert cb.allow_request() is False + + +def test_circuit_rejects_when_open(): + cb = CircuitBreaker(failure_threshold=2, recovery_timeout=60.0) + cb.record_failure() + cb.record_failure() + assert cb.state == CircuitState.OPEN + assert cb.allow_request() is False + + +def test_circuit_half_open_after_timeout(): + cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) + cb.record_failure() + cb.record_failure() + assert cb.state == CircuitState.OPEN + + time.sleep(0.06) + assert cb.allow_request() is True + assert cb.state == CircuitState.HALF_OPEN + + +def test_circuit_closes_on_success(): + cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) + cb.record_failure() + cb.record_failure() + assert cb.state == CircuitState.OPEN + + time.sleep(0.06) + cb.allow_request() # triggers HALF_OPEN + assert cb.state == CircuitState.HALF_OPEN + + cb.record_success() + assert cb.state == CircuitState.CLOSED + assert cb.allow_request() is True + + +def test_circuit_reopens_on_failure_in_half_open(): + cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) + cb.record_failure() + cb.record_failure() + time.sleep(0.06) + cb.allow_request() # HALF_OPEN + cb.record_failure() + assert cb.state == CircuitState.OPEN |
