"""Tests for health check server.""" from unittest.mock import MagicMock import pytest from multidict import CIMultiDict from prometheus_client import CollectorRegistry @pytest.fixture def registry(): return CollectorRegistry() def make_server(service_name="test-service", port=8080, registry=None, auth_token=""): from shared.healthcheck import HealthCheckServer return HealthCheckServer(service_name, port=port, auth_token=auth_token, registry=registry) def _fake_request(headers: dict | None = None) -> MagicMock: """Create a minimal mock that quacks like aiohttp.web.Request.""" req = MagicMock() req.headers = CIMultiDict(headers or {}) return req def test_init_defaults(registry): """HealthCheckServer initialises with service name and port.""" server = make_server("my-service", registry=registry) assert server.service_name == "my-service" assert server.port == 8080 assert server._checks == {} def test_register_check(registry): """register_check stores an async callable.""" server = make_server(registry=registry) async def check_redis(): return True server.register_check("redis", check_redis) assert "redis" in server._checks assert server._checks["redis"] is check_redis @pytest.mark.asyncio async def test_run_checks_all_pass(registry): """run_checks returns 'ok' when all checks pass.""" server = make_server(registry=registry) async def ok_check(): return True server.register_check("db", ok_check) server.register_check("redis", ok_check) result = await server.run_checks() assert result["status"] == "ok" assert result["service"] == "test-service" assert "uptime_seconds" in result assert result["checks"]["db"] == "ok" assert result["checks"]["redis"] == "ok" @pytest.mark.asyncio async def test_run_checks_one_fails(registry): """run_checks returns 'degraded' when a check fails.""" server = make_server(registry=registry) async def ok_check(): return True async def bad_check(): raise ConnectionError("down") server.register_check("db", ok_check) server.register_check("redis", bad_check) result = await server.run_checks() assert result["status"] == "degraded" assert result["checks"]["db"] == "ok" assert "fail" in result["checks"]["redis"] @pytest.mark.asyncio async def test_run_checks_false_is_fail(registry): """run_checks treats False return as failure.""" server = make_server(registry=registry) async def false_check(): return False server.register_check("cache", false_check) result = await server.run_checks() assert result["status"] == "degraded" assert result["checks"]["cache"] == "fail" # ── Bearer-token auth tests ──────────────────────────────────────── def test_healthcheck_no_auth_when_token_empty(registry): """When auth_token is empty, all requests pass auth regardless of headers.""" server = make_server(registry=registry, auth_token="") assert server._check_auth(_fake_request()) is True assert server._check_auth(_fake_request({"Authorization": "Bearer wrong"})) is True def test_healthcheck_auth_required_when_token_set(registry): """When auth_token is set, a matching Bearer header passes auth.""" server = make_server(registry=registry, auth_token="s3cret") req = _fake_request({"Authorization": "Bearer s3cret"}) assert server._check_auth(req) is True def test_healthcheck_rejects_wrong_token(registry): """When auth_token is set, a wrong or missing Bearer header is rejected.""" server = make_server(registry=registry, auth_token="s3cret") # Wrong token assert server._check_auth(_fake_request({"Authorization": "Bearer bad"})) is False # Missing header entirely assert server._check_auth(_fake_request()) is False # Malformed header assert server._check_auth(_fake_request({"Authorization": "Token s3cret"})) is False