diff options
Diffstat (limited to 'shared')
| -rw-r--r-- | shared/pyproject.toml | 33 | ||||
| -rw-r--r-- | shared/src/shared/resilience.py | 127 | ||||
| -rw-r--r-- | shared/tests/test_resilience.py | 154 |
3 files changed, 294 insertions, 20 deletions
diff --git a/shared/pyproject.toml b/shared/pyproject.toml index 830088d..dcddc84 100644 --- a/shared/pyproject.toml +++ b/shared/pyproject.toml @@ -4,28 +4,23 @@ version = "0.1.0" description = "Shared models, events, and utilities for trading platform" requires-python = ">=3.12" dependencies = [ - "pydantic>=2.0", - "pydantic-settings>=2.0", - "redis>=5.0", - "asyncpg>=0.29", - "sqlalchemy[asyncio]>=2.0", - "alembic>=1.13", - "structlog>=24.0", - "prometheus-client>=0.20", - "pyyaml>=6.0", - "aiohttp>=3.9", - "rich>=13.0", + "pydantic>=2.8,<3", + "pydantic-settings>=2.0,<3", + "redis>=5.0,<6", + "asyncpg>=0.29,<1", + "sqlalchemy[asyncio]>=2.0,<3", + "alembic>=1.13,<2", + "structlog>=24.0,<25", + "prometheus-client>=0.20,<1", + "pyyaml>=6.0,<7", + "aiohttp>=3.9,<4", + "rich>=13.0,<14", + "tenacity>=8.2,<10", ] [project.optional-dependencies] -dev = [ - "pytest>=8.0", - "pytest-asyncio>=0.23", - "ruff>=0.4", -] -claude = [ - "anthropic>=0.40", -] +dev = ["pytest>=8.0,<9", "pytest-asyncio>=0.23,<1", "ruff>=0.4,<1"] +claude = ["anthropic>=0.40,<1"] [build-system] requires = ["hatchling"] diff --git a/shared/src/shared/resilience.py b/shared/src/shared/resilience.py index 8d8678a..b18aaf7 100644 --- a/shared/src/shared/resilience.py +++ b/shared/src/shared/resilience.py @@ -1 +1,126 @@ -"""Resilience utilities for the trading platform.""" +"""Resilience utilities for the trading platform. + +Provides retry, circuit breaker, and timeout primitives using only stdlib. +No external dependencies required. +""" + +from __future__ import annotations + +import asyncio +import functools +import logging +import random +import time +from contextlib import asynccontextmanager +from typing import Any, Callable + +logger = logging.getLogger(__name__) + + +def retry_async( + max_retries: int = 3, + base_delay: float = 1.0, + max_delay: float = 30.0, + exclude: tuple[type[BaseException], ...] = (), +) -> Callable: + """Decorator: exponential backoff + jitter for async functions. + + Parameters: + max_retries: Maximum number of retry attempts (after the initial call). + base_delay: Base delay in seconds for exponential backoff. + max_delay: Maximum delay cap in seconds. + exclude: Exception types that should NOT be retried (raised immediately). + """ + + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + last_exc: BaseException | None = None + for attempt in range(max_retries + 1): + try: + return await func(*args, **kwargs) + except Exception as exc: + if exclude and isinstance(exc, exclude): + raise + last_exc = exc + if attempt < max_retries: + delay = min(base_delay * (2**attempt), max_delay) + jitter_delay = delay * random.uniform(0.5, 1.0) + logger.warning( + "Retry %d/%d for %s in %.2fs: %s", + attempt + 1, + max_retries, + func.__name__, + jitter_delay, + exc, + ) + await asyncio.sleep(jitter_delay) + raise last_exc # type: ignore[misc] + + return wrapper + + return decorator + + +class CircuitBreaker: + """Circuit breaker: opens after N consecutive failures, auto-recovers. + + States: closed -> open -> half_open -> closed + + Parameters: + failure_threshold: Number of consecutive failures before opening. + cooldown: Seconds to wait before allowing a half-open probe. + """ + + def __init__(self, failure_threshold: int = 5, cooldown: float = 60.0) -> None: + self.failure_threshold = failure_threshold + self.cooldown = cooldown + self._failures = 0 + self._state = "closed" + self._opened_at: float = 0.0 + + async def call(self, func: Callable, *args: Any, **kwargs: Any) -> Any: + """Execute func through the breaker.""" + if self._state == "open": + if time.monotonic() - self._opened_at >= self.cooldown: + self._state = "half_open" + else: + raise RuntimeError("Circuit breaker is open") + + try: + result = await func(*args, **kwargs) + except Exception: + self._failures += 1 + if self._state == "half_open": + self._state = "open" + self._opened_at = time.monotonic() + logger.error( + "Circuit breaker re-opened after half-open probe failure (threshold=%d)", + self.failure_threshold, + ) + elif self._failures >= self.failure_threshold: + self._state = "open" + self._opened_at = time.monotonic() + logger.error( + "Circuit breaker opened after %d consecutive failures", + self._failures, + ) + raise + + # Success: reset + self._failures = 0 + self._state = "closed" + return result + + +@asynccontextmanager +async def async_timeout(seconds: float): + """Async context manager wrapping asyncio.timeout(). + + Raises TimeoutError with a descriptive message on timeout. + """ + try: + async with asyncio.timeout(seconds): + yield + except TimeoutError: + raise TimeoutError(f"Operation timed out after {seconds}s") from None diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py new file mode 100644 index 0000000..dde47e1 --- /dev/null +++ b/shared/tests/test_resilience.py @@ -0,0 +1,154 @@ +"""Tests for shared.resilience module.""" + +import asyncio + +import pytest + +from shared.resilience import CircuitBreaker, async_timeout, retry_async + + +# --- retry_async tests --- + + +async def test_succeeds_without_retry(): + """Function succeeds first try, called once.""" + call_count = 0 + + @retry_async() + async def fn(): + nonlocal call_count + call_count += 1 + return "ok" + + result = await fn() + assert result == "ok" + assert call_count == 1 + + +async def test_retries_on_failure_then_succeeds(): + """Fails twice then succeeds, verify call count.""" + call_count = 0 + + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise RuntimeError("transient") + return "recovered" + + result = await fn() + assert result == "recovered" + assert call_count == 3 + + +async def test_raises_after_max_retries(): + """Always fails, raises after max retries.""" + call_count = 0 + + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): + nonlocal call_count + call_count += 1 + raise ValueError("permanent") + + with pytest.raises(ValueError, match="permanent"): + await fn() + + # 1 initial + 3 retries = 4 total calls + assert call_count == 4 + + +async def test_no_retry_on_excluded_exception(): + """Excluded exception raises immediately, call count = 1.""" + call_count = 0 + + @retry_async(max_retries=3, base_delay=0.01, exclude=(TypeError,)) + async def fn(): + nonlocal call_count + call_count += 1 + raise TypeError("excluded") + + with pytest.raises(TypeError, match="excluded"): + await fn() + + assert call_count == 1 + + +# --- CircuitBreaker tests --- + + +async def test_closed_allows_calls(): + """CircuitBreaker in closed state passes through.""" + cb = CircuitBreaker(failure_threshold=5, cooldown=60.0) + + async def fn(): + return "ok" + + result = await cb.call(fn) + assert result == "ok" + + +async def test_opens_after_threshold(): + """After N failures, raises RuntimeError.""" + cb = CircuitBreaker(failure_threshold=3, cooldown=60.0) + + async def fail(): + raise RuntimeError("fail") + + for _ in range(3): + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + + # Now the breaker should be open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + +async def test_half_open_after_cooldown(): + """After cooldown, allows recovery attempt.""" + cb = CircuitBreaker(failure_threshold=2, cooldown=0.05) + + async def fail(): + raise RuntimeError("fail") + + # Trip the breaker + for _ in range(2): + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + + # Breaker is open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + # Wait for cooldown + await asyncio.sleep(0.06) + + # Now should allow a call (half_open). Succeed to close it. + async def succeed(): + return "recovered" + + result = await cb.call(succeed) + assert result == "recovered" + + # Breaker should be closed again + result = await cb.call(succeed) + assert result == "recovered" + + +# --- async_timeout tests --- + + +async def test_completes_within_timeout(): + """async_timeout doesn't interfere with fast operations.""" + async with async_timeout(1.0): + await asyncio.sleep(0.01) + result = 42 + assert result == 42 + + +async def test_raises_on_timeout(): + """async_timeout raises TimeoutError for slow operations.""" + with pytest.raises(TimeoutError): + async with async_timeout(0.05): + await asyncio.sleep(1.0) |
