diff options
Diffstat (limited to 'shared/tests/test_resilience.py')
| -rw-r--r-- | shared/tests/test_resilience.py | 138 |
1 files changed, 138 insertions, 0 deletions
diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py new file mode 100644 index 0000000..514bcc2 --- /dev/null +++ b/shared/tests/test_resilience.py @@ -0,0 +1,138 @@ +"""Tests for retry with backoff and circuit breaker.""" +import asyncio +import time + +import pytest + +from shared.resilience import CircuitBreaker, CircuitState, retry_with_backoff + + +# --------------------------------------------------------------------------- +# retry_with_backoff tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_retry_succeeds_first_try(): + call_count = 0 + + @retry_with_backoff(max_retries=3, base_delay=0.01) + async def succeed(): + nonlocal call_count + call_count += 1 + return "ok" + + result = await succeed() + assert result == "ok" + assert call_count == 1 + + +@pytest.mark.asyncio +async def test_retry_succeeds_after_failures(): + call_count = 0 + + @retry_with_backoff(max_retries=3, base_delay=0.01) + async def flaky(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise ValueError("not yet") + return "recovered" + + result = await flaky() + assert result == "recovered" + assert call_count == 3 + + +@pytest.mark.asyncio +async def test_retry_raises_after_max_retries(): + call_count = 0 + + @retry_with_backoff(max_retries=3, base_delay=0.01) + async def always_fail(): + nonlocal call_count + call_count += 1 + raise RuntimeError("permanent") + + with pytest.raises(RuntimeError, match="permanent"): + await always_fail() + # 1 initial + 3 retries = 4 calls + assert call_count == 4 + + +@pytest.mark.asyncio +async def test_retry_respects_max_delay(): + """Backoff should be capped at max_delay.""" + @retry_with_backoff(max_retries=2, base_delay=0.01, max_delay=0.02) + async def always_fail(): + raise RuntimeError("fail") + + start = time.monotonic() + with pytest.raises(RuntimeError): + await always_fail() + elapsed = time.monotonic() - start + # With max_delay=0.02 and 2 retries, total delay should be small + assert elapsed < 0.5 + + +# --------------------------------------------------------------------------- +# CircuitBreaker tests +# --------------------------------------------------------------------------- + + +def test_circuit_starts_closed(): + cb = CircuitBreaker(failure_threshold=3, recovery_timeout=0.05) + assert cb.state == CircuitState.CLOSED + assert cb.allow_request() is True + + +def test_circuit_opens_after_threshold(): + cb = CircuitBreaker(failure_threshold=3, recovery_timeout=60.0) + for _ in range(3): + cb.record_failure() + assert cb.state == CircuitState.OPEN + assert cb.allow_request() is False + + +def test_circuit_rejects_when_open(): + cb = CircuitBreaker(failure_threshold=2, recovery_timeout=60.0) + cb.record_failure() + cb.record_failure() + assert cb.state == CircuitState.OPEN + assert cb.allow_request() is False + + +def test_circuit_half_open_after_timeout(): + cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) + cb.record_failure() + cb.record_failure() + assert cb.state == CircuitState.OPEN + + time.sleep(0.06) + assert cb.allow_request() is True + assert cb.state == CircuitState.HALF_OPEN + + +def test_circuit_closes_on_success(): + cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) + cb.record_failure() + cb.record_failure() + assert cb.state == CircuitState.OPEN + + time.sleep(0.06) + cb.allow_request() # triggers HALF_OPEN + assert cb.state == CircuitState.HALF_OPEN + + cb.record_success() + assert cb.state == CircuitState.CLOSED + assert cb.allow_request() is True + + +def test_circuit_reopens_on_failure_in_half_open(): + cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) + cb.record_failure() + cb.record_failure() + time.sleep(0.06) + cb.allow_request() # HALF_OPEN + cb.record_failure() + assert cb.state == CircuitState.OPEN |
