diff options
| author | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-02 15:31:09 +0900 |
|---|---|---|
| committer | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-02 15:31:09 +0900 |
| commit | 0e7fd5059e8a813ccffe2c376b1ff43898b4d966 (patch) | |
| tree | 3a4da6bcf61637e1e94907048c7110d7d3daabee /shared | |
| parent | 24cb407ef15f1997e2577c58e139b20d3986ed5b (diff) | |
fix: address code review issues in resilience module
Diffstat (limited to 'shared')
| -rw-r--r-- | shared/pyproject.toml | 1 | ||||
| -rw-r--r-- | shared/src/shared/resilience.py | 31 | ||||
| -rw-r--r-- | shared/tests/test_resilience.py | 23 |
3 files changed, 42 insertions, 13 deletions
diff --git a/shared/pyproject.toml b/shared/pyproject.toml index dcddc84..eb74a11 100644 --- a/shared/pyproject.toml +++ b/shared/pyproject.toml @@ -15,7 +15,6 @@ dependencies = [ "pyyaml>=6.0,<7", "aiohttp>=3.9,<4", "rich>=13.0,<14", - "tenacity>=8.2,<10", ] [project.optional-dependencies] diff --git a/shared/src/shared/resilience.py b/shared/src/shared/resilience.py index b18aaf7..ef2a1f6 100644 --- a/shared/src/shared/resilience.py +++ b/shared/src/shared/resilience.py @@ -12,8 +12,15 @@ import logging import random import time from contextlib import asynccontextmanager +from enum import StrEnum from typing import Any, Callable + +class _State(StrEnum): + CLOSED = "closed" + OPEN = "open" + HALF_OPEN = "half_open" + logger = logging.getLogger(__name__) @@ -73,17 +80,17 @@ class CircuitBreaker: """ def __init__(self, failure_threshold: int = 5, cooldown: float = 60.0) -> None: - self.failure_threshold = failure_threshold - self.cooldown = cooldown + self._failure_threshold = failure_threshold + self._cooldown = cooldown self._failures = 0 - self._state = "closed" + self._state = _State.CLOSED self._opened_at: float = 0.0 async def call(self, func: Callable, *args: Any, **kwargs: Any) -> Any: """Execute func through the breaker.""" - if self._state == "open": - if time.monotonic() - self._opened_at >= self.cooldown: - self._state = "half_open" + if self._state == _State.OPEN: + if time.monotonic() - self._opened_at >= self._cooldown: + self._state = _State.HALF_OPEN else: raise RuntimeError("Circuit breaker is open") @@ -91,15 +98,15 @@ class CircuitBreaker: result = await func(*args, **kwargs) except Exception: self._failures += 1 - if self._state == "half_open": - self._state = "open" + if self._state == _State.HALF_OPEN: + self._state = _State.OPEN self._opened_at = time.monotonic() logger.error( "Circuit breaker re-opened after half-open probe failure (threshold=%d)", - self.failure_threshold, + self._failure_threshold, ) - elif self._failures >= self.failure_threshold: - self._state = "open" + elif self._failures >= self._failure_threshold: + self._state = _State.OPEN self._opened_at = time.monotonic() logger.error( "Circuit breaker opened after %d consecutive failures", @@ -109,7 +116,7 @@ class CircuitBreaker: # Success: reset self._failures = 0 - self._state = "closed" + self._state = _State.CLOSED return result diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py index dde47e1..5ed4ac3 100644 --- a/shared/tests/test_resilience.py +++ b/shared/tests/test_resilience.py @@ -136,6 +136,29 @@ async def test_half_open_after_cooldown(): assert result == "recovered" +async def test_half_open_reopens_on_failure(): + cb = CircuitBreaker(failure_threshold=2, cooldown=0.05) + + async def always_fail(): + raise ConnectionError("fail") + + # Trip the breaker + for _ in range(2): + with pytest.raises(ConnectionError): + await cb.call(always_fail) + + # Wait for cooldown + await asyncio.sleep(0.1) + + # Half-open probe should fail and re-open + with pytest.raises(ConnectionError): + await cb.call(always_fail) + + # Should be open again (no cooldown wait) + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(always_fail) + + # --- async_timeout tests --- |
