summaryrefslogtreecommitdiff
path: root/shared/tests/test_resilience.py
blob: 514bcc282c43e64c46130a7b31b556cdc74c8613 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""Tests for retry with backoff and circuit breaker."""
import asyncio
import time

import pytest

from shared.resilience import CircuitBreaker, CircuitState, retry_with_backoff


# ---------------------------------------------------------------------------
# retry_with_backoff tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_retry_succeeds_first_try():
    """A coroutine that succeeds immediately is invoked exactly once."""
    attempts = []

    @retry_with_backoff(max_retries=3, base_delay=0.01)
    async def succeed():
        # Record each invocation so the test can count them afterwards.
        attempts.append(1)
        return "ok"

    assert await succeed() == "ok"
    assert len(attempts) == 1


@pytest.mark.asyncio
async def test_retry_succeeds_after_failures():
    """Two failing attempts followed by a success return the success value."""
    attempts = []

    @retry_with_backoff(max_retries=3, base_delay=0.01)
    async def flaky():
        attempts.append(1)
        # Succeed only from the third invocation onwards.
        if len(attempts) >= 3:
            return "recovered"
        raise ValueError("not yet")

    outcome = await flaky()
    assert outcome == "recovered"
    assert len(attempts) == 3


@pytest.mark.asyncio
async def test_retry_raises_after_max_retries():
    """A permanently failing coroutine raises after all retries are spent."""
    attempts = []

    @retry_with_backoff(max_retries=3, base_delay=0.01)
    async def always_fail():
        attempts.append(1)
        raise RuntimeError("permanent")

    with pytest.raises(RuntimeError, match="permanent"):
        await always_fail()

    # One initial attempt plus three retries.
    assert len(attempts) == 1 + 3


@pytest.mark.asyncio
async def test_retry_respects_max_delay():
    """Backoff should be capped at max_delay.

    Enough retries are configured that the elapsed-time bound can only be
    met when every individual wait is capped.
    """
    # NOTE(review): assumes the delay roughly doubles per attempt starting at
    # base_delay -- confirm against shared.resilience. Under that assumption:
    #   capped:   7 waits of at most 0.02s           -> <= ~0.14s
    #   uncapped: 0.01 * (2**7 - 1)                  -> ~1.27s
    # so the 0.5s bound below separates the two cases. With only two retries
    # an uncapped backoff would total ~0.03s and the bound could never fail.
    @retry_with_backoff(max_retries=7, base_delay=0.01, max_delay=0.02)
    async def always_fail():
        raise RuntimeError("fail")

    start = time.monotonic()
    with pytest.raises(RuntimeError):
        await always_fail()
    elapsed = time.monotonic() - start

    # Generous upper bound: well above the capped total, well below the
    # uncapped one, leaving headroom for slow CI machines.
    assert elapsed < 0.5


# ---------------------------------------------------------------------------
# CircuitBreaker tests
# ---------------------------------------------------------------------------


def test_circuit_starts_closed():
    """A freshly constructed breaker is CLOSED and lets requests through."""
    breaker = CircuitBreaker(recovery_timeout=0.05, failure_threshold=3)
    assert breaker.state == CircuitState.CLOSED
    assert breaker.allow_request() is True


def test_circuit_opens_after_threshold():
    """The breaker trips to OPEN on the threshold-th failure, not before."""
    threshold = 3
    cb = CircuitBreaker(failure_threshold=threshold, recovery_timeout=60.0)

    # Failures below the threshold must leave the breaker CLOSED; without
    # this check a breaker that opens on the very first failure would pass.
    for _ in range(threshold - 1):
        cb.record_failure()
        assert cb.state == CircuitState.CLOSED

    # The failure that reaches the threshold opens the circuit.
    cb.record_failure()
    assert cb.state == CircuitState.OPEN
    assert cb.allow_request() is False


def test_circuit_rejects_when_open():
    """An OPEN breaker with a long recovery timeout refuses requests."""
    breaker = CircuitBreaker(failure_threshold=2, recovery_timeout=60.0)
    for _ in range(2):
        breaker.record_failure()

    assert breaker.state == CircuitState.OPEN
    assert breaker.allow_request() is False


def test_circuit_half_open_after_timeout():
    """After the recovery timeout, a request is allowed and state is HALF_OPEN."""
    breaker = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
    for _ in range(2):
        breaker.record_failure()
    assert breaker.state == CircuitState.OPEN

    # Wait slightly longer than the 0.05s recovery timeout.
    time.sleep(0.06)
    allowed = breaker.allow_request()
    assert allowed is True
    assert breaker.state == CircuitState.HALF_OPEN


def test_circuit_closes_on_success():
    """A success recorded while HALF_OPEN returns the breaker to CLOSED."""
    breaker = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
    for _ in range(2):
        breaker.record_failure()
    assert breaker.state == CircuitState.OPEN

    # Wait slightly longer than the 0.05s recovery timeout, then issue the
    # request that moves the breaker to HALF_OPEN.
    time.sleep(0.06)
    breaker.allow_request()
    assert breaker.state == CircuitState.HALF_OPEN

    breaker.record_success()
    assert breaker.state == CircuitState.CLOSED
    assert breaker.allow_request() is True


def test_circuit_reopens_on_failure_in_half_open():
    """A failure recorded while HALF_OPEN sends the breaker back to OPEN."""
    cb = CircuitBreaker(failure_threshold=2, recovery_timeout=0.05)
    cb.record_failure()
    cb.record_failure()
    assert cb.state == CircuitState.OPEN

    # Wait slightly longer than the 0.05s recovery timeout.
    time.sleep(0.06)

    # Verify the breaker really reached HALF_OPEN; otherwise the final
    # assertion would also pass for a breaker that simply never left OPEN.
    assert cb.allow_request() is True
    assert cb.state == CircuitState.HALF_OPEN

    cb.record_failure()
    assert cb.state == CircuitState.OPEN