From 24cb407ef15f1997e2577c58e139b20d3986ed5b Mon Sep 17 00:00:00 2001 From: TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> Date: Thu, 2 Apr 2026 15:27:08 +0900 Subject: feat: implement resilience module (retry, circuit breaker, timeout) Add retry_async decorator with exponential backoff + jitter, CircuitBreaker class with closed/open/half_open states, and async_timeout context manager. Pin all shared deps with upper bounds. --- shared/tests/test_resilience.py | 154 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 shared/tests/test_resilience.py (limited to 'shared/tests') diff --git a/shared/tests/test_resilience.py b/shared/tests/test_resilience.py new file mode 100644 index 0000000..dde47e1 --- /dev/null +++ b/shared/tests/test_resilience.py @@ -0,0 +1,154 @@ +"""Tests for shared.resilience module.""" + +import asyncio + +import pytest + +from shared.resilience import CircuitBreaker, async_timeout, retry_async + + +# --- retry_async tests --- + + +async def test_succeeds_without_retry(): + """Function succeeds first try, called once.""" + call_count = 0 + + @retry_async() + async def fn(): + nonlocal call_count + call_count += 1 + return "ok" + + result = await fn() + assert result == "ok" + assert call_count == 1 + + +async def test_retries_on_failure_then_succeeds(): + """Fails twice then succeeds, verify call count.""" + call_count = 0 + + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise RuntimeError("transient") + return "recovered" + + result = await fn() + assert result == "recovered" + assert call_count == 3 + + +async def test_raises_after_max_retries(): + """Always fails, raises after max retries.""" + call_count = 0 + + @retry_async(max_retries=3, base_delay=0.01) + async def fn(): + nonlocal call_count + call_count += 1 + raise ValueError("permanent") + + with pytest.raises(ValueError, match="permanent"): + await fn() + + # 1 initial + 3 retries = 4 total calls + assert call_count == 4 + + +async def test_no_retry_on_excluded_exception(): + """Excluded exception raises immediately, call count = 1.""" + call_count = 0 + + @retry_async(max_retries=3, base_delay=0.01, exclude=(TypeError,)) + async def fn(): + nonlocal call_count + call_count += 1 + raise TypeError("excluded") + + with pytest.raises(TypeError, match="excluded"): + await fn() + + assert call_count == 1 + + +# --- CircuitBreaker tests --- + + +async def test_closed_allows_calls(): + """CircuitBreaker in closed state passes through.""" + cb = CircuitBreaker(failure_threshold=5, cooldown=60.0) + + async def fn(): + return "ok" + + result = await cb.call(fn) + assert result == "ok" + + +async def test_opens_after_threshold(): + """After N failures, raises RuntimeError.""" + cb = CircuitBreaker(failure_threshold=3, cooldown=60.0) + + async def fail(): + raise RuntimeError("fail") + + for _ in range(3): + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + + # Now the breaker should be open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + +async def test_half_open_after_cooldown(): + """After cooldown, allows recovery attempt.""" + cb = CircuitBreaker(failure_threshold=2, cooldown=0.05) + + async def fail(): + raise RuntimeError("fail") + + # Trip the breaker + for _ in range(2): + with pytest.raises(RuntimeError, match="fail"): + await cb.call(fail) + + # Breaker is open + with pytest.raises(RuntimeError, match="Circuit breaker is open"): + await cb.call(fail) + + # Wait for cooldown + await asyncio.sleep(0.06) + + # Now should allow a call (half_open). Succeed to close it. + async def succeed(): + return "recovered" + + result = await cb.call(succeed) + assert result == "recovered" + + # Breaker should be closed again + result = await cb.call(succeed) + assert result == "recovered" + + +# --- async_timeout tests --- + + +async def test_completes_within_timeout(): + """async_timeout doesn't interfere with fast operations.""" + async with async_timeout(1.0): + await asyncio.sleep(0.01) + result = 42 + assert result == 42 + + +async def test_raises_on_timeout(): + """async_timeout raises TimeoutError for slow operations.""" + with pytest.raises(TimeoutError): + async with async_timeout(0.05): + await asyncio.sleep(1.0) -- cgit v1.2.3