"""Tests for retry with backoff and circuit breaker.""" import asyncio import time import pytest from shared.resilience import CircuitBreaker, CircuitState, retry_with_backoff # --------------------------------------------------------------------------- # retry_with_backoff tests # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_retry_succeeds_first_try(): call_count = 0 @retry_with_backoff(max_retries=3, base_delay=0.01) async def succeed(): nonlocal call_count call_count += 1 return "ok" result = await succeed() assert result == "ok" assert call_count == 1 @pytest.mark.asyncio async def test_retry_succeeds_after_failures(): call_count = 0 @retry_with_backoff(max_retries=3, base_delay=0.01) async def flaky(): nonlocal call_count call_count += 1 if call_count < 3: raise ValueError("not yet") return "recovered" result = await flaky() assert result == "recovered" assert call_count == 3 @pytest.mark.asyncio async def test_retry_raises_after_max_retries(): call_count = 0 @retry_with_backoff(max_retries=3, base_delay=0.01) async def always_fail(): nonlocal call_count call_count += 1 raise RuntimeError("permanent") with pytest.raises(RuntimeError, match="permanent"): await always_fail() # 1 initial + 3 retries = 4 calls assert call_count == 4 @pytest.mark.asyncio async def test_retry_respects_max_delay(): """Backoff should be capped at max_delay.""" @retry_with_backoff(max_retries=2, base_delay=0.01, max_delay=0.02) async def always_fail(): raise RuntimeError("fail") start = time.monotonic() with pytest.raises(RuntimeError): await always_fail() elapsed = time.monotonic() - start # With max_delay=0.02 and 2 retries, total delay should be small assert elapsed < 0.5 # --------------------------------------------------------------------------- # CircuitBreaker tests # --------------------------------------------------------------------------- def test_circuit_starts_closed(): cb = CircuitBreaker(failure_threshold=3, recovery_timeout=0.05) assert cb.state == CircuitState.CLOSED assert cb.allow_request() is True def test_circuit_opens_after_threshold(): cb = CircuitBreaker(failure_threshold=3, recovery_timeout=60.0) for _ in range(3): cb.record_failure() assert cb.state == CircuitState.OPEN assert cb.allow_request() is False def test_circuit_rejects_when_open(): cb = CircuitBreaker(failure_threshold=2, recovery_timeout=60.0) cb.record_failure() cb.record_failure() assert cb.state == CircuitState.OPEN assert cb.allow_request() is False def test_circuit_half_open_after_timeout(): cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) cb.record_failure() cb.record_failure() assert cb.state == CircuitState.OPEN time.sleep(0.06) assert cb.allow_request() is True assert cb.state == CircuitState.HALF_OPEN def test_circuit_closes_on_success(): cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) cb.record_failure() cb.record_failure() assert cb.state == CircuitState.OPEN time.sleep(0.06) cb.allow_request() # triggers HALF_OPEN assert cb.state == CircuitState.HALF_OPEN cb.record_success() assert cb.state == CircuitState.CLOSED assert cb.allow_request() is True def test_circuit_reopens_on_failure_in_half_open(): cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05) cb.record_failure() cb.record_failure() time.sleep(0.06) cb.allow_request() # HALF_OPEN cb.record_failure() assert cb.state == CircuitState.OPEN