summaryrefslogtreecommitdiff
path: root/shared/src
diff options
context:
space:
mode:
Diffstat (limited to 'shared/src')
-rw-r--r--shared/src/shared/resilience.py127
1 files changed, 126 insertions, 1 deletions
diff --git a/shared/src/shared/resilience.py b/shared/src/shared/resilience.py
index 8d8678a..b18aaf7 100644
--- a/shared/src/shared/resilience.py
+++ b/shared/src/shared/resilience.py
@@ -1 +1,126 @@
-"""Resilience utilities for the trading platform."""
+"""Resilience utilities for the trading platform.
+
+Provides retry, circuit breaker, and timeout primitives using only stdlib.
+No external dependencies required.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import functools
+import logging
+import random
+import time
+from contextlib import asynccontextmanager
+from typing import Any, Callable
+
+logger = logging.getLogger(__name__)
+
+
+def retry_async(
+ max_retries: int = 3,
+ base_delay: float = 1.0,
+ max_delay: float = 30.0,
+ exclude: tuple[type[BaseException], ...] = (),
+) -> Callable:
+ """Decorator: exponential backoff + jitter for async functions.
+
+ Parameters:
+ max_retries: Maximum number of retry attempts (after the initial call).
+ base_delay: Base delay in seconds for exponential backoff.
+ max_delay: Maximum delay cap in seconds.
+ exclude: Exception types that should NOT be retried (raised immediately).
+ """
+
+ def decorator(func: Callable) -> Callable:
+ @functools.wraps(func)
+ async def wrapper(*args: Any, **kwargs: Any) -> Any:
+ last_exc: BaseException | None = None
+ for attempt in range(max_retries + 1):
+ try:
+ return await func(*args, **kwargs)
+ except Exception as exc:
+ if exclude and isinstance(exc, exclude):
+ raise
+ last_exc = exc
+ if attempt < max_retries:
+ delay = min(base_delay * (2**attempt), max_delay)
+ jitter_delay = delay * random.uniform(0.5, 1.0)
+ logger.warning(
+ "Retry %d/%d for %s in %.2fs: %s",
+ attempt + 1,
+ max_retries,
+ func.__name__,
+ jitter_delay,
+ exc,
+ )
+ await asyncio.sleep(jitter_delay)
+ raise last_exc # type: ignore[misc]
+
+ return wrapper
+
+ return decorator
+
+
+class CircuitBreaker:
+ """Circuit breaker: opens after N consecutive failures, auto-recovers.
+
+ States: closed -> open -> half_open -> closed
+
+ Parameters:
+ failure_threshold: Number of consecutive failures before opening.
+ cooldown: Seconds to wait before allowing a half-open probe.
+ """
+
+ def __init__(self, failure_threshold: int = 5, cooldown: float = 60.0) -> None:
+ self.failure_threshold = failure_threshold
+ self.cooldown = cooldown
+ self._failures = 0
+ self._state = "closed"
+ self._opened_at: float = 0.0
+
+ async def call(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
+ """Execute func through the breaker."""
+ if self._state == "open":
+ if time.monotonic() - self._opened_at >= self.cooldown:
+ self._state = "half_open"
+ else:
+ raise RuntimeError("Circuit breaker is open")
+
+ try:
+ result = await func(*args, **kwargs)
+ except Exception:
+ self._failures += 1
+ if self._state == "half_open":
+ self._state = "open"
+ self._opened_at = time.monotonic()
+ logger.error(
+ "Circuit breaker re-opened after half-open probe failure (threshold=%d)",
+ self.failure_threshold,
+ )
+ elif self._failures >= self.failure_threshold:
+ self._state = "open"
+ self._opened_at = time.monotonic()
+ logger.error(
+ "Circuit breaker opened after %d consecutive failures",
+ self._failures,
+ )
+ raise
+
+ # Success: reset
+ self._failures = 0
+ self._state = "closed"
+ return result
+
+
+@asynccontextmanager
+async def async_timeout(seconds: float):
+ """Async context manager wrapping asyncio.timeout().
+
+ Raises TimeoutError with a descriptive message on timeout.
+ """
+ try:
+ async with asyncio.timeout(seconds):
+ yield
+ except TimeoutError:
+ raise TimeoutError(f"Operation timed out after {seconds}s") from None