diff options
| author | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-02 15:27:08 +0900 |
|---|---|---|
| committer | TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> | 2026-04-02 15:27:08 +0900 |
| commit | 24cb407ef15f1997e2577c58e139b20d3986ed5b (patch) | |
| tree | aba9162b641e50ea0814995be0c404808980e25e /shared/tests | |
| parent | 8c852c5583b4f610844fe6001309b71e1958908e (diff) | |
feat: implement resilience module (retry, circuit breaker, timeout)
Add retry_async decorator with exponential backoff + jitter,
CircuitBreaker class with closed/open/half_open states, and
async_timeout context manager. Pin all shared deps with upper bounds.
Diffstat (limited to 'shared/tests')
| -rw-r--r-- | shared/tests/test_resilience.py | 154 |
1 files changed, 154 insertions, 0 deletions
diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py new file mode 100644 index 0000000..dde47e1 --- /dev/null +++ b/shared/tests/test_resilience.py @@ -0,0 +1,154 @@ +"""Tests for shared.resilience module.""" + +import asyncio + +import pytest + +from shared.resilience import CircuitBreaker, async_timeout, retry_async + + +# --- retry_async tests --- + + +async def test_succeeds_without_retry(): + """Function succeeds first try, called once.""" + call_count = 0 + + @retry_async() + async def fn(): + nonlocal call_count + call_count += 1 + return "ok" + + result = await fn() + assert result == "ok" + assert call_count == 1 + + +async def test_retries_on_failure_then_succeeds(): + """Fails twice then succeeds, verify call count.""" + call_count = 0 + + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise RuntimeError("transient") + return "recovered" + + result = await fn() + assert result == "recovered" + assert call_count == 3 + + +async def test_raises_after_max_retries(): + """Always fails, raises after max retries.""" + call_count = 0 + + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): + nonlocal call_count + call_count += 1 + raise ValueError("permanent") + + with pytest.raises(ValueError, match="permanent"): + await fn() + + # 1 initial + 3 retries = 4 total calls + assert call_count == 4 + + +async def test_no_retry_on_excluded_exception(): + """Excluded exception raises immediately, call count = 1.""" + call_count = 0 + + @retry_async(max_retries=3, base_delay=0.01, exclude=(TypeError,)) + async def fn(): + nonlocal call_count + call_count += 1 + raise TypeError("excluded") + + with pytest.raises(TypeError, match="excluded"): + await fn() + + assert call_count == 1 + + +# --- CircuitBreaker tests --- + + +async def test_closed_allows_calls(): + """CircuitBreaker in closed state passes through.""" + cb = CircuitBreaker(failure_threshold=5, cooldown=60.0) + + async def fn(): + return "ok" + + result = await cb.call(fn) + assert result == "ok" + + +async def test_opens_after_threshold(): + """After N failures, raises RuntimeError.""" + cb = CircuitBreaker(failure_threshold=3, cooldown=60.0) + + async def fail(): + raise RuntimeError("fail") + + for _ in range(3): + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + + # Now the breaker should be open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + +async def test_half_open_after_cooldown(): + """After cooldown, allows recovery attempt.""" + cb = CircuitBreaker(failure_threshold=2, cooldown=0.05) + + async def fail(): + raise RuntimeError("fail") + + # Trip the breaker + for _ in range(2): + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + + # Breaker is open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + # Wait for cooldown + await asyncio.sleep(0.06) + + # Now should allow a call (half_open). Succeed to close it. + async def succeed(): + return "recovered" + + result = await cb.call(succeed) + assert result == "recovered" + + # Breaker should be closed again + result = await cb.call(succeed) + assert result == "recovered" + + +# --- async_timeout tests --- + + +async def test_completes_within_timeout(): + """async_timeout doesn't interfere with fast operations.""" + async with async_timeout(1.0): + await asyncio.sleep(0.01) + result = 42 + assert result == 42 + + +async def test_raises_on_timeout(): + """async_timeout raises TimeoutError for slow operations.""" + with pytest.raises(TimeoutError): + async with async_timeout(0.05): + await asyncio.sleep(1.0) |
