summaryrefslogtreecommitdiff
path: root/shared
diff options
context:
space:
mode:
Diffstat (limited to 'shared')
-rw-r--r--shared/pyproject.toml33
-rw-r--r--shared/src/shared/resilience.py127
-rw-r--r--shared/tests/test_resilience.py154
3 files changed, 294 insertions, 20 deletions
diff --git a/shared/pyproject.toml b/shared/pyproject.toml
index 830088d..dcddc84 100644
--- a/shared/pyproject.toml
+++ b/shared/pyproject.toml
@@ -4,28 +4,23 @@ version = "0.1.0"
description = "Shared models, events, and utilities for trading platform"
requires-python = ">=3.12"
dependencies = [
- "pydantic>=2.0",
- "pydantic-settings>=2.0",
- "redis>=5.0",
- "asyncpg>=0.29",
- "sqlalchemy[asyncio]>=2.0",
- "alembic>=1.13",
- "structlog>=24.0",
- "prometheus-client>=0.20",
- "pyyaml>=6.0",
- "aiohttp>=3.9",
- "rich>=13.0",
+ "pydantic>=2.8,<3",
+ "pydantic-settings>=2.0,<3",
+ "redis>=5.0,<6",
+ "asyncpg>=0.29,<1",
+ "sqlalchemy[asyncio]>=2.0,<3",
+ "alembic>=1.13,<2",
+ "structlog>=24.0,<25",
+ "prometheus-client>=0.20,<1",
+ "pyyaml>=6.0,<7",
+ "aiohttp>=3.9,<4",
+ "rich>=13.0,<14",
+ "tenacity>=8.2,<10",
]
[project.optional-dependencies]
-dev = [
- "pytest>=8.0",
- "pytest-asyncio>=0.23",
- "ruff>=0.4",
-]
-claude = [
- "anthropic>=0.40",
-]
+dev = ["pytest>=8.0,<9", "pytest-asyncio>=0.23,<1", "ruff>=0.4,<1"]
+claude = ["anthropic>=0.40,<1"]
[build-system]
requires = ["hatchling"]
diff --git a/shared/src/shared/resilience.py b/shared/src/shared/resilience.py
index 8d8678a..b18aaf7 100644
--- a/shared/src/shared/resilience.py
+++ b/shared/src/shared/resilience.py
@@ -1 +1,126 @@
-"""Resilience utilities for the trading platform."""
+"""Resilience utilities for the trading platform.
+
+Provides retry, circuit breaker, and timeout primitives using only stdlib.
+No external dependencies required.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import functools
+import logging
+import random
+import time
+from contextlib import asynccontextmanager
+from typing import Any, Callable
+
+logger = logging.getLogger(__name__)
+
+
+def retry_async(
+ max_retries: int = 3,
+ base_delay: float = 1.0,
+ max_delay: float = 30.0,
+ exclude: tuple[type[BaseException], ...] = (),
+) -> Callable:
+ """Decorator: exponential backoff + jitter for async functions.
+
+ Parameters:
+ max_retries: Maximum number of retry attempts (after the initial call).
+ base_delay: Base delay in seconds for exponential backoff.
+ max_delay: Maximum delay cap in seconds.
+ exclude: Exception types that should NOT be retried (raised immediately).
+ """
+
+ def decorator(func: Callable) -> Callable:
+ @functools.wraps(func)
+ async def wrapper(*args: Any, **kwargs: Any) -> Any:
+ last_exc: BaseException | None = None
+ for attempt in range(max_retries + 1):
+ try:
+ return await func(*args, **kwargs)
+ except Exception as exc:
+ if exclude and isinstance(exc, exclude):
+ raise
+ last_exc = exc
+ if attempt < max_retries:
+ delay = min(base_delay * (2**attempt), max_delay)
+ jitter_delay = delay * random.uniform(0.5, 1.0)
+ logger.warning(
+ "Retry %d/%d for %s in %.2fs: %s",
+ attempt + 1,
+ max_retries,
+ func.__name__,
+ jitter_delay,
+ exc,
+ )
+ await asyncio.sleep(jitter_delay)
+ raise last_exc # type: ignore[misc]
+
+ return wrapper
+
+ return decorator
+
+
+class CircuitBreaker:
+ """Circuit breaker: opens after N consecutive failures, auto-recovers.
+
+ States: closed -> open -> half_open -> closed
+
+ Parameters:
+ failure_threshold: Number of consecutive failures before opening.
+ cooldown: Seconds to wait before allowing a half-open probe.
+ """
+
+ def __init__(self, failure_threshold: int = 5, cooldown: float = 60.0) -> None:
+ self.failure_threshold = failure_threshold
+ self.cooldown = cooldown
+ self._failures = 0
+ self._state = "closed"
+ self._opened_at: float = 0.0
+
+ async def call(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
+ """Execute func through the breaker."""
+ if self._state == "open":
+ if time.monotonic() - self._opened_at >= self.cooldown:
+ self._state = "half_open"
+ else:
+ raise RuntimeError("Circuit breaker is open")
+
+ try:
+ result = await func(*args, **kwargs)
+ except Exception:
+ self._failures += 1
+ if self._state == "half_open":
+ self._state = "open"
+ self._opened_at = time.monotonic()
+ logger.error(
+ "Circuit breaker re-opened after half-open probe failure (threshold=%d)",
+ self.failure_threshold,
+ )
+ elif self._failures >= self.failure_threshold:
+ self._state = "open"
+ self._opened_at = time.monotonic()
+ logger.error(
+ "Circuit breaker opened after %d consecutive failures",
+ self._failures,
+ )
+ raise
+
+ # Success: reset
+ self._failures = 0
+ self._state = "closed"
+ return result
+
+
+@asynccontextmanager
+async def async_timeout(seconds: float):
+ """Async context manager wrapping asyncio.timeout().
+
+ Raises TimeoutError with a descriptive message on timeout.
+ """
+ try:
+ async with asyncio.timeout(seconds):
+ yield
+ except TimeoutError:
+ raise TimeoutError(f"Operation timed out after {seconds}s") from None
diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py
new file mode 100644
index 0000000..dde47e1
--- /dev/null
+++ b/shared/tests/test_resilience.py
@@ -0,0 +1,154 @@
+"""Tests for shared.resilience module."""
+
+import asyncio
+
+import pytest
+
+from shared.resilience import CircuitBreaker, async_timeout, retry_async
+
+
+# --- retry_async tests ---
+
+
+async def test_succeeds_without_retry():
+ """Function succeeds first try, called once."""
+ call_count = 0
+
+ @retry_async()
+ async def fn():
+ nonlocal call_count
+ call_count += 1
+ return "ok"
+
+ result = await fn()
+ assert result == "ok"
+ assert call_count == 1
+
+
+async def test_retries_on_failure_then_succeeds():
+ """Fails twice then succeeds, verify call count."""
+ call_count = 0
+
+ @retry_async(max_retries=3, base_delay=0.01)
+ async def fn():
+ nonlocal call_count
+ call_count += 1
+ if call_count < 3:
+ raise RuntimeError("transient")
+ return "recovered"
+
+ result = await fn()
+ assert result == "recovered"
+ assert call_count == 3
+
+
+async def test_raises_after_max_retries():
+ """Always fails, raises after max retries."""
+ call_count = 0
+
+ @retry_async(max_retries=3, base_delay=0.01)
+ async def fn():
+ nonlocal call_count
+ call_count += 1
+ raise ValueError("permanent")
+
+ with pytest.raises(ValueError, match="permanent"):
+ await fn()
+
+ # 1 initial + 3 retries = 4 total calls
+ assert call_count == 4
+
+
+async def test_no_retry_on_excluded_exception():
+ """Excluded exception raises immediately, call count = 1."""
+ call_count = 0
+
+ @retry_async(max_retries=3, base_delay=0.01, exclude=(TypeError,))
+ async def fn():
+ nonlocal call_count
+ call_count += 1
+ raise TypeError("excluded")
+
+ with pytest.raises(TypeError, match="excluded"):
+ await fn()
+
+ assert call_count == 1
+
+
+# --- CircuitBreaker tests ---
+
+
+async def test_closed_allows_calls():
+ """CircuitBreaker in closed state passes through."""
+ cb = CircuitBreaker(failure_threshold=5, cooldown=60.0)
+
+ async def fn():
+ return "ok"
+
+ result = await cb.call(fn)
+ assert result == "ok"
+
+
+async def test_opens_after_threshold():
+ """After N failures, raises RuntimeError."""
+ cb = CircuitBreaker(failure_threshold=3, cooldown=60.0)
+
+ async def fail():
+ raise RuntimeError("fail")
+
+ for _ in range(3):
+ with pytest.raises(RuntimeError, match="fail"):
+ await cb.call(fail)
+
+ # Now the breaker should be open
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(fail)
+
+
+async def test_half_open_after_cooldown():
+ """After cooldown, allows recovery attempt."""
+ cb = CircuitBreaker(failure_threshold=2, cooldown=0.05)
+
+ async def fail():
+ raise RuntimeError("fail")
+
+ # Trip the breaker
+ for _ in range(2):
+ with pytest.raises(RuntimeError, match="fail"):
+ await cb.call(fail)
+
+ # Breaker is open
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(fail)
+
+ # Wait for cooldown
+ await asyncio.sleep(0.06)
+
+ # Now should allow a call (half_open). Succeed to close it.
+ async def succeed():
+ return "recovered"
+
+ result = await cb.call(succeed)
+ assert result == "recovered"
+
+ # Breaker should be closed again
+ result = await cb.call(succeed)
+ assert result == "recovered"
+
+
+# --- async_timeout tests ---
+
+
+async def test_completes_within_timeout():
+ """async_timeout doesn't interfere with fast operations."""
+ async with async_timeout(1.0):
+ await asyncio.sleep(0.01)
+ result = 42
+ assert result == 42
+
+
+async def test_raises_on_timeout():
+ """async_timeout raises TimeoutError for slow operations."""
+ with pytest.raises(TimeoutError):
+ async with async_timeout(0.05):
+ await asyncio.sleep(1.0)