diff options
Diffstat (limited to 'shared/tests/test_resilience.py')
| -rw-r--r-- | shared/tests/test_resilience.py | 203 |
1 files changed, 120 insertions, 83 deletions
diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py index e287777..e0781af 100644 --- a/shared/tests/test_resilience.py +++ b/shared/tests/test_resilience.py @@ -1,139 +1,176 @@ -"""Tests for retry with backoff and circuit breaker.""" +"""Tests for shared.resilience module.""" -import time +import asyncio import pytest -from shared.resilience import CircuitBreaker, CircuitState, retry_with_backoff +from shared.resilience import CircuitBreaker, async_timeout, retry_async +# --- retry_async tests --- -# --------------------------------------------------------------------------- -# retry_with_backoff tests -# --------------------------------------------------------------------------- - -@pytest.mark.asyncio -async def test_retry_succeeds_first_try(): +async def test_succeeds_without_retry(): + """Function succeeds first try, called once.""" call_count = 0 - @retry_with_backoff(max_retries=3, base_delay=0.01) - async def succeed(): + @retry_async() + async def fn(): nonlocal call_count call_count += 1 return "ok" - result = await succeed() + result = await fn() assert result == "ok" assert call_count == 1 -@pytest.mark.asyncio -async def test_retry_succeeds_after_failures(): +async def test_retries_on_failure_then_succeeds(): + """Fails twice then succeeds, verify call count.""" call_count = 0 - @retry_with_backoff(max_retries=3, base_delay=0.01) - async def flaky(): + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): nonlocal call_count call_count += 1 if call_count < 3: - raise ValueError("not yet") + raise RuntimeError("transient") return "recovered" - result = await flaky() + result = await fn() assert result == "recovered" assert call_count == 3 -@pytest.mark.asyncio -async def test_retry_raises_after_max_retries(): +async def test_raises_after_max_retries(): + """Always fails, raises after max retries.""" call_count = 0 - @retry_with_backoff(max_retries=3, base_delay=0.01) - async def always_fail(): + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): nonlocal call_count call_count += 1 - raise RuntimeError("permanent") + raise ValueError("permanent") - with pytest.raises(RuntimeError, match="permanent"): - await always_fail() - # 1 initial + 3 retries = 4 calls + with pytest.raises(ValueError, match="permanent"): + await fn() + + # 1 initial + 3 retries = 4 total calls assert call_count == 4 -@pytest.mark.asyncio -async def test_retry_respects_max_delay(): - """Backoff should be capped at max_delay.""" +async def test_no_retry_on_excluded_exception(): + """Excluded exception raises immediately, call count = 1.""" + call_count = 0 - @retry_with_backoff(max_retries=2, base_delay=0.01, max_delay=0.02) - async def always_fail(): - raise RuntimeError("fail") + @retry_async(max_retries=3, base_delay=0.01, exclude=(TypeError,)) + async def fn(): + nonlocal call_count + call_count += 1 + raise TypeError("excluded") - start = time.monotonic() - with pytest.raises(RuntimeError): - await always_fail() - elapsed = time.monotonic() - start - # With max_delay=0.02 and 2 retries, total delay should be small - assert elapsed < 0.5 + with pytest.raises(TypeError, match="excluded"): + await fn() + + assert call_count == 1 -# --------------------------------------------------------------------------- -# CircuitBreaker tests -# --------------------------------------------------------------------------- +# --- CircuitBreaker tests --- -def test_circuit_starts_closed(): - cb = CircuitBreaker(failure_threshold=3, recovery_timeout=0.05) - assert cb.state == CircuitState.CLOSED - assert cb.allow_request() is True +async def test_closed_allows_calls(): + """CircuitBreaker in closed state passes through.""" + cb = CircuitBreaker(failure_threshold=5, cooldown=60.0) + async def fn(): + return "ok" + + result = await cb.call(fn) + assert result == "ok" + + +async def test_opens_after_threshold(): + """After N failures, raises RuntimeError.""" + cb = CircuitBreaker(failure_threshold=3, cooldown=60.0) + + async def fail(): + raise RuntimeError("fail") -def test_circuit_opens_after_threshold(): - cb = CircuitBreaker(failure_threshold=3, recovery_timeout=60.0) for _ in range(3): - cb.record_failure() - assert cb.state == CircuitState.OPEN - assert cb.allow_request() is False + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + # Now the breaker should be open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + +async def test_half_open_after_cooldown(): + """After cooldown, allows recovery attempt.""" + cb = CircuitBreaker(failure_threshold=2, cooldown=0.05) + + async def fail(): + raise RuntimeError("fail") + + # Trip the breaker + for _ in range(2): + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + + # Breaker is open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + # Wait for cooldown + await asyncio.sleep(0.06) + + # Now should allow a call (half_open). Succeed to close it. + async def succeed(): + return "recovered" + + result = await cb.call(succeed) + assert result == "recovered" + + # Breaker should be closed again + result = await cb.call(succeed) + assert result == "recovered" + + +async def test_half_open_reopens_on_failure(): + cb = CircuitBreaker(failure_threshold=2, cooldown=0.05) + + async def always_fail(): + raise ConnectionError("fail") -def test_circuit_rejects_when_open(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=60.0) - cb.record_failure() - cb.record_failure() - assert cb.state == CircuitState.OPEN - assert cb.allow_request() is False + # Trip the breaker + for _ in range(2): + with pytest.raises(ConnectionError): + await cb.call(always_fail) + # Wait for cooldown + await asyncio.sleep(0.1) -def test_circuit_half_open_after_timeout(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) - cb.record_failure() - cb.record_failure() - assert cb.state == CircuitState.OPEN + # Half-open probe should fail and re-open + with pytest.raises(ConnectionError): + await cb.call(always_fail) - time.sleep(0.06) - assert cb.allow_request() is True - assert cb.state == CircuitState.HALF_OPEN + # Should be open again (no cooldown wait) + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(always_fail) -def test_circuit_closes_on_success(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) - cb.record_failure() - cb.record_failure() - assert cb.state == CircuitState.OPEN +# --- async_timeout tests --- - time.sleep(0.06) - cb.allow_request() # triggers HALF_OPEN - assert cb.state == CircuitState.HALF_OPEN - cb.record_success() - assert cb.state == CircuitState.CLOSED - assert cb.allow_request() is True +async def test_completes_within_timeout(): + """async_timeout doesn't interfere with fast operations.""" + async with async_timeout(1.0): + await asyncio.sleep(0.01) + result = 42 + assert result == 42 -def test_circuit_reopens_on_failure_in_half_open(): - cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) - cb.record_failure() - cb.record_failure() - time.sleep(0.06) - cb.allow_request() # HALF_OPEN - cb.record_failure() - assert cb.state == CircuitState.OPEN +async def test_raises_on_timeout(): + """async_timeout raises TimeoutError for slow operations.""" + with pytest.raises(TimeoutError): + async with async_timeout(0.05): + await asyncio.sleep(1.0) |
