summaryrefslogtreecommitdiff
path: root/shared
diff options
context:
space:
mode:
Diffstat (limited to 'shared')
-rw-r--r--shared/src/shared/healthcheck.py76
-rw-r--r--shared/src/shared/metrics.py41
-rw-r--r--shared/tests/test_healthcheck.py88
-rw-r--r--shared/tests/test_metrics.py51
4 files changed, 256 insertions, 0 deletions
diff --git a/shared/src/shared/healthcheck.py b/shared/src/shared/healthcheck.py
new file mode 100644
index 0000000..8294294
--- /dev/null
+++ b/shared/src/shared/healthcheck.py
@@ -0,0 +1,76 @@
+"""Health check HTTP server with Prometheus metrics endpoint."""
+from __future__ import annotations
+
+import time
+from typing import Any, Callable, Awaitable
+
+from aiohttp import web
+from prometheus_client import CollectorRegistry, REGISTRY, generate_latest, CONTENT_TYPE_LATEST
+
+
+class HealthCheckServer:
+ """Lightweight aiohttp server exposing /health and /metrics."""
+
+ def __init__(
+ self,
+ service_name: str,
+ port: int = 8080,
+ *,
+ registry: CollectorRegistry | None = None,
+ ) -> None:
+ self.service_name = service_name
+ self.port = port
+ self._checks: dict[str, Callable[[], Awaitable[bool]]] = {}
+ self._start_time = time.monotonic()
+ self._registry = registry or REGISTRY
+
+ def register_check(self, name: str, check_fn: Callable[[], Awaitable[bool]]) -> None:
+ """Register a named async health check function."""
+ self._checks[name] = check_fn
+
+ async def run_checks(self) -> dict[str, Any]:
+ """Execute all registered checks and return a status dict."""
+ checks: dict[str, str] = {}
+ all_ok = True
+
+ for name, fn in self._checks.items():
+ try:
+ result = await fn()
+ if result:
+ checks[name] = "ok"
+ else:
+ checks[name] = "fail"
+ all_ok = False
+ except Exception as exc:
+ checks[name] = f"fail: {exc}"
+ all_ok = False
+
+ return {
+ "status": "ok" if all_ok else "degraded",
+ "service": self.service_name,
+ "uptime_seconds": round(time.monotonic() - self._start_time, 2),
+ "checks": checks,
+ }
+
+ async def _handle_health(self, request: web.Request) -> web.Response:
+ """GET /health — JSON health status."""
+ result = await self.run_checks()
+ status_code = 200 if result["status"] == "ok" else 503
+ return web.json_response(result, status=status_code)
+
+ async def _handle_metrics(self, request: web.Request) -> web.Response:
+ """GET /metrics — Prometheus text exposition."""
+ output = generate_latest(self._registry)
+ return web.Response(body=output, content_type=CONTENT_TYPE_LATEST)
+
+ async def start(self) -> web.AppRunner:
+ """Create and start the aiohttp application, returning the runner."""
+ app = web.Application()
+ app.router.add_get("/health", self._handle_health)
+ app.router.add_get("/metrics", self._handle_metrics)
+
+ runner = web.AppRunner(app)
+ await runner.setup()
+ site = web.TCPSite(runner, "0.0.0.0", self.port)
+ await site.start()
+ return runner
diff --git a/shared/src/shared/metrics.py b/shared/src/shared/metrics.py
new file mode 100644
index 0000000..3b00c5d
--- /dev/null
+++ b/shared/src/shared/metrics.py
@@ -0,0 +1,41 @@
+"""Prometheus metrics for trading platform services."""
+from __future__ import annotations
+
+from prometheus_client import Counter, Gauge, Histogram, CollectorRegistry, REGISTRY
+
+
+class ServiceMetrics:
+ """Creates Prometheus metrics with a service-name prefix."""
+
+ def __init__(self, service_name: str, *, registry: CollectorRegistry | None = None) -> None:
+ self.service_name = service_name.replace("-", "_")
+ reg = registry or REGISTRY
+ prefix = self.service_name
+
+ self.errors_total = Counter(
+ f"{prefix}_errors_total",
+ "Total error count",
+ labelnames=["service", "error_type"],
+ registry=reg,
+ )
+
+ self.events_processed = Counter(
+ f"{prefix}_events_processed_total",
+ "Total events processed",
+ labelnames=["service", "event_type"],
+ registry=reg,
+ )
+
+ self.processing_seconds = Histogram(
+ f"{prefix}_processing_seconds",
+ "Processing duration in seconds",
+ labelnames=["service"],
+ registry=reg,
+ )
+
+ self.service_up = Gauge(
+ f"{prefix}_service_up",
+ "Whether the service is up (1) or down (0)",
+ labelnames=["service"],
+ registry=reg,
+ )
diff --git a/shared/tests/test_healthcheck.py b/shared/tests/test_healthcheck.py
new file mode 100644
index 0000000..1af86b1
--- /dev/null
+++ b/shared/tests/test_healthcheck.py
@@ -0,0 +1,88 @@
+"""Tests for health check server."""
+import pytest
+import asyncio
+from prometheus_client import CollectorRegistry
+
+
+@pytest.fixture
+def registry():
+ return CollectorRegistry()
+
+
+def make_server(service_name="test-service", port=8080, registry=None):
+ from shared.healthcheck import HealthCheckServer
+ return HealthCheckServer(service_name, port=port, registry=registry)
+
+
+def test_init_defaults(registry):
+ """HealthCheckServer initialises with service name and port."""
+ server = make_server("my-service", registry=registry)
+ assert server.service_name == "my-service"
+ assert server.port == 8080
+ assert server._checks == {}
+
+
+def test_register_check(registry):
+ """register_check stores an async callable."""
+ server = make_server(registry=registry)
+
+ async def check_redis():
+ return True
+
+ server.register_check("redis", check_redis)
+ assert "redis" in server._checks
+ assert server._checks["redis"] is check_redis
+
+
+@pytest.mark.asyncio
+async def test_run_checks_all_pass(registry):
+ """run_checks returns 'ok' when all checks pass."""
+ server = make_server(registry=registry)
+
+ async def ok_check():
+ return True
+
+ server.register_check("db", ok_check)
+ server.register_check("redis", ok_check)
+
+ result = await server.run_checks()
+ assert result["status"] == "ok"
+ assert result["service"] == "test-service"
+ assert "uptime_seconds" in result
+ assert result["checks"]["db"] == "ok"
+ assert result["checks"]["redis"] == "ok"
+
+
+@pytest.mark.asyncio
+async def test_run_checks_one_fails(registry):
+ """run_checks returns 'degraded' when a check fails."""
+ server = make_server(registry=registry)
+
+ async def ok_check():
+ return True
+
+ async def bad_check():
+ raise ConnectionError("down")
+
+ server.register_check("db", ok_check)
+ server.register_check("redis", bad_check)
+
+ result = await server.run_checks()
+ assert result["status"] == "degraded"
+ assert result["checks"]["db"] == "ok"
+ assert "fail" in result["checks"]["redis"]
+
+
+@pytest.mark.asyncio
+async def test_run_checks_false_is_fail(registry):
+ """run_checks treats False return as failure."""
+ server = make_server(registry=registry)
+
+ async def false_check():
+ return False
+
+ server.register_check("cache", false_check)
+
+ result = await server.run_checks()
+ assert result["status"] == "degraded"
+ assert result["checks"]["cache"] == "fail"
diff --git a/shared/tests/test_metrics.py b/shared/tests/test_metrics.py
new file mode 100644
index 0000000..079f01c
--- /dev/null
+++ b/shared/tests/test_metrics.py
@@ -0,0 +1,51 @@
+"""Tests for Prometheus metrics utilities."""
+import pytest
+from prometheus_client import CollectorRegistry
+
+
+def make_metrics(service_name="test-service", registry=None):
+ from shared.metrics import ServiceMetrics
+ return ServiceMetrics(service_name, registry=registry)
+
+
+def test_metrics_creates_with_prefix():
+ """ServiceMetrics stores sanitized service name prefix."""
+ registry = CollectorRegistry()
+ m = make_metrics("my-service", registry=registry)
+ assert m.service_name == "my_service"
+
+
+def test_errors_total_increment():
+ """errors_total counter can be incremented with labels."""
+ registry = CollectorRegistry()
+ m = make_metrics("test-svc", registry=registry)
+ m.errors_total.labels(service="test_svc", error_type="timeout").inc()
+ assert m.errors_total.labels(service="test_svc", error_type="timeout")._value.get() == 1.0
+
+
+def test_events_processed_increment():
+ """events_processed counter can be incremented with labels."""
+ registry = CollectorRegistry()
+ m = make_metrics("test-svc", registry=registry)
+ m.events_processed.labels(service="test_svc", event_type="candle").inc(5)
+ assert m.events_processed.labels(service="test_svc", event_type="candle")._value.get() == 5.0
+
+
+def test_processing_seconds_observe():
+ """processing_seconds histogram can observe values."""
+ registry = CollectorRegistry()
+ m = make_metrics("test-svc", registry=registry)
+ m.processing_seconds.labels(service="test_svc").observe(0.5)
+ m.processing_seconds.labels(service="test_svc").observe(1.5)
+ # Sum should be 2.0
+ assert m.processing_seconds.labels(service="test_svc")._sum.get() == 2.0
+
+
+def test_service_up_gauge():
+ """service_up gauge can be set."""
+ registry = CollectorRegistry()
+ m = make_metrics("test-svc", registry=registry)
+ m.service_up.labels(service="test_svc").set(1)
+ assert m.service_up.labels(service="test_svc")._value.get() == 1.0
+ m.service_up.labels(service="test_svc").set(0)
+ assert m.service_up.labels(service="test_svc")._value.get() == 0.0