summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- .env.example 1
-rw-r--r-- monitoring/prometheus.yml 3
-rw-r--r-- services/data-collector/src/data_collector/main.py 2
-rw-r--r-- services/order-executor/src/order_executor/main.py 2
-rw-r--r-- services/portfolio-manager/src/portfolio_manager/main.py 2
-rw-r--r-- services/strategy-engine/src/strategy_engine/main.py 2
-rw-r--r-- shared/src/shared/config.py 1
-rw-r--r-- shared/src/shared/healthcheck.py 13
-rw-r--r-- shared/tests/test_healthcheck.py 42
9 files changed, 62 insertions, 6 deletions
diff --git a/.env.example b/.env.example
index 428d388..9b3cf9c 100644
--- a/.env.example
+++ b/.env.example
@@ -14,3 +14,4 @@ LOG_FORMAT=json
HEALTH_PORT=8080
CIRCUIT_BREAKER_THRESHOLD=5
CIRCUIT_BREAKER_TIMEOUT=60
+METRICS_AUTH_TOKEN=
diff --git a/monitoring/prometheus.yml b/monitoring/prometheus.yml
index f6d5485..b6dc853 100644
--- a/monitoring/prometheus.yml
+++ b/monitoring/prometheus.yml
@@ -2,6 +2,9 @@ global:
scrape_interval: 15s
scrape_configs:
- job_name: "trading-services"
+ authorization:  # sent as an "Authorization: Bearer <credentials>" header on every scrape of this job
+ type: Bearer
+ credentials: "${METRICS_AUTH_TOKEN}"  # NOTE(review): Prometheus does not appear to expand ${VAR} in its config file; confirm this file is pre-rendered (e.g. envsubst) or switch to credentials_file
static_configs:
- targets:
- "data-collector:8080"
diff --git a/services/data-collector/src/data_collector/main.py b/services/data-collector/src/data_collector/main.py
index 170e8b1..81e13df 100644
--- a/services/data-collector/src/data_collector/main.py
+++ b/services/data-collector/src/data_collector/main.py
@@ -49,7 +49,7 @@ async def run() -> None:
on_candle=on_candle,
)
- health = HealthCheckServer("data-collector", port=config.health_port)
+ health = HealthCheckServer("data-collector", port=config.health_port, auth_token=config.metrics_auth_token)
health.register_check("redis", broker.ping)
await health.start()
metrics.service_up.labels(service="data-collector").set(1)
diff --git a/services/order-executor/src/order_executor/main.py b/services/order-executor/src/order_executor/main.py
index ab6ef4f..0198f65 100644
--- a/services/order-executor/src/order_executor/main.py
+++ b/services/order-executor/src/order_executor/main.py
@@ -57,7 +57,7 @@ async def run() -> None:
last_id = "$"
stream = "signals"
- health = HealthCheckServer("order-executor", port=config.health_port + 2)
+ health = HealthCheckServer("order-executor", port=config.health_port + 2, auth_token=config.metrics_auth_token)
health.register_check("redis", broker.ping)
await health.start()
metrics.service_up.labels(service="order-executor").set(1)
diff --git a/services/portfolio-manager/src/portfolio_manager/main.py b/services/portfolio-manager/src/portfolio_manager/main.py
index 02df5d2..be29a2f 100644
--- a/services/portfolio-manager/src/portfolio_manager/main.py
+++ b/services/portfolio-manager/src/portfolio_manager/main.py
@@ -63,7 +63,7 @@ async def run() -> None:
broker = RedisBroker(config.redis_url)
tracker = PortfolioTracker()
- health = HealthCheckServer("portfolio-manager", port=config.health_port + 3)
+ health = HealthCheckServer("portfolio-manager", port=config.health_port + 3, auth_token=config.metrics_auth_token)
health.register_check("redis", broker.ping)
await health.start()
metrics.service_up.labels(service="portfolio-manager").set(1)
diff --git a/services/strategy-engine/src/strategy_engine/main.py b/services/strategy-engine/src/strategy_engine/main.py
index 53681d1..fabd755 100644
--- a/services/strategy-engine/src/strategy_engine/main.py
+++ b/services/strategy-engine/src/strategy_engine/main.py
@@ -43,7 +43,7 @@ async def run() -> None:
engine = StrategyEngine(broker=broker, strategies=strategies)
- health = HealthCheckServer("strategy-engine", port=config.health_port + 1)
+ health = HealthCheckServer("strategy-engine", port=config.health_port + 1, auth_token=config.metrics_auth_token)
health.register_check("redis", broker.ping)
await health.start()
metrics.service_up.labels(service="strategy-engine").set(1)
diff --git a/shared/src/shared/config.py b/shared/src/shared/config.py
index 47bc2b1..7b34d78 100644
--- a/shared/src/shared/config.py
+++ b/shared/src/shared/config.py
@@ -20,5 +20,6 @@ class Settings(BaseSettings):
health_port: int = 8080
circuit_breaker_threshold: int = 5
circuit_breaker_timeout: int = 60
+ metrics_auth_token: str = "" # If set, /health and /metrics require Bearer token
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
diff --git a/shared/src/shared/healthcheck.py b/shared/src/shared/healthcheck.py
index be02712..7411e8a 100644
--- a/shared/src/shared/healthcheck.py
+++ b/shared/src/shared/healthcheck.py
@@ -17,12 +17,14 @@ class HealthCheckServer:
service_name: str,
port: int = 8080,
*,
+ auth_token: str = "",
registry: CollectorRegistry | None = None,
) -> None:
self.service_name = service_name
self.port = port
self._checks: dict[str, Callable[[], Awaitable[bool]]] = {}
self._start_time = time.monotonic()
+ self._auth_token = auth_token
self._registry = registry or REGISTRY
def register_check(self, name: str, check_fn: Callable[[], Awaitable[bool]]) -> None:
@@ -53,14 +55,25 @@ class HealthCheckServer:
"checks": checks,
}
+ def _check_auth(self, request: web.Request) -> bool:
+ """Return True when no token is configured or the Authorization header is exactly 'Bearer <token>'."""
+ import hmac  # function-scope import: the module's import block is outside this hunk
+ expected = f"Bearer {self._auth_token}".encode("utf-8")
+ supplied = request.headers.get("Authorization", "").encode("utf-8", "surrogateescape")  # never raises on odd header bytes
+ return not self._auth_token or hmac.compare_digest(supplied, expected)  # constant-time: no timing leak of the token
+
async def _handle_health(self, request: web.Request) -> web.Response:
"""GET /health — JSON health status."""
+ if not self._check_auth(request):  # reject before running any registered checks
+ return web.json_response({"error": "unauthorized"}, status=401, headers={"WWW-Authenticate": "Bearer"})  # a 401 must carry a challenge (RFC 9110 §15.5.2)
result = await self.run_checks()
status_code = 200 if result["status"] == "ok" else 503
return web.json_response(result, status=status_code)
async def _handle_metrics(self, request: web.Request) -> web.Response:
"""GET /metrics — Prometheus text exposition."""
+ if not self._check_auth(request):  # reject before serialising the registry
+ return web.Response(text="Unauthorized", status=401, headers={"WWW-Authenticate": "Bearer"})  # a 401 must carry a challenge (RFC 9110 §15.5.2)
output = generate_latest(self._registry)
return web.Response(body=output, content_type=CONTENT_TYPE_LATEST)
diff --git a/shared/tests/test_healthcheck.py b/shared/tests/test_healthcheck.py
index 6970a8f..2f79757 100644
--- a/shared/tests/test_healthcheck.py
+++ b/shared/tests/test_healthcheck.py
@@ -1,6 +1,9 @@
"""Tests for health check server."""
+from unittest.mock import MagicMock
+
import pytest
+from multidict import CIMultiDict
from prometheus_client import CollectorRegistry
@@ -9,10 +12,17 @@ def registry():
return CollectorRegistry()
-def make_server(service_name="test-service", port=8080, registry=None):
+def make_server(service_name="test-service", port=8080, registry=None, auth_token=""):
from shared.healthcheck import HealthCheckServer
- return HealthCheckServer(service_name, port=port, registry=registry)
+ return HealthCheckServer(service_name, port=port, auth_token=auth_token, registry=registry)
+
+
+def _fake_request(headers: dict | None = None) -> MagicMock:
+ """Create a minimal mock that quacks like aiohttp.web.Request."""
+ req = MagicMock()
+ req.headers = CIMultiDict(headers or {})
+ return req
def test_init_defaults(registry):
@@ -87,3 +97,31 @@ async def test_run_checks_false_is_fail(registry):
result = await server.run_checks()
assert result["status"] == "degraded"
assert result["checks"]["cache"] == "fail"
+
+
+# ── Bearer-token auth tests ────────────────────────────────────────
+
+
+def test_healthcheck_no_auth_when_token_empty(registry):
+ """When auth_token is empty, all requests pass auth regardless of headers."""
+ server = make_server(registry=registry, auth_token="")
+ assert server._check_auth(_fake_request()) is True
+ assert server._check_auth(_fake_request({"Authorization": "Bearer wrong"})) is True
+
+
+def test_healthcheck_auth_required_when_token_set(registry):
+ """When auth_token is set, a matching Bearer header passes auth."""
+ server = make_server(registry=registry, auth_token="s3cret")
+ req = _fake_request({"Authorization": "Bearer s3cret"})
+ assert server._check_auth(req) is True
+
+
+def test_healthcheck_rejects_wrong_token(registry):
+ """When auth_token is set, a wrong or missing Bearer header is rejected."""
+ server = make_server(registry=registry, auth_token="s3cret")
+ # Wrong token
+ assert server._check_auth(_fake_request({"Authorization": "Bearer bad"})) is False
+ # Missing header entirely
+ assert server._check_auth(_fake_request()) is False
+ # Malformed header
+ assert server._check_auth(_fake_request({"Authorization": "Token s3cret"})) is False