summaryrefslogtreecommitdiff
path: root/shared/src/shared/healthcheck.py
blob: 8294294474d3f6683da8b073b968866b6748cd46 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Health check HTTP server with Prometheus metrics endpoint."""
from __future__ import annotations

import time
from typing import Any, Callable, Awaitable

from aiohttp import web
from prometheus_client import CollectorRegistry, REGISTRY, generate_latest, CONTENT_TYPE_LATEST


class HealthCheckServer:
    """Lightweight aiohttp server exposing /health and /metrics.

    ``/health`` runs every registered check and reports the outcome as JSON;
    ``/metrics`` serves the Prometheus text exposition of the configured
    registry.
    """

    def __init__(
        self,
        service_name: str,
        port: int = 8080,
        *,
        registry: CollectorRegistry | None = None,
        host: str = "0.0.0.0",
    ) -> None:
        """Configure the server without binding any socket yet.

        Args:
            service_name: Name reported in the ``service`` field of /health.
            port: TCP port the server listens on once :meth:`start` is called.
            registry: Prometheus registry rendered by /metrics. The global
                default ``REGISTRY`` is used when omitted.
            host: Interface to bind. Defaults to all IPv4 interfaces; pass
                ``"127.0.0.1"`` to restrict the endpoints to loopback.
        """
        self.service_name = service_name
        self.host = host
        self.port = port
        self._checks: dict[str, Callable[[], Awaitable[bool]]] = {}
        # Monotonic clock so the reported uptime is immune to wall-clock jumps.
        self._start_time = time.monotonic()
        self._registry = registry if registry is not None else REGISTRY

    def register_check(self, name: str, check_fn: Callable[[], Awaitable[bool]]) -> None:
        """Register a named async health check function.

        Registering a second check under an existing ``name`` replaces the
        earlier one.

        Args:
            name: Key under which the result appears in the /health payload.
            check_fn: Zero-argument callable returning an awaitable that
                resolves to a truthy value when the dependency is healthy.
        """
        self._checks[name] = check_fn

    async def run_checks(self) -> dict[str, Any]:
        """Execute all registered checks and return a status dict.

        Checks are awaited one after another in registration order. A check
        that returns a falsy value is reported as ``"fail"``; a check that
        raises is reported as ``"fail: <exception message>"``. Neither case
        stops the remaining checks from running.

        Returns:
            A dict with the keys ``status`` (``"ok"`` when every check
            passed, otherwise ``"degraded"``), ``service``,
            ``uptime_seconds`` and ``checks`` (per-check result strings).
        """
        checks: dict[str, str] = {}
        all_ok = True

        for name, fn in self._checks.items():
            try:
                healthy = await fn()
            except Exception as exc:
                # A broken check must degrade the report, not crash the endpoint.
                checks[name] = f"fail: {exc}"
                all_ok = False
            else:
                checks[name] = "ok" if healthy else "fail"
                if not healthy:
                    all_ok = False

        return {
            "status": "ok" if all_ok else "degraded",
            "service": self.service_name,
            "uptime_seconds": round(time.monotonic() - self._start_time, 2),
            "checks": checks,
        }

    async def _handle_health(self, request: web.Request) -> web.Response:
        """GET /health — JSON health status.

        Responds with HTTP 200 when every check passed and HTTP 503 otherwise.
        """
        result = await self.run_checks()
        status_code = 200 if result["status"] == "ok" else 503
        return web.json_response(result, status=status_code)

    async def _handle_metrics(self, request: web.Request) -> web.Response:
        """GET /metrics — Prometheus text exposition."""
        output = generate_latest(self._registry)
        # CONTENT_TYPE_LATEST carries a "charset" parameter, and aiohttp
        # raises ValueError when the ``content_type`` argument contains one.
        # Sending the value as a raw header keeps the exact Prometheus
        # content type without tripping that validation.
        return web.Response(body=output, headers={"Content-Type": CONTENT_TYPE_LATEST})

    async def start(self) -> web.AppRunner:
        """Create and start the aiohttp application, returning the runner.

        Returns:
            The started ``web.AppRunner``. The caller owns it and should
            await ``runner.cleanup()`` on shutdown.

        Raises:
            OSError: If the listening socket cannot be bound (for example
                when the port is already in use). The runner is cleaned up
                before the error propagates.
        """
        app = web.Application()
        app.router.add_get("/health", self._handle_health)
        app.router.add_get("/metrics", self._handle_metrics)

        runner = web.AppRunner(app)
        await runner.setup()
        try:
            site = web.TCPSite(runner, self.host, self.port)
            await site.start()
        except BaseException:
            # The runner never reaches the caller on this path, so release
            # it here (also on cancellation) instead of leaking it.
            await runner.cleanup()
            raise
        return runner