summaryrefslogtreecommitdiff
path: root/shared/tests/test_healthcheck.py
blob: 2f797577c1e6a5c8317ba64fed9fac4c7d7e87c1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Tests for health check server."""

from unittest.mock import MagicMock

import pytest
from multidict import CIMultiDict
from prometheus_client import CollectorRegistry


@pytest.fixture
def registry():
    """Provide a brand-new Prometheus ``CollectorRegistry`` for the requesting test.

    Handing each server its own registry keeps metric state from being shared
    between servers built in different tests.
    """
    fresh_registry = CollectorRegistry()
    return fresh_registry


def make_server(service_name="test-service", port=8080, registry=None, auth_token=""):
    """Construct a ``HealthCheckServer`` with test-friendly defaults.

    Args:
        service_name: Name the server reports about itself.
        port: Port number forwarded to the server.
        registry: Prometheus registry forwarded to the server (tests pass the
            ``registry`` fixture).
        auth_token: Bearer token forwarded to the server; ``""`` is what the
            auth tests use to mean "no auth".

    Returns:
        The newly built ``HealthCheckServer`` instance.
    """
    # Imported at call time rather than at module top; kept that way to match
    # the original helper.
    from shared.healthcheck import HealthCheckServer as _Server

    return _Server(
        service_name,
        port=port,
        auth_token=auth_token,
        registry=registry,
    )


def _fake_request(headers: dict | None = None) -> MagicMock:
    """Return a stand-in for ``aiohttp.web.Request`` carrying only headers.

    Args:
        headers: Header name/value pairs; ``None`` (or any falsy value) yields
            an empty header set.

    Returns:
        A ``MagicMock`` whose ``headers`` attribute is a case-insensitive
        ``CIMultiDict`` built from *headers*.
    """
    request = MagicMock()
    request.headers = CIMultiDict(headers if headers else {})
    return request


def test_init_defaults(registry):
    """A newly built server keeps its name, carries make_server's default port, and has no checks."""
    srv = make_server("my-service", registry=registry)

    assert srv.service_name == "my-service"
    assert srv.port == 8080
    assert srv._checks == {}


def test_register_check(registry):
    """register_check stores the given coroutine function, unchanged, under its name."""
    srv = make_server(registry=registry)

    async def redis_probe():
        return True

    srv.register_check("redis", redis_probe)

    registered = srv._checks
    assert "redis" in registered
    assert registered["redis"] is redis_probe


@pytest.mark.asyncio
async def test_run_checks_all_pass(registry):
    """When every registered check returns True, the report is 'ok' everywhere."""
    srv = make_server(registry=registry)

    async def always_ok():
        return True

    check_names = ("db", "redis")
    for check_name in check_names:
        srv.register_check(check_name, always_ok)

    report = await srv.run_checks()

    assert report["status"] == "ok"
    assert report["service"] == "test-service"
    assert "uptime_seconds" in report
    for check_name in check_names:
        assert report["checks"][check_name] == "ok"


@pytest.mark.asyncio
async def test_run_checks_one_fails(registry):
    """A check that raises pulls the overall status down to 'degraded'.

    The healthy check must still report 'ok'. The failing one only has to
    mention "fail" somewhere in its status string.
    """
    srv = make_server(registry=registry)

    async def healthy():
        return True

    async def broken():
        raise ConnectionError("down")

    srv.register_check("db", healthy)
    srv.register_check("redis", broken)

    report = await srv.run_checks()
    checks = report["checks"]

    assert report["status"] == "degraded"
    assert checks["db"] == "ok"
    assert "fail" in checks["redis"]


@pytest.mark.asyncio
async def test_run_checks_false_is_fail(registry):
    """A check that returns False (instead of raising) is reported as 'fail'."""
    srv = make_server(registry=registry)

    async def falsy_probe():
        return False

    srv.register_check("cache", falsy_probe)
    report = await srv.run_checks()

    assert report["status"] == "degraded"
    assert report["checks"]["cache"] == "fail"


# ── Bearer-token auth tests ────────────────────────────────────────


def test_healthcheck_no_auth_when_token_empty(registry):
    """With an empty auth_token, _check_auth accepts every request, whatever its headers."""
    srv = make_server(registry=registry, auth_token="")
    for header_map in (None, {"Authorization": "Bearer wrong"}):
        assert srv._check_auth(_fake_request(header_map)) is True


def test_healthcheck_auth_required_when_token_set(registry):
    """With auth_token configured, sending exactly that token as a Bearer credential passes."""
    srv = make_server(registry=registry, auth_token="s3cret")
    authorised = srv._check_auth(_fake_request({"Authorization": "Bearer s3cret"}))
    assert authorised is True


def test_healthcheck_rejects_wrong_token(registry):
    """When auth_token is set, missing, malformed, or near-miss credentials are rejected.

    Besides a wrong token, a missing header and the wrong scheme, this also
    covers:

    - a bare token with no scheme,
    - an empty Bearer value,
    - near-miss tokens (a prefix of the secret, or the secret plus extra
      characters).

    A prefix/substring comparison, or naive ``split()[1]`` header parsing,
    would accept these or crash on them instead of returning False.
    """
    server = make_server(registry=registry, auth_token="s3cret")

    # Missing header entirely
    assert server._check_auth(_fake_request()) is False

    rejected_values = [
        "Bearer bad",  # wrong token
        "Token s3cret",  # malformed: wrong scheme
        "s3cret",  # malformed: bare token, no scheme
        "Bearer ",  # empty bearer value
        "Bearer s3cre",  # truncated token (prefix of the secret)
        "Bearer s3cretX",  # secret followed by extra characters
    ]
    for value in rejected_values:
        # Attach the offending value so a failure says which case slipped through.
        assert server._check_auth(_fake_request({"Authorization": value})) is False, value