summaryrefslogtreecommitdiff
path: root/shared/tests/test_resilience.py
diff options
context:
space:
mode:
Diffstat (limited to 'shared/tests/test_resilience.py')
-rw-r--r--shared/tests/test_resilience.py203
1 files changed, 120 insertions, 83 deletions
diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py
index e287777..e0781af 100644
--- a/shared/tests/test_resilience.py
+++ b/shared/tests/test_resilience.py
@@ -1,139 +1,176 @@
-"""Tests for retry with backoff and circuit breaker."""
+"""Tests for shared.resilience module."""
-import time
+import asyncio
import pytest
-from shared.resilience import CircuitBreaker, CircuitState, retry_with_backoff
+from shared.resilience import CircuitBreaker, async_timeout, retry_async
+# --- retry_async tests ---
-# ---------------------------------------------------------------------------
-# retry_with_backoff tests
-# ---------------------------------------------------------------------------
-
-@pytest.mark.asyncio
-async def test_retry_succeeds_first_try():
+async def test_succeeds_without_retry():
+ """Function succeeds first try, called once."""
call_count = 0
- @retry_with_backoff(max_retries=3, base_delay=0.01)
- async def succeed():
+ @retry_async()
+ async def fn():
nonlocal call_count
call_count += 1
return "ok"
- result = await succeed()
+ result = await fn()
assert result == "ok"
assert call_count == 1
-@pytest.mark.asyncio
-async def test_retry_succeeds_after_failures():
+async def test_retries_on_failure_then_succeeds():
+ """Fails twice then succeeds, verify call count."""
call_count = 0
- @retry_with_backoff(max_retries=3, base_delay=0.01)
- async def flaky():
+ @retry_async(max_retries=3, base_delay=0.01)
+ async def fn():
nonlocal call_count
call_count += 1
if call_count < 3:
- raise ValueError("not yet")
+ raise RuntimeError("transient")
return "recovered"
- result = await flaky()
+ result = await fn()
assert result == "recovered"
assert call_count == 3
-@pytest.mark.asyncio
-async def test_retry_raises_after_max_retries():
+async def test_raises_after_max_retries():
+ """Always fails, raises after max retries."""
call_count = 0
- @retry_with_backoff(max_retries=3, base_delay=0.01)
- async def always_fail():
+ @retry_async(max_retries=3, base_delay=0.01)
+ async def fn():
nonlocal call_count
call_count += 1
- raise RuntimeError("permanent")
+ raise ValueError("permanent")
- with pytest.raises(RuntimeError, match="permanent"):
- await always_fail()
- # 1 initial + 3 retries = 4 calls
+ with pytest.raises(ValueError, match="permanent"):
+ await fn()
+
+ # 1 initial + 3 retries = 4 total calls
assert call_count == 4
-@pytest.mark.asyncio
-async def test_retry_respects_max_delay():
- """Backoff should be capped at max_delay."""
+async def test_no_retry_on_excluded_exception():
+ """Excluded exception raises immediately, call count = 1."""
+ call_count = 0
- @retry_with_backoff(max_retries=2, base_delay=0.01, max_delay=0.02)
- async def always_fail():
- raise RuntimeError("fail")
+ @retry_async(max_retries=3, base_delay=0.01, exclude=(TypeError,))
+ async def fn():
+ nonlocal call_count
+ call_count += 1
+ raise TypeError("excluded")
- start = time.monotonic()
- with pytest.raises(RuntimeError):
- await always_fail()
- elapsed = time.monotonic() - start
- # With max_delay=0.02 and 2 retries, total delay should be small
- assert elapsed < 0.5
+ with pytest.raises(TypeError, match="excluded"):
+ await fn()
+
+ assert call_count == 1
-# ---------------------------------------------------------------------------
-# CircuitBreaker tests
-# ---------------------------------------------------------------------------
+# --- CircuitBreaker tests ---
-def test_circuit_starts_closed():
- cb = CircuitBreaker(failure_threshold=3, recovery_timeout=0.05)
- assert cb.state == CircuitState.CLOSED
- assert cb.allow_request() is True
+async def test_closed_allows_calls():
+ """CircuitBreaker in closed state passes through."""
+ cb = CircuitBreaker(failure_threshold=5, cooldown=60.0)
+ async def fn():
+ return "ok"
+
+ result = await cb.call(fn)
+ assert result == "ok"
+
+
+async def test_opens_after_threshold():
+ """After N failures, raises RuntimeError."""
+ cb = CircuitBreaker(failure_threshold=3, cooldown=60.0)
+
+ async def fail():
+ raise RuntimeError("fail")
-def test_circuit_opens_after_threshold():
- cb = CircuitBreaker(failure_threshold=3, recovery_timeout=60.0)
for _ in range(3):
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
- assert cb.allow_request() is False
+ with pytest.raises(RuntimeError, match="fail"):
+ await cb.call(fail)
+ # Now the breaker should be open
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(fail)
+
+
+async def test_half_open_after_cooldown():
+ """After cooldown, allows recovery attempt."""
+ cb = CircuitBreaker(failure_threshold=2, cooldown=0.05)
+
+ async def fail():
+ raise RuntimeError("fail")
+
+ # Trip the breaker
+ for _ in range(2):
+ with pytest.raises(RuntimeError, match="fail"):
+ await cb.call(fail)
+
+ # Breaker is open
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(fail)
+
+ # Wait for cooldown
+ await asyncio.sleep(0.06)
+
+ # Now should allow a call (half_open). Succeed to close it.
+ async def succeed():
+ return "recovered"
+
+ result = await cb.call(succeed)
+ assert result == "recovered"
+
+ # Breaker should be closed again
+ result = await cb.call(succeed)
+ assert result == "recovered"
+
+
+async def test_half_open_reopens_on_failure():
+ cb = CircuitBreaker(failure_threshold=2, cooldown=0.05)
+
+ async def always_fail():
+ raise ConnectionError("fail")
-def test_circuit_rejects_when_open():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=60.0)
- cb.record_failure()
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
- assert cb.allow_request() is False
+ # Trip the breaker
+ for _ in range(2):
+ with pytest.raises(ConnectionError):
+ await cb.call(always_fail)
+ # Wait for cooldown
+ await asyncio.sleep(0.1)
-def test_circuit_half_open_after_timeout():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
- cb.record_failure()
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
+ # Half-open probe should fail and re-open
+ with pytest.raises(ConnectionError):
+ await cb.call(always_fail)
- time.sleep(0.06)
- assert cb.allow_request() is True
- assert cb.state == CircuitState.HALF_OPEN
+ # Should be open again (no cooldown wait)
+ with pytest.raises(RuntimeError, match="Circuit breaker is open"):
+ await cb.call(always_fail)
-def test_circuit_closes_on_success():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
- cb.record_failure()
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
+# --- async_timeout tests ---
- time.sleep(0.06)
- cb.allow_request() # triggers HALF_OPEN
- assert cb.state == CircuitState.HALF_OPEN
- cb.record_success()
- assert cb.state == CircuitState.CLOSED
- assert cb.allow_request() is True
+async def test_completes_within_timeout():
+ """async_timeout doesn't interfere with fast operations."""
+ async with async_timeout(1.0):
+ await asyncio.sleep(0.01)
+ result = 42
+ assert result == 42
-def test_circuit_reopens_on_failure_in_half_open():
- cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
- cb.record_failure()
- cb.record_failure()
- time.sleep(0.06)
- cb.allow_request() # HALF_OPEN
- cb.record_failure()
- assert cb.state == CircuitState.OPEN
+async def test_raises_on_timeout():
+ """async_timeout raises TimeoutError for slow operations."""
+ with pytest.raises(TimeoutError):
+ async with async_timeout(0.05):
+ await asyncio.sleep(1.0)