summaryrefslogtreecommitdiff
path: root/shared
diff options
context:
space:
mode:
authorTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-01 16:06:54 +0900
committerTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-01 16:06:54 +0900
commit558de36fdd886a117c66dee73850cc219c86b2a4 (patch)
treeba3a534f928f2e66fe54d22f9654b737d491bd84 /shared
parentcd15c3f64d00c6c97f738d59f719bb0938d9f7cb (diff)
feat(shared): add retry with backoff and circuit breaker
Diffstat (limited to 'shared')
-rw-r--r--shared/src/shared/resilience.py105
-rw-r--r--shared/tests/test_resilience.py138
2 files changed, 243 insertions, 0 deletions
diff --git a/shared/src/shared/resilience.py b/shared/src/shared/resilience.py
new file mode 100644
index 0000000..d4e963b
--- /dev/null
+++ b/shared/src/shared/resilience.py
@@ -0,0 +1,105 @@
+"""Retry with exponential backoff and circuit breaker utilities."""
+
+from __future__ import annotations
+
+import asyncio
+import enum
+import functools
+import logging
+import random
+import time
+from typing import Any, Callable
+
+logger = logging.getLogger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# retry_with_backoff
+# ---------------------------------------------------------------------------
+
+
+def retry_with_backoff(
+ max_retries: int = 3,
+ base_delay: float = 1.0,
+ max_delay: float = 60.0,
+) -> Callable:
+ """Decorator that retries an async function with exponential backoff + jitter."""
+
+ def decorator(func: Callable) -> Callable:
+ @functools.wraps(func)
+ async def wrapper(*args: Any, **kwargs: Any) -> Any:
+ last_exc: BaseException | None = None
+ for attempt in range(max_retries + 1):
+ try:
+ return await func(*args, **kwargs)
+ except Exception as exc:
+ last_exc = exc
+ if attempt < max_retries:
+ delay = min(base_delay * (2 ** attempt), max_delay)
+ jitter = delay * random.uniform(0, 0.5)
+ total_delay = delay + jitter
+ logger.warning(
+ "Retry %d/%d for %s after error: %s (delay=%.3fs)",
+ attempt + 1,
+ max_retries,
+ func.__name__,
+ exc,
+ total_delay,
+ )
+ await asyncio.sleep(total_delay)
+ raise last_exc # type: ignore[misc]
+
+ return wrapper
+
+ return decorator
+
+
+# ---------------------------------------------------------------------------
+# CircuitBreaker
+# ---------------------------------------------------------------------------
+
+
+class CircuitState(enum.Enum):
+ CLOSED = "closed"
+ OPEN = "open"
+ HALF_OPEN = "half_open"
+
+
+class CircuitBreaker:
+ """Simple circuit breaker implementation."""
+
+ def __init__(
+ self,
+ failure_threshold: int = 5,
+ recovery_timeout: float = 60.0,
+ ) -> None:
+ self._failure_threshold = failure_threshold
+ self._recovery_timeout = recovery_timeout
+ self._failure_count: int = 0
+ self._state = CircuitState.CLOSED
+ self._opened_at: float = 0.0
+
+ @property
+ def state(self) -> CircuitState:
+ return self._state
+
+ def allow_request(self) -> bool:
+ if self._state == CircuitState.CLOSED:
+ return True
+ if self._state == CircuitState.OPEN:
+ if time.monotonic() - self._opened_at >= self._recovery_timeout:
+ self._state = CircuitState.HALF_OPEN
+ return True
+ return False
+ # HALF_OPEN
+ return True
+
+ def record_success(self) -> None:
+ self._failure_count = 0
+ self._state = CircuitState.CLOSED
+
+ def record_failure(self) -> None:
+ self._failure_count += 1
+ if self._failure_count >= self._failure_threshold:
+ self._state = CircuitState.OPEN
+ self._opened_at = time.monotonic()
diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py
new file mode 100644
index 0000000..514bcc2
--- /dev/null
+++ b/shared/tests/test_resilience.py
@@ -0,0 +1,138 @@
+"""Tests for retry with backoff and circuit breaker."""
+import asyncio
+import time
+
+import pytest
+
+from shared.resilience import CircuitBreaker, CircuitState, retry_with_backoff
+
+
+# ---------------------------------------------------------------------------
+# retry_with_backoff tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_retry_succeeds_first_try():
+ call_count = 0
+
+ @retry_with_backoff(max_retries=3, base_delay=0.01)
+ async def succeed():
+ nonlocal call_count
+ call_count += 1
+ return "ok"
+
+ result = await succeed()
+ assert result == "ok"
+ assert call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_retry_succeeds_after_failures():
+ call_count = 0
+
+ @retry_with_backoff(max_retries=3, base_delay=0.01)
+ async def flaky():
+ nonlocal call_count
+ call_count += 1
+ if call_count < 3:
+ raise ValueError("not yet")
+ return "recovered"
+
+ result = await flaky()
+ assert result == "recovered"
+ assert call_count == 3
+
+
+@pytest.mark.asyncio
+async def test_retry_raises_after_max_retries():
+ call_count = 0
+
+ @retry_with_backoff(max_retries=3, base_delay=0.01)
+ async def always_fail():
+ nonlocal call_count
+ call_count += 1
+ raise RuntimeError("permanent")
+
+ with pytest.raises(RuntimeError, match="permanent"):
+ await always_fail()
+ # 1 initial + 3 retries = 4 calls
+ assert call_count == 4
+
+
+@pytest.mark.asyncio
+async def test_retry_respects_max_delay():
+ """Backoff should be capped at max_delay."""
+ @retry_with_backoff(max_retries=2, base_delay=0.01, max_delay=0.02)
+ async def always_fail():
+ raise RuntimeError("fail")
+
+ start = time.monotonic()
+ with pytest.raises(RuntimeError):
+ await always_fail()
+ elapsed = time.monotonic() - start
+ # With max_delay=0.02 and 2 retries, total delay should be small
+ assert elapsed < 0.5
+
+
+# ---------------------------------------------------------------------------
+# CircuitBreaker tests
+# ---------------------------------------------------------------------------
+
+
+def test_circuit_starts_closed():
+ cb = CircuitBreaker(failure_threshold=3, recovery_timeout=0.05)
+ assert cb.state == CircuitState.CLOSED
+ assert cb.allow_request() is True
+
+
+def test_circuit_opens_after_threshold():
+ cb = CircuitBreaker(failure_threshold=3, recovery_timeout=60.0)
+ for _ in range(3):
+ cb.record_failure()
+ assert cb.state == CircuitState.OPEN
+ assert cb.allow_request() is False
+
+
+def test_circuit_rejects_when_open():
+ cb = CircuitBreaker(failure_threshold=2, recovery_timeout=60.0)
+ cb.record_failure()
+ cb.record_failure()
+ assert cb.state == CircuitState.OPEN
+ assert cb.allow_request() is False
+
+
+def test_circuit_half_open_after_timeout():
+ cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
+ cb.record_failure()
+ cb.record_failure()
+ assert cb.state == CircuitState.OPEN
+
+ time.sleep(0.06)
+ assert cb.allow_request() is True
+ assert cb.state == CircuitState.HALF_OPEN
+
+
+def test_circuit_closes_on_success():
+ cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
+ cb.record_failure()
+ cb.record_failure()
+ assert cb.state == CircuitState.OPEN
+
+ time.sleep(0.06)
+ cb.allow_request() # triggers HALF_OPEN
+ assert cb.state == CircuitState.HALF_OPEN
+
+ cb.record_success()
+ assert cb.state == CircuitState.CLOSED
+ assert cb.allow_request() is True
+
+
+def test_circuit_reopens_on_failure_in_half_open():
+ cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
+ cb.record_failure()
+ cb.record_failure()
+ time.sleep(0.06)
+ cb.allow_request() # HALF_OPEN
+ cb.record_failure()
+ assert cb.state == CircuitState.OPEN