summaryrefslogtreecommitdiff
path: root/shared
diff options
context:
space:
mode:
authorTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-02 15:31:09 +0900
committerTheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com>2026-04-02 15:31:09 +0900
commit0e7fd5059e8a813ccffe2c376b1ff43898b4d966 (patch)
tree3a4da6bcf61637e1e94907048c7110d7d3daabee /shared
parent24cb407ef15f1997e2577c58e139b20d3986ed5b (diff)
fix: address code review issues in resilience module
Diffstat (limited to 'shared')
-rw-r--r--shared/pyproject.toml1
-rw-r--r--shared/src/shared/resilience.py31
-rw-r--r--shared/tests/test_resilience.py23
3 files changed, 42 insertions, 13 deletions
diff --git a/shared/pyproject.toml b/shared/pyproject.toml
index dcddc84..eb74a11 100644
--- a/shared/pyproject.toml
+++ b/shared/pyproject.toml
@@ -15,7 +15,6 @@ dependencies = [
"pyyaml>=6.0,<7",
"aiohttp>=3.9,<4",
"rich>=13.0,<14",
- "tenacity>=8.2,<10",
]
[project.optional-dependencies]
diff --git a/shared/src/shared/resilience.py b/shared/src/shared/resilience.py
index b18aaf7..ef2a1f6 100644
--- a/shared/src/shared/resilience.py
+++ b/shared/src/shared/resilience.py
@@ -12,8 +12,15 @@ import logging
import random
import time
from contextlib import asynccontextmanager
+from enum import StrEnum
from typing import Any, Callable
+
+class _State(StrEnum):
+ CLOSED = "closed"
+ OPEN = "open"
+ HALF_OPEN = "half_open"
+
logger = logging.getLogger(__name__)
@@ -73,17 +80,17 @@ class CircuitBreaker:
"""
def __init__(self, failure_threshold: int = 5, cooldown: float = 60.0) -> None:
- self.failure_threshold = failure_threshold
- self.cooldown = cooldown
+ self._failure_threshold = failure_threshold
+ self._cooldown = cooldown
self._failures = 0
- self._state = "closed"
+ self._state = _State.CLOSED
self._opened_at: float = 0.0
async def call(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
"""Execute func through the breaker."""
- if self._state == "open":
- if time.monotonic() - self._opened_at >= self.cooldown:
- self._state = "half_open"
+ if self._state == _State.OPEN:
+ if time.monotonic() - self._opened_at >= self._cooldown:
+ self._state = _State.HALF_OPEN
else:
raise RuntimeError("Circuit breaker is open")
@@ -91,15 +98,15 @@ class CircuitBreaker:
result = await func(*args, **kwargs)
except Exception:
self._failures += 1
- if self._state == "half_open":
- self._state = "open"
+ if self._state == _State.HALF_OPEN:
+ self._state = _State.OPEN
self._opened_at = time.monotonic()
logger.error(
"Circuit breaker re-opened after half-open probe failure (threshold=%d)",
- self.failure_threshold,
+ self._failure_threshold,
)
- elif self._failures >= self.failure_threshold:
- self._state = "open"
+ elif self._failures >= self._failure_threshold:
+ self._state = _State.OPEN
self._opened_at = time.monotonic()
logger.error(
"Circuit breaker opened after %d consecutive failures",
@@ -109,7 +116,7 @@ class CircuitBreaker:
# Success: reset
self._failures = 0
- self._state = "closed"
+ self._state = _State.CLOSED
return result
diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py
index dde47e1..5ed4ac3 100644
--- a/shared/tests/test_resilience.py
+++ b/shared/tests/test_resilience.py
@@ -136,6 +136,29 @@ async def test_half_open_after_cooldown():
assert result == "recovered"
+async def test_half_open_reopens_on_failure():
+ cb = CircuitBreaker(failure_threshold=2, cooldown=0.05)
+
+ async def always_fail():
+ raise ConnectionError("fail")
+
+ # Trip the breaker
+ for _ in range(2):
+ with pytest.raises(ConnectionError):
+ await cb.call(always_fail)
+
+ # Wait for cooldown
+ await asyncio.sleep(0.1)
+
+ # Half-open probe should fail and re-open
+ with pytest.raises(ConnectionError):
+ await cb.call(always_fail)
+
+ # Should be open again (no cooldown wait)
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(always_fail)
+
+
# --- async_timeout tests ---