From 4e6ae373b6abc7ef0d5fb810385d14250757f3f1 Mon Sep 17 00:00:00 2001 From: TheSiahxyz <164138827+TheSiahxyz@users.noreply.github.com> Date: Wed, 1 Apr 2026 17:22:44 +0900 Subject: feat(security): add bearer token auth for health/metrics endpoints --- shared/src/shared/config.py | 1 + shared/src/shared/healthcheck.py | 13 +++++++++++++ shared/tests/test_healthcheck.py | 42 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 54 insertions(+), 2 deletions(-) (limited to 'shared') diff --git a/shared/src/shared/config.py b/shared/src/shared/config.py index 47bc2b1..7b34d78 100644 --- a/shared/src/shared/config.py +++ b/shared/src/shared/config.py @@ -20,5 +20,6 @@ class Settings(BaseSettings): health_port: int = 8080 circuit_breaker_threshold: int = 5 circuit_breaker_timeout: int = 60 + metrics_auth_token: str = "" # If set, /health and /metrics require Bearer token model_config = {"env_file": ".env", "env_file_encoding": "utf-8"} diff --git a/shared/src/shared/healthcheck.py b/shared/src/shared/healthcheck.py index be02712..7411e8a 100644 --- a/shared/src/shared/healthcheck.py +++ b/shared/src/shared/healthcheck.py @@ -17,12 +17,14 @@ class HealthCheckServer: service_name: str, port: int = 8080, *, + auth_token: str = "", registry: CollectorRegistry | None = None, ) -> None: self.service_name = service_name self.port = port self._checks: dict[str, Callable[[], Awaitable[bool]]] = {} self._start_time = time.monotonic() + self._auth_token = auth_token self._registry = registry or REGISTRY def register_check(self, name: str, check_fn: Callable[[], Awaitable[bool]]) -> None: @@ -53,14 +55,25 @@ class HealthCheckServer: "checks": checks, } + def _check_auth(self, request: web.Request) -> bool: + """Return True if the request is authorised (or no token is configured).""" + if not self._auth_token: + return True + auth = request.headers.get("Authorization", "") + return auth == f"Bearer {self._auth_token}" + async def _handle_health(self, request: web.Request) -> web.Response: """GET /health — JSON health status.""" + if not self._check_auth(request): + return web.json_response({"error": "unauthorized"}, status=401) result = await self.run_checks() status_code = 200 if result["status"] == "ok" else 503 return web.json_response(result, status=status_code) async def _handle_metrics(self, request: web.Request) -> web.Response: """GET /metrics — Prometheus text exposition.""" + if not self._check_auth(request): + return web.Response(text="Unauthorized", status=401) output = generate_latest(self._registry) return web.Response(body=output, content_type=CONTENT_TYPE_LATEST) diff --git a/shared/tests/test_healthcheck.py b/shared/tests/test_healthcheck.py index 6970a8f..2f79757 100644 --- a/shared/tests/test_healthcheck.py +++ b/shared/tests/test_healthcheck.py @@ -1,6 +1,9 @@ """Tests for health check server.""" +from unittest.mock import MagicMock + import pytest +from multidict import CIMultiDict from prometheus_client import CollectorRegistry @@ -9,10 +12,17 @@ def registry(): return CollectorRegistry() -def make_server(service_name="test-service", port=8080, registry=None): +def make_server(service_name="test-service", port=8080, registry=None, auth_token=""): from shared.healthcheck import HealthCheckServer - return HealthCheckServer(service_name, port=port, registry=registry) + return HealthCheckServer(service_name, port=port, auth_token=auth_token, registry=registry) + + +def _fake_request(headers: dict | None = None) -> MagicMock: + """Create a minimal mock that quacks like aiohttp.web.Request.""" + req = MagicMock() + req.headers = CIMultiDict(headers or {}) + return req def test_init_defaults(registry): @@ -87,3 +97,31 @@ async def test_run_checks_false_is_fail(registry): result = await server.run_checks() assert result["status"] == "degraded" assert result["checks"]["cache"] == "fail" + + +# ── Bearer-token auth tests ──────────────────────────────────────── + + +def test_healthcheck_no_auth_when_token_empty(registry): + """When auth_token is empty, all requests pass auth regardless of headers.""" + server = make_server(registry=registry, auth_token="") + assert server._check_auth(_fake_request()) is True + assert server._check_auth(_fake_request({"Authorization": "Bearer wrong"})) is True + + +def test_healthcheck_auth_required_when_token_set(registry): + """When auth_token is set, a matching Bearer header passes auth.""" + server = make_server(registry=registry, auth_token="s3cret") + req = _fake_request({"Authorization": "Bearer s3cret"}) + assert server._check_auth(req) is True + + +def test_healthcheck_rejects_wrong_token(registry): + """When auth_token is set, a wrong or missing Bearer header is rejected.""" + server = make_server(registry=registry, auth_token="s3cret") + # Wrong token + assert server._check_auth(_fake_request({"Authorization": "Bearer bad"})) is False + # Missing header entirely + assert server._check_auth(_fake_request()) is False + # Malformed header + assert server._check_auth(_fake_request({"Authorization": "Token s3cret"})) is False -- cgit v1.2.3